原文链接:https://blog.csdn.net/m0_37041325/article/details/77448971
总结起来,tf.app.flags.DEFINE_xxx() 用于添加命令行的 optional argument(可选参数),而 tf.app.flags.FLAGS 用于读取对应命令行参数的值。(注:TensorFlow 2.x 已移除 tf.app,可改用 tf.compat.v1.flags 或 absl.flags。)
举个栗子
新建 test.py 文件并输入如下代码。这段代码创建了几个命令行参数,然后把它们的值打印出来:
# test.py: define a few command-line flags with tf.app.flags and print them.
import tensorflow as tf

# FLAGS exposes every flag defined below as an attribute; reading one returns
# the value given on the command line, or the default if none was given.
FLAGS = tf.app.flags.FLAGS

# Each DEFINE_xxx call registers one optional command-line argument:
# (flag name, default value, help text shown by `python test.py -h`).
tf.app.flags.DEFINE_float('flag_float', 0.01, 'input a float')
tf.app.flags.DEFINE_integer('flag_int', 400, 'input a int')
tf.app.flags.DEFINE_boolean('flag_bool', True, 'input a bool')
tf.app.flags.DEFINE_string('flag_string', 'yes', 'input a string')

# Print the current value of each flag.
print(FLAGS.flag_float)
print(FLAGS.flag_int)
print(FLAGS.flag_bool)
print(FLAGS.flag_string)
1. 查看帮助信息:在命令行输入 python test.py -h
注意输出中每个参数后面的说明文字(原文截图中用红色框标出),它就是用 DEFINE_xxx 添加命令行参数时传入的第三个参数,即帮助信息。
2. 直接运行 test.py(不带任何参数):python test.py
因为没有给对应的命令行参数赋值,所以输出的是这些参数的默认值。
3. 带命令行参数运行 test.py,例如:python test.py --flag_float=0.02 --flag_int=100 --flag_bool=False --flag_string=no
此时输出的就是我们通过命令行赋给这些参数的值。
如果要套用的话,也很简单,参考下面这两段代码:
套用代码1:
import os
import pprint  # pretty-printer used to dump all flag values

import numpy as np
import tensorflow as tf

# NOTE(review): SRCNN (used in main) is neither defined nor imported in this
# snippet; it must be imported from the project's own model module.

# Start from a clean default graph.
tf.reset_default_graph()

# tf.app.flags is TensorFlow's helper for declaring command-line flags.
# Each DEFINE_xxx call takes (name, default value, help text); the bracketed
# value at the end of each help text states the default, so keep them in sync.
flags = tf.app.flags
flags.DEFINE_integer("epoch", 15000, "Number of epoch [15000]")
flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
flags.DEFINE_integer("image_size", 33, "The size of image to use [33]")
flags.DEFINE_integer("label_size", 21, "The size of label to produce [21]")
flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of gradient descent algorithm [1e-4]")
flags.DEFINE_integer("c_dim", 1, "Dimension of image color. [1]")
flags.DEFINE_integer("scale", 3, "The size of scale factor for preprocessing input image [3]")
flags.DEFINE_integer("stride", 21, "The size of stride to apply input image [21]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
flags.DEFINE_string("sample_dir", "sample", "Name of sample directory [sample]")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")

FLAGS = flags.FLAGS
pp = pprint.PrettyPrinter()


def main(_):
    """Dump all flag values, make sure output directories exist, then train.

    Args:
      _: the argument list tf.app.run() passes to main; unused here.
    """
    # Print {flag name: value} for every flag defined above, via the public
    # flag_values_dict() rather than the private FLAGS.__flags attribute.
    pp.pprint(FLAGS.flag_values_dict())

    # Create the checkpoint and sample directories if they do not exist yet.
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session() as sess:
        srcnn = SRCNN(sess,
                      image_size=FLAGS.image_size,
                      label_size=FLAGS.label_size,
                      batch_size=FLAGS.batch_size,
                      c_dim=FLAGS.c_dim,
                      checkpoint_dir=FLAGS.checkpoint_dir,  # checkpoint directory
                      sample_dir=FLAGS.sample_dir)
        srcnn.train(FLAGS)


if __name__ == '__main__':
    # tf.app.run() parses the command-line flags and then calls main().
    tf.app.run()
套用代码2:
#! /usr/bin/env python
import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
from text_cnn import TextCNN
from tensorflow.contrib import learn

# Parameters
# ==================================================

# Data loading params
tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")
tf.flags.DEFINE_string("positive_data_file", "./data/rt-polaritydata/rt-polarity.pos", "Data source for the positive data.")
tf.flags.DEFINE_string("negative_data_file", "./data/rt-polaritydata/rt-polarity.neg", "Data source for the negative data.")

# Model Hyperparameters
tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of character embedding (default: 128)")
tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')")
tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")

# Training parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
# Help text corrected to state the actual default value (50).
tf.flags.DEFINE_integer("num_epochs", 50, "Number of training epochs (default: 50)")
tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps (default: 100)")
tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps (default: 100)")
tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store (default: 5)")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

# FLAGS holds the values of all command-line flags defined above.
FLAGS = tf.flags.FLAGS

# Print every flag as NAME=value, sorted by name. flag_values_dict() returns
# a {flag name: value} dict; its result must be used directly (it does not
# store anything back into FLAGS), and it replaces the old private
# FLAGS._parse_flags() / FLAGS.__flags combination.
print("\nParameters:")
for attr, value in sorted(FLAGS.flag_values_dict().items()):
    print("{}={}".format(attr.upper(), value))
print("")

# Data Preparation
# ==================================================

# Load data
print("Loading data...")
x_text, y = data_helpers.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)
..........