当前位置: 代码迷 >> 综合 >> embedding = torch.nn.Embedding(10, 3)
  详细解决方案

embedding = torch.nn.Embedding(10, 3)

热度:74   发布时间:2023-12-21 02:42:22.0

embedding = torch.nn.Embedding(10, 3)

通过 word embedding,就可以将自然语言所表示的单词或短语转换为计算机能够理解的由实数构成的向量或矩阵形式(比如,one-hot 就是一种简单的 word embedding 的方法)。

import torch
# 10个单词的向量矩阵,10*3的矩阵,会初始化
embedding = torch.nn.Embedding(10, 3)
# input是要取的单词的下标矩阵
input = torch.LongTensor([[2,2]])print(embedding(input))
print(embedding(torch.LongTensor([2])))

结果

tensor([[[ 1.2912,  0.9011, -0.7307],[ 1.2912,  0.9011, -0.7307]]], grad_fn=<SqueezeBackward1>)tensor([[ 1.2912,  0.9011, -0.7307]], grad_fn=<EmbeddingBackward>)

我遇到的情况

感觉他用 2位数组来取是多此一举

input = torch.LongTensor([[2],[3]])
print(embedding(input))
print(embedding(input).squeeze(dim=1))  # squeeze 是把维数压缩
# 2*3*1
tensor([[[-1.4740, -1.0795, -1.5344]],[[-0.3839, -0.8261,  1.4525]]], grad_fn=<EmbeddingBackward>)
# 2*3
tensor([[-1.4740, -1.0795, -1.5344],[-0.3839, -0.8261,  1.4525]], grad_fn=<SqueezeBackward1>)
  相关解决方案