Value passed to parameter ‘x’ has DataType int32 not in list of allowed values: bfloat16, float16, float32, float64, complex64, complex128
原代码是这样的:
dv=64
att = Lambda(lambda x: tf.tensordot(x[0],x[1],axes=[-1,-1])/tf.sqrt(dv),output_shape=(l, nv, nv))([q,k])
然后就出现了上面的错误,分析了原因,是我使用tf.sqrt(dv)时,传递的参数有问题,tf.sqrt()不接受int类型的参数,所以会报错。我将其改成numpy类型的,也就是
att = Lambda(lambda x: tf.tensordot(x[0],x[1],axes=[-1,-1])/np.sqrt(dv),output_shape=(l, nv, nv))([q,k])
修改后不再报错