当前位置: 代码迷 >> 综合 >> tensorflow2 按层加载模型参数 finetune
  详细解决方案

tensorflow2 按层加载模型参数 finetune

热度:70   发布时间:2023-12-15 15:55:30.0

tensorflow 2.3.1
python 3.6.13

model = ArcfaceModel()
# 加载原始权重文件
ckpt_path = './checkpoints/arc_mbv2' # 该文件夹下存放了4个文件:checkpoint, e_8_b_40000.ckpt.data-00000-of-00002, e_8_b_40000.ckpt.data-00001-of-00002,e_8_b_40000.ckpt.index
ckpt_path = tf.train.latest_checkpoint(ckpt_path)
model.load_weights(ckpt_path)# 保存为pb模型
model.save('./checkpoints/tf_model')# 加载保存的pb模型参数
pretrain_model_weights = tf.saved_model.load('./checkpoints/tf_model')
params_dict = {
    }
for v in pretrain_model_weights.trainable_variables:params_dict[v.name] = v.read_value()# 加载 除最后一层外 的预训练参数到模型,以便下一步finetune
for idx,layer in enumerate(model.variables[:-1]):layer.assign(pretrain_model_weights.variables[idx])
  相关解决方案