当前位置: 代码迷 >> 综合 >> transforms.ToTensor()的作用就是变换图像数据为tensor并shape转化为[channel, h, w]
  详细解决方案

transforms.ToTensor()的作用就是变换图像数据为tensor并shape转化为[channel, h, w]

热度:22   发布时间:2023-12-17 05:42:01.0

图片数据封装时候往往需要用到transform, 这个方法一方面将图片数据shape成【channel, w, h】,另一方卖弄将图片array转化为tensor

class CatsAndDogsDataset(Dataset):def __init__(self, csv_file, root_dir, transform=None):self.annotations = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.annotations)def __getitem__(self, index):img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])image = io.imread(img_path) #array [h,w,channel] y_label = torch.tensor(int(self.annotations.iloc[index, 1]))if self.transform:image = self.transform(image) #tensor [channel, h,w]return (image, y_label)

上面这段代码实例化数据集时候需要传入transform方法

dataset = CatsAndDogsDataset(csv_file='./custom_dataset/cats_dogs.csv',root_dir='./custom_dataset/cats_dogs_resized',transform=transforms.ToTensor(),
)

需要先导入import torchvision.transforms as transforms
其中image = io.imread(img_path) 读取的image格式是array, shape为[h,w,channel] , 但是网络训练时候一般要求是[channel, h, w], 所以image = self.transform(image)将其转化为tensor, shape为 [channel, h,w]格式

  相关解决方案