"""
设计步骤
1. Data generator: 训练模型时提供数据
   a. load vocab
   b. load image features
   c. provide data for training
2. Build the image caption model
3. Train the model
"""import os
import sys
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
import pprint
import pickle
import numpy as np
import numpy
# ---- Paths (fill these in before running) ----
# Caption/description file to read.
input_description_file = ""
# Directory containing the pickled image-feature files.
input_img_feature_dir = ""
# Vocabulary file produced by the vocab-generation step.
input_vocab_file = ""
# Destination for summaries and checkpoints; created below if missing.
output_dir = ""

if not gfile.Exists(output_dir):
    gfile.MakeDirs(output_dir)

# Hyperparameters required by the model.
def get_default_params():
    """Return the default hyperparameters as a ``tf.contrib.training.HParams``."""
    defaults = dict(
        # Vocab filter: words seen fewer than this many times are dropped.
        num_vocab_word_threshold=3,
        # Embedding / RNN structure.
        num_embedding_nodes=32,
        num_timesteps=10,
        num_lstm_nodes=[64, 64],
        num_lstm_layers=2,
        num_fc_nodes=32,
        batch_size=80,
        cell_type="lstm",
        # Global-norm threshold used for gradient clipping.
        clip_lstm_grads=1.0,
        learning_rate=0.001,
        keep_prob=0.8,
        # Log once every `log_frequent` steps.
        log_frequent=100,
        # Checkpoint once every `save_frequent` steps.
        save_frequent=1000,
    )
    return tf.contrib.training.HParams(**defaults)


hps = get_default_params()
# Vocabulary loading.
class Vocab(object):
    """Bidirectional word <-> id mapping read from a vocabulary file.

    Each line of the file is ``word<TAB>count``.  Words whose count is below
    ``word_num_threshold`` are skipped; the remaining words receive
    consecutive ids in file order.  The kept ``'<UNK>'`` entry becomes the
    unknown-word id and the kept ``'.'`` entry the end-of-sentence id; each
    stays -1 if that word was not kept.
    """

    def __init__(self, filename, word_num_threshold):
        self._id_to_word = {}
        self._word_to_id = {}
        self._unk = -1
        # End-of-sentence token id.
        self._eos = -1
        self._word_num_threshold = word_num_threshold
        # BUG FIX: the original called _read_dict() without the filename,
        # which raised TypeError on every construction.
        self._read_dict(filename)

    def _read_dict(self, filename):
        """Populate the word/id mappings from ``filename``.

        Raises:
            ValueError: if a kept word appears more than once.
        """
        with gfile.GFile(filename, 'r') as f:
            lines = f.readlines()
        for line in lines:
            line = line.strip('\r\n')
            if not line:
                # Tolerate blank lines instead of failing the tab split.
                continue
            word, occurrence = line.split('\t')
            occurrence = int(occurrence)
            if occurrence < self._word_num_threshold:
                continue
            idx = len(self._id_to_word)
            if word == '<UNK>':
                self._unk = idx
            elif word == '.':
                self._eos = idx
            if word in self._word_to_id or idx in self._id_to_word:
                raise ValueError("duplicate word in vocab file: %r" % (word,))
            self._word_to_id[word] = idx
            self._id_to_word[idx] = word

    @property
    def unk(self):
        """Id of '<UNK>' (-1 if it was not kept)."""
        return self._unk

    @property
    def eos(self):
        """Id of '.' (-1 if it was not kept)."""
        return self._eos

    def id_to_word(self, word_id):
        """Return the word for ``word_id``, or '<UNK>' for unknown ids."""
        return self._id_to_word.get(word_id, '<UNK>')

    def word_to_id(self, word):
        """Return the id of ``word``, falling back to the unknown-word id."""
        # NOTE(review): if '<UNK>' was not kept this fallback is -1, which is
        # not a valid embedding index -- confirm the vocab file contains it.
        return self._word_to_id.get(word, self.unk)

    def size(self):
        """Number of words kept in the vocabulary."""
        return len(self._id_to_word)

    def encode(self, sentence):
        """Split ``sentence`` on single spaces and map each word to its id."""
        return [self.word_to_id(word) for word in sentence.split(' ')]

    def decode(self, sentence_id):
        """Map ids back to words and join them with single spaces."""
        return ' '.join(self.id_to_word(word_id) for word_id in sentence_id)


vocab = Vocab(input_vocab_file, hps.num_vocab_word_threshold)
vocab_size = vocab.size()
print(vocab_size)
def parse_token_file(token_file):
    """Group the captions in ``token_file`` by image name.

    Each line is ``<img_id><TAB><caption>`` where ``img_id`` has the form
    ``<img_name>#<suffix>``; the suffix is discarded so that every caption of
    one image lands in the same list.

    Returns:
        dict mapping image name -> list of caption strings, in file order.
    """
    captions_by_image = {}
    with gfile.GFile(token_file, 'r') as token_fh:
        raw_lines = token_fh.readlines()
    for raw_line in raw_lines:
        image_id, caption = raw_line.strip('\r\n').split('\t')
        image_name, _ = image_id.split('#')
        captions_by_image.setdefault(image_name, []).append(caption)
    return captions_by_image


# Convert every caption of every image into a list of token ids.
def convert_token_to_id(img_name_to_tokens, vocab):
    """Encode every caption of every image with ``vocab.encode``.

    Returns:
        dict mapping image name -> list of token-id lists (one per caption),
        preserving the key and caption order of ``img_name_to_tokens``.
    """
    return {
        img_name: [vocab.encode(description) for description in descriptions]
        for img_name, descriptions in img_name_to_tokens.items()
    }


img_name_to_tokens = parse_token_file(input_description_file)
img_name_to_tokens_id = convert_token_to_id(img_name_to_tokens, vocab)
# BUG FIX: the original format string had no placeholder, so the `%`
# operator raised "not all arguments converted during string formatting".
logging.info("num of all images: %d", len(img_name_to_tokens))
class ImageCaptionData(object):
    """Serve (image feature, caption ids, caption mask) batches for training.

    Image features are loaded from every pickle file in ``img_feature_dir``.
    Captions come from ``img_name_to_tokens_id`` (image name -> list of
    token-id lists); for each image in a batch one caption is sampled at
    random and padded with ``vocab.eos`` or truncated to ``num_timesteps``.
    """

    def __init__(self, img_name_to_tokens_id, img_feature_dir, num_timesteps,
                 vocab, deterministic=False):
        self._vocab = vocab
        self._img_name_to_tokens_id = img_name_to_tokens_id
        self._num_timesteps = num_timesteps
        # When True the data order is never shuffled.
        self._deterministic = deterministic
        # Start offset of the next batch.
        self._indicator = 0
        self._img_feature_filenames = []
        self._img_feature_data = []
        self._all_img_feature_filepaths = [
            os.path.join(img_feature_dir, filename)
            for filename in gfile.ListDirectory(img_feature_dir)
        ]
        pprint.pprint(self._all_img_feature_filepaths)
        self._load_img_feature_pickle()
        if not self._deterministic:
            self._random_shuffle()

    def _load_img_feature_pickle(self):
        """Load and concatenate ``(filenames, features)`` from every pickle file."""
        for filepath in self._all_img_feature_filepaths:
            logging.info("loading %s" % filepath)
            # BUG FIX: pickles are binary; text mode breaks pickle.load on Python 3.
            # NOTE(review): pickle.load can execute arbitrary code -- only load
            # trusted feature files.
            with gfile.GFile(filepath, 'rb') as f:
                filenames, features = pickle.load(f)
            self._img_feature_filenames += filenames
            self._img_feature_data.append(features)
        self._img_feature_data = np.vstack(self._img_feature_data)
        origin_shape = self._img_feature_data.shape
        # Assumes the stacked features are 4-D (N, 1, 1, D) so the two middle
        # axes can be dropped -- TODO confirm against the feature extractor.
        self._img_feature_data = np.reshape(
            self._img_feature_data, (origin_shape[0], origin_shape[3]))
        self._img_feature_filenames = np.asarray(self._img_feature_filenames)
        # BUG FIX: was `print(self._img_feature_data, shape)` (NameError).
        print(self._img_feature_data.shape)
        print(self._img_feature_filenames.shape)

    def size(self):
        """Number of images with loaded features."""
        return len(self._img_feature_filenames)

    def img_feature_size(self):
        """Dimensionality of one image feature vector."""
        return self._img_feature_data.shape[1]

    def _random_shuffle(self):
        """Shuffle filenames and features with the same permutation."""
        p = np.random.permutation(self.size())
        self._img_feature_filenames = self._img_feature_filenames[p]
        self._img_feature_data = self._img_feature_data[p]

    def _img_desc(self, batch_filenames):
        """Sample one caption per image and pad/truncate it to num_timesteps.

        Returns:
            (batch_sentence_ids, batch_weights): arrays with one row of
            ``num_timesteps`` entries per filename; weights are 1 for real
            tokens and 0 for padding.
        """
        batch_sentence_ids = []
        batch_weights = []
        for filename in batch_filenames:
            token_ids_set = self._img_name_to_tokens_id[filename]
            # BUG FIX: copy the sampled caption.  The original padded the list
            # stored in img_name_to_tokens_id in place, so the padding (with an
            # all-ones mask) leaked into every later sample of that caption.
            chosen_token_ids = list(random.choice(token_ids_set))
            chosen_token_ids_length = len(chosen_token_ids)
            # BUG FIX: the original comprehension was missing `in` (SyntaxError).
            weight = [1] * chosen_token_ids_length
            if chosen_token_ids_length >= self._num_timesteps:
                chosen_token_ids = chosen_token_ids[:self._num_timesteps]
                weight = weight[:self._num_timesteps]
            else:
                remaining_length = self._num_timesteps - chosen_token_ids_length
                chosen_token_ids += [self._vocab.eos] * remaining_length
                weight += [0] * remaining_length
            batch_sentence_ids.append(chosen_token_ids)
            batch_weights.append(weight)
        return np.asarray(batch_sentence_ids), np.asarray(batch_weights)

    def next_batch(self, batch_size):
        """Return the next batch of training data.

        When the current pass has fewer than ``batch_size`` examples left, the
        data is reshuffled (unless deterministic) and a new pass starts; the
        leftover examples of the previous pass are skipped.

        Returns:
            (batch_img_features, batch_sentence_ids, batch_weights,
             batch_filenames)

        Raises:
            ValueError: if ``batch_size`` exceeds the number of images.
        """
        # BUG FIX: the original read the non-existent attribute `self.indicator`.
        end_indicator = self._indicator + batch_size
        if end_indicator > self.size():
            if not self._deterministic:
                self._random_shuffle()
            self._indicator = 0
            end_indicator = self._indicator + batch_size
        # BUG FIX: the original `assert end_indicator < size` rejected a batch
        # that exactly fills the data set.
        if end_indicator > self.size():
            raise ValueError("batch_size %d exceeds data size %d"
                             % (batch_size, self.size()))
        batch_filenames = self._img_feature_filenames[self._indicator:end_indicator]
        batch_img_features = self._img_feature_data[self._indicator:end_indicator]
        # BUG FIX: _img_desc is a method; the original indexed it with [].
        batch_sentence_ids, batch_weights = self._img_desc(batch_filenames)
        self._indicator = end_indicator
        # The fourth element lets callers map each row back to its image.
        return batch_img_features, batch_sentence_ids, batch_weights, batch_filenames


caption_data = ImageCaptionData(
    img_name_to_tokens_id, input_img_feature_dir, hps.num_timesteps, vocab)
img_feature_dim = caption_data.img_feature_size()
def create_rnn_cell(hidden_dim, cell_type):
    """Build a single RNN cell with ``hidden_dim`` units.

    Args:
        hidden_dim: number of units in the cell.
        cell_type: ``'lstm'`` or ``'gru'``.

    Raises:
        ValueError: for any other ``cell_type``.
    """
    if cell_type == 'lstm':
        return tf.contrib.rnn.BasicLSTMCell(hidden_dim, state_is_tuple=True)
    if cell_type == 'gru':
        # BUG FIX: the original returned GLSTMCell (a grouped LSTM) here.
        return tf.contrib.rnn.GRUCell(hidden_dim)
    raise ValueError("unsupported cell_type: %r" % (cell_type,))


def dropout(cell, keep_prob):
    """Wrap ``cell`` so its outputs are dropped out with ``keep_prob``."""
    return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)


# Computation graph.
def get_train_model(hps, vocab_size, img_feature_dim):
    """Build the training graph of the image-caption model.

    The image embedding is the first RNN input and the ground-truth tokens
    ``sentence[:, :-1]`` are the remaining inputs, so step t predicts
    ``sentence[:, t]``:

        img_feature -> embed_img  -> rnn -> (predict w0)
        w0          -> embed_word -> rnn -> (predict w1)
        ...

    Returns:
        ((img_feature, sentence, mask, keep_prob),  # placeholders
         (loss, accuracy, train_op),
         global_step)
    """
    num_timesteps = hps.num_timesteps
    batch_size = hps.batch_size

    img_feature = tf.placeholder(tf.float32, (batch_size, img_feature_dim))
    sentence = tf.placeholder(tf.int32, (batch_size, num_timesteps))
    # Per-token weights: nonzero for real tokens, zero for padding.
    mask = tf.placeholder(tf.int32, (batch_size, num_timesteps))
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    global_step = tf.Variable(
        tf.zeros([], tf.int32), name="global_step", trainable=False)

    # Word embedding layer.
    embedding_initializer = tf.random_uniform_initializer(-1.0, 1.0)
    with tf.variable_scope('embedding', initializer=embedding_initializer):
        embeddings = tf.get_variable(
            'embedding', [vocab_size, hps.num_embedding_nodes], tf.float32)
        # embed_token_ids: [batch_size, num_timesteps - 1, num_embedding_nodes]
        embed_token_ids = tf.nn.embedding_lookup(
            embeddings, sentence[:, 0:num_timesteps - 1])

    # Image feature embedding layer.
    img_feature_embed_init = tf.uniform_unit_scaling_initializer(factor=1.0)
    with tf.variable_scope('img_feature_embed', initializer=img_feature_embed_init):
        # embed_img: [batch_size, num_embedding_nodes]
        embed_img = tf.layers.dense(img_feature, hps.num_embedding_nodes)
    # embed_img: [batch_size, 1, num_embedding_nodes]
    embed_img = tf.expand_dims(embed_img, 1)
    # embed_inputs: [batch_size, num_timesteps, num_embedding_nodes]
    embed_inputs = tf.concat([embed_img, embed_token_ids], axis=1)

    # RNN layers.
    # BUG FIX: the hps fields are num_lstm_nodes / num_lstm_layers (plural);
    # num_lstm_nodes is a list, so the last layer's width is used here.
    scale = 1.0 / math.sqrt(hps.num_embedding_nodes + hps.num_lstm_nodes[-1])
    rnn_init = tf.random_uniform_initializer(-scale, scale)
    with tf.variable_scope('lstm_nn', initializer=rnn_init):
        cells = []
        for i in range(hps.num_lstm_layers):
            cell = create_rnn_cell(hps.num_lstm_nodes[i], hps.cell_type)
            cells.append(dropout(cell, keep_prob))
        cell = tf.contrib.rnn.MultiRNNCell(cells)
        init_state = cell.zero_state(hps.batch_size, tf.float32)
        # rnn_outputs: [batch_size, num_timesteps, num_lstm_nodes[-1]]
        # BUG FIX: dynamic_rnn's keyword is `initial_state`, not `init_state`.
        rnn_outputs, _ = tf.nn.dynamic_rnn(
            cell, embed_inputs, initial_state=init_state)

    # Fully connected layers.
    fc_init = tf.uniform_unit_scaling_initializer(factor=1.0)
    with tf.variable_scope('fc', initializer=fc_init):
        rnn_outputs_2d = tf.reshape(rnn_outputs, [-1, hps.num_lstm_nodes[-1]])
        fc1 = tf.layers.dense(rnn_outputs_2d, hps.num_fc_nodes, name="fc1")
        fc1_dropout = tf.contrib.layers.dropout(fc1, keep_prob)
        fc1_relu = tf.nn.relu(fc1_dropout)
        logits = tf.layers.dense(fc1_relu, vocab_size, name="logits")

    # Loss and accuracy, averaged over unmasked tokens only.
    with tf.variable_scope('loss'):
        sentence_flatten = tf.reshape(sentence, [-1])
        # BUG FIX: cast the int32 mask to float32 once; the original multiplied
        # and divided float32 tensors by int32 ones, which TF rejects.
        mask_flatten = tf.cast(tf.reshape(mask, [-1]), tf.float32)
        mask_sum = tf.reduce_sum(mask_flatten)
        softmax_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=sentence_flatten)
        weighted_softmax_loss = tf.multiply(softmax_loss, mask_flatten)
        loss = tf.reduce_sum(weighted_softmax_loss) / mask_sum

        prediction = tf.argmax(logits, 1, output_type=tf.int32)
        correct_prediction = tf.equal(prediction, sentence_flatten)
        weighted_correct_prediction = tf.multiply(
            tf.cast(correct_prediction, tf.float32), mask_flatten)
        accuracy = tf.reduce_sum(weighted_correct_prediction) / mask_sum
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', accuracy)

    # Training op with global-norm gradient clipping.
    with tf.variable_scope('train_op'):
        tvars = tf.trainable_variables()
        for var in tvars:
            logging.info('variable name: %s' % var.name)
        grads, _ = tf.clip_by_global_norm(
            tf.gradients(loss, tvars), hps.clip_lstm_grads)
        optimizer = tf.train.AdamOptimizer(hps.learning_rate)
        train_op = optimizer.apply_gradients(
            zip(grads, tvars), global_step=global_step)

    return ((img_feature, sentence, mask, keep_prob),
            (loss, accuracy, train_op),
            global_step)


placeholder, metrics, global_step = get_train_model(hps, vocab_size, img_feature_dim)
img_feature, sentence, mask, keep_prob = placeholder
loss, accuracy, train_op = metrics
summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=10)
# Model training.
training_steps = 1000

with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(output_dir, sess.graph)
    for i in range(training_steps):
        # Only the first three values (features, token ids, mask) are fed;
        # slicing works whether next_batch returns three or four values.
        batch_img_features, batch_sentence_ids, batch_weights = (
            caption_data.next_batch(hps.batch_size)[:3])
        input_vals = (batch_img_features, batch_sentence_ids,
                      batch_weights, hps.keep_prob)
        feed_dict = dict(zip(placeholder, input_vals))
        fetches = [global_step, loss, accuracy, train_op]

        should_log = (i + 1) % hps.log_frequent == 0
        should_save = (i + 1) % hps.save_frequent == 0
        if should_log:
            # Evaluate the summary op only on logging steps.
            fetches += [summary_op]
        outputs = sess.run(fetches, feed_dict=feed_dict)
        global_step_val, loss_val, accuracy_val = outputs[0:3]
        if should_log:
            summary_str = outputs[-1]
            writer.add_summary(summary_str, global_step_val)
            logging.info('step: %5d,loss%3.3f, accu:%3.3f'
                         % (global_step_val, loss_val, accuracy_val))
        if should_save:
            model_save_file = os.path.join(output_dir, "image_caption")
            logging.info('step: %5d, model saved' % global_step_val)
            # BUG FIX: was `Saver.saver(...)`; the Saver instance is `saver`
            # and its method is `save`.
            saver.save(sess, model_save_file, global_step=global_step_val)
    # Flush pending summary events to disk.
    writer.close()
#模型训练
training_steps=1000with tf.Session() as sess:sess.run(init_op)