tf.app.run
其常出现的场景为
if __name__ == "__main__":tf.app.run()
tf.app.run()
会调用main
函数,并传递参数。因此,必须在main
函数中设置一个参数位置。如果想要更换main
名字,只需要在tf.app.run
中传入一个指定的函数名即可。
def test(args):passif __name__ == "__main__":tf.app.run(test)
因为tf.app.run()
要传递参数给调用的函数,所以其首先要加载flags
的参数项。那么参数的定义和获取就涉及到tf.app.flags.FLAGS
。
tf.app.flags
tf.app.flags
主要用于处理从终端传入的命令行参数,相当于对python中的命令行参数模块optargs
做了一层封装。
在optpars中,参数类型都是通过参数“type=xxx”定义的,tf中每一个合法数据类型都有对应的DEFIN_xxx
函数。常用:
- tf.app.flags.DEFIN_string():定义一个用于接受string类型的变量;
- tf.app.flags.DEFIN_integer():定义一个用于接受int类型的变量
- tf.app.flags.DEFIN_float():定义一个用于接受float类型的变量
- tf.app.flags.DEFIN_boolean():定义一个用于接受bool类型的变量
tf.app.flags.DEFIN_xxx()
该函数中带3个参数,分别是变量名称、默认值、用法描述,例如:
tf.app.flags.DEFIN_string('ckpt_path','model/model.ckpt-1000','checkpoint directory to restore')
上述代码,定义了一个名称是ckpt_path的变量,默认值是model/model.ckpt-1000,描述信息表明这是一个用于保存节点信息的路径。
tf.app.flags.FLAGS
该函数用于实例化这个类,后续需要获取命令行参数的值直接通过实例获取即可
import tensorflow as tf tf.app.flags.DEFINE_string('data_dir', '/tmp/mnist', 'Directory with the MNIST data.')
tf.app.flags.DEFINE_integer('batch_size', 5, 'Batch size.')
tf.app.flags.DEFINE_integer('num_evals', 1000, 'Number of batches to evaluate.')
FLAGS = tf.app.flags.FLAGSprint(FLAGS.data_dir, FLAGS.batch_size, FLAGS.num_evals)