当前位置: 代码迷 >> 综合 >> tf.contrib.lookup.index_to_string_table_from_tensor和tf.nn.top_k()
  详细解决方案

tf.contrib.lookup.index_to_string_table_from_tensor和tf.nn.top_k()

热度:45   发布时间:2023-11-04 09:19:21.0
tf.contrib.lookup.index_to_string_table_from_tensor(mapping,default_value='UNK',name=None
)

Returns a lookup table that maps a Tensor of indices into strings.

1.一维情况

import tensorflow as tf
sess=tf.Session()
mapping_string = tf.constant(["emerson", "lake", "palmer"])
indices1 = tf.constant([0,1,2], tf.int64)
indices2 = tf.constant([0,2], tf.int64)
indices3 = tf.constant([1], tf.int64)
table = tf.contrib.lookup.index_to_string_table_from_tensor(mapping_string)
values1 = table.lookup(indices1)
values2 = table.lookup(indices2)
values3 = table.lookup(indices3)
sess.run(tf.tables_initializer())print(sess.run(mapping_string))print(sess.run(indices1))
print(sess.run(values1))
print(sess.run(indices2))
print(sess.run(values2))
print(sess.run(indices3))
print(sess.run(values3))

运行结果:
[b’emerson’ b’lake’ b’palmer’]
[0 1 2]
[b’emerson’ b’lake’ b’palmer’]
[0 2]
[b’emerson’ b’palmer’]
[1]
[b’lake’]

2.二维情况

import tensorflow as tf
sess=tf.Session()
mapping_string = tf.constant(["emerson", "lake", "palmer"])
indices1 = tf.constant([[0,1],[1,2],[2,3]], tf.int64)table = tf.contrib.lookup.index_to_string_table_from_tensor(mapping_string)
values1 = table.lookup(indices1)sess.run(tf.tables_initializer())print(sess.run(mapping_string))
print(sess.run(indices1))
print(sess.run(values1))

运行结果:
[b’emerson’ b’lake’ b’palmer’]
[[0 1]
[1 2]
[2 3]]
[[b’emerson’ b’lake’]
[b’lake’ b’palmer’]
[b’palmer’ b’UNK’]]

总结:按照张量indices的格式,返回indices中索引的值,格式不变.

3.tf.nn.top_k()
以下以一个5分类为例,一行表示一个样本属于5类的概率.按下面代码转化完了后,输出的每行第一个数字就代表该行样本所属的分类.

import tensorflow as tf
import numpy as np
sess=tf.Session()
y=np.array([[0.10,0.20,0.40,0.25,0.05],[0.45,0.15,0.10,0.05,0.25],[0.15,0.65,0.05,0.10,0.05]])values,indices=tf.nn.top_k(y,5)
table = tf.contrib.lookup.index_to_string_table_from_tensor(tf.constant([str(i) for i in range(5)]))
prediction_classes = table.lookup(tf.to_int64(indices))
sess.run(tf.tables_initializer())
print(sess.run(values))
print(sess.run(indices))
print(sess.run(prediction_classes))

运行结果:
[[0.4 0.25 0.2 0.1 0.05]
[0.45 0.25 0.15 0.1 0.05]
[0.65 0.15 0.1 0.05 0.05]]
[[2 3 1 0 4]
[0 4 1 2 3]
[1 0 3 2 4]]
[[b’2’ b’3’ b’1’ b’0’ b’4’]
[b’0’ b’4’ b’1’ b’2’ b’3’]
[b’1’ b’0’ b’3’ b’2’ b’4’]]
总结: 返回的是每行由大到小排序的结果以及对应的它们在原来数组中的行索引.例如:第一行最大值为0.4,它在原数组中行索引是2(该行第三个数字).

  相关解决方案