当前位置: 代码迷 >> 综合 >> Pytorch强化一 || 自定义Dataset实现torchvision.datasets.ImageFolder的相同功能属性,支持批量读取高通道维度的tif格式图片
  详细解决方案

Pytorch强化一 || 自定义Dataset实现torchvision.datasets.ImageFolder的相同功能属性,支持批量读取高通道维度的tif格式图片

热度:91   发布时间:2023-11-30 19:57:51.0

1.问题描述

  1. 普通的torchvision.datasets.ImageFolder()函数读取4通道的tif格式时,输出的tensor向量还是三通道的,因为其底层就是使用PIL读取图片,无法读入高维度图片,解决方案是重写torch底层,采用skimage读取图片

  2. 将重写的代码命名为loadTifImage.py,存放于lib文件夹内,使用如下(与使用torch自带的ImageFolder()一样):

from lib import loadTifImagedata_transform = transforms.Compose([transforms.ToTensor()])train_dataset = loadTifImage.DatasetFolder(root='路径',transform=data_transform)
  1. 使用时图片文件夹的目录结构,train下有两个子文件夹,每个子文件夹内的图片是同一类,代码自动对其打上label标签,我们传入的路径精确到 ../../train 即可
    请添加图片描述

2.重写代码的读取图片方法之处

def loadTifImage(path):image = io.imread(path)# print('image.shape=>',image.shape)image = transform.resize(image, (224, 224))     # 修改尺寸,仅能在此处修改image = image/255.0             # 归一化# print(image)im = np.array(image, dtype=np.float32)return im

3.完整代码 loadTifImage.py

import os
import numpy as np
import sys
from torch.utils.data import Dataset
from skimage import transform,io# 支持的图片格式
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']def has_file_allowed_extension(filename, extensions):"""查看文件是否是支持的可扩展类型Args:filename (string): 文件路径extensions (iterable of strings): 可扩展类型列表,即能接受的图像文件类型Returns:bool: True if the filename ends with one of given extensions"""filename_lower = filename.lower()return any(filename_lower.endswith(ext) for ext in extensions) # 返回True或False列表def make_dataset(dir, class_to_idx, extensions):"""返回形如[(图像路径, 该图像对应的类别索引值),(),...]"""images = []dir = os.path.expanduser(dir)for target in sorted(class_to_idx.keys()):d = os.path.join(dir, target)if not os.path.isdir(d):continuefor root, _, fnames in sorted(os.walk(d)): #层层遍历文件夹,返回当前文件夹路径,存在的所有文件夹名,存在的所有文件名for fname in sorted(fnames):if has_file_allowed_extension(fname, extensions): #查看文件是否是支持的可扩展类型,是则继续path = os.path.join(root, fname)item = (path, class_to_idx[target])images.append(item)return imagesdef loadTifImage(path):image = io.imread(path)# print('image.shape=>',image.shape)image = transform.resize(image, (224, 224))     # 修改尺寸,仅能在此处修改image = image/255.0             # 归一化# print(image)im = np.array(image, dtype=np.float32)return imclass DatasetFolder(Dataset):"""Args:root (string): 根目录路径loader (callable): 根据给定的路径来加载样本的可调用函数extensions (list[string]): 可扩展类型列表,即能接受的图像文件类型.transform (callable, optional): 用于样本的transform函数,然后返回样本transform后的版本E.g, ``transforms.RandomCrop`` for images.target_transform (callable, optional): 用于样本标签的transform函数Attributes:classes (list): 类别名列表class_to_idx (dict): 项目(class_name, class_index)字典,如{'cat': 0, 'dog': 1}samples (list): (sample path, class_index) 元组列表,即(样本路径, 类别索引)targets (list): 在数据集中每张图片的类索引值,为列表"""def __init__(self, root, loader=loadTifImage, extensions=IMG_EXTENSIONS, transform=None, target_transform=None):classes, class_to_idx = self._find_classes(root)    # 得到类名和类索引,如['cat', 'dog']和{'cat': 0, 'dog': 1}# 返回形如[(图像路径, 该图像对应的类别索引值),(),...],即对每个图像进行标记samples = make_dataset(root, class_to_idx, extensions)if len(samples) == 0:raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n""Supported extensions are: " + ",".join(extensions)))self.root = rootself.loader = loaderself.extensions = extensionsself.classes = classesself.class_to_idx = class_to_idxself.samples = samplesself.targets = [s[1] for s in samples]  # 所有图像的类索引值组成的列表self.transform = transformself.target_transform = target_transformdef _find_classes(self, dir):"""在数据集中查找类文件夹。Args:dir (string): 根目录路径Returns:返回元组: (classes, class_to_idx)即(类名, 类索引),其中classes即相应的目录名,如['cat', 'dog'];class_to_idx为形如{类名:类索引}的字典,如{'cat': 0, 'dog': 1}.Ensures:保证没有类名是另一个类目录的子目录"""if sys.version_info >= (3, 5):# Faster and available in Python 3.5 and aboveclasses = [d.name for d in os.scandir(dir) if d.is_dir()]   # 获得根目录dir的所有第一层子目录名else:classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]   # 效果和上面的一样,只是版本不同方法不同classes.sort() #然后对类名进行排序class_to_idx = {
    classes[i]: i for i in range(len(classes))}     # 然后将类名和索引值一一对应的到相应字典,如{'cat': 0, 'dog': 1}return classes, class_to_idx    # 然后返回类名和类索引def __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (sample, target) where target is class_index of the target class."""path, target = self.samples[index]sample = self.loader(path)  # 加载图片函数,可自定义为opencv,默认为PILif self.transform is not None:sample = self.transform(sample)if self.target_transform is not None:target = self.target_transform(target)return sample, targetdef __len__(self):return len(self.samples)def __repr__(self):fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())fmt_str += ' Root Location: {}\n'.format(self.root)tmp = ' Transforms (if any): 'fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))tmp = ' Target Transforms (if any): 'fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))return fmt_str
  相关解决方案