当前位置: 代码迷 >> 综合 >> 【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search
  详细解决方案

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

热度:18   发布时间:2024-02-22 16:43:18.0

本文用了强化学习,在知识图谱上游走,寻找目标节点。

一、简介

大概意思就是,在知识图谱上,给出一个起始节点和查询(query),然后找到目标节点。

 图G包含节点和边。

如下图,给出起始节点Obama,query:citizenship,目标节点是USA。

 

 

我们要学习一个方法来预测。

我们我们将f作为强化学习的agent。他要学习搜索策略(search policy)

训练的时候,我们给出,让f自己学习路径,如果他走到,就给他一个正的reward(1分),其他时候是0分(没停或者停错地方都是0)。学完后只给出,预测。(这个reward只用在Q learning的时候了)

 

所以设计了一个神经网络的agent,叫M-walk。用RNN将历史路径转化为一个向量,用来学policy和Q function 。reward稀疏,所以用带蒙特卡洛树搜索的RNN,生成路径。

 

二、用马尔科夫决策过程来进行图的游走

(S,A,R,P) s是state,a是action,r是reward function,p是state transition probability

初始状态s0和下一个状态的表示,如上图所示。

\varepsilon _{n_{t}}是连接点nt的所有边,N_{n_{t}}是nt的所有邻居节点。

st包括1)到t时刻所有走过的节点(包括他们的邻居和邻边) 2)动作 3)初始query q构成。

 

集合S由所有可能出现的st构成。

在状态st,agent有以下动作可以选择:1)选择\varepsilon _{n_{t}}中的一条边,他连接到点n_{t+1} 2)选择STOP,则n_{t}就是要预测的n_{T}。通常是随着时间t而改变的。

t时刻的动作集合由下图表示,A是所有时刻的At的并集。

选择stop之后,输出

 

如果输出是n_{T}(即输出了正确的答案),则reward=1,否则为0.

这可以看出来,reward是非常稀疏的,只有走到正确的位置才有reward。但是由于图是已知静态确定的,所以如果确定了上一个状态和动作,那么下一个状态时确定的。(文中说这有助于解决reward稀疏。)

 

π是policy(给出状态s,选择动作a),Q是Q function(在状态s下选择动作a,它的Q value是多少,即之后的长期收益是多少)

 

三、M-walk agent

3.1 π和Q的神经网路结构

用RNN获得当前状态st的表达ht

ht分为三个部分:

1) 将上个时间的状态、动作、当前节点,综合。

2)综合了nt的邻居n'节点,以及nt和n'之间的边e,代表第n'个候选动作(不包括STOP动作)

3)  综合了\varepsilon _{n_{t}}N_{n_{t}},用来判断STOP的概率。

 

所以π和Q的计算。

u0是将hst,hAt通过一个full-connected neural network。(这里没说这两个h要怎么整合到一起,应该是拼接)

un'是hst和hn't做内积(即点乘,对应位相乘,求和)

u0(STOP的分数),un'(邻居的分数)都是一个数字

Q是对每个数字做sigmoid

(这里做sigmoid,将q value化到0-1,因为这个模型的分数只有0和1,q value=0代表在当前s采取a,预期的总reward是0,是找不到的,如果是1代表未来可能找到。)

π是做温度参数为τ的softmax

关于温度参数

 

3.2 训练算法

传统的使用蒙特卡罗方法的REINFORCE,需要sample一个完整的序列,sample的效率很低,而且reward稀疏。所以sample的时候使用PUCT算法的变体。

 

π是上面提到的策略分数(softmax算的),c和β用来控制探索的程度。N是visit count。W是走(s-a)这条边上的蒙特卡罗树的total action reward。

PUCT算法最开始倾向于选择在状态s下出现少的action(式子的前半部分), 后来倾向于选择分数高的(式子的后半部分)。

当PUCT算法选择了STOP,或者到达了最大探索数(应该是强行选择STOP),则停止。使用

用下面的式子,更新上一个式子中的N和W。γ是衰减因子(discount factor).

主要目标就是多生成reward为正的路径。

然后用DQN网络,寻找更好的π就是max Q

(由于Q和π共享参数,且算的时候只用了sigmoid和softmax这种没参数的函数,所以训一个就行)

莫烦python-DQN网络代码详解(pytorch)

 3.3预测算法

已知(ns,q)求nT。利用π在G上寻找nT。

一种方法是用训练好的π去寻找。然而这并没有用MDP的转移模型(?)(下方这个公式)

所以利用上面训练好的模型π、Q去生成蒙特卡罗树,就像训练时那样(Q stop作为上文提到的V进行更新)。但是可能有多路径到达同一个终止节点n。走不同路径,就有不同的叶子节点是n。

怎么比较选择哪个终点n(而且n需要综合多条路径),需要算一个分数,排序。

N是蒙特卡罗树的总模拟数量

综合叶子节点是n的情况,求n的分数。

在所有的候选节点中,我们选择score最大的。

 

3.4 RNN encoder

 qt约等于右边的式子

所以st大约可以写成

st由两部分组成 1)\varepsilon _{n_{t}}  N_{n_{t}}代表候选动作(包括STOP) 2)qt代表历史

所以用两个不同的神经网络去编码他们

前面说过,ht分为三个部分:

1) 将上个时间的状态、动作、当前节点,综合。

2)综合了nt的邻居n'节点,以及nt和n'之间的边e,代表第n'个候选动作(包括STOP动作)

3)  综合了\varepsilon _{n_{t}}N_{n_{t}},用来判断STOP的概率。

求 2)的方法很简单,就是边和点的表达通过full-connected neural network

求 3) 的方法,就是max 2)的结果,因为每一次的节点数可能都不一样,这样可以得到统一的结果

求1) 就是编码qt 使用gru的思想

可以看出,q就相当于rnn里的hidden,初始是query q,之后为qt,rnn的输入是[h_{A,t},h_{a_{t},t},n_{t+1}]

 

  相关解决方案