Review - Learning to Diagnose with LSTM Recurrent Neural Networks

最近在忙着搭框架和了解签证和硕士课程的问题,鸽了一段时间……没办法,没人疼的孩子什么都得一手抓,自己处理数据,自己写初始试验框架【在做的】,自己看论文,自己……好吧,也就这会儿觉得业界真是又好薪水又多,某商汤大牛今天刚刚把我秀得北都找不着。

回归正题———医学数据标记在近年来成为了不错的研究热点,尤其是近年精准医疗的发展使得医学大数据更加地受到人们的关注,学界的很多大牛以及研究机构开始和医院合作获得相关的数据并设计数据模型做出诊断。今天介绍的这篇《Learning to Diagnose with LSTM Recurrent Neural Networks》就是ICLR 2016上发表的一篇针对这个领域的应用型论文。原文链接:Learning to Diagnose with LSTM Recurrent Neural Networks

根据论文作者的Literature Review和作者目前实习所获取的知识,医学测量数据,尤其是时序上的数据,普遍具有以下几个特点:
1)数据集子集(episode)长度不固定。比如我们在使用多导心电图(Polysomnography, PSG)获得声音数据时,录制得到的病人不同睡眠周期的声音长度往往是不一样的;
2)数据本身的长相关性(Long Time/Range Dependency)。这个在上一篇Adadelta的介绍里提到过。这个问题在医学领域显得尤为严重,因为长相关性的测量数据会直接影响到诊断的效率。然而在采集数据时,尤其是对于某些呼吸类疾病(哮喘,呼吸暂停,呼吸急促等)及其他潜伏期较长的疾病,长相关性似乎无法避免———毕竟不停按别人肚子或者掐别人脖子得到的数据是不真实的……

然而一旦提到处理长相关性的数据,循环神经网络(RNN)自然不会沉默。尤其是带有LSTM单元的RNN,自从1997年被提出之后迅速取代传统RNN成为数据处理的一大杀器。在这篇文章里它也是整个算法的核心。下文将会主要介绍LSTM以及本文对于训练LSTM的新思路。

Algorithm: LSTM Recurrent Neural Network

相信任何一个对于神经网络有所研究的人都不可能不会接触这个架构。那么和上次一样,我们依然说话简单点,直接将LSTM计算单元的架构和算法双手奉上:

关于传统MLP,RNN和LSTM-RNN之间的区别Google大牛Alex Graves的博士论文Supervised Sequence Labelling with Recurrent Neural Networks已经解释的非常详细了,文中对于LSTM也有很详细的解读,不过我在这里依然用“不那么数学的方式”大致介绍一下。

LSTM-RNN的开发本身是为了解决1994年由Bengio提出的梯度消失(Vanishing Gradient)问题,通俗点说就是RNN容易“忘事”:在一定的时间过后,长时数据在梯度算法迭代的过程中对于某一模式的识别能力会逐渐下降。为了解决这个问题,1997年LSTM被提出取代传统的RNN cell。这个新架构几乎完美解决了VG问题,也使其在处理长时相关数据上非常强大。

从上图中可以看到,相比于传统的时序RNN(即RNN时序展开后的样子,其实就是一个针对时序数据的,所有模块之间共用一个激活函数的普通DNN),LSTM多出了三道“门”(gate):输入门(Input Gate),输出门(Output Gate)和忘记门(Forget Gate)。三个门的激活函数是一样的(本文使用sigmoid),也有各自的初始权重。这三个权重矩阵和一般的RNN反馈权重矩阵一样也会参与反向传播(Back-Propagation)更新。三个门分别控制着数据的输入比,输出比和传输比。值得一提的是,由于LSTM是在保留原有RNN cell的基础上建立的,上一层RNN cell以及这一层cell本身在上一时间点的state对于下一层的三个门也有影响,这从公式中能够很明显的看出。这三道门就像三个开关一样,可以决定信息是否保留。也因此,“忘记”的pattern也好,label也好,probability也好,都会被保留下来而不会逐渐随着梯度算法的迭代而“消失”。

目前被广泛应用的LSTM是2000年的加了忘记门的版本而并非1997年的I/O门版本。

在这篇论文中,LSTM-RNN主要被用于数据分类和标记。

Method for Training

这篇文章在RNN和LSTM的训练方法上做出了创新,主要着手在loss function和target generation上。Loss function方面,针对普通RNN中不同时序对于target value的access time的差别,文中提出在每一个时间点对于target value做一个带有权重的复制(target replication):

之后将最终的loss和每一个时间点的loss使用类似于凸组合(convex combination)的计算手段相加,得出当前时间点的loss。(这一步很像Boosting….?)

再次强调:LSTM在大架构上依然是RNN,因此这种优化方式适用。

文章对于RNN的训练方式也做出了一定的创新型调整。由于他们得到的数据对应的标记较多而所要求的最终数据上的标记没那么多,因此他们有效利用了多余的标记(注意在这里所有的标记都是已知的)———他们将这些标记当做target放入了网络当中用于辅助模型的进一步训练(auxiliary targets)。他们根据Literature Review和试验之后发现,将标记放入数据中不仅减小了整体的loss,还减轻了数据的过拟合(overfitting)问题。具体的idea见下图。

具体的实验部分由于数据太过庞大,请查看原文的试验结果及附录部分。

Personal View

这篇文章个人认为不仅对于医学数据处理,而是对于很多领域的数据相关建模提供了一个不错的范例。原因在于他没有花大量时间在参数的调整(深度学习目前的瓶颈也就在这…)和数学推导上,而是通过灵活运用Training Data改变RNN的训练架构,从结构上做出了创新并达到了不错的效果。这可以成为日后的一个不错的研究方向。而半监督学习也可以从这篇文章中得到不错的启发,毕竟半监督学习的核心数据里大部分是没有标记的,而这里恰恰反了过来:有大部分的标记不需要对应数据……不过笔者毕竟不是大(Luo)长(Fang)者(Hao),姿势水平有限,还是需要多多研读SSL方面的资源。

All pictures and formulas are cited from “Learning to Diagnose with LSTM Recurrent Neural Networks”, Zachary C.Lipton et al.