当前位置: 代码迷 >> 综合 >> tf.train.batch 和 tf.train.shuffle_batch学习
  详细解决方案

tf.train.batch 和 tf.train.shuffle_batch学习

热度:44   发布时间:2023-11-25 04:03:44.0

转自 http://blog.csdn.net/wuguangbin1230/article/details/72810706

    tf.train.batch([example, label], batch_size=batch_size, capacity=capacity):[example, label]表示样本和样本标签,这个可以是一个样本和一个样本标签,batch_size是返回的一个batch样本集的样本个数。capacity是队列中的容量。这主要是按顺序组合成一个batch。 tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue)。这里面的参数和上面的一样的意思。不一样的是这个参数min_after_dequeue,一定要保证这参数小于capacity参数的值,否则会出错。这个代表队列中的元素大于它的时候就输出乱的顺序的batch。也就是说这个函数的输出结果是一个乱序的样本排列的batch,不是按照顺序排列的。

上面的函数返回值都是一个batch的样本和样本标签,只是一个是按照顺序,另外一个是随机的。

下文转自 http://blog.csdn.net/ying86615791/article/details/73864381

capacity是队列的长度
min_after_dequeue是出队后,队列至少剩下min_after_dequeue个数据
假设现在有个test.tfrecord文件,里面按从小到大顺序存放整数0~100

  1. tf.train.batch是按顺序读取数据,队列中的数据始终是一个有序的队列,
    比如队列的capacity=20,开始队列内容为0,1,…,19=>读取10条记录后,队列剩下10,11,…,19,然后又补充10条变成=>10,11,…,29,
    队头一直按顺序补充,队尾一直按顺序出队,到了第100条记录后,又重头开始补充0,1,2…

  2. tf.train.shuffle_batch是将队列中数据打乱后,再读取出来,因此队列中剩下的数据也是乱序的,队头也是一直在补充(我猜也是按顺序补充),
    比如batch_size=5,capacity=10,min_after_dequeue=5,
    初始是有序的0,1,…,9(10条记录),
    然后打乱8,2,6,4,3,7,9,2,0,1(10条记录),
    队尾取出5条,剩下7,9,2,0,1(5条记录),
    然后又按顺序补充进来,变成7,9,2,0,1,10,11,12,13,14(10条记录),
    再打乱13,10,2,7,0,12…1(10条记录),

再出队…

capacity可以看成是局部数据的范围,读取的数据是基于这个范围的,

在这个范围内,min_after_dequeue越大,数据越乱

这样按batch读取的话,最后会自动在前面添加一个维度,比如数据的维度是[1],batch_size是10,那么读取出来的shape就是[10,1]

  相关解决方案