y = torch.index_select(x, dim=1, index=index)
该函数是按照维度对原张量进行切分
index 是切分的索引,通常是一个数字型列表
x是待切分的张量
假设 x = torch.rand(3,2,1) 且 index = torch.LongTensor([0])
tensor([[[0.3048],[0.0140]],[[0.6699],[0.3395]],[[0.1088],[0.9452]]])tensor([0]) 注意这里的index 必须为longTensor类型
y = torch.index_select(x, dim=1, index=index)
将张量x 按照第二个维度(dim从0开始算)进行切分,搜索到第一个元素(index=0)
得到
tensor([[[0.3048]],[[0.6699]],[[0.1088]]])