Lesson 2.张量的索引、分片、合并以及维度调整
转载自:https://www.bilibili.com/video/BV14X4y1A7KT?p=3
??张量作为有序的序列,也是具备数值索引的功能,并且基本索引方法和Python原生的列表、NumPy中的数组基本一致,当然,所有不同的是,PyTorch中还定义了一种采用函数来进行索引的方式。
??而作为PyTorch中基本数据类型,张量即具备了列表、数组的基本功能,同时还充当着向量、矩阵、甚至是数据框等重要数据结构,因此PyTorch中也设置了非常完备的张量合并与变换的操作。
import torch
import numpy as np
一、张量的符号索引
??张量也是有序序列,我们可以根据每个元素在系统内的顺序“编号”,来找出特定的元素,也就是索引。
1.一维张量索引
??一维张量的索引过程和Python原生对象类型的索引一致,基本格式遵循[start: end: step]
,索引的基本要点回顾如下。
t1 = torch.arange(1, 11)
t1
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
- 从左到右,从零开始
t1[0]
tensor(1)
注:张量索引出来的结果还是零维张量, 而不是单独的数。要转化成单独的数,需要使用item()方法。
- 冒号分隔,表示对某个区域进行索引,也就是所谓的切片
t1[1: 8] # 索引其中2-9号元素,并且左包含右不包含
tensor([2, 3, 4, 5, 6, 7, 8])
- 第二个冒号,表示索引的间隔
t1[1: 8: 2] # 索引其中2-9号元素,左包含右不包含,且隔两个数取一个
tensor([2, 4, 6, 8])
- 冒号前后没有值,表示索引这个区域
t1[1: : 2] # 从第二个元素开始索引,一直到结尾,并且每隔两个数取一个
tensor([ 2, 4, 6, 8, 10])
t1[: 8: 2] # 从第一个元素开始索引到第9个元素(不包含),并且每隔两个数取一个
tensor([1, 3, 5, 7])
在张量的索引中,step位必须大于0
t1[9: 1: -1]
---------------------------------------------------------------------------ValueError Traceback (most recent call last)<ipython-input-312-b82bb967c8e2> in <module>
----> 1 t1[9: 1: -1]ValueError: step must be greater than zero
2.二维张量索引
??二维张量的索引逻辑和一维张量的索引逻辑基本相同,二维张量可以视为两个一维张量组合而成,而在实际的索引过程中,需要用逗号进行分隔,分别表示对哪个一维张量进行索引、以及具体的一维张量的索引。
t2 = torch.arange(1, 10).reshape(3, 3)
t2
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
t2[0, 1] # 表示索引第一行、第二个(第二列的)元素
tensor(2)
t2[0, ::2] # 表示索引第一行、每隔两个元素取一个
tensor([1, 3])
t2[0, [0, 2]] # 索引结果同上
tensor([1, 3])
t2[::2, ::2] # 表示每隔两行取一行、并且每一行中每隔两个元素取一个
tensor([[1, 3],[7, 9]])
t2[[0, 2], 1] # 索引第一行、第三行、第二列的元素
tensor([2, 8])
理解:对二维张量来说,基本可以视为是对矩阵的索引,并且行、列的索引遵照相同的索引规范,并用逗号进行分隔。
3.三维张量的索引
??在二维张量索引的基础上,三维张量拥有三个索引的维度。我们将三维张量视作矩阵组成的序列,则在实际索引过程中拥有三个维度,分别是索引矩阵、索引矩阵的行、索引矩阵的列。
t3 = torch.arange(1, 28).reshape(3, 3, 3)
t3
tensor([[[ 1, 2, 3],[ 4, 5, 6],[ 7, 8, 9]],[[10, 11, 12],[13, 14, 15],[16, 17, 18]],[[19, 20, 21],[22, 23, 24],[25, 26, 27]]])
t3[1, 1, 1] # 索引第二个矩阵中,第二行、第二个元素
tensor(14)
t3[1, ::2, ::2] # 索引第二个矩阵,行和列都是每隔两个取一个
tensor([[10, 12],[16, 18]])
t3[:: 2, :: 2, :: 2] # 每隔两个取一个矩阵,对于每个矩阵来说,行和列都是每隔两个取一个
tensor([[[ 1, 3],[ 7, 9]],[[19, 21],[25, 27]]])
理解:更为本质的角度去理解高维张量的索引,其实就是围绕张量的“形状”进行索引
t3.shape
torch.Size([3, 3, 3])
t3[1, 1, 1] # 与shape一一对应
tensor(14)
二、张量的函数索引
??在PyTorch中,我们还可以使用index_select函数,通过指定index来对张量进行索引。
t1
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
t1.ndim
1
indices = torch.tensor([1, 2])
indices
tensor([1, 2])
torch.index_select(t1, 0, indices)
tensor([2, 3])
在index_select函数中,第二个参数实际上代表的是索引的维度。对于t1这个一维向量来说,由于只有一个维度,因此第二个参数取值为0,就代表在第一个维度上进行索引
t2 = torch.arange(12).reshape(4, 3)
t2
tensor([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
t2.shape
torch.Size([4, 3])
indices
tensor([1, 2])
torch.index_select(t2, 0, indices)
tensor([[3, 4, 5],[6, 7, 8]])
dim参数取值为0,代表在shape的第一个维度上索引
torch.index_select(t2, 1, indices)
tensor([[ 1, 2],[ 4, 5],[ 7, 8],[10, 11]])
dim参数取值为1,代表在shape的第二个维度上索引
三、tensor.view()方法
??在正式介绍张量的切分方法之前,需要首先介绍PyTorch中的.view()方法。该方法会返回一个类似视图的结果,该结果和原张量对象共享一块数据存储空间,并且通过.view()方法,还可以改变对象结构,生成一个不同结构,但共享一个存储空间的张量。当然,共享一个存储空间,也就代表二者是“浅拷贝”的关系,修改其中一个,另一个也会同步进行更改。
t = torch.arange(6).reshape(2, 3)
t
tensor([[0, 1, 2],[3, 4, 5]])
te = t.view(3, 2) # 构建一个数据相同,但形状不同的“视图”
te
tensor([[0, 1],[2, 3],[4, 5]])
t
tensor([[0, 1, 2],[3, 4, 5]])
t[0] = 1 # 对t进行修改
t
tensor([[1, 1, 1],[3, 4, 5]])
te # te同步变化
tensor([[1, 1],[1, 3],[4, 5]])
tr = t.view(1, 2, 3) # 维度也可以修改
tr
tensor([[[1, 1, 1],[3, 4, 5]]])
“视图”的作用就是节省空间,而值得注意的是,在接下来介绍的很多切分张量的方法中,返回结果都是“视图”,而不是新生成一个对象。
三、张量的分片函数
1.分块:chunk函数
??chunk函数能够按照某维度,对张量进行均匀切分,并且返回结果是原张量的视图。
t2 = torch.arange(12).reshape(4, 3)
t2
tensor([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
tc = torch.chunk(t2, 4, dim=0) # 在第零个维度上(按行),进行四等分
tc
(tensor([[0, 1, 2]]),tensor([[3, 4, 5]]),tensor([[6, 7, 8]]),tensor([[ 9, 10, 11]]))
注意:chunk返回结果是一个视图,不是新生成了一个对象
tc[0][0]
tensor([0, 1, 2])
tc[0][0][0] = 1 # 修改tc中的值
tc
(tensor([[1, 1, 2]]),tensor([[3, 4, 5]]),tensor([[6, 7, 8]]),tensor([[ 9, 10, 11]]))
t2 # 原张量也会对应发生变化
tensor([[ 1, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
当原张量不能均分时,chunk不会报错,但会返回其他均分的结果
torch.chunk(t2, 3, dim=0) # 次一级均分结果
(tensor([[1, 1, 2],[3, 4, 5]]),tensor([[ 6, 7, 8],[ 9, 10, 11]]))
len(torch.chunk(t2, 3, dim=0))
2
torch.chunk(t2, 5, dim=0) # 次一级均分结果
(tensor([[1, 1, 2]]),tensor([[3, 4, 5]]),tensor([[6, 7, 8]]),tensor([[ 9, 10, 11]]))
2.拆分:split函数
??split既能进行均分,也能进行自定义切分。当然,需要注意的是,和chunk函数一样,split返回结果也是view。
t2 = torch.arange(12).reshape(4, 3)
t2
tensor([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
均分情况
torch.split(t2, 2, 0) # 第二个参数只输入一个数值时表示均分,第三个参数表示切分的维度
(tensor([[0, 1, 2],[3, 4, 5]]),tensor([[ 6, 7, 8],[ 9, 10, 11]]))
按照索引切分
torch.split(t2, [1, 3], 0) # 第二个参数输入一个序列时,表示按照序列数值进行切分,也就是1/3分
(tensor([[0, 1, 2]]),tensor([[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]]))
注意,当第二个参数位输入一个序列时,序列的各数值的和必须等于对应维度下形状分量的取值。例如,上述代码中,是按照第一个维度进行切分,而t2总共有4行,因此序列的求和必须等于4,也就是1+3=4,而序列中每个分量的取值,则代表切块大小。
torch.split(t2, [1, 1, 1, 1], 0)
(tensor([[0, 1, 2]]),tensor([[3, 4, 5]]),tensor([[6, 7, 8]]),tensor([[ 9, 10, 11]]))
torch.split(t2, [1, 1, 2], 0)
(tensor([[0, 1, 2]]),tensor([[3, 4, 5]]),tensor([[ 6, 7, 8],[ 9, 10, 11]]))
ts = torch.split(t2, [1, 2], 1)
ts
(tensor([[0],[3],[6],[9]]),tensor([[ 1, 2],[ 4, 5],[ 7, 8],[10, 11]]))
ts[0][0] = 1 # view进行修改
t2 # 原对象同步改变
tensor([[ 1, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
tensor的split方法和array的split方法有很大的区别,array的split方法是根据索引进行切分。
四、张量的合并操作
??张量的合并操作类似与列表的追加元素,可以拼接、也可以堆叠。
- 拼接函数:cat
PyTorch中,可以使用cat函数实现张量的拼接。
a = torch.zeros(2, 3)
a
tensor([[0., 0., 0.],[0., 0., 0.]])
b = torch.ones(2, 3)
b
tensor([[1., 1., 1.],[1., 1., 1.]])
c = torch.zeros(3, 3)
c
tensor([[0., 0., 0.],[0., 0., 0.],[0., 0., 0.]])
torch.cat([a, b]) # 按照行进行拼接,dim默认取值为0
tensor([[0., 0., 0.],[0., 0., 0.],[1., 1., 1.],[1., 1., 1.]])
torch.cat([a, b], 1) # 按照列进行拼接
tensor([[0., 0., 0., 1., 1., 1.],[0., 0., 0., 1., 1., 1.]])
torch.cat([a, c], 1) # 形状不匹配时将报错
---------------------------------------------------------------------------RuntimeError Traceback (most recent call last)<ipython-input-153-8bdd1a857266> in <module>
----> 1 torch.cat([a, c], 1) # 形状不匹配时将报错RuntimeError: Sizes of tensors must match except in dimension 1. Got 2 and 3 in dimension 0 (The offending index is 1)
注意理解,拼接的本质是实现元素的堆积,也就是构成a、b两个二维张量的各一维张量的堆积,最终还是构成二维向量。
- 堆叠函数:stack
??和拼接不同,堆叠不是将元素拆分重装,而是简单的将各参与堆叠的对象分装到一个更高维度的张量里。
a
tensor([[0., 0., 0.],[0., 0., 0.]])
b
tensor([[1., 1., 1.],[1., 1., 1.]])
torch.stack([a, b]) # 堆叠之后,生成一个三维张量
tensor([[[0., 0., 0.],[0., 0., 0.]],[[1., 1., 1.],[1., 1., 1.]]])
torch.stack([a, b]).shape
torch.Size([2, 2, 3])
torch.cat([a, b])
tensor([[0., 0., 0.],[0., 0., 0.],[1., 1., 1.],[1., 1., 1.]])
注意对比二者区别,拼接之后维度不变,堆叠之后维度升高。拼接是把一个个元素单独提取出来之后再放到二维张量中,而堆叠则是直接将两个二维张量封装到一个三维张量中,因此,堆叠的要求更高,参与堆叠的张量必须形状完全相同。
a
tensor([[0., 0., 0.],[0., 0., 0.]])
c
tensor([[0., 0., 0.],[0., 0., 0.],[0., 0., 0.]])
torch.cat([a, c]) # 横向拼接时,对行数没有一致性要求
tensor([[0., 0., 0.],[0., 0., 0.],[0., 0., 0.],[0., 0., 0.],[0., 0., 0.]])
torch.stack([a, c]) # 维度不匹配时也会报错
---------------------------------------------------------------------------RuntimeError Traceback (most recent call last)<ipython-input-167-0311d15e051e> in <module>
----> 1 torch.stack([a, c]) # 维度不匹配时也会报错RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [3, 3] at entry 1
五、张量维度变换
??此前我们介绍过,通过reshape方法,能够灵活调整张量的形状。而在实际操作张量进行计算时,往往需要另外进行降维和升维的操作,当我们需要除去不必要的维度时,可以使用squeeze函数,而需要手动升维时,则可采用unsqueeze函数。
- squeeze函数:删除不必要的维度
t = torch.zeros(1, 1, 3, 1)
t
tensor([[[[0.],[0.],[0.]]]])
t.shape
torch.Size([1, 1, 3, 1])
t张量解释:一个包含一个三维的四维张量,三维张量只包含一个三行一列的二维张量。
torch.squeeze(t)
tensor([0., 0., 0.])
torch.squeeze(t).shape
torch.Size([3])
转化后生成了一个一维张量
t1 = torch.zeros(1, 1, 3, 2, 1, 2)
t1.shape
torch.Size([1, 1, 3, 2, 1, 2])
torch.squeeze(t1)
tensor([[[0., 0.],[0., 0.]],[[0., 0.],[0., 0.]],[[0., 0.],[0., 0.]]])
torch.squeeze(t1).shape
torch.Size([3, 2, 2])
简单理解,squeeze就相当于提出了shape返回结果中的1
- unsqeeze函数:手动升维
t = torch.zeros(1, 2, 1, 2)
t.shape
torch.Size([1, 2, 1, 2])
torch.unsqueeze(t, dim = 0) # 在第1个维度索引上升高1个维度
tensor([[[[[0., 0.]],[[0., 0.]]]]])
torch.unsqueeze(t, dim = 0).shape
torch.Size([1, 1, 2, 1, 2])
torch.unsqueeze(t, dim = 2).shape # 在第3个维度索引上升高1个维度
torch.Size([1, 2, 1, 1, 2])
torch.unsqueeze(t, dim = 4).shape # 在第5个维度索引上升高1个维度
torch.Size([1, 2, 1, 2, 1])
注意理解维度和shape返回结果一一对应的关系,shape返回的序列有几个元素,张量就有多少维度。