tf.gather(params, indices, validate_indices=None, axis=None, batch_dims=0, name=None )
请注意,在CPU上,如果找到超出范围的索引,则会返回错误。在GPU上,如果找到越界索引,则将0存储在相应的输出值中。
另请参阅tf.gather_nd
。
import numpy as np
import tensorflow as tflogits = [0,2,2,2,2,3,4,5,6,7,8,9]
target_id = [0,6]
selected_logits = tf.gather(logits, target_id)
with tf.Session() as sess:print(sess.run(selected_logits))
selected_logits取logits得第0位和第六位