如果某一维度为batch_size
,代码会显示该维度为None
;
使用tf.squeeze()
时,如果不指定axis
,会出现压缩异常的情况。理论上,会压缩所有维度为1的维度,但是这里压缩后却为unknown
。
如果指定了axis
,则没有问题:
由于这里使用Flatten
也可以达到相同效果,所以最后代码为:
Flatten()(tf.concat(embedding_query, axis=-1)
但是Flatten
是将所有的维度拉成一维,这个也要注意;