定义
tf.concat(values, concat_dim)
作用是将待操作的张量进行连接/合并,
函数中第一个参数是待操作的张量,用[]包起来;第二个参数指定某个维度,
本文的concat函数例子是因为在tf的工程中,静态图的tensor具体维度常常未知,所以需要用tf.palceholder()表示,非一般的已定维度的常量concat。
那种其实很好理解,无非叠加在一起罢了。
例子1
import tensorflow as tf
import numpy as np
i1 = tf.placeholder(tf.float32,[2,None,2])
i2 = tf.placeholder(tf.float32,[2,None,None])con_output = tf.concat([i1,i2],axis=0)num = np.ones([2,2,2]) # 作为真正的输入
# num = tf.random_normal([2,2,2]) # 会报类型错误with tf.Session() as sess:print(sess.run(con_output, feed_dict={
i1:num,i2:num}))
通过shape可以看到最终的shape是(4,?,2),即将i1和i2的第一个维度合并/连接,同时由于第二个维度未定,所以为?;第三个维度因为i1的缘故,为2。
例子2
这里仅修改输入和concat的中的维度
i1 = tf.placeholder(tf.float32,[None,None,2])
i2 = tf.placeholder(tf.float32,[None,None,None])
num = np.ones([2,2,2]) # 作为真正的输入
con_output = tf.concat([i1,i2],axis=0)
得到的shape是:(?, ?, 2),并不是因为num的第一个维度为2,shape就为4,这里只要i1,i2中的某个维度一同为None,最终的shape也为None
待操作的张量不能是tf.Tensor类型,会报错:
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.