Causal Effect Inference with Deep Latent-Variable Models
笔者最近在做causal inference这个方向,因此会把日常读到的还(neng)不(kan)错(dong)的paper简单整理一下做个笔记,欢迎感兴趣的童鞋交流讨论~
背景
Causal inference涉及到的数据集通常由三个变量组成{X,T,Y}\left\{X,T,Y\right\}{ X,T,Y}。其中,XXX代表特征(covariate),例如病人的身体、经济状况,TTT代表某个操作(treatment),通常是0-1的,例如是否服用某种药物,YYY代表输出(outcome),例如病人一段时间后的血压血糖水平。简单来说,causal inference的任务是想在给定XXX的情况下估计TTT对YYY的影响。
本文作者考虑的一个问题是如何消除hidden confounder对causal inference的影响。简单来说,confounder ZZZ就是会对TTT和YYY都产生影响的变量,例如一个人的经济实力社会地位,这些会对他是否能够服用某种药物产生影响,但ZZZ又是一个很难准确观测的变量。这里,作者假设可观测到的XXX是ZZZ的代理变量,例如我们虽然很难准确度量一个人的社会地位ZZZ,但可以通过调查他的职业收入XXX侧面反映ZZZ。这里,作者构建了一个如下的因果图:
这个因果图理解起来也比较直观,深色的是可以观测到的,白色的是无法观测到的,XXX是ZZZ的一个noisy observation,因此Z→XZ\rightarrow XZ→X,其他几个箭头都是causal inference里的常用假设。
方法
其实看到这个因果图,熟悉VAE的童鞋可能已经猜到了作者的思路,就是把ZZZ当做隐空间表示,然后套用VAE的架构。
Encoder:文章里叫inference network,结构如下图:
这个结构作者参考的是TARnet网络,这是causal inference里一个非常经典的深度模型,会在之后的博客里介绍。q(t∣x)q(t|x)q(t∣x)是在计算propensity score(不过这个东西在原始TARnet并没有用到,估计是作者为了实验效果后加上去的一项),在学完共同特征表示之后,根据t=0/1t=0/1t=0/1接出两个分支。
Decoder:文章中叫model network,结构如下:
这个结构可以根据之前的因果图分解得到:
p(x,t,y,z)=p(z)p(t,x,y∣z)=p(z)p(t,y∣z)p(x∣t,y,z)=p(z)p(t,y∣z)p(x∣z)=p(z)p(t∣z)p(y∣t,z)p(x∣z)p(x,t,y,z)=p(z)p(t,x,y|z)=p(z)p(t,y|z)p(x|t,y,z)=p(z)p(t,y|z)p(x|z)=p(z)p(t|z)p(y|t,z)p(x|z)p(x,t,y,z)=p(z)p(t,x,y∣z)=p(z)p(t,y∣z)p(x∣t,y,z)=p(z)p(t,y∣z)p(x∣z)=p(z)p(t∣z)p(y∣t,z)p(x∣z)
目标函数的推导与VAE基本一致:
当然,就像笔者之前提到的,为了实验效果,作者又在原始VAE loss上加了新的两项:
结论
这应该是第一篇利用深度生成模型求解causal inference的文章,文章的motivation(解决hidden confounder)和构建因果图的方式(XXX是ZZZ的noisy observation)很让人信服,不过实验效果好像一般(hhh可能也是因为如此大家都喜欢把它当做baseline),套用VAE的框架也不算难,读起来也比较轻松。