1. 实现tf.gather
在pytorch中,实现 tf.gather 很简单,只需要使用 select。
select(dim, index) → Tensor
比如,
import numpy as np
a = np.array([[1],[2],[3],[4],[5]])
b = torch.from_numpy(a)
indices = [ 1, 2, 0]
b[indices]
Output:
tensor([[2],[3],[1]])
参考:
- How to implement an equivalent of tf.gather in pytorch