【论文笔记】2019-WWW-Multiple Treatment Effect Estimation using Deep Generative Model with Task Embedding

背景

这篇文章考虑了一个新的causal inference设定:treatment不是简单的二元变量 { 0 , 1 } \left\{0,1\right\} {0,1},而是二元变量的组合 { 0 , 1 } k \left\{0,1\right\}^k {0,1}k。这个设定也比较好理解,还用医生治病的例子来说,通常医生使用的是多种药的组合。如果总共涉及到三种药物,而病人使用了第一种和第三种,则对应的 k = 3 k=3 k=3,treatment就是 [ 1 , 0 , 1 ] [1,0,1] [1,0,1]

挑战

这个设定的挑战在于如何设计针对多个treatment的网络结构。在经典的TARnet和Dragonnet中,作者针对 p ( y ∣ t = 0 , x ) p(y|t=0,x) p(yt=0,x) p ( y ∣ t = 1 , x ) p(y|t=1,x) p(yt=1,x)都设计了不同的网络,如果本文也沿用这个方法,就会出现网络结构冗余的问题。比如例子中涉及到3个treatment的组合,那就要相应设计 2 3 = 8 2^3=8 23=8个网络,非常不高效,还会出现因为数据分布不均匀网络训练不准确的问题。

方法

整体的框架还是套用的CEVAE(可以参见笔者写的上一篇文章),创新之处在于引入了一个可学习的embedding matrix。

Encoder

网络结构如下图所示:
在这里插入图片描述
前向传播:首先输入 x x x会经过网络 g 1 g_1 g1得到 q ( t ∣ x ) = ∏ i = 1 k B e r n ( q t , i ) q(t|x)=\prod_{i=1}^k Bern(q_{t,i}) q(tx)=i=1kBern(qt,i),然后从 q ( t ∣ x ) q(t|x) q(tx)中采样得到 t ′ t' t(这里有个问题就是怎么反向传播?采样得到 t ′ t' t没法反向传播吧),接下来 t ′ t' t会和一个embedding matrix W W W相乘得到新的表示 τ = W ⋅ t ′ \tau=W\cdot t' τ=Wt。新表示 τ \tau τ经过网络 g 2 g_2 g2得到 q ( y ∣ t , x ) = N ( g 2 , 1 ) q(y|t,x)=N(g_2,1) q(yt,x)=N(g2,1),这里方差设为1也是为了简单防止过拟合吧,避免网络中要学习太多变量。之后,作者把 τ , x , g 2 \tau, x, g_2 τ,x,g2concatenate到一起得到 g 3 g_3 g3 g 4 g_4 g4的输入, g 3 g_3 g3 g 4 g_4 g4的输出恰好是 q ( z ∣ x , t , y ) q(z|x,t,y) q(zx,t,y)的均值和方差。

Decoder

网络结构如下图所示:
在这里插入图片描述前向传播:这里作者没写清楚decoder的输入 z z z怎么来的(吐槽一句,作者有很多细节都没写清楚),我猜测就是从encoder的输出采样得到。接下来先看下面四个网络 f 1 , f 2 , f 3 , f 4 f_1,f_2,f_3,f_4 f1,f2,f3,f4,其实是针对 x x x的三种可能情形:二元变量、目录变量、连续变量,这里只以连续变量为例进行说明。 f 1 f_1 f1 f 2 f_2 f2的输出分别是 p ( x ∣ z ) p(x|z) p(xz)的均值和方差。 f 5 f_5 f5的设计和 g 1 g_1 g1基本一致,输出就是 p ( t ∣ z ) = ∏ i = 1 k B e r n ( p t , i ) p(t|z)=\prod_{i=1}^k Bern(p_{t,i}) p(tz)=i=1kBern(pt,i),然后继续采样得到 t ~ \widetilde{t} t t ~ \widetilde{t} t 再与embedding matrix相乘得到 τ ~ = W ⋅ t ~ \widetilde{\tau}=W \cdot \widetilde{t} τ =Wt 。之后作者在文章里说把 τ ~ , x , z \widetilde{\tau},x,z τ ,x,z concatanate到一起作为 f 6 f_6 f6的输入,但根据流程图似乎没有 x x x?(这个作者写作有点不认真啊,文章居然和图对不上)
作者没具体写出训练的目标函数(很迷,这么重要的东西居然文章里没有明确写出来),只是说利用和VAE类似的变分推断的方法,估计是和CEVAE差不多,先验分布也是标准正态分布。

总结

文章的亮点在于提出了multiple treatment的范式和embedding的解决思路,缺点在于作者写作实在太不严谨了,很多细节没交代清楚(当然也可能是我读的还不够细),类似于采样 t t t怎么反向传播、目标函数之类的都没有具体写出来。