当前位置: 代码迷 >> 综合 >> torchvision 数据加载和可视化:ImageFolder、make_grid
  详细解决方案

torchvision 数据加载和可视化:ImageFolder、make_grid

热度:85   发布时间:2024-01-04 21:07:01.0

torchvision 是 pytorch 框架适配的相当好用的工具包,它封装了最流行的数据集(torchvision.datasets)、模型(torchvision.models)和常用于 CV 的图像转换组件(torchvision.transforms)和其它工具:

有时间一定要通读一遍官方文档 TORCHVISION,内容不多,简明易懂,有助于上手。


以 notebook 的方式实践 torchvision
# 导入必要的包
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms%pylab inline  # 魔法方法用于显示 plt.show()
一、torchvision.transforms

transforms.Compose([ ... ]) 定义常用的图像转换流程,以字典的方式保存方便调用:

# 按照数据集的图像大小选择转换组件和参数
data_transforms = {
    # 训练数据集的转换组件'train': transforms.Compose([transforms.Resize(230),  # 图片自适应缩小(或放大)到最大边长为230的大小 == transforms.Scale(230)transforms.CenterCrop(224),  # 居中裁剪成 224×224的图transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转图像,图像被翻转的概率默认为 p=0.5transforms.ToTensor(),  # 转成 Tensor 格式,并归一化像素值到 [0., 1.]transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # 以均值0.5,标准差0.5,分别在3个通道上进行归一化]),# 测试数据集的转换组件'test': transforms.Compose([transforms.Resize(256),  # 保留测试图像更大一些,考验一下神经网络transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}

官方文档摘取:

  • torchvision.transforms.ToTensor()

    ToTensor() 将 shape 为 (H, W, C) 的 nump.ndarray 或 PIL img 转为 shape 为 (C, H, W) 的 torch.Tensor(其中, C 为图像通道数),并将每一个数值归一化到 [0,1],其归一化方法比较简单,直接除以 255 即可。

  • torchvision.transforms.Normalize(mean, std, inplace=False)

    对于 n 个通道,给定均值数组:(mean [1],...,mean [n]) 和标准差数组:(std [1],..,std [n]),此变换将对输入的每个通道分别用对应值的进行归一化。

    归一化方式:output[channel] = (input[channel] - mean[channel]) / std[channel]

    它的目的是让图像的像素值接近 N ( 0 , 1 ) \mathcal N(0,1) N(01) 标准正态分布,即以 0 为均值,1 为标准差的分布,所以大部分的像素值最终在 [-1, 1] 区间。

二、torchvision.datasetstorch.utils.data.DataLoader

PyTorch 数据加载的核心是 torch.utils.data.DataLoader 类。它返回一个 Python 可迭代的数据集

DataLoader 的参数配置如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None, *, prefetch_factor=2,persistent_workers=False)

我们一般只需要指定 dataset、batch_size、shuffle 和 num_workers。

DataLoader 的第一个参数 dataset 接收一个数据集类的实例,该类必须继承 torch.utils.data.Dataset,并至少实现两个魔法方法:__getitem__(self, index)__len__(self)

torchvision.datasets 中的所有数据集类就都是这样的:
1
所以像 mnist_train = torchvision.datasets.MNIST(train=True),mnist_train 就可以直接放进 DataLoader:train_loader = DataLoader(dataset=mnist_train, batch_size=1) 并得到 MNIST 训练数据集的加载器 train_loader。

三、torchvision.datasets.ImageFolder

它可以把本地的文件夹和里面的文件作为一个数据集并自动按各自文件夹分类,它接收的参数为:ImageFolder(data_dir, data_transforms),data_dir 是数据集所在的文件夹,data_transforms 是上面定义的图像转换方式。

train_set = datasets.ImageFolder(os.path.join(data_dir, 'train'), data_transforms['train'])
test_set = datasets.ImageFolder(os.path.join(data_dir, 'test'), data_transforms['test'])train_loader = DataLoader(train_set, batch_size=5, shuffle=True, num_workers=4)
test_loader = DataLoader(train_set, batch_size=5, shuffle=True, num_workers=4)

官方文档摘取:

ImageFolder 默认以图片所在文件夹作为分类标签

  • ImageFolder 要求加载的文件夹里面是各类图像所在的各自的文件夹,并以文件夹顺序作为标签(0, 1, …, folder_n - 1)
  • Returns: (sample, target) where target is class_index of the target class.
四、torchvision.utils.make_gridplt.imshow

这个工具可以很方便地可视化数据集。这里 还有更多非常实用的 torchvision.utils 的可视化示例。

有了数据加载器 train_loader,我们可以很容易通过迭代来得到一个 batch_size 大小的图像和标签数据组:

imgs, labels = next(iter(train_loader))

定义一个便利的图像查看函数:

def imshow(imgs):imgs = imgs / 2 + 0.5  # 逆归一化,像素值从[-1, 1]回到[0, 1]imgs = imgs.numpy().transpose((1, 2, 0))  # 图像从(C, H, W)转回(H, W, C)的numpy矩阵plt.imshow(imgs)plt.show()

文档摘取:

plt.imshow() 的图像参数为:numpy array-like or PIL image,并且要求像素值为 [0,1] 区间的 float 类型或 [0, 255] 区间的 uint8 类型。

所以经过 transforms.ToTensor() 转换过的 PIL img 只需转回 img.numpy() 以及调整通道顺序即可,不必要乘以 255。

现在获取一个 image grid 图像栅格,便于成组展示图象:

img_grid = torchvision.utils.make_grid(imgs)

作画:

imshow(img_grid)

00