当前位置: 代码迷 >> 综合 >> RL策略梯度方法之(四): Asynchronous Advantage Actor-Critic(A3C)
  详细解决方案

RL策略梯度方法之(四): Asynchronous Advantage Actor-Critic(A3C)

热度:2   发布时间:2023-12-15 05:35:43.0

本专栏按照 https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html 顺序进行总结 。


文章目录

  • 原理解析
  • 算法实现
    • 总体流程
    • 代码实现


A3C\color{red}A3CA3C :[ paper | code ]


原理解析

在A3C中,critic 学习值函数,同时多个 actor 并行训练,并不时地与全局参数同步。因此,A3C可以很好地用于并行训练。

  • 服务器的每个核都是一个线程,也就是一个平行世界,同样的一个程序,在平行世界里同时运行,可以成倍数的提升运行速度。每个线程中的运行结果反馈给主网络,同时从主网络获取最新的参数更新,这样就将多个线程结合在一起,同时进一步减弱了事件的相关性,利于程序的收敛。

以 state-value 函数为例,状态值的损失函数是最小均方误差:Jv(w)=(Gt?Vw(s))2J_v(w) = (G_t - V_w(s))^2Jv?(w)=(Gt??Vw?(s))2,梯度下降可以用来寻找最优的 www。此状态值函数用作策略梯度更新的基线。

接下来,看一下A3C 的实现流程图:
在这里插入图片描述

可以看到,我们有一个主网络,还有许多Worker(主网络和每个线程中的网络结构相同,均为AC结构,唯一不同点在于主网络不需要进行训练,仅用于存储AC结构的参数。)。A3C主要有两个操作,一个是pull,一个是push:
pull:把主网络的参数直接赋予Worker中的网络
push:使用各Worker中的梯度,对主网络的参数进行更新

算法实现

总体流程

以下是算法的整体流程:

  1. 全局参数 θ\thetaθwww,相同的 特定线程的参数 θ′\theta'θw′w'w
  2. 初始化时间步:t=1t=1t=1
  3. T≤TMAXT \le T_{MAX}TTMAX?
    (1). 重置梯度 dθ=0,dw=0d\theta =0, dw=0dθ=0,dw=0
    (2). 使用全局参数同步特定线程的参数 θ′=θ\theta' = \thetaθ=θw′=ww' = ww=w
    (3). tstart=tt_{start}=ttstart?=t,采样一个初始状态 sts_tst?
    (4). st!=s_t !=st?!= TERMINAL &&t?tstart≤tmax\ \ \&\& \ \ \ t - t_\text{start} \leq t_\text{max}  &&   t?tstart?tmax?):
    \quad① 取动作:At?πθ′(At∣St)A_t \sim \pi_{\theta'}(A_t \vert S_t)At??πθ?(At?St?),从环境中获得奖励 RtR_tRt? 和 下一个状态 st+1s_{t+1}st+1?
    \quad② 更新 t=t+1t=t+1t=t+1T=T+1T=T+1T=T+1
    (5). 初始化 保存return估计 的变量
    \quadR={0if stis TERMINALVw′(st)otherwiseR = \begin{cases} 0 & \text{if } s_t \text{ is TERMINAL} \\ V_{w'}(s_t) & \text{otherwise} \end{cases}R={ 0Vw?(st?)?if st? is TERMINALotherwise?
    (6). 对于i=t?1,…,tstarti = t-1, \dots, t_\text{start}i=t?1,,tstart?
    \quadR←γR+RiR \leftarrow \gamma R + R_iRγR+Ri?,这里 RRRGiG_iGi? 的一个 MC 度量
    \quad② 累加梯度关于θ′\theta'θdθ←dθ+?θ′log?πθ′(ai∣si)(R?Vw′(si))d\theta \leftarrow d\theta + \nabla_{\theta'} \log \pi_{\theta'}(a_i \vert s_i)(R - V_{w'}(s_i))dθdθ+?θ?logπθ?(ai?si?)(R?Vw?(si?))
    \qquad累加梯度关于w′w'wdw←dw+2(R?Vw′(si))?w′(R?Vw′(si))dw \leftarrow dw + 2 (R - V_{w'}(s_i)) \nabla_{w'} (R - V_{w'}(s_i))dwdw+2(R?Vw?(si?))?w?(R?Vw?(si?))
    (7). 使用 dθ,dw\mathrm{d}\theta, \mathrm{d}wdθ,dw 异步更新 θ,w\theta, wθ,w

算法流程图如下:
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

A3C允许在多个agent的训练中并行进行。梯度积累步骤(6.2)可以看作是对基于mini-batch的随机梯度更新的并行改善:wwwθ\thetaθ 的值分别在每个训练线程的方向上进行一点独立的修正。

代码实现

https://github.com/MorvanZhou/pytorch-A3C

  相关解决方案