当前位置: 代码迷 >> 综合 >> Pytorch 实现tf.gather()
  详细解决方案

Pytorch 实现tf.gather()

热度:70   发布时间:2023-12-19 03:20:59.0

1. 实现tf.gather

在pytorch中,实现 tf.gather 很简单,只需要使用 select。
select(dim, index) → Tensor

比如,

import numpy as np
a = np.array([[1],[2],[3],[4],[5]])
b = torch.from_numpy(a)
indices = [ 1, 2, 0]
b[indices]

Output:

tensor([[2],[3],[1]])

参考:

  1. How to implement an equivalent of tf.gather in pytorch