DataLoader可通过collate_fn参数,对Dataset生成的mini-batch的可迭代数据进行后处理,如数据填充。
collate_fn应当是一个可调用对象,常见的可以是外部定义的函数或者lambda函数。
其接受DataLoader。在不设置collate_fn参数时,DataLoader的mini-batch样本序列样式取决于对应dataset参数的设置。而dataset只要是覆写了__init__、__getitem__和__len__方法的Dataset子类即可,所以其输出形式可以很多样化。而collate_fn对于上述的输出进行进一步的加工,可通过定义采样、切片、组合和数学操作等一系列操作,生成最终的训练数据。
下面的代码如果对x,y处设置调试断点,则第一步跳转就是执行MyDataset的gittiem获取一个mini-batch数据,然后跳转到unify_fn后处理函数处理一个批次数据,所以DataLoader的mini-batch输出是经过unify处理后的输出。
import torch
from torch.utils.data import Dataset, DataLoaderA = torch.randn(8, 3) # 待拼接的张量A
B = torch.randn(8, 2) # 待拼接的张量B
C = torch.randn(8, 2) # labels张量# 1. 利用Dataset封装数据集
class MyDataset(Dataset):def __init__(self, x1, x2, y):assert x1.size(0)==x2.size(0)==y.size(0)self.x1, self.x2, self.y = x1, x2, ydef __getitem__(self, idx):return (self.x1[idx], self.x2[idx], self.y[idx])def __len__(self):return self.x1.size(0)dataset = MyDataset(A, B, C)# 2. 定义后处理的collate_fn函数
# 需要特别注意的是:输入的mini-batch数据,返回的为张量,因此要注意对格式进行转换和统一
def unify_fn(batch_data):x_ = [x1.tolist()+x2.tolist() for x1, x2, y in batch_data]y_ = [y.tolist() for x1, x2, y in batch_data]return torch.tensor(x_), torch.tensor(y_)# 3. 利用DataLoader完成数据集的批量化
MyDataLoader = DataLoader(dataset=dataset, shuffle=True, batch_size=4, collate_fn=unify_fn)
for data_iter in MyDataLoader:x, y = data_iter
后处理函数另外一种定义方式是定义为一个类,然后重写call魔法方法把该对象当成函数使用,如现在有一种需求加一个为0的特征。
MyDataset定义不变,定义一个MyCollate类来完成这个处理,处理逻辑在call里写就行
class MyCollate:def __init__(self, pad_value):self.padding = pad_valuedef __call__(self, batch_data):x_ = [x1.tolist()+x2.tolist()+[self.padding]for x1, x2, y in batch_data]y_ = [y.tolist() for x1, x2, y in batch_data]return torch.tensor(x_), torch.tensor(y_)
调用接口修改为
MyDataLoader1 = DataLoader(dataset=dataset, shuffle=True, batch_size=2, collate_fn=MyCollate(0))
以上实例完整代码
import torch
from torch.utils.data import Dataset, DataLoaderA = torch.randn(8, 1) # 待拼接的张量A
B = torch.randn(8, 2) # 待拼接的张量B
C = torch.randn(8, 3) # labels张量# 1. 利用Dataset封装数据集
class MyDataset(Dataset):def __init__(self, x1, x2, y):assert x1.size(0)==x2.size(0)==y.size(0)self.x1, self.x2, self.y = x1, x2, ydef __getitem__(self, idx):return (self.x1[idx], self.x2[idx], self.y[idx])def __len__(self):return self.x1.size(0)dataset = MyDataset(A, B, C)# 2. 定义后处理的collate_fn函数
# 需要特别注意的是:输入的mini-batch数据,返回的为张量,因此要注意对格式进行转换和统一
def unify_fn(batch_data):print(batch_data) #batch_data是一个元组list,一个list是一个批量数据,一个数据是样本加目标x_ = [x1.tolist()+x2.tolist() for x1, x2, y in batch_data]y_ = [y.tolist() for x1, x2, y in batch_data]return torch.tensor(x_), torch.tensor(y_)class MyCollate:def __init__(self, pad_value):self.padding = pad_valuedef __call__(self, batch_data):x_ = [x1.tolist()+x2.tolist()+[self.padding]for x1, x2, y in batch_data]y_ = [y.tolist() for x1, x2, y in batch_data]return torch.tensor(x_), torch.tensor(y_)3. 利用DataLoader完成数据集的批量化
MyDataLoader = DataLoader(dataset=dataset, shuffle=True, batch_size=2, collate_fn=unify_fn)
for data_iter in MyDataLoader:x, y = data_iterMyDataLoader1 = DataLoader(dataset=dataset, shuffle=True, batch_size=2, collate_fn=MyCollate(0))
for data_iter in MyDataLoader1:x, y = data_iterprint(x)
参考:
https://blog.csdn.net/guofei_fly/article/details/104384497
https://blog.csdn.net/guofei_fly/article/details/104382583