pytorch报错
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 1 and 3 in dimension 1 at /pytorch/aten/src/TH/generic/THTensorMath.cpp:3616
原因分析
使用DataLoader加载图像,这些图像中的一些具有3个通道(彩色图像),而其他图像可能具有单个通道(BW图像),由于dim1的尺寸不同,因此无法将它们连接成批次。
解决方法
将img = img.convert(‘RGB’)添加到数据集图像读取中
img = Image.open(img_path)
img = img.convert('RGB')