当前位置: 代码迷 >> 综合 >> tf.keras.preprocessing.image_dataset_from_directory() 简介
  详细解决方案

tf.keras.preprocessing.image_dataset_from_directory() 简介

热度:55   发布时间:2023-12-14 08:30:24.0
  • ? 运行环境:python3
  • ? 作者:K同学啊
  • ? 精选专栏:《深度学习100例》
  • ? 推荐专栏:《新手入门深度学习》
  • ? 选自专栏:《Matplotlib教程》
  • ? 优秀专栏:《Python入门100题》

函数原型

tf.keras.preprocessing.image_dataset_from_directory(directory,labels="inferred",label_mode="int",class_names=None,color_mode="rgb",batch_size=32,image_size=(256, 256),shuffle=True,seed=None,validation_split=None,subset=None,interpolation="bilinear",follow_links=False,
)

官网地址:https://tensorflow.google.cn/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory

作用

将文件夹中的数据加载到tf.data.Dataset中,且加载的同时会打乱数据。

注: 如果你的目录结构是:

main_directory/
…class_a/
…a_image_1.jpg
…a_image_2.jpg
…class_b/
…b_image_1.jpg
…b_image_2.jpg

然后调用 image_dataset_from_directory(main_directory, labels=‘inferred’) 将返回一个tf.data.Dataset, 该数据集从子目录class_a和class_b生成批次图像,同时生成标签0和1(0对应class_a,1对应class_b),

支持的图像格式:jpeg, png, bmp, gif. 动图被截断到第一帧。

参数

  • directory: 数据所在目录。如果标签是inferred(默认),则它应该包含子目录,每个目录包含一个类的图像。否则,将忽略目录结构。
  • labels: inferred(标签从目录结构生成),或者是整数标签的列表/元组,其大小与目录中找到的图像文件的数量相同。标签应根据图像文件路径的字母顺序排序(通过Python中的os.walk(directory)获得)。
  • label_mode:
    • int标签将被编码成整数(使用的损失函数应为:sparse_categorical_crossentropy loss)
    • categorical标签将被编码为分类向量(使用的损失函数应为:categorical_crossentropy loss)
    • binary:意味着标签(只能有2个)被编码为值为0或1的float32标量(例如:binary_crossentropy)。
    • None:(无标签)。
  • class_names: 仅当labelsinferred时有效。这是类名称的明确列表(必须与子目录的名称匹配)。用于控制类的顺序(否则使用字母数字顺序)。
  • color_mode: grayscalergbrgba之一。默认值:rgb。图像将被转换为1、3或者4通道。
  • batch_size: 数据批次的大小。默认值:32
  • image_size: 从磁盘读取数据后将其重新调整大小。默认:(256,256)。由于管道处理的图像批次必须具有相同的大小,因此该参数必须提供。
  • shuffle: 是否打乱数据。默认值:True。如果设置为False,则按字母数字顺序对数据进行排序。
  • seed: 用于shuffle和转换的可选随机种子。
  • validation_split: 0和1之间的可选浮点数,可保留一部分数据用于验证。
  • subset: trainingvalidation之一。仅在设置validation_split时使用。
  • interpolation: 字符串,当调整图像大小时使用的插值方法。默认为:bilinear。支持bilinear, nearest, bicubic, area, lanczos3, lanczos5, gaussian, mitchellcubic
  • follow_links: 是否访问符号链接指向的子目录。默认:False

Returns

一个tf.data.Dataset对象。

  • 如果label_mode为None,它将生成float32张量,其shape为(batch_size, image_size[0], image_size(1), num_channels),并对图像进行编码。
  • 否则,将生成一个元组(images, labels),其中图像的shape为(batch_size, image_size[0], image_size(1), num_channels),并且labels遵循下面描述的格式。

关于labels格式规则:

  • 如果label_mode 是 int, labels是形状为(batch_size, )的int32张量
  • 如果label_mode 是 binary, labels是形状为(batch_size, 1)的1和0的float32张量。
  • 如果label_mode 是 categorial, labels是形状为(batch_size, num_classes)的float32张量,表示类索引的one-hot编码。

有关生成图像中通道数量的规则:

  • 如果color_mode 是 grayscale, 图像张量有1个通道。
  • 如果color_mode 是 rgb, 图像张量有3个通道。
  • 如果color_mode 是 rgba, 图像张量有4个通道。
  相关解决方案