图片数据封装时候往往需要用到transform, 这个方法一方面将图片数据shape成【channel, w, h】,另一方卖弄将图片array转化为tensor
class CatsAndDogsDataset(Dataset):def __init__(self, csv_file, root_dir, transform=None):self.annotations = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.annotations)def __getitem__(self, index):img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])image = io.imread(img_path) #array [h,w,channel] y_label = torch.tensor(int(self.annotations.iloc[index, 1]))if self.transform:image = self.transform(image) #tensor [channel, h,w]return (image, y_label)
上面这段代码实例化数据集时候需要传入transform方法
dataset = CatsAndDogsDataset(csv_file='./custom_dataset/cats_dogs.csv',root_dir='./custom_dataset/cats_dogs_resized',transform=transforms.ToTensor(),
)
需要先导入import torchvision.transforms as transforms
其中image = io.imread(img_path)
读取的image格式是array, shape为[h,w,channel] , 但是网络训练时候一般要求是[channel, h, w], 所以image = self.transform(image)
将其转化为tensor, shape为 [channel, h,w]格式