
3. Dataset and DataLoader

# -*- coding: utf-8 -*-
"""
# @file name  : train_lenet.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : RMB (renminbi) banknote classification model training
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed()  # set the random seed
rmb_label = {"1": 0, "100": 1}

# parameter settings
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 data ============================
split_dir = os.path.join("..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# [1] build the MyDataset instances (they must be implemented by the user) --> ctrl+click --> my_dataset.py
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# [2] once the Datasets exist, the DataLoaders can be built
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================
net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()                                                   # choose the loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # learning-rate decay policy

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

# [3] training is organized in epochs; each epoch consists of multiple iterations
for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    # [4] data fetching --> debug to see how pytorch fetches the data --> dataloader.py
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward --> compute gradients
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # print training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    # validate the model: the validation set is evaluated every epoch to watch for overfitting
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            loss_val_epoch = loss_val / len(valid_loader)
            valid_curve.append(loss_val_epoch)
            # valid_curve.append(loss.item())    # changed 2019-10-22: record the average loss over the whole epoch
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_epoch, correct_val / total_val))

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters * val_interval    # valid_curve stores one loss per epoch, so its x positions are converted to iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

# ============================ inference ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")

test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)

for i, data in enumerate(valid_loader):
    # forward
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)

    rmb = 1 if predicted.numpy()[0] == 0 else 100
    print("the model predicts a {} yuan note".format(rmb))
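Before stepping through the DataLoader source, it helps to check what one iteration of train_loader actually yields. A minimal sketch (it assumes the rmb_split directory from step 1/5 is in place; the shapes follow from BATCH_SIZE = 16 and the 32x32 Resize/RandomCrop on RGB images):

inputs, labels = next(iter(train_loader))   # one batch, exactly what the for-loop receives as `data`
print(inputs.shape)   # expected: torch.Size([16, 3, 32, 32])  (batch, channels, height, width)
print(labels.shape)   # expected: torch.Size([16])
print(labels[:5])     # class indices: 0 = "1 yuan", 1 = "100 yuan"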
1. In dataloader.py, __iter__ decides whether a single-process or a multi-process iterator is used; we follow the single-process case here:

def __iter__(self):
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)
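Which branch is taken is controlled by the num_workers argument of the DataLoader constructor. A small sketch, reusing train_data and BATCH_SIZE from the training script above (the value of 4 workers is just an arbitrary example):

# num_workers=0 (the default) --> _SingleProcessDataLoaderIter: samples are read in the main process
loader_single = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# num_workers>0 --> _MultiProcessingDataLoaderIter: that many worker processes read samples in parallel
loader_multi = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)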
2. Step into _SingleProcessDataLoaderIter(self):

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self.timeout == 0
        assert self.num_workers == 0
        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collation, self.collate_fn, self.drop_last)

    def __next__(self):
        index = self._next_index()  # may raise StopIteration
        data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
        if self.pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

    next = __next__  # Python 2 compatibility

3. The core method of the single-process iterator is __next__(self): it determines which samples are read in each iteration.

4. At index = self._next_index()  # may raise StopIteration --> run to cursor.
Step into --> sampler.py: this is the sampler, which tells us which batch_size indices should be read in each iteration.

def _next_index(self):
    return next(self.sampler_iter)  # may raise StopIteration

Step into the batch sampler's __iter__:

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch
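The same index-grouping behaviour can be reproduced outside the debugger with the public sampler classes. A minimal sketch (the toy length of 10 and the batch size of 4 are arbitrary values chosen for illustration):

from torch.utils.data import RandomSampler, BatchSampler

data_source = range(10)                                     # stand-in for a dataset; only its length matters here
sampler = RandomSampler(data_source)                        # yields single indices in random order
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)

for batch_indices in batch_sampler:
    print(batch_indices)   # e.g. [7, 2, 9, 0] -- each list is the `index` handed to the fetcher

With drop_last=False the last list may be shorter than batch_size, which is exactly the `if len(batch) > 0 and not self.drop_last` branch above.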
5. Step out back to dataloader.py --> data = self.dataset_fetcher.fetch(index): the index list is fed into the dataset to get the data --> step into --> fetch.py:

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    # the actual data reading is implemented here
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # collect the individual samples into a list
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

At data = [self.dataset[idx] for idx in possibly_batched_index] --> step into --> my_dataset.py --> def __getitem__(self, index): this is where the user-defined Dataset actually reads and returns one sample.
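my_dataset.py itself is not reproduced above; as a reference point for this last step, here is a minimal sketch of what such a map-style RMBDataset typically looks like. The folder layout with sub-directories named "1" and "100" and the .jpg filter are assumptions based on the rmb_label mapping in the training script; the course's actual implementation may differ in details.

import os
from PIL import Image
from torch.utils.data import Dataset

rmb_label = {"1": 0, "100": 1}

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        # collect (image_path, label) pairs from sub-folders named "1" and "100" (assumed layout)
        self.data_info = []
        for sub_dir in rmb_label:
            img_dir = os.path.join(data_dir, sub_dir)
            for img_name in os.listdir(img_dir):
                if img_name.endswith(".jpg"):
                    self.data_info.append((os.path.join(img_dir, img_name), rmb_label[sub_dir]))
        self.transform = transform

    def __getitem__(self, index):
        # the method the fetcher calls: one index in, one (img, label) sample out
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data_info)

Each call returns a single (img, label) pair; the fetcher gathers batch_size of these into a list and hands it to collate_fn.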

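For completeness, the final collate_fn step can also be sketched in isolation. The default collate function stacks the per-sample tensors into one batch tensor; it is exposed as torch.utils.data.default_collate in recent PyTorch releases (older versions keep it in the private torch.utils.data._utils.collate module), so treat the import path as version-dependent:

import torch
from torch.utils.data import default_collate   # public in recent PyTorch versions

# what fetch() produces before collation: a list of (img, label) samples
samples = [(torch.randn(3, 32, 32), 0),
           (torch.randn(3, 32, 32), 1)]

imgs, labels = default_collate(samples)
print(imgs.shape)    # torch.Size([2, 3, 32, 32]) -- images stacked along a new batch dimension
print(labels)        # tensor([0, 1])             -- integer labels gathered into one tensor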