当前位置: 代码迷 >> 综合 >> Pytorch:关于epoch、batch_size和batch_idx(iteration )的一些理解(深度学习)
  详细解决方案

Pytorch:关于epoch、batch_size和batch_idx(iteration )的一些理解(深度学习)

热度:0   发布时间:2023-12-17 04:48:42.0

前言

在新手搭建神经网络时,常弄不清epoch、batch_size、iteration和batch_idx(iteration )的区别。
这里以torchvision自带的CIFAR10数据集来举例,通过代码操作来直观地对这几个概念进行理解。
声明,这里batch_idx==iteration

数据准备

首先加载数据集:

import torch
import torch.nn as nn
import torchvisiontrain_dataset = torchvision.datasets.CIFAR10(root="data/",train=True,download=False)
test_dataset = torchvision.datasets.CIFAR10(root="data/",train=False,download=False)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=100,shuffle=True,num_workers=0)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=50,shuffle=False,num_workers=0)

如果之前没有下载过此数据集,请将train_dataset中的download设置为True进行下载。
这里最关键的数据是代码loader中的batch_size,这里记住数值为50

数据简介

输入:

print(train_dataset)
print(train_dataset.data.shape)

得到:

ataset CIFAR10Number of datapoints: 50000Root location: data/Split: Train
(50000, 32, 32, 3)

train_dataset中含有50000个32*32的3通道图片数据,test_dataset含有10000个。


我们使用enumerate函数对已经载入的数据进行迭代,这是一个能够迭代式地从数据装载器中拿出数据的函数,看看能够迭代多少次。

for batch_idx,(inputs,labels) in enumerate(train_loader,0): # 从0开始迭代print(batch_idx)

得到batch_idx最终显示为1000。发现,存在等式50000=50*1000
让我们试试test_loader,此时的test的batch_size25。结果batch_idx为多少呢?400!满足10000=25*400
很直观地,存在等式:数据个数/长度 = batch_size * batch_idx。一般的,数据个数/长度是已知的,batch_size是程序员自行设计的,而batch_idx是根据等式得出的。

那究竟这些参数是什么意思呢?

  • batch_size : 代表每次从所有数据中取出一小筐子数据进行训练,类似挑十框石头,每次挑一筐,此时的batch_size=1。这个参数是由于深度学习中尝使用SGD(随机梯度下降)产生。batch_size的大小取值和GPU内存会有关系,数值越大一次性载入数据越多,占用的GPU内存越多。适当增加batch_size能够增加训练速度和训练精度(因为梯度下降时震动较小),过小会导致模型收敛困难。
  • batch_idx(iteration ) : 代表要进行多少次batch_size的迭代,十框石头一次挑一框,batch_idx即为10。而为什么会介绍这个参数呢? 是因为在深度学习训练中,需要print出一个个阶段的精度/loss,而这个参数就是用来作为pirnt精度/loss的index。实际上,这是一个代号参数,也可以写成iterationstepi或者其他自己能够清晰理解的名字。
  • epoch : 把所有的训练数据全部迭代遍历一遍(单次epoch),在上面的例子是把train_loader的50000个数据遍历一遍,举例的话是将十框石头全部搬一遍称为一个epoch。深度学习中常常训练十个甚至百个epoch,这是根据网络最终是否收敛决定的。
  • 数据个数/长度 = 1 epoch = batch_size * batch_idx
  相关解决方案