当前位置: 代码迷 >> 综合 >> OpenAI baseline GAIL代码讲解及其可视化
  详细解决方案

OpenAI baseline GAIL代码讲解及其可视化

热度:91   发布时间:2023-11-20 23:02:00.0

最近在研究关于强化学习的部分工作,首先从OpenAI的Baseline中的小型GAIL算法出发。

首先参考了大神的文章从《西部世界》到GAIL(Generative Adversarial Imitation Learning)算法。

原文链接:https://blog.csdn.net/jinzhuojun/article/details/85220327#commentBox

对大神写的文章做一些补充和细节解释。

在baseline 的文件夹中运行即可以进行模型的训练

python3 -m baselines.gail.run_mujoco

在run_mujoco.py代码中写到

   parser.add_argument('--task', type=str, choices=['train', 'evaluate', 'sample'], default='train')

 可以在命令行后面添加 --task 改变任务为train 和evaluate。evaluate后面要加上存储的模型的地址

# 假设训练模型放在/home/jzj/source/baselines/checkpoint/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0/
python3 -m baselines.gail.run_mujoco --task=evaluate  --load_model_path=/home/jzj/source/baselines/checkpoint/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0

在baseline 中使用tensorflow方式存储模型:  在trpo_mpi.py  232行。

        # Save modelif rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:fname = os.path.join(ckpt_dir, task_name)#U.save_variables(fname)#print("the save path is ",fname)os.makedirs(os.path.dirname(fname), exist_ok=True)saver = tf.train.Saver()saver.save(tf.get_default_session(), fname)

所以在checkpoint中存储了可以用tensorflow方式读取模型的三个文件,而在运行评估模型时读取模型的方式采用的是baseline 中common自己定义的    U.load_variables(load_model_path)来读取文件,读取文件的类型是上面由tensorflow生成的文件的集合体。

    U.load_variables(load_model_path)

因此在存储模型的时候也应该采用common中的定义的save_variables来存储模型生成集成文件:

        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:fname = os.path.join(ckpt_dir, task_name)U.save_variables(fname)print("the save path is ",fname)os.makedirs(os.path.dirname(fname), exist_ok=True)saver = tf.train.Saver()saver.save(tf.get_default_session(), fname)

 然后运行train的命令行,在训练100次迭代之后就可以在保存模型的文件夹中发现一个无.data/.index/.meta后缀的集成文件。

此时再运行evaluate命令行就可以出现对模型的评估返回数据

在run_mujoco.py中的traj_1_generator函数中的while函数中插入env.render()就可以渲染出模型可视化结果。

    while True:ac, vpred = pi.act(stochastic, ob)obs.append(ob)news.append(new)acs.append(ac)ob, rew, new, _ = env.step(ac)rews.append(rew)env.render()cur_ep_ret += rewcur_ep_len += 1if new or t >= horizon:breakt += 1

感谢大佬的分享,同时在遇到困难的时候还是要敢于挑战权威呀。

  相关解决方案