当前位置: 代码迷 >> 综合 >> tensor的存储方式 + view() reshape() resize_() 区别
  详细解决方案

tensor的存储方式 + view() reshape() resize_() 区别

热度:60   发布时间:2023-11-25 15:38:27.0

文章目录

  • 1.前言
  • 2.tensor 的存储方式
    • 2.1.基本知识
      • 2.1.1官方文档
    • 2.2.tensor 的stride() 和 storage_offset() 属性
      • 2.2.1 stride()
      • 2.2.2 storage_offset()
  • 3.view(),reshape(),resize_()之间的关系
    • 3.1.view()
    • 3.2.reshape()
      • 3.2.1.tensor的连续性
    • 3.3.resize_()
      • 3.3.1.数据多的时候
      • 3.3.2.数据少的时候
      • 3.3.3.处理不连续数据
  • 4.总结

1.前言

在使用 expand() 函数时,查看官方文档,文档中说返回的是原 tensor 的一个 view,不太理解 view 的意思。遂查找了解。后续会更新 expand() 函数的用法。

2.tensor 的存储方式

2.1.基本知识

tensor的存储,分为两个部分(一个 tensor 占用两个内存位置)

  • 一个位置存储了真正的数据,我们称为存储区(Storage)
  • 一个位置存储 tensor 的 形状(size),步长(stride),索引等信息。我们称为头信息(Tensor)

假如我们有两个 tensor: A, B。我们利用 = 号,将 A 赋值给 B,做的其实是浅拷贝。也就是说,AB共享数据(存储部分),不同的只是头信息

头信息是对存储区的一种表现形式,这决定了我们是以什么排列方式看到真实数据的。其实这就是 tensor 的视图,tensor.view() 函数就是通过改变头信息,来使数据以不同的形式展示(真实数据并没有改变)。我们在下面会提到。

我们利用代码来说明一下。函数 tensor.storage().data_ptr() 是用于获取 tensor 存储区地址的。

import torch
a = torch.tensor([1, 2, 3])
b = a
b[0] = 100
print(f'a : {
      a}')
print(f'b : {
      b}')
print(f'storage address of a: {
      a.storage().data_ptr()}')
print(f'storage address of b: {
      b.storage().data_ptr()}')
>>a : tensor([100,   2,   3])
>>b : tensor([100,   2,   3])
>>storage address of a: 94784411983360
>>storage address of b: 94784411983360

操作:我们使用 =a 赋值给 b ,然后修改了 b

从结果中我们可以看出,a , b 发生了改变。二者存储区的地址相同。这说明二者共用存储区。

2.1.1官方文档

其实这里官方文档是有提到的

在这里插入图片描述

意思就是,PyTorch 支持一个 tensor 是一个现存 tensorView 的。(现存 tensor 称为 base tensor,另一个称为 view tensor)。二者共享内存。

  • 这种操作可以避免一些数据复制,使得我们能够快速,并且节省内存得进行 reshape,切片,和一些基于元素的操作

2.2.tensor 的stride() 和 storage_offset() 属性

tensor 为了节约内存,很多操作都是在更改头信息区。头信息区包含了,如何组织数据,以及从哪里开始组织数据。其中两个重要的属性是 stride()storage_offset()

2.2.1 stride()

在指定维度 dim 上,从一个元素跳到下一个元素所必须的步长(在存储区中经过的元素的个数

a = torch.randn(3, 2)
print(a.stride())
>>(2, 1)

在这里插入图片描述

其实不难理解,在第0 维,想要跳到下一个元素,比如从 a[0][0] -> a[1][0] ,需要经过两个元素,步长是 2。在第 1 维,想跳到下一个元素,从a[0][0] -> a[0][1],需要经过一个元素,步长是 1。

2.2.2 storage_offset()

表示 tensor 的第 0 个元素与真实存储区的第 0 个元素的偏移量

a = torch.tensor([1, 2, 3, 4, 5])
b = a[1:]
c = a[3:]
print(b.storage_offset())
print(c.storage_offset())
>>1
>>3

可见,b的第 0 个元素与 a 的第 0 个元素之间的偏移量是 1,ca 的偏移量是 3

3.view(),reshape(),resize_()之间的关系

3.1.view()

view 从字面意思上就是 视图 的意思。因此,就是将数据以某种排列方式展示给我们,不改变存储区的真实数据,只改变头信息区。

a = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)
print(f'a : {
      a}')
print(f'b : {
      b}')
print(f'storage address of a: {
      a.storage().data_ptr()}')
print(f'storage address of b: {
      b.storage().data_ptr()}')
>>a : tensor([1, 2, 3, 4, 5, 6])
>>b : tensor([[1, 2, 3],[4, 5, 6]])
>>storage address of a: 93972187307840
>>storage address of b: 93972187307840

可见,二者共享存储区

print(f'storage of a: {
      a.storage()}')
print(f'storage of b: {
      b.storage()}')
>>[torch.LongStorage of size 6]
storage  of a:  123456
>>[torch.LongStorage of size 6]
storage  of b:  123456

存储区的数据并没有发生改变

print(f'stride of a : {
      a.stride()}')
print(f'stride of b : {
      b.stride()}')
>>(1,)
>>(3, 1)

可见,stride 发生改变,也就是头信息区发生改变

3.2.reshape()

3.2.1.tensor的连续性

tensor 的连续性说的其实是 stride() 属性 和 size() 之间的关系

连续性条件: s t r i d e [ i ] = s t r i d e [ i + 1 ] ? s i z e [ i + 1 ] stride[i] = stride[i + 1]*size[i+1] stride[i]=stride[i+1]?size[i+1]

意思就是,第 i 维跳到下一个元素走的步数,是 i + 1 维走到下一维的步数,乘以 i + 1 维数的个数。

比如二维数组中 s t r i d e [ 0 ] = s t r i d e [ 1 ] ? s i z e [ 1 ] stride[0] = stride[1]*size[1] stride[0]=stride[1]?size[1],代表的就是第 0 维走到下一个数,需要走完这一行。

比如上面的例子中

a = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)

b 来说:stride[0] = 3stride[1] = 1size[1] = 3。满足上面的条件

直观来说,就是:在存储区的真实数据中,在我旁边的数,现在还在我旁边,就叫连续

有些操作会改变连续性,比如转置

a = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()
print(f'a : {
      a}')
print(f'b : {
      b}')
print(a.stride())
print(b.stride())
>>a : tensor([[1, 2, 3],[4, 5, 6]])
>>b : tensor([[1, 4],[2, 5],[3, 6]])
>>(3, 1)
>>(1, 3)

ba 的转置。==在第 0 维走到下一个元素,1 -> 2,步长是 1,因为在存储区中 1 和 2 是相邻的。==这就不满足上面的式子了,因此是不连续的。

再次强调:stride 是在当前维走到下一个元素,在存储区中需要经过的元素的个数

直观理解:在第 0 维,如果数据是连续的,走到下一个元素,应该把这一行走完,步长是 2。现在 4 的邻居是 1 和 2 了,实际上应该是 3 和 5。这就说明数据不连续了

不连续是不能使用 view() 方法的。那有什么办法可以让 b 使用 view() 呢?就是将其连续化(b.contiguous()

a = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()
c = b.contiguous()
print(f'a : {
      a}')
print(f'b : {
      b}')
print(f'c : {
      c}')
print(f'stride of a : {
      a.stride()}')
print(f'stride of b : {
      b.stride()}')
print(f'stride of c : {
      c.stride()}')
print(f'storage address of a: {
      a.storage().data_ptr()}')
print(f'storage address of b: {
      b.storage().data_ptr()}')
print(f'storage address of c: {
      c.storage().data_ptr()}')
>>a : tensor([[1, 2, 3],[4, 5, 6]])
>>b : tensor([[1, 4],[2, 5],[3, 6]])
>>c : tensor([[1, 4],[2, 5],[3, 6]])
>>stride of a : (3, 1)
>>stride of b : (1, 3)
>>stride of c : (2, 1)
>>storage address of a: 94382097256256
>>storage address of b: 94382097256256
>>storage address of c: 94382053572096

我们看到,c 的数据恢复了连续性,且其存储区的地址与 a, b 不同了。

contiguous() 函数其实就是创造了一个全新的 tensor。在存储区中,将 b 中的数据按顺序存放,得到 c

这样我们可以说明 reshape()view() 的区别了

  • tensor 满足连续性要求时,reshape() = view(),和原来 tensor 共用存储区
  • tensor 不满足连续性要求时,reshape() = **contiguous() + view()会产生新的存储区tensor,与原来 tensor 不共用存储区

3.3.resize_()

前面说到的 reshapeview 都必须要用到全部的原始数据,比如你的原始数据只有12个,无论你怎么变形都必须要用到12个数字,不能多不能少。因此你就不能把只有12个数字的 tensor 强行 reshap2*5 的维度的tensor。但是 resize_() 可以做到,无论你存储区原始有多少个数字,我都能变成你想要的维度,数字不够怎么办?随机产生凑!数字多了怎么办?就取我需要的部分!

3.3.1.数据多的时候

a = torch.tensor([1, 2, 3, 4, 5, 6, 7])
b = a.resize_(2, 3)
print(f'a : {
      a}')
print(f'b : {
      b}')
print(f'stride of a : {
      a.stride()}')
print(f'stride of b : {
      b.stride()}')
print(f'storage address of a: {
      a.storage().data_ptr()}')
print(f'storage address of b: {
      b.storage().data_ptr()}')
>>a : tensor([[1, 2, 3],[4, 5, 6]])
>>b : tensor([[1, 2, 3],[4, 5, 6]])
>>stride of a : (3, 1)
>>stride of b : (3, 1)
>>storage address of a: 94579423708416
>>storage address of b: 94579423708416print(a.storage())
>> 1234567

可见,取的是前 6 个。

会改变 a,但是并没有改变存储区中的数据a, b 共用存储区

3.3.2.数据少的时候

a = torch.tensor([1, 2, 3, 4, 5])
b = a.resize_(2, 3)
print(f'a : {
      a}')
print(f'b : {
      b}')
print(f'stride of a : {
      a.stride()}')
print(f'stride of b : {
      b.stride()}')
print(f'storage address of a: {
      a.storage().data_ptr()}')
print(f'storage address of b: {
      b.storage().data_ptr()}')
>>a : tensor([[             1,              2,              3],[             4,              5, 94673159007352]])
>>b : tensor([[             1,              2,              3],[             4,              5, 94673159007352]])
>>stride of a : (3, 1)
>>stride of b : (3, 1)
>>storage address of a: 94673159007296
>>storage address of b: 94673159007296print(a.storage())
>> 1234594673159007352

可见,补充了一个数

会改变 a且改变了存储区中的数据a, b 共用存储区(但是已经不是刚刚那个存储区了,地址变了)

3.3.3.处理不连续数据

a = torch.arange(6).view(2, 3)
b = a.t()
c = b.resize_(3, 2)
print(f'a : {
      a}')
print(f'b : {
      b}')
print(f'c : {
      c}')
print(f'stride of a : {
      a.stride()}')
print(f'stride of b : {
      b.stride()}')
print(f'stride of c : {
      c.stride()}')
print(f'storage address of a: {
      a.storage().data_ptr()}')
print(f'storage address of b: {
      b.storage().data_ptr()}')
print(f'storage address of c: {
      c.storage().data_ptr()}')
>>a : tensor([[0, 1, 2],[3, 4, 5]])
>>b : tensor([[0, 3],[1, 4],[2, 5]])
>>c : tensor([[0, 3],[1, 4],[2, 5]])
>>stride of a : (3, 1)
>>stride of b : (1, 3)
>>stride of c : (1, 3)
>>storage address of a: 94375435009664
>>storage address of b: 94375435009664
>>storage address of c: 94375435009664

可见,使用 resize_() 之后,数据仍然保持连续性并且没有开辟新的 tensor ,与原 tensor 共享存储区

print(a.storage())
>> 012345

并且,没有改变存储区中的数。

也就是说,resize_() 只是改变了头信息,使得数据以我们想要的形式呈现,而并没有改变其他信息

4.总结

最后总结一下 view()reshape()resize_()三者的关系和区别。

  • view() 只能对满足连续性要求的tensor使用。
  • tensor 满足连续性要求时,reshape() = view(),和原来 tensor 共用内存。
  • tensor 不满足连续性要求时,reshape() = **contiguous() + view()会产生新的存储区tensor,与原来 tensor 不共用存储区。
  • resize_() 可以随意的获取任意维度的 tensor,不用在意真实数据的个数限制,但是不推荐使用。

参考:Pytorch——Tensor的储存机制以及view()、reshape()、reszie_()三者的关系和区别 - Circle_Wang - 博客园 (cnblogs.com)

  相关解决方案