当前位置: 代码迷 >> python >> Tensorflow占位符来自函数
  详细解决方案

Tensorflow占位符来自函数

热度:86   发布时间:2023-07-14 09:49:23.0

我想做这样的事情:

def f():
    place = tf.placeholder(tf.int32)
    return 2 * place

y = f()

with tf.Session() as sess:
    a = sess.run(y, feed_dict={ place: 5 })

当然,占位符位置在外面不可见。

NameError                                 Traceback (most recent call last)
<ipython-input-88-8b3c17d16dce> in <module>()
      1 with tf.Session() as sess:
----> 2     a = sess.run(y, feed_dict={ place: 5 })

NameError: name 'place' is not defined

而且,我可以这样解决这个问题:

def f():
    global place
    place = tf.placeholder(tf.int32)
    return 2 * place

但是,有没有人有更好的解决方案。 如何调用函数并将其返回值作为运算符传递给上面的例子,如何在函数内部创建占位符并向其外部提供值。

您可以通过其名称访问占位符:

import tensorflow as tf

def f():
    place = tf.placeholder(tf.int32,name='place')
    return 2 * place

y = f()

with tf.Session() as sess:
    place = tf.get_default_graph().get_tensor_by_name('place:0')
    a = sess.run(y, feed_dict={ place: 5 })

说明

无论何时定义占位符(或任何其他TensorFlow张量或操作),都会将其添加到计算图中,该图是位于后台并管理所有计算的对象。 每个占位符都有一个默认名称,但您也可以为其选择一个名称。 在这个例子中,我选择了名字place

现在,对于高级用例,您可能有多个计算图,但总有一个是默认计算图。 要获得默认值,我使用了tf.get_default_graph() 然后为了获得对占位符的引用,我使用了get_tensor_by_name('place:0') (我使用名称'place:0'而不是'place'因为当你定义一个占位符时,实际上你可以创建一个tf.Tensor ,你也可以创建一个执行喂食的操作。该操作将有名称'place'而实际张量的名称为'place:0' 。)

这纯粹是Python,但您可以简单地将占位符与使用它创建的图形一起返回。

像这样:

return place*2, place

然后像这样使用f():

y, place = f()

我认为在Tensorflow中的一个类中工作是一个好主意,并且能够轻松地在任何地方访问这些占位符。

编辑:当然,总是可以给占位符一个名称,然后在需要提供它时从图表中获取它。

  相关解决方案