当前位置: 代码迷 >> 综合 >> pytorch 《Writing Custom Datasets, Dataloaders and Transforms》官方指导 笔记
  详细解决方案

pytorch 《Writing Custom Datasets, Dataloaders and Transforms》官方指导 笔记

热度:13   发布时间:2023-12-21 06:41:12.0

源码:

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode


def show_landmarks(image, landmarks):
    """Show image with landmarks.

    Args:
        image: Image to draw; passed straight to ``plt.imshow``.
        landmarks: Array whose column 0 is plotted as x and column 1 as y.
    """
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=100, marker='.', c='r')


class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    # Initializer.
    # Inputs: path to the csv file, root directory path, transform object.
    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        # Read the csv annotations into the dataset.
        self.landmarks_frame = pd.read_csv(csv_file)
        # Keep the root directory; __getitem__ joins it with each file name.
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Args:
            idx: Row position in the annotation table; the DataLoader
                normally passes a plain int.

        Returns:
            dict: ``{'image': image, 'landmarks': landmarks}`` where
            ``landmarks`` is a float array reshaped to ``(-1, 2)``, after
            ``self.transform`` has been applied when one was given.
        """
        # Tensor indices are converted to plain Python values before being
        # handed to pandas ``iloc``.
        # NOTE(review): presumably because iloc does not accept tensors --
        # confirm.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Build the full image path: column 0 of the row holds the file name.
        # (Author note: running this line on its own raised an error, see
        # draft.)
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)   # read the image from disk

        # The remaining columns are the landmark coordinates, viewed as
        # (x, y) pairs.
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.asarray(landmarks)
        landmarks = landmarks.astype('float').reshape(-1, 2)

        # Bundle the image and its landmarks into one dict-typed sample.
        sample = {'image': image, 'landmarks': landmarks}

        # The transform object implements __call__(), so it is invoked like a
        # function. Compare against None explicitly so that a callable that
        # happens to be falsy is still applied.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample


# Create an instance of FaceLandmarksDataset (the original walkthrough then
# iterates over the samples, printing the size of the first 4 and showing
# their landmarks).
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')


class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with the image resized and the landmarks
        scaled by the same factors."""
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            # Match the shorter edge to output_size, scale the longer one.
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict holding a random crop of the image and
        the landmarks shifted into the crop's coordinate frame.

        Raises:
            ValueError: From ``np.random.randint`` when the image is smaller
                than the requested crop in either dimension.
        """
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # randint's upper bound is exclusive, so add 1: this makes the last
        # valid offset reachable and lets an image that is exactly the crop
        # size through (offset 0) instead of raising ValueError.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        # Important: once the image has been replaced by its crop, the
        # landmarks must be shifted by the same offset.
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        """Return a new sample dict whose image and landmarks are tensors."""
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}


transformed_dataset = FaceLandmarksDataset(
    csv_file='data/faces/face_landmarks.csv',
    root_dir='data/faces/',
    transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]))

dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True,
                        num_workers=0)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples.

    Args:
        sample_batched: Dict with an ``'image'`` tensor indexed as
            (batch, channel, height, width) and a ``'landmarks'`` tensor
            indexed as (batch, point, xy), as produced by the dataloader.
    """
    images_batch, landmarks_batch = \
        sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    # Images are tiled left to right, so the per-image horizontal offset is
    # the image WIDTH (dim 3), not the height (dim 2).
    im_width = images_batch.size(3)
    grid_border_size = 2

    # Pass the border explicitly so the scatter offsets below cannot drift
    # from the grid padding, and keep all images on a single row because the
    # overlay only offsets landmarks horizontally.
    grid = utils.make_grid(images_batch, nrow=batch_size,
                           padding=grid_border_size)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        # Image i starts at x = i * width + (i + 1) * border, y = border.
        x_offset = i * im_width + (i + 1) * grid_border_size
        plt.scatter(landmarks_batch[i, :, 0].numpy() + x_offset,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

    plt.title('Batch from dataloader')


for i_batch, sample_batched in enumerate(dataloader):
    # __getitem__() of the custom dataset returns a dict, and the batch
    # received here as sample_batched is still a dict: its 'image' entry is
    # the 4 images stacked into one 4x3x224x224 tensor and its 'landmarks'
    # entry is the 4 landmark arrays stacked into one tensor, so the batch
    # mirrors the structure of a single sample.
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

用到的API等:

路径连接:

os.path.join()

用路径读图像为一个3维数组:

skimage.io.imread("图像路径")

如果 output_size 是 int 型或 tuple 型则继续往下执行,否则 assert 语句会抛出 AssertionError 并中断程序:

assert isinstance(output_size, (int, tuple))

调整图像大小以符合一定的尺寸:

skimage.transform.resize(图像, (新高度, 新宽度))

把对象1、对象2、…(我们自定义的 transform 类对象)按顺序组合成一个大的 transform,传给自定义 dataset 作为它的 transform 成员变量:

torchvision.transforms.Compose([对象1, 对象2, ...])

Make a grid of images:

torchvision.utils.make_grid(images_batch)
  相关解决方案