当前位置: 代码迷 >> 综合 >> tf.gather()用法详解
  详细解决方案

tf.gather()用法详解

热度:57   发布时间:2024-01-19 11:17:12.0

tf.gather(params, indices, validate_indices=None, axis=None, batch_dims=0, name=None )

 

请注意,在CPU上,如果找到超出范围的索引,则会返回错误。在GPU上,如果找到越界索引,则将0存储在相应的输出值中。

另请参阅tf.gather_nd

 

import numpy as np
import tensorflow as tflogits = [0,2,2,2,2,3,4,5,6,7,8,9]
target_id = [0,6]
selected_logits = tf.gather(logits, target_id)
with tf.Session() as sess:print(sess.run(selected_logits))

 

selected_logits取logits得第0位和第六位