当前位置: 代码迷 >> 综合 >> TensorFlow 中 tf.app.flags.FLAGS
  详细解决方案

TensorFlow 中 tf.app.flags.FLAGS

热度:33   发布时间:2023-12-15 18:58:54.0

原文链接:https://blog.csdn.net/m0_37041325/article/details/77448971

总结起来,tf.app.flags.DEFINE_xxx() 用于添加命令行的可选参数(optional argument),而 tf.app.flags.FLAGS 用于读取这些命令行参数的值。注意:TensorFlow 2.x 已移除 tf.app 模块,可改用 tf.compat.v1.app.flags 或 absl.flags。

举个栗子

新建 test.py 文件,输入如下代码。这段代码定义了几个命令行参数,然后将它们的值打印出来:


  
  # test.py: define one command-line flag of each basic type and print the values.
import tensorflow as tf

# FLAGS gives attribute access to every flag defined below, e.g. FLAGS.flag_int.
# NOTE: tf.app was removed in TensorFlow 2.x; there use tf.compat.v1.app.flags.
FLAGS = tf.app.flags.FLAGS

# DEFINE_xxx(name, default, help): the third argument is the help text that
# `python test.py -h` prints next to each flag.
tf.app.flags.DEFINE_float('flag_float', 0.01, 'input a float')
tf.app.flags.DEFINE_integer('flag_int', 400, 'input a int')
tf.app.flags.DEFINE_boolean('flag_bool', True, 'input a bool')
tf.app.flags.DEFINE_string('flag_string', 'yes', 'input a string')

# Reading a flag here, without tf.app.run(), still sees command-line overrides:
# tf.app.flags (TF 1.x) parses sys.argv the first time a flag value is read.
print(FLAGS.flag_float)
print(FLAGS.flag_int)
print(FLAGS.flag_bool)
print(FLAGS.flag_string)


1.在命令行中查看帮助信息,在命令行输入 python test.py -h



注意帮助信息中每个参数后面的说明文字(原文截图中用红框标出),它就是我们用 DEFINE_XXX 添加命令行参数时传入的第三个参数,即该参数的帮助说明。

2.直接运行test.py


因为没有给对应的命令行参数赋值,所以输出的是各参数的默认值,即依次输出 0.01、400、True 和 yes。

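3.带命令行参数运行 test.py 文件,例如:

python test.py --flag_float=0.02 --flag_int=300 --flag_bool=False --flag_string=no

此时输出的是我们在命令行中赋给各参数的值,即依次输出 0.02、300、False 和 no。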

如果要套用的话,也很简单,参考下面这两段代码:
套用代码1:

import numpy as np
import os
import pprint  # pretty-printer used to dump every flag value in main()

import tensorflow as tf

# Clear the default graph before anything is built.
# NOTE(review): presumably so the script can be re-run in one interactive session.
tf.reset_default_graph()

# tf.app.flags is TensorFlow's command-line flag module; each call below
# registers one flag as DEFINE_xxx(name, default, help_string).
flags = tf.app.flags
flags.DEFINE_integer("epoch", 15000, "Number of epoch [15000]")
flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
flags.DEFINE_integer("image_size", 33, "The size of image to use [33]")
flags.DEFINE_integer("label_size", 21, "The size of label to produce [21]")
flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of gradient descent algorithm [1e-4]")
flags.DEFINE_integer("c_dim", 1, "Dimension of image color. [1]")
flags.DEFINE_integer("scale", 3, "The size of scale factor for preprocessing input image [3]")
# The bracketed value in each help string advertises the default, so it must
# match the actual default argument (21 here, and False for is_train below).
flags.DEFINE_integer("stride", 21, "The size of stride to apply input image [21]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
flags.DEFINE_string("sample_dir", "sample", "Name of sample directory [sample]")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
FLAGS = flags.FLAGS

pp = pprint.PrettyPrinter()


def main(_):
    """Print all flag values, create the output directories, then train SRCNN.

    Args:
        _: remaining argv handed over by tf.app.run(); unused.
    """
    # flag_values_dict() maps each flag name to its current value. The private
    # FLAGS.__flags holds absl Flag objects (not values) in TF >= 1.5, so
    # pretty-printing it would only show object reprs.
    pp.pprint(FLAGS.flag_values_dict())

    # Create the checkpoint and sample directories if they do not exist yet.
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session() as sess:
        # NOTE(review): SRCNN is neither defined nor imported in this snippet; it is
        # the model class from the original project and must be imported first.
        srcnn = SRCNN(sess,
                      image_size=FLAGS.image_size,
                      label_size=FLAGS.label_size,
                      batch_size=FLAGS.batch_size,
                      c_dim=FLAGS.c_dim,
                      checkpoint_dir=FLAGS.checkpoint_dir,  # checkpoint directory
                      sample_dir=FLAGS.sample_dir)
        srcnn.train(FLAGS)


if __name__ == '__main__':
    # tf.app.run() parses the command-line flags, then calls main(argv).
    tf.app.run()

套用代码2:

#! /usr/bin/env python
import datetime
import os
import sys
import time

import numpy as np
import tensorflow as tf
from tensorflow.contrib import learn

import data_helpers
from text_cnn import TextCNN

# Parameters
# ==================================================

# Data loading params
tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")
tf.flags.DEFINE_string("positive_data_file", "./data/rt-polaritydata/rt-polarity.pos", "Data source for the positive data.")
tf.flags.DEFINE_string("negative_data_file", "./data/rt-polaritydata/rt-polarity.neg", "Data source for the negative data.")

# Model Hyperparameters
tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of character embedding (default: 128)")
tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')")
tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")

# Training parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
# Help text now states the real default (50).
tf.flags.DEFINE_integer("num_epochs", 50, "Number of training epochs (default: 50)")
tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps (default: 100)")
tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps (default: 100)")
tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store (default: 5)")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS  # holds the values of the command-line flags defined above

# Parse sys.argv before dumping the parameters. flag_values_dict() does not
# trigger tf.flags' implicit parsing (only reading a flag attribute does), so
# without this the dump would show defaults instead of command-line overrides.
# (TF < 1.5 used the now-removed FLAGS._parse_flags() for this step.)
if not FLAGS.is_parsed():
    FLAGS(sys.argv)

print("\nParameters:")
# flag_values_dict() returns {name: value}; FLAGS.__flags would yield Flag objects.
for attr, value in sorted(FLAGS.flag_values_dict().items()):
    print("{}={}".format(attr.upper(), value))
print("")

# Data Preparation
# ==================================================

# Load data
print("Loading data...")
x_text, y = data_helpers.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)
# ... (the rest of the training script is omitted in the original article)
  相关解决方案