当前位置: 代码迷 >> 综合 >> pytorch 代码:yield features.index_select(0,j), labels.index_select(0,j)
  详细解决方案

pytorch 代码:yield features.index_select(0,j), labels.index_select(0,j)

热度:51   发布时间:2023-12-21 02:52:40.0

pytorch 代码:yield features.index_select(0,j), labels.index_select(0,j)

yield features.index_select(0,j), labels.index_select(0,j)

yield 首先作用理解为return,它也可以返回一个或多个值,要想调用返回值就必须在循环中

index_select() 中第一个参数 0 表示以行为标准选择,例如j = tensor([1,2]),结果为选取features 第1,第2行数据

例子

import torcha = torch.randn(5, 5, dtype=torch.float32)
print(a)
i=1
indices = list(range(5))
j = torch.LongTensor(indices[i:min(i+2,5)])
print(j)
b = a.index_select(0,j)
print(b)

结果及其说明

# a 的值
tensor([[-1.2528, -0.3235,  0.2825, -0.5463,  0.0053],[-0.3129,  0.4375,  0.4789, -0.3872,  0.1995],[-0.9480,  0.3840,  1.1145,  0.8569, -0.8921],[ 1.0255,  0.0352, -0.0806, -1.2422,  0.4661],[-0.5092, -1.2760, -0.1923,  0.2986, -0.7680]])
# j 的值
tensor([1, 2])# b 的值,可以看到是来自 a 的 1,2 行
tensor([[-0.3129,  0.4375,  0.4789, -0.3872,  0.1995],[-0.9480,  0.3840,  1.1145,  0.8569, -0.8921]])
  相关解决方案