这篇论文同样作为将 Multi-task learning 应用于药物发现,较之 [1],其整理出一个新的数据集,做了大量的对比实验,更侧重于去验证 multi-task learning 的有效性,以及探究利用 MTL 后效果得到提升的潜在因素。
作者在实验后,得到了几个比较有意义的结论:
- 随着 task 和 data 的增加,模型的性能增益会衰减,但性能依旧会提高(至少在他的数据集中是这样);
- task 和 data 是两个影响模型的比较重要的因素;
- MTL模型中抽取出的特征,其具有 transferability;
- task之间存在的共有活性分子也对模型性能提升有一定的影响,而靶点类别并没有影响。
Experimental Section
这篇论文在模型设计上和 [1] 中类似,主要是通过大量的实验解释了如下几个问题:
- Do massively multitask networks provide a performance boost over simple machine learning? If do, what is the optimal architecture for massively multitask networks?
- How does the performance of a multitask network depend on the number of tasks? How does the performance depend on the total amount of data?
- Do massively multitask networks extract generalizable information about chemical space?
- When do datasets benefit from multitask training?
Experimental Exploration of Massively Multitask Networks
对于第一个问题中 optimal architecture 是 pyramidal multitask networks。这种结构的设计的目的,是出于为了解决 overfitting 的问题(一个主要的因素是没有强正则项),而作者考虑的是设计一个金字塔型的网络结构(pyramidal architecture),以此减少模型的参数,变相地消除 overfitting 问题。这种思路受启发于 GoogleNet [2] 中 1x1卷积核的使用。第一层是一个较宽的隐层(2000个nodes),第二层则采用一个较窄的隐层(100个nodes)。这样的好处在于,先用一个较宽的隐层抽取一个复杂的表征,再采用一个较窄的隐层,则是起到了一个降维的作用,减少了模型的参数(这个地方就类似于 1x1 卷积核在 GoogleNet中的使用,起到了一个降维的作用)。
Relationship between performance and number of tasks
对于第二个问题,作者实验上的设置是通过训练好几个 MTL 模型,不同点在于 task的数量,分别是 10, 20, 40, 80, 160, 249。实验的结果表明,随着模型中的 tasks 数量增加,模型的性能也在提高,至少在作者用的数据集中有249个 tasks,模型的性能提升依旧没达到瓶颈。
More tasks or more data?
对于第二个问题的后面那个问题,作者的实验设置则是训练样本数量不同的模型以及tasks数量不同的模型,进行对比得出的结论是,更多的data 和 tasks 对于模型性能的提升有一定的帮助。
Do massively multitask networks extract generalizable features?
针对第三个问题,作者则是去实验MTL模型中抽取出的特征能否做迁移学习(transfer learning),也就是用MTL模型中抽取的特征去初始化一个 single-task networks(训练集不包含在前者训练的数据集中),并做微调,与不使用迁移学习的 single-task networks进行对比。实验的结果证明,当采用 task 数量较小的 MTL模型,做迁移学习,其效果并不好,但是随着task数量的增江,其迁移学习的效果得到了提高。
总结
这篇论文并没有提出什么新的方法,但是其大量的对比实验得出的结论对于MTL应用于分子筛选比较有意义。如下是目前的一些思考:
一个是对于 overfitting 问题的处理。这个问题是DL应用于分子筛选任务中比较大的问题,因为对于一个单独的 task,其中的活性分子的样本是很少的(也就是 bias 的问题)。MTL以及作者提出的减少参数的方式是一个思路,但这个地方是不是可以借鉴最近的一些处理样本偏差问题的工作(好像这类问题在推荐系统任务中比较常见)去提升模型效果。
另外一个则是更多 tasks 和 data,这个可以去实验下,模型提升是否存在瓶颈。
还有一个利用MTL模型得到的表征,用作迁移学习的思路,本文中好像提到效果比不适用迁移学习的baseline效果要好,但并没有超过直接使用 MTL模型去做预测的结果来的好,但也算是一个不错的点子。
参考文献
[1] Unterthiner T, Mayr A, Klambauer G, et al. Deep learning as an opportunity in virtual screening[C]//Proceedings of the deep learning workshop at NIPS. 2014, 27: 1-9.
[2] Szegedy C, Liu W, Jia Y, et al. Going deeper with convolutions[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2015: 1-9.