本文用了强化学习,在知识图谱上游走,寻找目标节点。
一、简介
大概意思就是,在知识图谱上,给出一个起始节点和查询(query),然后找到目标节点。
图G包含节点和边。
如下图,给出起始节点Obama,query:citizenship,目标节点是USA。
我们要学习一个方法来预测。
我们我们将f作为强化学习的agent。他要学习搜索策略(search policy)
训练的时候,我们给出,让f自己学习路径,如果他走到,就给他一个正的reward(1分),其他时候是0分(没停或者停错地方都是0)。学完后只给出,预测。(这个reward只用在Q learning的时候了)
所以设计了一个神经网络的agent,叫M-walk。用RNN将历史路径转化为一个向量,用来学policy和Q function 。reward稀疏,所以用带蒙特卡洛树搜索的RNN,生成路径。
二、用马尔科夫决策过程来进行图的游走
(S,A,R,P) s是state,a是action,r是reward function,p是state transition probability
初始状态s0和下一个状态的表示,如上图所示。
是连接点nt的所有边,是nt的所有邻居节点。
st包括1)到t时刻所有走过的节点(包括他们的邻居和邻边) 2)动作 3)初始query q构成。
集合S由所有可能出现的st构成。
在状态st,agent有以下动作可以选择:1)选择中的一条边,他连接到点 2)选择STOP,则就是要预测的。通常是随着时间t而改变的。
t时刻的动作集合由下图表示,A是所有时刻的At的并集。
选择stop之后,输出
如果输出是(即输出了正确的答案),则reward=1,否则为0.
这可以看出来,reward是非常稀疏的,只有走到正确的位置才有reward。但是由于图是已知静态确定的,所以如果确定了上一个状态和动作,那么下一个状态时确定的。(文中说这有助于解决reward稀疏。)
π是policy(给出状态s,选择动作a),Q是Q function(在状态s下选择动作a,它的Q value是多少,即之后的长期收益是多少)
三、M-walk agent
3.1 π和Q的神经网路结构
用RNN获得当前状态st的表达ht
ht分为三个部分:
1) 将上个时间的状态、动作、当前节点,综合。
2)综合了nt的邻居n'节点,以及nt和n'之间的边e,代表第n'个候选动作(不包括STOP动作)
3) 综合了和,用来判断STOP的概率。
所以π和Q的计算。
u0是将hst,hAt通过一个full-connected neural network。(这里没说这两个h要怎么整合到一起,应该是拼接)
un'是hst和hn't做内积(即点乘,对应位相乘,求和)
u0(STOP的分数),un'(邻居的分数)都是一个数字
Q是对每个数字做sigmoid
(这里做sigmoid,将q value化到0-1,因为这个模型的分数只有0和1,q value=0代表在当前s采取a,预期的总reward是0,是找不到的,如果是1代表未来可能找到。)
π是做温度参数为τ的softmax
关于温度参数
3.2 训练算法
传统的使用蒙特卡罗方法的REINFORCE,需要sample一个完整的序列,sample的效率很低,而且reward稀疏。所以sample的时候使用PUCT算法的变体。
π是上面提到的策略分数(softmax算的),c和β用来控制探索的程度。N是visit count。W是走(s-a)这条边上的蒙特卡罗树的total action reward。
PUCT算法最开始倾向于选择在状态s下出现少的action(式子的前半部分), 后来倾向于选择分数高的(式子的后半部分)。
当PUCT算法选择了STOP,或者到达了最大探索数(应该是强行选择STOP),则停止。使用
用下面的式子,更新上一个式子中的N和W。γ是衰减因子(discount factor).
主要目标就是多生成reward为正的路径。
然后用DQN网络,寻找更好的π就是max Q
(由于Q和π共享参数,且算的时候只用了sigmoid和softmax这种没参数的函数,所以训一个就行)
莫烦python-DQN网络代码详解(pytorch)
3.3预测算法
已知(ns,q)求nT。利用π在G上寻找nT。
一种方法是用训练好的π去寻找。然而这并没有用MDP的转移模型(?)(下方这个公式)
所以利用上面训练好的模型π、Q去生成蒙特卡罗树,就像训练时那样(Q stop作为上文提到的V进行更新)。但是可能有多路径到达同一个终止节点n。走不同路径,就有不同的叶子节点是n。
怎么比较选择哪个终点n(而且n需要综合多条路径),需要算一个分数,排序。
N是蒙特卡罗树的总模拟数量
综合叶子节点是n的情况,求n的分数。
在所有的候选节点中,我们选择score最大的。
3.4 RNN encoder
qt约等于右边的式子
所以st大约可以写成
st由两部分组成 1) 代表候选动作(包括STOP) 2)qt代表历史
所以用两个不同的神经网络去编码他们
前面说过,ht分为三个部分:
1) 将上个时间的状态、动作、当前节点,综合。
2)综合了nt的邻居n'节点,以及nt和n'之间的边e,代表第n'个候选动作(包括STOP动作)
3) 综合了和,用来判断STOP的概率。
求 2)的方法很简单,就是边和点的表达通过full-connected neural network
求 3) 的方法,就是max 2)的结果,因为每一次的节点数可能都不一样,这样可以得到统一的结果
求1) 就是编码qt 使用gru的思想
可以看出,q就相当于rnn里的hidden,初始是query q,之后为qt,rnn的输入是