当前位置: 代码迷 >> 综合 >> 记录一些Pytorch方便的函数<即插即用>
  详细解决方案

记录一些Pytorch方便的函数<即插即用>

热度:15   发布时间:2023-12-23 04:15:56.0

引言

Pytorch自己有一些函数可以实现很复杂的一些功能,自己以前想创建一个tensor,经常傻乎乎的创建一个空Tensor,然后再慢慢调整,不但不美观,而且有的时候时间复杂度很高。这个博客记录了一些Pyrotch的很方便的函数,想实现某个功能时,可以去查阅一下有没有一步到位的函数。


TORCH.FULL

torch.full(size, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) 

作用: 返回以size填充的大小的张量fill_value

关键参数:
size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.
fill_value – the number to fill the output tensor with.
out (Tensor, optional) – the output tensor.
dtype (torch.dtype, optional) – the desired data type of returned tensor.

样例:

>>> torch.full((2, 3), 3.141592)
tensor([[ 3.1416,  3.1416,  3.1416],[ 3.1416,  3.1416,  3.1416]])

TORCH.WHERE

torch.where(condition, x, y) → Tensor

作用:返回从x或y中选择的元素的张量,具体取决于condition。该操作定义为:
在这里插入图片描述

关键参数:
condition (BoolTensor) – When True (nonzero), yield x, otherwise yield y
x (Tensor) – values selected at indices where condition is True
y (Tensor) – values selected at indices where condition is False

样例:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],[ 0.3898, -0.7197],[ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],[ 0.3898,  1.0000],[ 0.0478,  1.0000]])

TORCH.NE

torch.ne(input, other, out=None) → Tensor

作用:逐元素计算input != output,第二个参数可以是数字或张量,其形状可 与第一个参数一起广播。

关键参数:
input (Tensor) – the tensor to compare
other (Tensor or float) – the tensor or value to compare
out (Tensor, optional) – the output tensor that must be a BoolTensor

样例:

>>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[False, True], [True, False]])

TORCH.BMM

torch.bmm(input, mat2, deterministic=False, out=None) 

作用:执行存储在input 中的矩阵和矩阵矩mat2的乘积。input和mat2必须是3D张量,每个张量包含相同数量的矩阵。如果input是(b?n?m)(b*n*m)(b?n?m) 张量,mat2是 (b?m?p)(b *m*p)b?m?p 张量,out将是 (b?n?p)(b*n*p)(b?n?p) 张量。

关键元素:
input (Tensor) – the first batch of matrices to be multiplied
mat2 (Tensor) – the second batch of matrices to be multiplied
deterministic (bool, optional) – flag to choose between a faster non-deterministic calculation, or a slower deterministic calculation. This argument is only available for sparse-dense CUDA bmm. Default: False
out (Tensor, optional) – the output tensor.

样例:

>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

TORCH.CLAMP

torch.clamp(input, min, max, out=None) 

作用:将所有元素固定在[min,max][min,max][minmax]范围内,并返回结果张量:
在这里插入图片描述

关键元素:
input (Tensor) – the input tensor.
min (Number) – lower-bound of the range to be clamped to
max (Number) – upper-bound of the range to be clamped to
out (Tensor, optional) – the output tensor

样例:

>>> a = torch.randn(4)
>>> a
tensor([-1.7120,  0.1734, -0.0478, -0.0922])
>>> torch.clamp(a, min=-0.5, max=0.5)
tensor([-0.5000,  0.1734, -0.0478, -0.0922])