当前位置: 代码迷 >> 综合 >> Value passed to parameter ‘x‘ has DataType int32 not in list of allowed values: bfloat16, float16, f
  详细解决方案

Value passed to parameter ‘x‘ has DataType int32 not in list of allowed values: bfloat16, float16, f

热度:53   发布时间:2024-02-07 14:29:03.0

Value passed to parameter ‘x’ has DataType int32 not in list of allowed values: bfloat16, float16, float32, float64, complex64, complex128

原代码是这样的:

dv=64
att = Lambda(lambda x: tf.tensordot(x[0],x[1],axes=[-1,-1])/tf.sqrt(dv),output_shape=(l, nv, nv))([q,k])

然后就出现了上面的错误,分析了原因,是我使用tf.sqrt(dv)时,传递的参数有问题,tf.sqrt()不接受int类型的参数,所以会报错。我将其改成numpy类型的,也就是

att = Lambda(lambda x: tf.tensordot(x[0],x[1],axes=[-1,-1])/np.sqrt(dv),output_shape=(l, nv, nv))([q,k])

修改后不再报错

  相关解决方案