当前位置: 代码迷 >> 综合 >> tf.train.slice_input_producer,tf.train.string_input_producer两种队列批量读取方式
  详细解决方案

tf.train.slice_input_producer,tf.train.string_input_producer两种队列批量读取方式

热度:103   发布时间:2023-11-04 09:24:28.0

一.tf.train.slice_input_producer()

    tf.train.slice_input_producer([image,label],num_epochs=10),随机产生一个图片和标签,num_epochs=10,则表示把所有的数据过10遍,使用完所有的图片数据为一个epoch,这是重复使用10次。上面的用法表示你的数据集和标签已经全部加载到内存中了,如果数据集非常庞大,我们通过这个函数也可以只加载图片的路径,放入图片的path,注意path必须是一个list或者tensorlist.见下面代码实例


   
  1. # -*- coding: utf-8 -*-
  2. ”“”
  3. Created on Mon Mar 26 22:02:22 2018
  4. @author: Administrator
  5. “”“
  6. import tensorflow as tf
  7. import glob
  8. import matplotlib.pyplot as plt
  9. import time
  10. datapath= r’/media/wsw/文档/pythonfile_withpycharm/SVMLearning/faceLibrary/人脸库/ORL/’
  11. imgpath = glob.glob(datapath+ ‘*.bmp’)
  12. # 将路径转化成张量形式
  13. imgpath = tf.convert_to_tensor(imgpath)
  14. # 产生一个队列每次随机产生一张图片地址
  15. # 注意这里要放在数组里面
  16. image = tf.train.slice_input_producer([imgpath])
  17. # 得到一个batch的图片地址
  18. # 由于tf.train.slice_input_producer()函数默认是随机产生一个实例
  19. # 所以在这里直接使用tf.train.batch()直接获得一个batch的数据即可
  20. # 没有必要再去使用tf.trian.shuffle_batch() 速度会慢
  21. img_batch = tf.train.batch([image],batch_size= 20,capacity= 100)
  22. with tf.Session() as sess:
  23. coord = tf.train.Coordinator()
  24. thread = tf.train.start_queue_runners(sess,coord)
  25. i = 0
  26. try:
  27. while not coord.should_stop():
  28. imgs = sess.run(img_batch)
  29. print(imgs)
  30. fig = plt.figure()
  31. for i,path in enumerate(imgs):
  32. img = plt.imread(path[ 0].decode( ‘utf-8’))
  33. axes = fig.add_subplot( 5, 4,i+ 1)
  34. axes.imshow(img)
  35. axes.axis( ‘off’)
  36. plt.ion()
  37. plt.show()
  38. time.sleep( 1)
  39. plt.close()
  40. i+= 1
  41. if i% 10== 0:
  42. break
  43. except tf.errors.OutOfRangeError:
  44. pass
  45. finally:
  46. coord.request_stop()
  47. coord.join(thread)

注意路径此时被加载成二进制编码格式了。

二.批量读取图片数据

    使用tf.train.slice_input_producer([path]),也可以批量读取图片,得到每个图片的路径后,我们可以加载图片并解码成三维数组的形式(图像的深度必须是3通道或者4通道,笔者实验灰度图像,一直不成功)。当使用tf.train.slice_input_producer()时,加载图片数据的reader使用tf.read_file(filename),直接读取。注意图片记得resize().见下面代码:


   
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Tue Mar 27 14:18:34 2018
  5. @author: wsw
  6. """
  7. # 用于通过读取图片的path,然后解码成图片数组的形式,最后返回batch个图片数组
  8. import glob
  9. import tensorflow as tf
  10. import matplotlib.pyplot as plt
  11. path_list = r'/media/wsw/文档/pythonfile_withpycharm/SVMLearning/faceLibrary/人脸库/Yale2/'
  12. img_path = glob.glob(path_list+ '*.bmp')
  13. img_path = tf.convert_to_tensor(img_path,dtype=tf.string)
  14. # 这里img_path,不放在数组里面
  15. # num_epochs = 1,表示将文件下所有的图片都使用一次
  16. # num_epochs和tf.train.slice_input_producer()中是一样的
  17. # 此参数可以用来设置训练的 epochs
  18. image = tf.train.slice_input_producer([img_path],num_epochs= 1)
  19. # load one image and decode img
  20. def load_img(path_queue):
  21. # 创建一个队列读取器,然后解码成数组
  22. # reader = tf.WholeFileReader()
  23. # key,value = reader.read(path_queue)
  24. file_contents = tf.read_file(path_queue[ 0])
  25. img = tf.image.decode_bmp(file_contents,channels= 1)
  26. # 这里很有必要,否则会出错
  27. # 感觉这个地方貌似只能解码3通道以上的图片
  28. img = tf.image.resize_images(img,size=( 100, 100))
  29. # img = tf.reshape(img,shape=(50,50,4))
  30. return img
  31. img = load_img(image)
  32. print(img.shape)
  33. image_batch = tf.train.batch([img],batch_size= 20)
  34. with tf.Session() as sess:
  35. # initializer for num_epochs
  36. tf.local_variables_initializer().run()
  37. coord = tf.train.Coordinator()
  38. thread = tf.train.start_queue_runners(sess=sess,coord=coord)
  39. try:
  40. while not coord.should_stop():
  41. imgs = sess.run(image_batch)
  42. print(imgs.shape)
  43. except tf.errors.OutOfRangeError:
  44. print( 'done')
  45. finally:
  46. coord.request_stop()
  47. coord.join(thread)

三.使用tf.train.string_input_producer()

    tf.train.string_input_producer(path),传入路径时,不需要放入list中。然后加载图片的reader是tf.WholeFileReader(),其他地方和tf.train.slice_input_producer()函数用法基本类似。见代码:


   
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Tue Mar 27 14:18:34 2018
  5. @author: wsw
  6. """
  7. # 用于通过读取图片的path,然后解码成图片数组的形式,最后返回batch个图片数组
  8. import glob
  9. import tensorflow as tf
  10. path_list = r'/media/wsw/文档/pythonfile_withpycharm/SVMLearning/faceLibrary/人脸库/Yale2/'
  11. img_path = glob.glob(path_list+ '*.bmp')
  12. img_path = tf.convert_to_tensor(img_path,dtype=tf.string)
  13. # 这里img_path,不放在数组里面
  14. # num_epochs = 1,表示将文件下所有的图片都使用一次
  15. # num_epochs和tf.train.slice_input_producer()中是一样的
  16. # 此参数可以用来设置训练的 epochs
  17. image = tf.train.string_input_producer(img_path,num_epochs= 1)
  18. # load one image and decode img
  19. def load_img(path_queue):
  20. # 创建一个队列读取器,然后解码成数组
  21. reader = tf.WholeFileReader()
  22. key,value = reader.read(path_queue)
  23. img = tf.image.decode_bmp(value,channels= 3)
  24. # 这里很有必要,否则会出错
  25. # 感觉这个地方貌似只能解码3通道以上的图片
  26. # img = tf.image.resize_images(img,size=(100,100))
  27. img = tf.reshape(img,shape=( 224, 224, 3))
  28. return img
  29. img = load_img(image)
  30. print(img.shape)
  31. image_batch = tf.train.batch([img],batch_size= 20)
  32. with tf.Session() as sess:
  33. # initializer for num_epochs
  34. tf.local_variables_initializer().run()
  35. coord = tf.train.Coordinator()
  36. thread = tf.train.start_queue_runners(sess=sess,coord=coord)
  37. try:
  38. while not coord.should_stop():
  39. imgs = sess.run(image_batch)
  40. print(imgs.shape)
  41. except tf.errors.OutOfRangeError:
  42. print( 'done')
  43. finally:
  44. coord.request_stop()
  45. coord.join(thread)

  相关解决方案