当前位置: 代码迷 >> 综合 >> pytorch函数 torch.index_select( )学习
  详细解决方案

pytorch函数 torch.index_select( )学习

热度:127   发布时间:2023-09-27 14:04:19.0

y = torch.index_select(x, dim=1, index=index)

该函数是按照维度对原张量进行切分
index 是切分的索引,通常是一个数字型列表
x是待切分的张量

假设 x = torch.rand(3,2,1) 且 index = torch.LongTensor([0])

tensor([[[0.3048],[0.0140]],[[0.6699],[0.3395]],[[0.1088],[0.9452]]])tensor([0])   注意这里的index 必须为longTensor类型
y = torch.index_select(x, dim=1, index=index)

将张量x 按照第二个维度(dim从0开始算)进行切分,搜索到第一个元素(index=0)
得到

tensor([[[0.3048]],[[0.6699]],[[0.1088]]])