当前位置: 代码迷 >> 综合 >> 【主动学习论文】Bayesian Generative Active Deep Learning,ICML 2019
  详细解决方案

【主动学习论文】Bayesian Generative Active Deep Learning,ICML 2019

热度:48   发布时间:2024-02-02 04:55:32.0

Bayesian Generative Active Deep Learning

  • 一、解决的问题
  • 二、模型的框架:BALD挑选策略+VAE-ACGAN生成网络
    • (1)Bayesian Active Learning by Disagreement
    • (2)变分自编码器 (VAE, Variational Auto Encoder)
    • (3)ACGAN(Auxiliary Classifier Generative Adversarial Network)
    • (4)Bayesian Generative Active Deep Learning
    • (5)损失函数
  • 三、实验与结果
    • (1) 对比的模型
    • (2) 使用的数据集
    • (3) 使用的超参数
    • (4) 使用的分类器
    • (5) 结果
  • 四、参考文献
  • 五、其他链接

原文传送链接

  • author: Toan Tran1 Thanh-Toan Do2 Ian Reid1 Gustavo Carneiro1
  • 1University of Adelaide, Australia; 2University of Liverpool.
  • Correspondence to: Toan Tran toan.m.tran@adelaide.edu.au.
  • 2019, ICML

一、解决的问题

当标记数据量较少的时候,如何应用深度学习框架解决多分类任务?基于此,作者提出了Bayesian Generative Active Deep Learning这一框架,融合AL和DA方法,生成有价值的(informative)数据,扩大标注数据集,从而达到提高模型分类正确率的效果。

二、模型的框架:BALD挑选策略+VAE-ACGAN生成网络

(1)Bayesian Active Learning by Disagreement

  • 利用Shannon Entropy(香农熵,即信息熵;不同于交叉熵损失)度量样本的信息量;信息熵的差值代表当某一样本的信息熵和总体的平均信息熵之间的差值,这一差值越大,越能说明该样本相对于平均水平包含更多的信息量。

x ? = a r g m a x x D p o o l α ( x , M ) = a r g m a x x D p o o l H [ Y X , D ] ? E θ   p ( θ D ) [ H [ y x , θ ] ] x^{*}=argmax_{x \in D_{pool}} \alpha(x,M)\\=argmax_{x \in D_{pool}}{H[Y|X,D]-E_{\theta~p(\theta|D)}[H[y|x,\theta]]}

  • 这一公式是计算样本x在分类模型M下的函数值
  • H [ Y X , D ] H[Y|X,D] 是计算x预测值 p ( y x , D ) p(y|x,D) 的信息熵
  • E θ   p ( θ D ) [ H [ y x , θ ] ] E_{\theta~p(\theta|D)}[H[y|x,\theta]] 是计算分布 p ( y x , θ ) p(y|x,\theta) 的信息熵,其中 θ \theta 是网络M的参数

α ( x , M ) ? c 1 T t p ^ c t ) ? l o g ( 1 T t p ^ c t ) + 1 T c , t p ^ c t ? l o g p ^ c t \alpha(x,M)\approx-\sum_{c}(\frac{1}{T}\sum_t \hat p_{c}^{t}) ·log(\frac{1}{T}\sum_t \hat p_{c}^{t}) )+\frac{1}{T}\sum_{c,t} \hat p_{c}^{t}·log\hat p_{c}^{t}

  • 蒙特卡洛模拟过程(由于无法直接计算上式的期望,故通过模拟的方法得到其近似值):

    • 在不修改现有深度网络模型的基础上,只要模型带有dropout层,便可以完成模拟过程。
    • 训练的时候MC dropout的表现形式和dropout没有区别,按照正常方式训练即可
    • 但是在测试的时候,在前向传递过程中,dropout是不关闭的。
    • 为了得到上述函数的一个近似值,我们对同一个输入进行多次前向传递过程,这相当于在蒙特卡洛方法的帮助下我们得到为了不同网络结构的输出,然后对这些输出去平均值,就可以得到模型的预测结果。
    • 同时这个多次前向传递过程是可以并行的(进行一次完整的前向传递,然后随机选择一定比例的模型参数),因此在时间上相当于是一次前向传播。
  • 参数解读

    • t是指第t次dropout,t=0,1,…,T;c是指分类类别, c { 1 , 2 , . . . , C } c\in\{1,2,...,C\}
    • p ^ t = [ p ^ 1 t , p ^ 2 t , . . . , p ^ C t ] = s o f t m a x ( f ( x ; θ t ) ) \hat p^{t}=[\hat p_{1}^{t},\hat p_{2}^{t},...,\hat p_{C}^{t}]=softmax(f(x;\theta^t)) f f 是指参数为 θ t \theta^t 的网络M, θ t \theta^t 是指第t次未被dropout的参数
  • 近似函数意义解读:这相当于一个琴生不等式

    • 根据琴生不等式:这里的主函数是 f ( x ) = ? x ? l o g x f(x)=-x·logx ,由于二阶偏导 ? 2 f ( x ) ? x 2 = ? 1 x < 0 \frac{\partial^2 f(x)}{\partial x^2}=-\frac1x < 0 ,期望的函数值要大于函数值的期望, f ( E ( x ) ) > E ( f ( x ) ) f(E(x))>E(f(x))
    • α ( x , M ) ? ( E x ) ? l o g ( E x ) ? ( ? E ( x ? l o g x ) ) \alpha(x,M)\approx-(Ex)·log(Ex)-(-E(x·logx))
    • 具体计算过程如图所示,最终结果为 2.78722-2.64644
      在这里插入图片描述
  • 实际意义解读:度量的是 模型预测结果的波动程度

    • 考虑一个二分类问题,此时x是一个二维列向量:XOY平面是列向量的取值空间,Z是对应列向量的交叉熵的取值;所以这是一个交叉熵函数图像。
    • 当样本的预测值波动越大时,T个列向量之间的差距越大,T个点(淡橘色)会分散地分布于不同的坡面上;此时样本熵的均值(深橘色)和样本均值的熵(蓝色)之间的差距会比较大
    • 当样本的预测值波动越小时,T个列向量之间的差距越小,T个点(淡橘色)会分散地分布于相同的坡面上;此时样本熵的均值(深橘色)和样本均值的熵(蓝色)之间的差距会比较小
      模型预测结果的波动程度
    • 按照作者的理解,BALD寻找的是位于分类边界的样本。
      在这里插入图片描述

(2)变分自编码器 (VAE, Variational Auto Encoder)

  • 变分自编码器由一个编码器和一个解码器构成-

    • 编码器:通过真实样本,学习到真实样本在latent space当中的后验分布的均值和方差
    • 解码器:通过在latent space中采样,生成一个新的样本
      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kCnCgKql-1595943497623)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/95173f16-6d27-4007-80b5-89b7b9aba359/Untitled.png)]
  • 损失函数:最小化近似后验分布和真实后验分布之间的KL散度

    m i n { D K L ( q θ ( z x ) p ( z x ) ) } min\{D_{KL}(q_{\theta}(z|x)||p(z|x))\}

  • 该损失函数无法直接计算,通过等价变形得到近似的损失函数计算方法

    m i n { D K L ( q ? ( z L x L p ( z ) ) ) ? E [ l o g ( p θ ( x L z L ) ) ] } min \{D_{KL}(q_{\phi}(z_{L}|x_{L}||p(z))) - E[log (p_{\theta}(x_{L}|z_{L}))]\}

    第一部分是近似后验分布和先验分布之间的KL散度,第二部分是样本的重构损失

  • 等价损失函数的推导过程
    ![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-h47eMKPz-159594

(3)ACGAN(Auxiliary Classifier Generative Adversarial Network)

  • ACGAN由一个生成网络和判别网络构成

    • 生成网络:通过输入的噪音(Z)和类别信息(C)生成具有标签的假数据(X_fake)
    • 判别网络:完成两个分类任务:1)区别真假数据‘;2)将数据正确分类’
  • 对抗体现在真假数据区分,而对数据做多类别分类任务时不产生对抗

  • 损失函数:生成器最小化 L S ? L C L_{S}-L_{C} ,判别器最小化 L C + L S L_{C}+L_{S}

    L C = ? E [ l o g P ( C = c X r e a l ) ] ? E [ l o g P ( C = c X f a k e ) ] L_{C}=-E[log P(C=c|X_{real})]-E[log P(C=c|X_{fake})]

    L S = ? E [ l o g P ( S = r e a l X r e a l ) ] ? E [ l o g P ( S = f a k e X f a k e ) ] L_{S} =-E[log P(S=real|X_{real})]-E[log P(S=fake|X_{fake})]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IZft7p9n-1595943497627)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/1c4a6128-45ee-4f9d-bcdd-ecd1ad58c208/Untitled.png)]

(4)Bayesian Generative Active Deep Learning

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-D1f82FOn-1595943497629)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/f687a9e9-483a-4326-b0f4-282cdab3c7ed/Untitled.png)]

  • 根据BALD策略,选择信息量最大的样本 x ? x^* ,然后将其送到Oracle处打上标签 y ? y^* ,获得 ( x ? , y ? ) (x^*,y^*)

  • 然后控制生成模型g,生成信息量大的样本 x , y ? (x', y^*)

    • 重构损失最小化保证了样本之间的相似性,从而也保证了生成样本具备一定的信息量
  • 然后通过判别器D使得生成的数据逼近真实数据,且其标签 y ? y^* 是合理的

  • 因此,作者通过生成“真实的”带有给定标签的数据,完成数据增强的任务。其中“真实”是通过D的区别真假数据的对抗获得的,数据带有给定标签是通过ACGAN的判别器的分类器以及VAE的输入标签信息获得的。VAE-ACGAN解决了传统生成网络生成的数据缺乏真实性、模式崩塌的问题,获得了更加真实的、具备多样性的数据。

(5)损失函数

  • 优化网络 e ( x ; θ E ) e(x;\theta_{E}) L = L r e c ( x , g ( e ( x ; θ E ) ) ; θ G ) + D K L ( q ( z x ) p ( z ) ) L=L_{rec}(x,g(e(x;\theta_E));\theta_G)+D_{KL}(q(z|x)||p(z))

  • L A C G A N = l o g d ( ( x ; θ D ) ) + l o g ( 1 ? d ( g ( z ; θ G ) ; θ D ) ) + l o g ( 1 ? d ( g ( u ; θ G ) ; θ D ) ) + l o g ( s o f t m a x ( c ( x ; θ C ) ) + l o g ( s o f t m a x ( c ( g ( z ; θ G ) ; θ C ) ) + l o g ( s o f t m a x ( c ( g ( u ; θ G ) ; θ C ) ) L_{ACGAN}=logd((x;\theta_D))+log(1-d(g(z;\theta_G);\theta_D))+log(1-d(g(u;\theta_G);\theta_D))\\+log(softmax(c(x;\theta_C))+log(softmax(c(g(z;\theta_G);\theta_C))\\+log(softmax(c(g(u;\theta_G);\theta_C))

    z = e ( x ; θ E ) z=e(x;\theta_E) ,为了优化VAE

    u ? N 0 , 1 u\sim N(0,1) ,为了优化ACGAN

  • 优化网络 g ( z ; θ G ) g(z;\theta_G) L = γ L r e c ( x , g ( e ( x ; θ E ) ) ; θ G ) ? L A C G A N L=\gamma L_{rec}(x,g(e(x;\theta_E));\theta_G)-L_{ACGAN}

    重构损失、欺骗判别器、正确分类

  • 优化网络 d ( x ; θ D ) d(x;\theta_D) L = L A C G A N L=L_{ACGAN}

  • 优化网络 c ( x ; θ C ) c(x;\theta_C) L = L A C G A N L=L_{ACGAN}

三、实验与结果

(1) 对比的模型

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8Il1by9d-1595943497631)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/cce4efda-d53a-4712-8ebe-698de08e950d/Untitled.png)]

(2) 使用的数据集

数据集信息

(3) 使用的超参数

  • C使用SGD,lr=0.01,momentum=0.9
  • E、G、D使用Adam,lr=0.0002, β 1 = 0.5 β 2 = 0.999 \beta_1=0.5,\beta_2=0.999
  • mini-batch=100
  • 三次正确率的均值

(4) 使用的分类器

  • ResNet18

  • ResNet18pa

(5) 结果

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bgl4ZPEb-1595943497632)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/2f9874ae-d6ea-435f-b1ed-df7b5add15b3/Untitled.png)]

  • 使用生成模型生成的结果作数据加强的精度普遍比不使用的高
  • AL+VAEACGAN的精度基本接近使用全部数据作十倍数据扩大的精度
  • 同时对比VAAL的实验结果,发现在CIFAR10/100上精度基本高15%
  • 一个大胆的猜测:挑选策略不是很重要,生成式的数据增强能够大幅度提高精度

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2tkkisTk-1595943497634)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/e5e43c20-d073-441b-bfc5-db110a89f180/Untitled.png)]

  • 使用VAEACGAN生成的数据更具有价值,信息量更高

四、参考文献

[1] Tran T, Do T T, Reid I, et al. Bayesian generative active deep learning[J]. arXiv preprint arXiv:1904.11643, 2019.

五、其他链接

b站相关视频:https://www.bilibili.com/video/BV1aT4y1E7yU
本期作者相关主页:https://me.csdn.net/StatsChenXiaoshu

  相关解决方案