文章目录
- 1.前言
- 2.tensor 的存储方式
-
- 2.1.基本知识
-
- 2.1.1官方文档
- 2.2.tensor 的stride() 和 storage_offset() 属性
-
- 2.2.1 stride()
- 2.2.2 storage_offset()
- 3.view(),reshape(),resize_()之间的关系
-
- 3.1.view()
- 3.2.reshape()
-
- 3.2.1.tensor的连续性
- 3.3.resize_()
-
- 3.3.1.数据多的时候
- 3.3.2.数据少的时候
- 3.3.3.处理不连续数据
- 4.总结
1.前言
在使用 expand() 函数时,查看官方文档,文档中说返回的是原 tensor
的一个 view
,不太理解 view
的意思。遂查找了解。后续会更新 expand()
函数的用法。
2.tensor 的存储方式
2.1.基本知识
tensor
的存储,分为两个部分(一个 tensor
占用两个内存位置)
- 一个位置存储了真正的数据,我们称为存储区(Storage)
- 一个位置存储
tensor
的 形状(size
),步长(stride
),索引等信息。我们称为头信息(Tensor)
假如我们有两个 tensor: A, B
。我们利用 =
号,将 A
赋值给 B
,做的其实是浅拷贝。也就是说,A
与B
共享数据(存储部分),不同的只是头信息
头信息是对存储区的一种表现形式,这决定了我们是以什么排列方式看到真实数据的。其实这就是 tensor
的视图,tensor.view()
函数就是通过改变头信息,来使数据以不同的形式展示(真实数据并没有改变)。我们在下面会提到。
我们利用代码来说明一下。函数 tensor.storage().data_ptr()
是用于获取 tensor
存储区地址的。
import torch
a = torch.tensor([1, 2, 3])
b = a
b[0] = 100
print(f'a : {
a}')
print(f'b : {
b}')
print(f'storage address of a: {
a.storage().data_ptr()}')
print(f'storage address of b: {
b.storage().data_ptr()}')
>>a : tensor([100, 2, 3])
>>b : tensor([100, 2, 3])
>>storage address of a: 94784411983360
>>storage address of b: 94784411983360
操作:我们使用 =
将 a
赋值给 b
,然后修改了 b
。
从结果中我们可以看出,a , b
发生了改变。二者存储区的地址相同。这说明二者共用存储区。
2.1.1官方文档
其实这里官方文档是有提到的
意思就是,PyTorch
支持一个 tensor
是一个现存 tensor
的 View
的。(现存 tensor
称为 base tensor
,另一个称为 view tensor
)。二者共享内存。
- 这种操作可以避免一些数据复制,使得我们能够快速,并且节省内存得进行
reshape
,切片,和一些基于元素的操作
2.2.tensor 的stride() 和 storage_offset() 属性
tensor
为了节约内存,很多操作都是在更改头信息区。头信息区包含了,如何组织数据,以及从哪里开始组织数据。其中两个重要的属性是 stride()
和 storage_offset()
2.2.1 stride()
在指定维度 dim
上,从一个元素跳到下一个元素所必须的步长(在存储区中经过的元素的个数)
a = torch.randn(3, 2)
print(a.stride())
>>(2, 1)
其实不难理解,在第0
维,想要跳到下一个元素,比如从 a[0][0] -> a[1][0]
,需要经过两个元素,步长是 2。在第 1
维,想跳到下一个元素,从a[0][0] -> a[0][1]
,需要经过一个元素,步长是 1。
2.2.2 storage_offset()
表示 tensor
的第 0 个元素与真实存储区的第 0 个元素的偏移量
a = torch.tensor([1, 2, 3, 4, 5])
b = a[1:]
c = a[3:]
print(b.storage_offset())
print(c.storage_offset())
>>1
>>3
可见,b
的第 0 个元素与 a
的第 0 个元素之间的偏移量是 1,c
与 a
的偏移量是 3
3.view(),reshape(),resize_()之间的关系
3.1.view()
view
从字面意思上就是 视图 的意思。因此,就是将数据以某种排列方式展示给我们,不改变存储区的真实数据,只改变头信息区。
a = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)
print(f'a : {
a}')
print(f'b : {
b}')
print(f'storage address of a: {
a.storage().data_ptr()}')
print(f'storage address of b: {
b.storage().data_ptr()}')
>>a : tensor([1, 2, 3, 4, 5, 6])
>>b : tensor([[1, 2, 3],[4, 5, 6]])
>>storage address of a: 93972187307840
>>storage address of b: 93972187307840
可见,二者共享存储区
print(f'storage of a: {
a.storage()}')
print(f'storage of b: {
b.storage()}')
>>[torch.LongStorage of size 6]
storage of a: 123456
>>[torch.LongStorage of size 6]
storage of b: 123456
存储区的数据并没有发生改变
print(f'stride of a : {
a.stride()}')
print(f'stride of b : {
b.stride()}')
>>(1,)
>>(3, 1)
可见,stride
发生改变,也就是头信息区发生改变
3.2.reshape()
3.2.1.tensor的连续性
tensor
的连续性说的其实是 stride()
属性 和 size()
之间的关系
连续性条件: s t r i d e [ i ] = s t r i d e [ i + 1 ] ? s i z e [ i + 1 ] stride[i] = stride[i + 1]*size[i+1] stride[i]=stride[i+1]?size[i+1]
意思就是,第 i
维跳到下一个元素走的步数,是 i + 1
维走到下一维的步数,乘以 i + 1
维数的个数。
比如二维数组中 s t r i d e [ 0 ] = s t r i d e [ 1 ] ? s i z e [ 1 ] stride[0] = stride[1]*size[1] stride[0]=stride[1]?size[1],代表的就是第 0
维走到下一个数,需要走完这一行。
比如上面的例子中
a = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)
对 b
来说:stride[0] = 3
,stride[1] = 1
,size[1] = 3
。满足上面的条件
直观来说,就是:在存储区的真实数据中,在我旁边的数,现在还在我旁边,就叫连续
有些操作会改变连续性,比如转置
a = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()
print(f'a : {
a}')
print(f'b : {
b}')
print(a.stride())
print(b.stride())
>>a : tensor([[1, 2, 3],[4, 5, 6]])
>>b : tensor([[1, 4],[2, 5],[3, 6]])
>>(3, 1)
>>(1, 3)
b
是 a
的转置。==在第 0
维走到下一个元素,1 -> 2
,步长是 1,因为在存储区中 1 和 2 是相邻的。==这就不满足上面的式子了,因此是不连续的。
再次强调:stride
是在当前维走到下一个元素,在存储区中需要经过的元素的个数
直观理解:在第 0
维,如果数据是连续的,走到下一个元素,应该把这一行走完,步长是 2。现在 4 的邻居是 1 和 2 了,实际上应该是 3 和 5。这就说明数据不连续了
不连续是不能使用 view()
方法的。那有什么办法可以让 b
使用 view()
呢?就是将其连续化(b.contiguous()
)
a = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()
c = b.contiguous()
print(f'a : {
a}')
print(f'b : {
b}')
print(f'c : {
c}')
print(f'stride of a : {
a.stride()}')
print(f'stride of b : {
b.stride()}')
print(f'stride of c : {
c.stride()}')
print(f'storage address of a: {
a.storage().data_ptr()}')
print(f'storage address of b: {
b.storage().data_ptr()}')
print(f'storage address of c: {
c.storage().data_ptr()}')
>>a : tensor([[1, 2, 3],[4, 5, 6]])
>>b : tensor([[1, 4],[2, 5],[3, 6]])
>>c : tensor([[1, 4],[2, 5],[3, 6]])
>>stride of a : (3, 1)
>>stride of b : (1, 3)
>>stride of c : (2, 1)
>>storage address of a: 94382097256256
>>storage address of b: 94382097256256
>>storage address of c: 94382053572096
我们看到,c
的数据恢复了连续性,且其存储区的地址与 a, b
不同了。
contiguous()
函数其实就是创造了一个全新的 tensor
。在存储区中,将 b
中的数据按顺序存放,得到 c
这样我们可以说明 reshape()
和 view()
的区别了
- 当
tensor
满足连续性要求时,reshape() = view()
,和原来tensor
共用存储区 - 当
tensor
不满足连续性要求时,reshape() = **contiguous() + view()
,会产生新的存储区的tensor
,与原来tensor
不共用存储区
3.3.resize_()
前面说到的 reshape
和 view
都必须要用到全部的原始数据,比如你的原始数据只有12个,无论你怎么变形都必须要用到12个数字,不能多不能少。因此你就不能把只有12个数字的 tensor
强行 reshap
成 2*5
的维度的tensor
。但是 resize_()
可以做到,无论你存储区原始有多少个数字,我都能变成你想要的维度,数字不够怎么办?随机产生凑!数字多了怎么办?就取我需要的部分!
3.3.1.数据多的时候
a = torch.tensor([1, 2, 3, 4, 5, 6, 7])
b = a.resize_(2, 3)
print(f'a : {
a}')
print(f'b : {
b}')
print(f'stride of a : {
a.stride()}')
print(f'stride of b : {
b.stride()}')
print(f'storage address of a: {
a.storage().data_ptr()}')
print(f'storage address of b: {
b.storage().data_ptr()}')
>>a : tensor([[1, 2, 3],[4, 5, 6]])
>>b : tensor([[1, 2, 3],[4, 5, 6]])
>>stride of a : (3, 1)
>>stride of b : (3, 1)
>>storage address of a: 94579423708416
>>storage address of b: 94579423708416print(a.storage())
>> 1234567
可见,取的是前 6 个。
会改变 a
,但是并没有改变存储区中的数据,a, b
共用存储区
3.3.2.数据少的时候
a = torch.tensor([1, 2, 3, 4, 5])
b = a.resize_(2, 3)
print(f'a : {
a}')
print(f'b : {
b}')
print(f'stride of a : {
a.stride()}')
print(f'stride of b : {
b.stride()}')
print(f'storage address of a: {
a.storage().data_ptr()}')
print(f'storage address of b: {
b.storage().data_ptr()}')
>>a : tensor([[ 1, 2, 3],[ 4, 5, 94673159007352]])
>>b : tensor([[ 1, 2, 3],[ 4, 5, 94673159007352]])
>>stride of a : (3, 1)
>>stride of b : (3, 1)
>>storage address of a: 94673159007296
>>storage address of b: 94673159007296print(a.storage())
>> 1234594673159007352
可见,补充了一个数
会改变 a
,且改变了存储区中的数据,a, b
共用存储区(但是已经不是刚刚那个存储区了,地址变了)
3.3.3.处理不连续数据
a = torch.arange(6).view(2, 3)
b = a.t()
c = b.resize_(3, 2)
print(f'a : {
a}')
print(f'b : {
b}')
print(f'c : {
c}')
print(f'stride of a : {
a.stride()}')
print(f'stride of b : {
b.stride()}')
print(f'stride of c : {
c.stride()}')
print(f'storage address of a: {
a.storage().data_ptr()}')
print(f'storage address of b: {
b.storage().data_ptr()}')
print(f'storage address of c: {
c.storage().data_ptr()}')
>>a : tensor([[0, 1, 2],[3, 4, 5]])
>>b : tensor([[0, 3],[1, 4],[2, 5]])
>>c : tensor([[0, 3],[1, 4],[2, 5]])
>>stride of a : (3, 1)
>>stride of b : (1, 3)
>>stride of c : (1, 3)
>>storage address of a: 94375435009664
>>storage address of b: 94375435009664
>>storage address of c: 94375435009664
可见,使用 resize_()
之后,数据仍然保持连续性。并且没有开辟新的 tensor
,与原 tensor
共享存储区
print(a.storage())
>> 012345
并且,没有改变存储区中的数。
也就是说,resize_()
只是改变了头信息,使得数据以我们想要的形式呈现,而并没有改变其他信息。
4.总结
最后总结一下 view()
、reshape()
、resize_()
三者的关系和区别。
view()
只能对满足连续性要求的tensor使用。- 当
tensor
满足连续性要求时,reshape() = view()
,和原来tensor
共用内存。 - 当
tensor
不满足连续性要求时,reshape() = **contiguous() + view()
,会产生新的存储区的tensor
,与原来tensor
不共用存储区。 resize_()
可以随意的获取任意维度的tensor
,不用在意真实数据的个数限制,但是不推荐使用。
参考:Pytorch——Tensor的储存机制以及view()、reshape()、reszie_()三者的关系和区别 - Circle_Wang - 博客园 (cnblogs.com)