当前位置: 代码迷 >> 综合 >> TORCH.NN.FUNCTIONAL.ONE_HOT
  详细解决方案

TORCH.NN.FUNCTIONAL.ONE_HOT

热度:14   发布时间:2024-01-09 05:53:23.0

文章目录

  • 1.独热编码
  • 2.不带参数独热编码,仅只输入一个张量 one_hot
    • 2.1 代码
    • 2.2 输出结果
  • 3. 带参数独热编码,一个参数为输入张量,另外一个为类别数
    • 3.2 结果

1.独热编码

定义:torch.nn.functional.one_hot(tensor, num_classes=- 1) → LongTensor
描述:Takes LongTensor with index values of shape () and returns a tensor of shape (, num_classes) that have zeros everywhere except where the index of last dimension matches the corresponding value of the input tensor, in which case it will be 1.
译文:
使用index值为shape()的LongTensor,并返回一个shape (, num_classes)的张量,除了最后一个维度的索引与输入张量的对应值相匹配的地方外,其他地方都是零,在这种情况下,它将是1。
说人话: 就是在你给定一个张量的时候,可以对你给的张量进行编码,这里分两种情况

2.不带参数独热编码,仅只输入一个张量 one_hot

2.1 代码

import torch
from torch.nn import functional as Fx = torch.tensor([1, 2, 3, 8, 5])
# 定义一个张量输入,因为此时有 5 个数值,且最大值为8,
# 所以我们可以得到 y 的输出结果的形状应该为 shape=(5,9);5行9列
y = F.one_hot(x) # 只有一个参数张量x
print(f'x = {
      x}') # 输出 x
print(f'x_shape = {
      x.shape}') # 查看 x 的形状
print(f'y = {
      y}') # 输出 y
print(f'y_shape = {
      y.shape}') # 输出 y 的形状

2.2 输出结果

我们可以看出来,所得的结果为 X 中每个张量里面的值为 Y 结果中的序号为 1 的地方;
比如: X 中第 4 个值表示为 8 的值,可以看到 Y 中第 4 行的 8 个(下标从 0 开始)

x = tensor([1, 2, 3, 8, 5])
x_shape = torch.Size([5])
y = tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0],[0, 0, 1, 0, 0, 0, 0, 0, 0],[0, 0, 0, 1, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 1, 0, 0, 0]])
y_shape = torch.Size([5, 9])

3. 带参数独热编码,一个参数为输入张量,另外一个为类别数

import torch
from torch.nn import functional as Fx = torch.tensor([1, 2, 3, 8, 5])
# 定义一个张量输入,因为此时有 5 个数值,且最大值为8,且 类别数为 12
# 所以我们可以得到 y 的输出结果的形状应该为 shape=(5,12);5行12列
y = F.one_hot(x, 12)  # 一个参数张量x, 12 为类别数,其中 12 > max{x}
print(f'x = {
      x}')  # 输出 x
print(f'x_shape = {
      x.shape}')  # 查看 x 的形状
print(f'y = {
      y}')  # 输出 y
print(f'y_shape = {
      y.shape}')  # 输出 y 的形状

3.2 结果

x = tensor([1, 2, 3, 8, 5])
x_shape = torch.Size([5])
y = tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])
y_shape = torch.Size([5, 12])
  相关解决方案