
Vision Transformer Explained (with Code)


1 Introduction

The Transformer has been hugely successful in NLP, and Vision Transformer extends the Transformer architecture to computer vision. It can replace convolution surprisingly well: without relying on convolution at all, it still achieves strong results on image classification. A convolution only takes local feature information into account, whereas the attention mechanism in the Transformer can weigh global feature information. Vision Transformer tries to transfer the Transformer Encoder from NLP to computer vision without changing its architecture, so that the original Transformer model can be used out of the box. For a detailed introduction to how the Transformer works, see my previous article "Transformer Explained (with Code)".

2 Applications of the Attention Mechanism

Before introducing Vision Transformer in detail, this section first looks at two earlier applications of attention in computer vision. Vision Transformer was not the first model to bring attention into vision: SAGAN and AttnGAN had already introduced attention mechanisms into the GAN framework, and both substantially improved the quality of the generated images.

2.1 Self-Attention GAN

SAGAN uses self-attention within the GAN framework to capture long-range dependencies among image features, so that the synthesized image takes all of the feature information into account. The figure below illustrates how self-attention operates in SAGAN.
Given a 3-channel input feature map $X=(X^1,X^2,X^3)\in\mathbb{R}^{3\times 3\times 3}$, where $X^i\in\mathbb{R}^{3\times 3}$ and $i\in\{1,2,3\}$, feed $X$ into three different $1\times 1$ convolution layers to produce the query feature map $Q\in\mathbb{R}^{3\times 3\times 3}$, the key feature map $K\in\mathbb{R}^{3\times 3\times 3}$ and the value feature map $V\in\mathbb{R}^{3\times 3\times 3}$. Concretely, $Q$ is computed as follows: given three convolution kernels $W^{q1},W^{q2},W^{q3}\in\mathbb{R}^{1\times 1\times 3}$, convolve each of them with $X$ to obtain $Q^1,Q^2,Q^3\in\mathbb{R}^{3\times 3}$, i.e.

$$\left\{\begin{aligned}Q^1&=X*W^{q1}\\Q^2&=X*W^{q2}\\Q^3&=X*W^{q3}\end{aligned}\right.$$

where $*$ denotes the convolution operator. $K$ and $V$ are computed in the same way. Then $Q$ and $K$ are used to compute the attention scores, giving a matrix $A\in\mathbb{R}^{3\times 3}$ whose elements are

$$a_{ml}=Q^m * K^l,\quad m\in\{1,2,3\},\ l\in\{1,2,3\},$$

i.e. each score is the inner product of the two flattened $3\times 3$ maps. Applying the softmax function to $A$ gives the attention distribution matrix $S\in\mathbb{R}^{3\times 3}$ with elements

$$s_{ml}=\frac{\exp(a_{ml})}{\sum_{j=1}^{3}\exp(a_{mj})},\quad m\in\{1,2,3\},\ l\in\{1,2,3\}.$$

Finally, the attention distribution $S$ and the value feature map $V$ give the output $O=(O^1,O^2,O^3)\in\mathbb{R}^{3\times 3\times 3}$:

$$\left\{\begin{aligned}O^1&=s_{11}\cdot V^1+s_{12}\cdot V^2+s_{13}\cdot V^3\\O^2&=s_{21}\cdot V^1+s_{22}\cdot V^2+s_{23}\cdot V^3\\O^3&=s_{31}\cdot V^1+s_{32}\cdot V^2+s_{33}\cdot V^3\end{aligned}\right.$$
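To make the computation concrete, here is a minimal PyTorch sketch of SAGAN-style self-attention on a feature map. It is not the official SAGAN implementation: it follows the standard formulation in which attention is computed between the spatial positions of the query, key and value maps produced by $1\times 1$ convolutions (the worked example above indexes the three channel slices instead, purely for simplicity), and the class name SelfAttention2d and the toy sizes are illustrative.

import torch
import torch.nn as nn

class SelfAttention2d(nn.Module):
    """Minimal SAGAN-style self-attention over the spatial positions of a feature map."""
    def __init__(self, channels):
        super().__init__()
        # 1x1 convolutions produce the query, key and value feature maps
        self.to_q = nn.Conv2d(channels, channels, kernel_size=1)
        self.to_k = nn.Conv2d(channels, channels, kernel_size=1)
        self.to_v = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):                                 # x: (B, C, H, W)
        B, C, H, W = x.shape
        q = self.to_q(x).flatten(2)                       # (B, C, H*W)
        k = self.to_k(x).flatten(2)                       # (B, C, H*W)
        v = self.to_v(x).flatten(2)                       # (B, C, H*W)
        scores = q.transpose(1, 2) @ k                    # (B, H*W, H*W) attention scores A
        attn = torch.softmax(scores, dim=-1)              # attention distribution S
        out = v @ attn.transpose(1, 2)                    # O_i = sum_j s_ij * V_j
        return out.reshape(B, C, H, W)

x = torch.randn(2, 3, 3, 3)                               # toy 3-channel 3x3 feature map, as in the text
print(SelfAttention2d(3)(x).shape)                        # torch.Size([2, 3, 3, 3])

The full SAGAN module additionally scales this attention output by a learnable coefficient and adds it back to the input as a residual.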

2.2 AttnGAN

AttnGAN uses attention to perform multi-stage, fine-grained text-to-image generation: it synthesizes different sub-regions of an image by attending to the most relevant words in the natural-language description. For example, when generating an image from the sentence "a bird has yellow feathers and black eyes", the keywords "bird", "feathers", "eyes", "yellow" and "black" receive different generation weights, and the details in different sub-regions of the image are filled in under the guidance of these keywords. The figure below illustrates how attention operates in AttnGAN.
Given image feature vectors $h=(h^1,h^2,h^3,h^4)\in\mathbb{R}^{\hat{D}\times 4}$ and word feature vectors $e=(e^1,e^2,e^3,e^4)$, where $h^i\in\mathbb{R}^{\hat{D}\times 1}$, $e^i\in\mathbb{R}^{D\times 1}$ and $i\in\{1,2,3,4\}$, first use a matrix $W$ to linearly map the word feature space $\mathbb{R}^{D}$ into the image feature space $\mathbb{R}^{\hat{D}}$:

$$\hat{e}=W\cdot e=(\hat{e}^1,\hat{e}^2,\hat{e}^3,\hat{e}^4)\in\mathbb{R}^{\hat{D}\times 4}.$$

The transformed word features $\hat{e}$ and the image features $h$ are then used to compute the attention scores, giving a matrix $S$ with elements

$$s_{ij}=(h^i)^{\top}\cdot\hat{e}^j,\quad i\in\{1,2,3,4\},\ j\in\{1,2,3,4\}.$$

Applying softmax to $S$ gives the attention distribution matrix $\beta\in\mathbb{R}^{4\times 4}$ with elements

$$\beta_{ij}=\frac{\exp(s_{ij})}{\sum_{k=1}^{4}\exp(s_{ik})},\quad i\in\{1,2,3,4\},\ j\in\{1,2,3,4\}.$$

Finally, the attention distribution $\beta$ and the image features $h$ give the output $o=(o^1,o^2,o^3,o^4)\in\mathbb{R}^{\hat{D}\times 4}$:

$$\left\{\begin{aligned}o^1&=\beta_{11}\cdot h^1+\beta_{12}\cdot h^2+\beta_{13}\cdot h^3+\beta_{14}\cdot h^4\\o^2&=\beta_{21}\cdot h^1+\beta_{22}\cdot h^2+\beta_{23}\cdot h^3+\beta_{24}\cdot h^4\\o^3&=\beta_{31}\cdot h^1+\beta_{32}\cdot h^2+\beta_{33}\cdot h^3+\beta_{34}\cdot h^4\\o^4&=\beta_{41}\cdot h^1+\beta_{42}\cdot h^2+\beta_{43}\cdot h^3+\beta_{44}\cdot h^4\end{aligned}\right.$$
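As a concrete illustration, here is a minimal PyTorch sketch of this word-to-image attention step. It follows the simplified walkthrough above (in the AttnGAN paper itself the region-wise context vector is assembled from the projected word features rather than from the image features); the dimensions D and D_hat and all tensors are toy values.

import torch

D, D_hat = 6, 8                        # word-feature and image-feature dimensions (toy values)
h = torch.randn(D_hat, 4)              # image features  h = (h^1, ..., h^4)
e = torch.randn(D, 4)                  # word features   e = (e^1, ..., e^4)
W = torch.randn(D_hat, D)              # maps the word space R^D into the image space R^D_hat

e_hat = W @ e                          # projected word features, shape (D_hat, 4)
s = h.t() @ e_hat                      # attention scores, s[i, j] = <h^i, e_hat^j>, shape (4, 4)
beta = torch.softmax(s, dim=1)         # attention distribution over the four words for each region
o = h @ beta.t()                       # o^i = sum_j beta_ij * h^j, shape (D_hat, 4)
print(o.shape)                         # torch.Size([8, 4])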

3 Vision Transformer

This section describes how Vision Transformer works in detail. Section 3.1 covers its overall framework, and Section 3.2 covers the internal operations of the Transformer Encoder. The principles of the Transformer Encoder and Multi-Head Attention are not repeated here; see the previous article "Transformer Explained (with Code)" for details. Note that whether it is the Transformer in natural language processing, SAGAN for image generation, or AttnGAN for text-to-image generation, the main purpose of the attention mechanism at the core of each model is to compute an attention distribution.

3.1 Overall Framework of Vision Transformer

The figure below shows the overall framework of Vision Transformer and the corresponding training procedure:

  • Given an image $X\in\mathbb{R}^{3n\times 3n}$, split it into 9 patches $x^1,\cdots,x^9\in\mathbb{R}^{n\times n}$, then flatten each patch so that $x^1,\cdots,x^9\in\mathbb{R}^{n^2}$.
  • Use a matrix $W\in\mathbb{R}^{l\times n^2}$ to linearly transform each flattened vector $x^i\in\mathbb{R}^{n^2}$, $i\in\{1,\cdots,9\}$, into an image encoding vector $z^i\in\mathbb{R}^{l}$: $z^i=W\cdot x^i,\ i\in\{1,\cdots,9\}$.
  • Add the image encoding vectors $z^i$, $i\in\{1,\cdots,9\}$, and the class encoding vector $z^0$ to their corresponding positional encodings to obtain the input encoding vectors $z^i+p^i\in\mathbb{R}^{l}$, $i\in\{0,\cdots,9\}$.
  • Feed the input encoding vectors into the Vision Transformer Encoder to obtain the corresponding outputs $o^i\in\mathbb{R}^{l}$, $i\in\{0,\cdots,9\}$.
  • Finally, feed the class encoding output $o^0$ into a fully connected network (MLP) to obtain the class prediction vector $\hat{y}\in\mathbb{R}^{c}$, compute the cross-entropy loss against the true class vector $y\in\mathbb{R}^{c}$, and update the model weights with an optimization algorithm (a shape-level sketch of the whole pipeline follows this list).
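The following is a shape-level sketch of the pipeline just described, using toy sizes. PyTorch's built-in nn.TransformerEncoder stands in for the encoder detailed in Section 3.2, and the linear layers and tokens are created ad hoc (in a real model they are learnable parameters of a single module); the full implementation used in this article is given in Section 4.

import torch
import torch.nn as nn

n, l, c = 4, 32, 10                              # patch side, embedding dim, number of classes (toy sizes)
img = torch.randn(1, 1, 3 * n, 3 * n)            # a single-channel 3n x 3n image, batch size 1

# Steps 1-2: split into 9 patches, flatten each to n^2 and project to R^l
patches = img.unfold(2, n, n).unfold(3, n, n)    # (1, 1, 3, 3, n, n)
patches = patches.reshape(1, 9, n * n)           # x^1..x^9, each in R^{n^2}
z = nn.Linear(n * n, l, bias=False)(patches)     # (1, 9, l) image encodings z^1..z^9

# Step 3: prepend the class token z^0 and add the positional encodings p^0..p^9
cls_token = torch.randn(1, 1, l)
pos = torch.randn(1, 10, l)
tokens = torch.cat([cls_token, z], dim=1) + pos  # (1, 10, l)

# Steps 4-5: run the encoder, then classify from the output o^0 at the class-token position
layer = nn.TransformerEncoderLayer(d_model=l, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)
out = encoder(tokens)                            # (1, 10, l)
logits = nn.Linear(l, c)(out[:, 0])              # MLP head on o^0
print(logits.shape)                              # torch.Size([1, 10])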

Note: A natural question at this point is why only the class encoding output $o^0$ is used for prediction, and why the other outputs of the Vision Transformer Encoder are not fed into the MLP. To answer this, let $f_0(\cdot)$ denote the Vision Transformer Encoder; the class encoding output can then be written as

$$o^0=f_0(z^0+p^0,\cdots,z^9+p^9).$$

This shows that $o^0$ is a high-level feature that aggregates all of the image encoding information, which is why it can be used for classification. This is analogous to a convolutional network, where the final class output vector is likewise a high-level feature obtained from layer after layer of convolutions.

3.2 How the Transformer Encoder Works

The figures below show the structure of the Vision Transformer Encoder and of the original Transformer Encoder. Both consist of the same four operations (layer normalization, multi-head attention, residual connections and linear transformations), differing only in the order in which these operations are applied. The Transformer code example in Section 4 implements both Encoder structures, and both of them train well.
The operation flow of the Vision Transformer Encoder (the left half of the figure) is as follows:

  • Given the input encoding matrix $Z\in\mathbb{R}^{l\times n}$, first apply layer normalization to obtain $Z'\in\mathbb{R}^{l\times n}$.
  • Use matrices $W^{q},W^{k},W^{v}\in\mathbb{R}^{l\times l}$ to linearly transform $Z'$ into $Q,K,V\in\mathbb{R}^{l\times n}$:
$$\left\{\begin{aligned}Q&=W^{q}\cdot Z'\\K&=W^{k}\cdot Z'\\V&=W^{v}\cdot Z'\end{aligned}\right.$$
Feed these three matrices into Multi-Head Attention (see "Transformer Explained (with Code)" for the underlying principle) to obtain $Z''\in\mathbb{R}^{l\times n}$, then add the original input $Z$ as a residual connection to obtain $Z+Z''\in\mathbb{R}^{l\times n}$.
  • Apply a second layer normalization to $Z+Z''$ to obtain $Z'''\in\mathbb{R}^{l\times n}$, feed $Z'''$ into a fully connected network to obtain $Z''''\in\mathbb{R}^{l\times n}$, and add another residual connection to obtain the output of this Block, $Z+Z''+Z''''\in\mathbb{R}^{l\times n}$. An Encoder stacks $N$ such Blocks, and the final output is $O\in\mathbb{R}^{l\times n}$ (a minimal sketch of such a pre-norm Block follows this list).

4 Code

A code example for Vision Transformer is given below. It is adapted from the code in the previous article "Transformer Explained (with Code)". The Vision Transformer authors' intent was to transfer the NLP Transformer architecture to CV with as few modifications as possible, so the code below stays close to that intent and implements both Encoder structures shown in Section 3.2: the original Encoder structure and the Encoder structure from the Vision Transformer paper. Note that Vision Transformer has no Decoder module, so there is no Encoder-Decoder cross-attention to compute, which further simplifies the implementation. The open-source Vision Transformer code is available at https://github.com/lucidrains/vit-pytorch/tree/main/vit_pytorch.

import torch
import torch.nn as nn
from einops import repeat
from einops.layers.torch import Rearrange


def inputs_deal(inputs):
    # allow image_size / patch_size to be given as an int or an (h, w) tuple
    return inputs if isinstance(inputs, tuple) else (inputs, inputs)


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # split the embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # queries: (N, query_len, heads, head_dim), keys: (N, key_len, heads, head_dim)
        # energy : (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # attention distribution over the key positions
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # attention: (N, heads, query_len, key_len), values: (N, value_len, heads, head_dim)
        # out      : (N, query_len, heads * head_dim)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim)

        out = self.fc_out(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, mlp_dim):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        # feed-forward network; mlp_dim is its hidden width
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, embed_size))
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, x, type_mode):
        if type_mode == 'original':
            # original (post-norm) Encoder: attention -> add & norm -> FFN -> add & norm
            attention = self.attention(value, key, query)
            x = self.dropout(self.norm(attention + x))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm(forward + x))
            return out
        else:
            # Vision Transformer (pre-norm) Encoder: norm -> attention -> add, norm -> FFN -> add
            attention = self.attention(self.norm(value), self.norm(key), self.norm(query))
            x = self.dropout(attention + x)
            forward = self.feed_forward(self.norm(x))
            out = self.dropout(forward + x)
            return out


class TransformerEncoder(nn.Module):
    def __init__(self,
                 embed_size,
                 num_layers,
                 heads,
                 mlp_dim,
                 dropout=0,
                 type_mode='original'):
        super(TransformerEncoder, self).__init__()
        self.embed_size = embed_size
        self.type_mode = type_mode
        self.Query_Key_Value = nn.Linear(embed_size, embed_size * 3, bias=False)
        self.layers = nn.ModuleList([
            TransformerBlock(
                embed_size,
                heads,
                dropout=dropout,
                mlp_dim=mlp_dim)
            for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        for layer in self.layers:
            # project x into query, key and value, then run one Encoder block
            QKV_list = self.Query_Key_Value(x).chunk(3, dim=-1)
            x = layer(QKV_list[0], QKV_list[1], QKV_list[2], x, self.type_mode)
        return x


class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, embed_size, num_layers,
                 heads, mlp_dim, pool='cls', channels=3, dropout=0,
                 emb_dropout=0.1, type_mode='vit'):
        super(VisionTransformer, self).__init__()
        img_h, img_w = inputs_deal(image_size)
        patch_h, patch_w = inputs_deal(patch_size)

        assert img_h % patch_h == 0 and img_w % patch_w == 0, \
            'Img dimensions must be divisible by the patch dimensions'

        num_patches = (img_h // patch_h) * (img_w // patch_w)
        patch_dim = channels * patch_h * patch_w

        # split the image into patches and linearly project each flattened patch
        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_h, p2=patch_w),
            nn.Linear(patch_dim, embed_size, bias=False))

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_size))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = TransformerEncoder(embed_size, num_layers, heads,
                                              mlp_dim, dropout, type_mode)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, num_classes))

    def forward(self, img):
        x = self.patch_embedding(img)
        b, n, _ = x.shape

        # prepend the class token and add the positional embeddings
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # classify from the class token (or from the mean of all tokens)
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)


if __name__ == '__main__':
    vit = VisionTransformer(
        image_size=256,
        patch_size=16,
        num_classes=10,
        embed_size=256,
        num_layers=6,
        heads=8,
        mlp_dim=512,
        dropout=0.1,
        emb_dropout=0.1)

    img = torch.randn(3, 3, 256, 256)
    pred = vit(img)
    print(pred)

The following code trains a classifier on the MNIST dataset using the Vision Transformer network defined above.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import VIT
def train():
    batch_size = 4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    epoches = 20

    mnist_train = datasets.MNIST("mnist-data", train=True, download=True,
                                 transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

    mnist_model = VIT.VisionTransformer(
        image_size=28,
        patch_size=7,
        num_classes=10,
        channels=1,
        embed_size=512,
        num_layers=1,
        heads=2,
        mlp_dim=1024,
        dropout=0,
        emb_dropout=0)

    loss_fn = nn.CrossEntropyLoss()
    mnist_model = mnist_model.to(device)
    optimizer = optim.Adam(mnist_model.parameters(), lr=0.00001)

    mnist_model.train()
    for epoch in range(epoches):
        total_loss = 0
        corrects = 0
        num = 0
        for batch_X, batch_Y in train_loader:
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            optimizer.zero_grad()
            outputs = mnist_model(batch_X)
            _, pred = torch.max(outputs.data, 1)
            loss = loss_fn(outputs, batch_Y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            corrects += torch.sum(pred == batch_Y.data)  # accumulate correct predictions
            num += batch_size
        # report the mean loss per sample and the training accuracy for this epoch
        print(epoch, total_loss / float(num), corrects.item() / float(num))


if __name__ == '__main__':
    train()

The training log shows the loss decreasing steadily. However, training a Vision Transformer model is genuinely hardware-hungry: compared with an ordinary CNN, training a Vision Transformer takes noticeably more time and compute.
