当前位置: 代码迷 >> 综合 >> TensorFlow: tf.flags.DEFINE_xxx()用法
  详细解决方案

TensorFlow: tf.flags.DEFINE_xxx()用法

热度:62   发布时间:2024-02-19 17:22:14.0

读代码的时候常常会遇到flag到处飞

比如:

FLAGS = flags.FLAGSflags.DEFINE_integer("data_reading_num_threads", 64,"The number of threads used to read the dataset.")
FLAGS = tf.flags.FLAGStf.flags.DEFINE_string('name', 'default', 'name of the model')
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')

tf 中定义了 tf.app.flags.FLAGS ,用于接受从终端传入的命令行参数,用起来的话很方便

用法:

在终端输入如下

python train.py \--input_file data/shakespeare.txt  \--name shakespeare \--num_steps 50 \--num_seqs 32 \--learning_rate 0.01 \--max_steps 20000

通过输入不同的文件名、参数,可以快速完成程序的调参和更换训练集的操作,不需要进入源码中更改。

描述:

DEFINE_xxx 函数带3个参数,分别是变量名称,默认值,用法描述。

DEFINE后面规定变量数据类型,例如常见的

tf.app.flags.DEFINE_string() :定义一个用于接收 string 类型数值的变量;

tf.app.flags.DEFINE_integer() : 定义一个用于接收 int 类型数值的变量;

tf.app.flags.DEFINE_float() :定义一个用于接收 float 类型数值的变量;

tf.app.flags.DEFINE_boolean() : 定义一个用于接收 bool 类型数值的变量;

  相关解决方案