当前位置: 首页 > 编程日记 > 正文

让学生网络相互学习,为什么深度相互学习优于传统蒸馏模型?| 论文精读

640?wx_fmt=png

作者 | Ying Zhang,Tao Xiang等
译者 | 李杰
出品 | AI科技大本营(ID:rgznai100)

蒸馏模型是一种将知识从教师网络(teacher)传递到学生网络(student)的有效且广泛使用的技术。通常来说,蒸馏模型是从功能强大的大型网络或集成网络转移到结构简单,运行快速的小型网络。本文决定打破这种预先定义好的“强弱关系”,提出了一种深度相互学习策略(deep mutual learning, DML)。

在此策略中,一组学生网络在整个训练过程中相互学习、相互指导,而不是静态的预先定义好教师和学生之间的单向转换通路。作者通过在CIFAR-100和Market-1501数据集上的实验,表明DML网络在分类和任务重识别任务中的有效性。更重要的是,DML的成功揭示了没有强大的教师网络是可行的,相互学习的对象是由一个个简单的学生网络组成的集合。

简介

深度神经网络已经广泛应用到计算机视觉的各个任务中,并获得了很好的性能表现,但是,这种SOTA通常是依靠深度堆叠网络层数,增加网络宽度实现,这种结构设计会产生大量的参数,一方面,会拖慢运行速度和执行效率,另一方面,需要很大的存储空间进行存储。这两方面也限制了很多网络在实际应用中落地。

因此,如何在保证效果的情况下设计更小,更快速的网络,就成了我们关注的重点。基于这种思想,涌现了很多好的工作,蒸馏模型(model distillation)就是其中的代表,为了更好地学习小型网络,蒸馏方法从一个强大的(更深或更宽)教师网络开始,然后训练一个更小的学生网络来模仿教师网络。下图是蒸馏模型的一个结构表示:

640?wx_fmt=png
蒸馏模型

符号表示:

  • Big Model:复杂强大的教师网络

  • Small Model:轻巧简单的学生网络

  • soft targets:输入x经过教师网络后得到的softmax层输出

  • hard targets:输入数据对应的label标签

  • softmax公式表示:

640?wx_fmt=png


其中,qi是第i类的概率, Zi和Zj分别表示softmax层输出,T是温度系数,控制着输出概率的软化(soft)程度,T越大,不同类别输出概率在不改变相对大小关系的情况下,差值会越小,也就是更加soft。

定义好基本概念后,实现步骤可以表示为:

1.设置一个较大的 T,输入x训练一个教师网络,经过softmax层后生成soft targets。
2.使用步骤1得到的soft targets来训练学生网络。
3.最终模型的目标函数由soft targets和学生网络的输出数据的交叉熵,hard targets 和学生网络的输出数据的交叉熵两部分共同组成。

这些训练步骤,能够保证学生网络和教师网络的结果尽可能一致,也就代表学生网络学到了教师网络的知识;能够保证学生网络的结果和实际类别标签尽可能一致,也就代表学生网络的能力很强。

在本文中,作者剥离了教师,学生网络的概念,提出了一个与蒸馏模型不同但又相关的概念——相互学习(mutual learning)。通过上文的介绍我们知道,蒸馏模型从一个强大的、预先培训过的教师网络开始,然后将知识传递给一个小的、未经训练的学生网络,这种传递方式是一条单向的通路。与之相反,在相互学习中,从一组未经训练的学生网络开始,它们同时学习,共同解决任务。

在训练过程中,每个学生网络的损失函数由两部分组成:(1)传统的监督学习损失(2)模仿损失,使每个学生的预测类别与其他学生的类别预测概率保持一致。实验证明,在这种基于同伴教学(peer-teaching)的训练中,每个学生网络的学习效果都比在传统的监督学习场景中单独学习要好得多。此外,虽然传统的蒸馏模型需要一个比预期学生网络更大、更强大的教师网络,但事实证明,在许多情况下,几个大型网络的相互学习也比独立学习提高了性能。

你可能会有这样的疑问:为什么这种相互学习的策略会比蒸馏模型更有用?如果整个训练过程都是从一个小的且未经预训练的网络开始,那么网络中额外的知识从哪里产生?为什么它会收敛的好,而不是被群体思维所束缚,造成“瞎子带领瞎子"(theblind lead the blind)的局面?

针对这些疑问,作者给出了相应的解释:每个学生网络主要受传统的监督学习损失的指导,这意味着他们的表现通常会提高,而且也限制了他们作为一个群体任意地进行群体思维的能力。有了监督学习,所有的网络很快就可以为每个训练实例预测相同的标签,这些标签大多是正确且相同的。

但是由于每个网络从不同的初始条件开始,它们对下一个最有可能的类的概率的估计是不同的,而正是这些secondary信息,为蒸馏和相互学习提供了额外的知识。在相互学习网络中,每个学生网络有效地汇集了他们对下一个最有可能的类别的集体估计,根据每个训练实例找出并匹配其他最有可能的类会增加每个学生网络的后验熵,这有助于得到一个更健壮和泛化能力更强的网络。

综上所述,相互学习通过利用一组小的未经训练的网络协作进行训练,可以简洁而有效的提高网络的泛化能力。实验结果表明,与经过预训练的静态大型网络相比,同伴相互学习可以获得更好的性能。此外,作者认为相互学习还有以下几点优势:

1.网络效果随队列中网络的数量增加而增加;
2.相互学习适用于各种网络架构,以及由不同大小的混合网络组成的异构群组;
3.与独立训练相比,即使是在队列中相互训练的大型网络也能提高性能;
4.虽然作者的重点是获得一个单一有效的网络,但整个队列也可以整合为一个高效的集成模型。
深度相互学习
*为表述方便,本文以两个网络为例进行说明。

  • DML通用表示


如下图所示,本文提出的DML网络,在队列中有两个网络θ1,θ2。给定来自 M个类别的 N个样本,表示为:

640?wx_fmt=png

其对应的标签集合为:

640?wx_fmt=png

那么θ1网络中某个样本 xi属于类别 m的概率可以表示为:

640?wx_fmt=png

其中,640?wx_fmt=png是θ1网络中经过softmax层后输出的预测概率。

640?wx_fmt=png

对于多目标分类任务而言,θ1网络的目标函数可以用交叉熵表示:

640?wx_fmt=png

其中, 相当于一个指示函数,如下式所示,如果标签值和预测值相同,置为1,否则置为0:

640?wx_fmt=png

传统的监督损失训练能够帮助网络预测实例的正确标签,为了进一步提升网络θ1的泛化能力,DML引入了同伴网络θ2,θ2同样会产生一个预测概率p2,在这里引入KL散度的概念,相信了解过GAN网络的小伙伴对KL应该不会陌生,KL 散度是一种衡量两个概率分布的匹配程度的指标,两个分布差异越大,KL散度越大。作者采用KL散度,衡量这两个网络的预测p1和p2是否匹配。
p1和p2的KL散度距离计算公式为:

640?wx_fmt=png

综上,对于θ1网络来说,此时总的损失函数就由两部分构成:自身监督损失函数,来自θ2网络的匹配损失函数:

640?wx_fmt=png

同理,θ2可以表示为:

640?wx_fmt=png

  • 算法优化


DML在每次训练迭代中,都计算两个模型的预测,并根据另一个模型的预测更新两个网络的参数。θ1和θ2网络一直在迭代直至收敛,整个训练优化细节如下图所示:
640?wx_fmt=png
输入:
训练集 X,标签集 Y,学习率 γ1,γ2

初始化:
θ1,θ2不同初始化条件

步骤:
从训练集 X中随机抽样 x
1.根据上文中的概率计算公式p,分别计算两个网络的在当前batch的预测p1和p2,得到θ1的总损失函数 Lθ1
2.利用随机梯度下降,更新θ1参数:

640?wx_fmt=png

3. 根据上文中的概率计算公式p,分别计算两个网络的在当前batch的预测p1和p2,得到θ2的总损失函数Lθ2
4. 利用随机梯度下降,更新θ2参数:

640?wx_fmt=png

重复以上步骤直至网络收敛

  • 学生网络的扩展


前几节我们用两个网络θ1和θ2说明了DML的结构,算法。其实DML不仅在两个网络中有效,还可以扩展到多个网络中去。假定我们要训练一个有K(K>2)个学生网络的互相学习网络,那么对于其中的某个网络θk而言,总的损失函数可以表示为:

640?wx_fmt=png

该公式说明,每个学生网络都能够从另外的K-1个网络中学到知识,换而言之,对于一个学生网络,另外的K-1个网络都能作为该网络的教师网络。K=2就是该扩展网络的特例。注意,在上式中,对于其他网络的KL散度和,前面添加了权重系数1/(K-1),这是为了确保整个训练过程主要以监督学习的真正标签为指导。

对于两个以上的网络,除了DML训练策略外,在K个网络的训练中,对于一个学生网络,我们还可以将所有其他的k-1个网络集成作为一个单独的教师网络来提供综合平均的学习知识。这种思想与蒸馏模型类似,但是在参数更新上,在每个mini-batch上进行更新。基于这种思想,一个学生网络θk的损失函数可以表示为:

640?wx_fmt=png

实验
主要是两方面的实验,利用CIFAR-100和Market-1501两个数据集分别进行目标分类和人物重识别任务测试。

  • Results on CIFAR-100


在CIFAR-100上进行top-1指标测试。首先对只有两个网络的DML进行测试,。采用不同的网络结构,结果如下表所示,可以看到,相比独立的分类网络,基于任何组合方式的,添加DML策略的网络,表现都有所提升;体量较小的网络(如ResNet-32),从DML中提升更多;在大网络(如WRN-28-10)中添加DML策略,也会使得性能得到提升,与传统的蒸馏模型相比,可以看到一个大型的预培训的教师网络并非必要条件。

640?wx_fmt=png

  • Results on Market-1501


在Market-1501上进行mAP和rank-1指标测试。每个MobileNet在一个双网络队列中训练,并计算队列中两个网络的平均性能。如下表所示,与单独学习相比,DML显著的提升了MobileNet的性能,我们还可以看到,使用两个MobileNet训练的DML方法的性能显著优于先前最主流的方法。

640?wx_fmt=png

  • Comparison with Distillation


本文提出的DML模型与蒸馏模型密切相关,因此作者对比了这两个模型的效果。如下表所示,设置了三组网络,分别是:独立网络net1,net2;蒸馏模型net1为教师网络,net2为学生网络;DML模型,net1和net2相互学习。从实验结果分析,意料之中,传统的蒸馏方法从一个强大的预训练的教师网络指导学生网络的确实提升了性能。但结果同样表明,预训练的强大教师网络不是必要条件,与蒸馏模型相比,在DML中一起训练的两个网络也获得了明显的提升。

640?wx_fmt=png

  • DML的有效性


上述实验部分证明了DML的有效性,我们再从理论上讨论一下DML为什么能够提升以及通过哪些方法进行提升。

1)更鲁棒的最小值

与传统的优化方法相比,DML不是帮助我们找到一个更好的或者更深层次的训练损失最小值,而是帮助我们找到一个更广泛或者更可靠的最小值,它能更好地概括测试数据,更加健壮。作者利用Market-1501数据集和MobileNet主干网络做了一个小实验来证明DML能够找到更鲁棒的最小值。

作者比较了DML模型和独立模型在添加高斯噪声前后训练的损失变化。从图(a)可以看出两个模型的极小值的深度是相同的,但是在加入高斯噪声后,独立模型的训练损失增加较多,而DML模型的训练损失较少。这表明DML模型找到了一个更广泛,健壮的最小值,进而提供更好的泛化性能。

640?wx_fmt=png

2)怎样找到更好的最小值?

那么DML是怎样找到这个广泛健壮的最小值的呢?DML会要求每个网络匹配其同伴网络的概率估计,如果给定网络预测为零,而其对等网络预测为非零,则该网络将受到严重惩罚。总体上,DML是指,当每个网络独立地将一个关注点放在一个小的次概率集合上时,DML中的所有网络都倾向于聚合它们对次级概率的预测。也就是说所有的网络把重心放在次概率上,并且把更多重心放在更明显的次概率上。因此,DML是通过对“合理的”次概率预测的相互概率匹配来寻找更宽泛的最小值。

结论
本文提出了一种简单且普适的方法DML来提高深度神经网络的性能,方法是将几个网络一起训练,相互蒸馏。用这种方法,可以获得紧凑的网络。实验证明,DML相比传统的蒸馏模型更好更健壮。此外,DML也能提高大型网络的性能,并且以这种方式训练的网络队列可以作为一个集成来进一步提高性能。

论文链接:
https://arxiv.org/abs/1706.00384
代码链接
YingZhangDUT/Deep-Mutual-Learninggithub.com

(*本文为 AI科技大本营编译文章,请微信联系 1092722531


精彩推荐


2019 中国大数据技术大会(BDTC)再度来袭!豪华主席阵容及百位技术专家齐聚,15 场精选专题技术和行业论坛,超强干货+技术剖析+行业实践立体解读,深入解析热门技术在行业中的实践落地。

即日起,限量 5 折票开售,数量有限,扫码购买,先到先得!

640?wx_fmt=png

推荐阅读

640?wx_fmt=png

你点的每个“在看”,我都认真当成了AI

相关文章:

mac apache 配置

mac系统自带apache这无疑给广大的开发朋友提供了便利,接下来是针对其中的一些说明 一、自带apache相关命令 1. sudo apachectl start 启动服务,需要权限,就是你计算机的password 2. sudo apachectl stop 终止服务 ####3. sudo apachectl rest…

jQuery学习---------认识事件处理

3种事件模型:原始事件模型DOM事件模型IE事件模型原始事件模型(0级事件模型)1、事件处理程序被定义为函数实例,然后绑定到DOM元素事件对象上,实现事件的注册。例子:var btn document.getElementsByTagName(…

C++中的虚函数表介绍

在C语言中,当我们使用基类的引用或指针调用一个虚成员函数时会执行动态绑定。因为我们直到运行时才能知道到底调用了哪个版本的虚函数,所以所有虚函数都必须有定义。通常情况下,如果我们不使用某个函数,则无须为该函数提供定义。但是我们必须…

AI如何赋能金融行业?百度、图灵深视等同台分享技术实践

近日,由BTCMEX举办的金融技术创新研讨会在北京举办。BTCMEX投资人李笑来,AI技术公司TuringPass、百度、美国Apache基金会项目Pulsar、区块链安全公司SlowMist等相关专家参加了此次会议,共同探讨了金融技术在创新方面的现状。 图灵深视副总裁许…

【Win32 API学习]打开可执行文件

在MFC中打开其他可执行文件常用到的方法有:WinExec、ShellExecute、CreatProcess。 1.WinExec WinExec 主要运行EXE文件,用法简单,只有两个参数,前一个指定命令路径,后一个指定窗口显示方式: UINT WinExec(…

支付宝接口使用文档说明 支付宝异步通知

支付宝接口使用文档说明 支付宝异步通知(notify_url)与return_url. 现支付宝的通知有两类。 A服务器通知,对应的参数为notify_url,支付宝通知使用POST方式 B页面跳转通知,对应的参数为return_url,支付宝通知使用GET方式 &#xff…

完全隐藏Master Page Site Actions菜单只有管理员才可以看见

1. 在Master Page Head 增加下面的Style <style type"text/css"> .ms-cui-tt{visibility:hidden;} </style> 2. 增加SPSecurityTrimmedControl <SharePoint:SPRibbonPeripheralContent runat"server" Location"TabRowLeft&qu…

深度学习中的随机梯度下降(SGD)简介

随机梯度下降(Stochastic Gradient Descent, SGD)是梯度下降算法的一个扩展。机器学习中反复出现的一个问题是好的泛化需要大的训练集&#xff0c;但大的训练集的计算代价也更大。机器学习算法中的代价函数通常可以分解成每个样本的代价函数的总和。随着训练集规模增长为数十亿…

推荐系统中的前沿技术研究与落地:深度学习、AutoML与强化学习 | AI ProCon 2019...

整理 | 夕颜出品 | AI科技大本营&#xff08;ID:rgznai100&#xff09;个性化推荐算法滥觞于互联网的急速发展&#xff0c;随着国内外互联网公司&#xff0c;如 Netflix 在电影领域&#xff0c;亚马逊、淘宝、京东等在电商领域&#xff0c;今日头条在内容领域的采用和推动&…

运维日志管理系统

因公司数据安全和分析的需要&#xff0c;故调研了一下 GlusterFS lagstash elasticsearch kibana 3 redis 整合在一起的日志管理应用&#xff1a;安装&#xff0c;配置过程&#xff0c;使用情况等续一&#xff0c;glusterfs分布式文件系统部署&#xff1a;说明&#xf…

NLP学习思维导图,非常的全面和清晰

作者 | Tae Hwan Jung & Kyung Hee编译 | ronghuaiyang【导读】Github上有人整理了NLP的学习路线图&#xff08;思维导图&#xff09;&#xff0c;非常的全面和清晰&#xff0c;分享给大家。先奉上GitHub地址&#xff1a;https://github.com/graykode/nlp-roadmapnlp-roadm…

Go在windows10 64位上安装过程

1. 从 https://golang.org/dl/ 下载最新的发布版本go1.10即go1.10.windows-amd64.msi; 2. 双击go1.10.windows-amd64.msi ,使用默认选项&#xff0c;默认会安装到C:\Go目录下&#xff1b; 3. 将C:\Go\bin目录添加到系统环境变量中(默认已自动添加)&#xff0c;此目录下有go.exe…

Windows SharePoint Services 3.0 应用程序模板

微软发布的一些WSS模板&#xff0c;看了一下&#xff0c;跟以前看到的模板好像不同模板分两类&#xff0c;一类是站点管理模板&#xff0c;一类是服务器管理模板站点管理模板&#xff1a;董事会、业务绩效报告、政府机构案例管理、课堂管理、临床试验启动和管理、竞争性分析站点…

HAProxy+Keepalived高可用负载均衡配置

一、系统环境&#xff1a;系统版本&#xff1a;CentOS5.5 x86_64master_ip:172.20.27.40backup_ip:172.20.27.50 vip:172.20.27.200web_1: 172.20.27.90web_2:172.20.27.100二、haproxy安装&#xff1a;1.首先172.20.27.40安装上安装&#xff1a;1.1安装 tar zxvf haproxy-1.3.…

Go在Ubuntu 14.04 64位上的安装过程

1. 从 https://golang.org/dl/ 或 https://studygolang.com/dl 下载最新的发布版本go1.10即go1.10.linux-amd64.tar.gz&#xff1b; 2. 将下载的tar包解压缩到/usr/local目录下&#xff0c;执行以下命令&#xff0c;结果如下&#xff1a; $ sudo tar -C /usr/local -xzf go1.…

毕业就拿阿里offer,你和他比差在哪?

我在大学的时候&#xff0c;真的遇到一个神人&#xff0c;叫他小马吧。超前学习。1024&#xff0c;是程序员的节日&#xff0c;恰逢CSDN的20周年&#xff0c;我们准备为你做件大事&#xff01;我们与AI博士唐宇迪、畅销书作家、北大硕士阿甘等4位老师&#xff0c;共同为大家带来…

04号团队-团队任务5:项目总结会

1.团队信息 团队序号&#xff1a;04 开发项目&#xff1a;北软毕设管理系统 整理人&#xff1a;丛云聪 学号&#xff1a;2017035107185 在团队中的职务&#xff1a;项目经理兼产品经理 2.代码仓库地址 主仓库&#xff1a;https://gitee.com/The_Old_Cousin/StuInfoManage…

微软职位内部推荐-Sr SDE for Win Apps Ecosystem

微软近期Open的职位:Job posting title: Senior Software Design EngineerLocation: China, BeijingLevel: 63Division: Operations System Group EngineeringGroup OverviewOSG is delivering flagship products in Microsoft. China is a second largest economy in the worl…

C# Winform 启动和停止进程

启动和停止进程 一、启动进程 方法1&#xff1a; &#xff08;1&#xff09; 创建一个Process组件的实例&#xff0c;例如&#xff1a; Process myProcess new Process(); &#xff08;2&#xff09; 设置其对应的StartInfo属性&#xff0c;指定要运行的应用程序名…

在Windows/Ubuntu上使用Visual Studio Code作为Go语言编辑器操作步骤

下面以在Windows10上操作为例&#xff0c;在Ubuntu上操作步骤与windows一致&#xff1a; 1. 从 https://code.visualstudio.com/ 下载windows上的最新发布版本1.21.1&#xff0c;即VSCodeSetup-x64-1.21.1.exe&#xff1b; 2. 以管理员身份运行VSCodeSetup-x64-1.21.1.exe&…

实战:基于tensorflow 的中文语音识别模型 | CSDN博文精选

作者 | Pelhans来源 | CSDN博客目前网上关于tensorflow 的中文语音识别实现较少&#xff0c;而且结构功能较为简单。而百度在PaddlePaddle上的 Deepspeech2 实现功能却很强大&#xff0c;因此就做了一次大自然的搬运工把框架转为tensorflow….简介百度开源的基于PaddlePaddle的…

js获取Html元素的实际宽度高度

第一种情况就是宽高都写在样式表里&#xff0c;就比如#div1{width:120px;}。这中情况通过#div1.style.width拿不到宽度&#xff0c;而通过#div1.offsetWidth才可以获取到宽度。第二种情况就是宽和高是写在行内中&#xff0c;比如style"width:120px;"&#xff0c;这中…

新框架ES-MAML:基于进化策略、简易的元学习方法

作者 | Xingyou Song、Wenbo Gao、Yuxiang Yang、Krzysztof Choromanski、Aldo Pacchiano、Yunhao Tang译者 | TroyChang编辑 | Jane出品 | AI科技大本营&#xff08;ID&#xff1a;rgznai100&#xff09;【导读】现有的MAML算法都是基于策略梯度的&#xff0c;在试图利用随机策…

Tesseract-OCR 3.04简单使用举例(读入图像输出识别结果)

下面code是对Tesseract-OCR 3.04版本进行简单使用的举例&#xff1a;包括两段&#xff0c;一个是读入带有中文字符的图像&#xff0c;一个是读入仅有英文字符的图像&#xff1a; #include "funset.hpp"#include <iostream> #include <string> #include &…

坑爹的微软官方文档:SQL无人值守安装

我在部署项目的时候&#xff0c;需要用批处理无人值守安装SQLserver,.Net等组件。 于是查了微软官方文档&#xff0c;其中一项内容如下&#xff1a; http://msdn.microsoft.com/zh-cn/library/ms144259.aspx SQL Server 安装程序控件 /IACCEPTSQLSERVERLICEN…

各种 django 静态文件的配置总结【待续】

2019独角兽企业重金招聘Python工程师标准>>> 最近在学习django框架的使用&#xff0c;想引用静态css文件&#xff0c;怎么都引用不到&#xff0c;从网搜了好多&#xff0c;大多因为版本问题&#xff0c;和我现在的使用的dango1.1配置不同&#xff0c;根据资料和公司…

实战:人脸识别的Arcface实现 | CSDN博文精选

来源 | CSDN博客本文将简单讲述arcface从训练到部署的整个过程&#xff0c;主要包括前期的数据筛选和准备&#xff0c;模型训练以及模型部署。此文参考的arcface的代码地址&#xff1a;https://github.com/ronghuaiyang/arcface-pytorch数据集准备1. 首先准备需要训练的人脸数据…

Windows7/10上快速搭建Tesseract-OCR开发环境操作步骤

之前在https://blog.csdn.net/fengbingchun/article/details/51628957 中描述过如何在Windows上搭建Tesseract-OCR开发环境&#xff0c;那时除了需要clone https://github.com/fengbingchun/OCR_Test 工程外&#xff0c;还需要依赖 https://github.com/fengbingchun/Liblept_T…

C#基础系列:实现自己的ORM(反射以及Attribute在ORM中的应用)

反射以及Attribute在ORM中的应用 一、 反射什么是反射&#xff1f;简单点吧&#xff0c;反射就是在运行时动态获取对象信息的方法&#xff0c;比如运行时知道对象有哪些属性&#xff0c;方法&#xff0c;委托等等等等。反射有什么用呢&#xff1f;反射不但让你在运行是获取对象…

Network | sk_buff

sk_buff结构可能是linux网络代码中最重要的数据结构&#xff0c;它表示接收或发送数据包的包头信息。它在中定义&#xff0c;并包含很多成员变量供网络代码中的各子系统使用。 这个结构被不同的网络层&#xff08;MAC或者其他二层链路协议&#xff0c;三层的IP&#xff0c;四…