train的时候没问题,inference时,出现了tensorflow的error: Input to reshape is a tensor with 41 values, but the requested shape requires a multiple of 16。
原因是在写模型时用到了shape(tensor)而不是tf.shape(tensor)。
原代码:
max_sentence_size = shape(text)[1]
new_a = tf.reshape(tensor=a,shape=(-1, max_sentence_size,3),name="logits_reshape")
改成:
max_sentence_size = tf.shape(text)[1]
new_a = tf.reshape(tensor=a,shape=(-1, max_sentence_size,3),name="logits_reshape")
就可以了。
得到了个教训, 写模型的时候要一直用tensor的操作。tensorflow-gpu的版本是2.2。