当前位置: 首页 > 编程日记 > 正文

ICCV 2019 | 无需数据集的Student Networks

640?wx_fmt=png

译者 | 李杰
出品 | AI科技大本营(ID:rgznai100)

640?wx_fmt=png

本文是华为诺亚方舟实验室联合北京大学和悉尼大学在ICCV2019的工作。

摘要

在计算机视觉任务中,为了将预训练的深度神经网络模型应用到各种移动设备上,学习一个轻便的网络越来越重要。当我们可以直接访问训练数据集时,现有的深度神经网络压缩和加速方法对于训练紧凑的深度模型是非常有效的,但是现实情况却是,有了隐私保护,法规政策等,数据集的回去越来越困难,为此,本文提出了一种利用生成对抗网络(GANs)训练高效深度神经网络的新框架DAFL(Data-Free Learning)。

该框架无需训练数据集,具体来说,将蒸馏模型中预先训练好的教师网络充当GAN中的判别器的角色,生成器的任务是生成让判别器响应最大的样本,然后利用生成的数据和教师网络,训练出模型尺寸较小、计算复杂度较低的高效网络。实验表明,在CIFAR-10和CIFAR-100中,利用零数据训练的DAFL框架分别达到了92.22%和74.47%的正确率,同时在CelebA中,也获得了80.56%的正确率,证明了DAFL的有效性。

引言

  • CNNs的冗余性

CNNs在计算机视觉中有广泛的应用,但是好的性能表现通常依赖于很深,很宽的网络,这种设计模式对训练提出了高要求,有大量的参数数据需要进行处理,这种情况下,想将网络应用到诸如自动驾驶,边缘计算等内容上,几乎是不可能的。虽然这些预训练好的神经网络有许多参数,但是研究表明,在给定的神经网络中,丢弃85%以上的权值并不会明显损害神经网络的性能,这说明这些神经网络存在显著的冗余。

  • 数据的难获得性

针对CNNs的冗余,有很多压缩和加速算法,但是这些算法都是基于丰富的训练数据的支撑,容易忽视的一个点是,由于隐私和传输限制,训练数据集在现实应用中通常是未知的,例如,用户不想让自己的照片泄露给别人,而且一些训练数据集太大,无法快速上传到云端。

  • Data-Free Learning

本文聚焦无训练数据的情况,提出了一种不需要原始训练数据集的深度神经网络压缩新框架DAFL。具体来说,将给定的大型的教师网络作为生成对抗网络中的判别器部分,在对抗生成过程中,通过从网络中提取信息,建立一个生成网络来代替原来的训练集,从而为学习性能可接受的小网络提供参考。

无数据驱动的学生网络学习

现有的轻量级网络学习方法可分为两类:数据驱动方法(Data-Driven Network)和无数据(Data-Free Network)方法。数据驱动的代表性算法有两个解决思路,一是消除减少原模型中的冗余共结构和参数;二是蒸馏模型,利用教师网络训练学生网络,但是这两种方法都依赖大量训练数据的支撑,一旦数据没有,很难达到预期效果。这时候,就显示出无数据方法的重要性了。本文提出一种新的无数据框架,通过在师生学习范式中嵌入一个生成网络来压缩深度神经网络。

  • 教师网络和学生网络的知识蒸馏

知识蒸馏模型将强大网络中的知识信息传输到小网络中,期望得到一个体积小,正确率高的网络。定义蒸馏模型中的教师网络NTN_{T} 和学生网络 Ns{,对学生网络的优化可以通过下列损失函数表示:

640?wx_fmt=png

其中,Hcrosscross} 表示交叉熵函数,ys{S和 yT分别表示学生网络和教师网络的输出,可以通过下式得到:

640?wx_fmt=png

640?wx_fmt=png

  • 利用GANs生成训练样本

为了在没有原始数据集的情况下学习一个轻量级网络,本文利用GAN生成训练样本。生成对抗网络(GANs)被广泛应用在样本生成领域。GANs包括两部分,一个生成器网络G和一个判别器网络D。G的作用是生成数据,D的作用是判别数据是真实图片数据还是生成图片数据。详细说,GANs的流程可以这样表示:输入一个噪声,z,送入生成器网络G,得到生成的数据x,对于任意一个GANs来说,损失函数可以用下式表示:

640?wx_fmt=png

在对抗过程中,生成器G根据生成器产生的训练误差持续进化,对G的优化可以看做是下列问题的优化:

640?wx_fmt=png

其中,D*是固定最优的判别器。

通过描述GANs的训练流程,我们可以发现生成器很适合用来生成训练数据,但是根据上列GANs的优化函数,判别器的训练需要真实图片,这显然与零样本不相符。

近来的工作已经证明,判别器D可以样本中学习表示的层次,这说明判别器机制有在其它计算机视觉任务中的泛化能力。本文提出将一个深度神经网络直接作为判别器D,这样,生成器G就可以直接进行优化而不用和D同时进行训练,即在训练G时,原网络D的参数是固定的。此外,原始的GANs中,判别器的目的是判定图片是真实还是生成的,但是神经网络是进行分类的,如果用一个训练好的神经网络作为判别器,输出是图像类别,而不是图像的真实性。这样一来,原始GANs中的损失函数也不再适用,需要设计几个新的损失函数。

在图像分类任务中,深度神经网络在训练阶段利用交叉熵作为损失函数,监督分类网络生成与真实标签相同的结果,在多分类任务中,网络输出要接近一个one-hot向量,只有一个类别的概率为1,其余均为0。

640?wx_fmt=png

one-hot示例

具体过程为:

给定生成器G,教师网络 NT,随机噪声{ z1,z2,……,zn},生成器生成的图像{ x1,x2,……,xn},将这些生成图像送到教师网络中,得到输出结果{ y1,y2,……,yn},然后通过下式计算预测类别标签:

640?wx_fmt=png

我们定义三个子损失来得到最终的损失函数。

  • 交叉熵损失Lcross

如果G生成的图像与教师网络的训练数据分布相同,那么它们的输出也应该与训练数据具有相似的输出。基于此,引入了one-hot loss,激励教师网络对生成器生成图像的输出接近one-hot向量,这样倒逼生成器生成与原始训练集相似的图片。

640?wx_fmt=png

其中, Hcross是交叉熵函数, yi是教师网络预测输出,ti是真实标签。通过one-hot loss,我们期望生成的图像能够以较高的概率被教师网络划分为一个特定的类别。

  • 特征激活损失La

除了教师网络的预测类标签外,卷积层提取的中间特征也是输入图像的重要特征,不同卷积层对应了不同的语义。由于教师网络中的卷积过滤器已经被训练来提取训练数据中的固有模式,因此如果输入的图像是真实的,而不是一些随机的向量,特征图往往会收到更高的激活值。因此,定义教师网络最后一层的提取特征 xi为 fi,定义一个激活损失函数。

640?wx_fmt=png

其中,||· ||1是l1范数。加负号的原因是想让尽可能多的f被激活。

  • 信息熵损失Linfo 

为了简化深度神经网络的训练过程,每一类训练实例的数量通常是平衡的,以MNIST为例,有60000图片,被分为10类,每一类6000张。1948年,香农提出了“信息熵”的概念,解决了对信息的量化度量问题。通常,一个信源发送出什么符号是不确定的,衡量它可以根据其出现的概率来度量。概率大,出现机会多,不确定性小;反之不确定性就大。本文提出利用信息熵损失来衡量生成图像的类平衡性。

具体来说,给定一个概率向量{p1,p2,......,pk},信息熵会用来衡量信息的混乱程度。p的信息熵计算公式为:

640?wx_fmt=png

H(p)的值表示p携带的信息量,当所有的变量取值为1/k时,取得最大值。这里涉及到信息熵最大值问题,感兴趣的同学可以查阅详细信息,最终证明,当所有概率取值相同且为1/k时,信息熵最大。

我们把信息熵的概念迁移到图像生成中来。给定一组教师网络的输出向量,{ y1,y2,……,yn},对于每一类图像的生成概率我们可以表示为:

640?wx_fmt=png

基于此,我们定义信息熵损失函数如下:

640?wx_fmt=png

当损失取得最小值的时候,向量中的每一项都相等且等于 1/k,这说明,生成器G可以以大致相同的概率生成每个类别的图像。因此,最小化生成图像的信息熵可以得到一组类别均衡的图像。

  • 总损失函数

基于上述三个子损失函数,我们可以得到最终的损失函数:

640?wx_fmt=png

其中,α和 β是平衡三个任务的超参数,通过最小化上述函数,最优生成器G可以生成与之前用于训练教师网络的训练数据分布相似的图像。本文提出的方法可以直接模拟训练数据的分布,更灵活、高效地生成新图像。

  • 优化算法

如下图所示,DAFL算法可以分为两个步骤。

640?wx_fmt=png

首先,将一个预训练好的固定教师网络充当生成对抗网络中的判别器,利用上文中的总损失函数 Ltotal,优化了一个生成器G来生成与原始训练图像分布相似的教师网络图像;其次,利用知识蒸馏的方法直接将知识从教师网络转移到学生网络。对学生网络的优化方法,是利用LKD损失函数:

640?wx_fmt=png

整体流程如下图所示,通过从给定的教师网络中提取有用信息,训练生成器逼近原始训练集中的图像。然后,利用生成的图像和教师网络有效地学习可移植的学生网络。

640?wx_fmt=png

论文实验

  • 在MNIST数据集上的实验

首先在MNIST数据集上进行实验,MNIST数据集由10个类别(从0到9)的28×28像素图像组成,整个数据集包括6万张训练图像和1万张测试图像。为了选择超参数,从训练图像中选取10,000张图像作为验证集。然后,在全部60,000张图像上训练模型,以获得最终的网络。为了便于比较,设置了两种结构,一种是基于卷积(用LeNet-5作为教师网络,用比教师网络参通道数少一半的LeNet-5-HALF作为学生网络),另一种是基于全连接层(教师网络是有两个隐层,每个隐层1200个节点的HiltonNet,学生网络是有两个隐层,每个隐层800节点的HiltonNet)。我们发现,学生网络的参数量明显少于教师网络。在实验中,设置损失函数中的超参数 α=0.1,β=5,训练200轮。

实验结果如下图所示,在没有原始数据的情况下,DAFL的学生网络分别取得了98.2%和97.91%的正确率,这些数据与利用原始数据训练的教师网络性能相近,但是极大简化了结构和参数量,证明了DAFL的有效性。

640?wx_fmt=png

  • 在CIFAR数据集上的实验

在CIFAR上同样有很好的性能表现,不再过多赘述。

640?wx_fmt=png

  • 结果可视化

(1)生成图可视化对比

在研究了DAFL方法的有效性之后,进一步对MNIST数据集进行了可视化实验。如下图所示,(a)图是MNIST数据集上的图像,(b)是生成器生成图像。通过损失函数深度挖掘教师网络的信息,生成了与训练图像较为相似的图像,这表明生成器能够以某种方式学习数据分布。

640?wx_fmt=png

640?wx_fmt=png

(2)卷积核可视化

如下图所示,可视化了LeNet-5教师网络和学生网络的过滤器。虽然学生网络是在没有真实数据的情况下训练的,但是通过所提出的方法学习的学生网络的卷积核仍然与教师网络的卷积核相似。可视化实验进一步证明,该生成器可以生成与原始图像模式相似的图像,并且利用生成的样本,学生网络可以从教师网络获取有价值的知识。

结论

传统的方法需要原始的训练数据集来微调压缩后的深度神经网络,使其具有可接受的精度。然而,由于一些隐私和传输限制,给定深度网络的训练集和详细的体系结构信息通常是不可用的。在本文中,我们提出了一个新的框架来训练一个生成器来逼近原始数据集而不需要训练数据。然后通过知识蒸馏方案有效地学习可移植网络。在分类数据集上的实验表明,该方法能够在不需要任何训练数据的情况下学习可移植的深度神经网络并取得很好的性能。

Paper link:

https://arxiv.org/abs/1904.01186

Code link:

https://github.com/huawei-noah/DAFL

(*本文为AI科技大本营投稿文章,转载请微信联系 1092722531


精彩推荐



12月6-8日,深圳!2019嵌入式智能国际大会,集聚500+位主流AIoT中坚力量,100+位海内外特邀技术领袖!9场技术论坛布道,更有最新芯片和模组等新品展示!点击链接或扫码,输入本群专属购票优惠码CSDNQRSH,即可享受6.6折早鸟优惠,比原价节省1000元,学生票仅售399元

640?wx_fmt=jpeg

推荐阅读

相关文章:

oc中特殊字符的判断方法

-(BOOL)isSpacesExists { // NSString *_string [NSString stringWithFormat:"123 456"]; NSRange _range [self rangeOfString:" "]; if (_range.location ! NSNotFound) { //有空格 return YES; }else { //没有空格 return NO; } } -(BOOL)i…

理解 Delphi 的类(十) - 深入方法[23] - 重载

为什么80%的码农都做不了架构师?>>> {下面的函数重名, 但参数不一样, 此类情况必须加 overload 指示字;调用时, 会根据参数的类型和个数来决定调用哪一个;这就是重载. }function MyFun(s: string): string; overload; beginResult : 参数是一个字符串: …

玩转ios友盟远程推送,16年5月图文防坑版

最近有个程序员妹子在做远程推送的时候遇到了困难,求助本帅。尽管本帅也是多彩的绘图工具,从没做过远程推送,但是本着互相帮助,共同进步的原则,本帅还是掩饰了自己的彩笔身份,耗时三天(休息时间…

提高C++性能的编程技术笔记:临时对象+测试代码

类型不匹配:一般情况是指当需要X类型的对象时提供的却是其它类型的对象。编译器需要以某种方式将提供的类型转换成要求的X类型。这一过程可能会产生临时对象。 按值传递:创建和销毁临时对象的代价是比较高的。倘若可以,我们应该按指针或者引…

北美欧洲顶级大咖齐聚,在这里读懂 AIoT 未来!

2019 嵌入式智能国际大会即将来袭!购票官网:https://dwz.cn/z1jHouwE随着海量移动设备的时代到来,以传统数据中心运行的人工智能计算正在受到前所未有的挑战。在这一背景下,聚焦于在远离数据中心的互联网边缘进行人工智能运算的「…

c# 关闭软件 进程 杀死进程

c# 关闭软件 进程 杀死进程 foreach (System.Diagnostics.Process p in System.Diagnostics.Process.GetProcessesByName("Server")){p.Kill();} 转载于:https://www.cnblogs.com/lxctboy/p/3999053.html

提高C++性能的编程技术笔记:单线程内存池+测试代码

频繁地分配和回收内存会严重地降低程序的性能。性能降低的原因在于默认的内存管理是通用的。应用程序可能会以某种特定的方式使用内存,并且为不需要的功能付出性能上的代价。通过开发专用的内存管理器可以解决这个问题。对专用内存管理器的设计可以从多个角度考虑。…

【Swift】 GETPOST请求 网络缓存的简单处理

GET & POST 的对比 源码: https://github.com/SpongeBob-GitHub/Get-Post.git 1. URL - GET 所有的参数都包含在 URL 中 1. 如果需要添加参数,脚本后面使用 ? 2. 参数格式:值对 参数名值 3. 如果有多个参数,使用 & 连接 …

深度CTR预估模型的演化之路2019最新进展

作者 | 锅逗逗来源 | 深度传送门(ID: deep_deliver)导读:本文主要介绍深度CTR经典预估模型的演化之路以及在2019工业界的最新进展。介绍在计算广告和推荐系统中,点击率(Click Through Rate,以下简称CTR&…

2015大型互联网公司校招都开始了,薪资你准备好了嘛?

2015年的校招早就开始了,你还不知道吧?2015年最难就业季来了,你还没准备好嘛?现在就开始吧,已经很多大型互联网公司祭出毕业生底薪了看谷歌、看百度、看腾讯、看阿里巴巴再看传统软件公司:看微软、看联想、…

提高C++性能的编程技术笔记:多线程内存池+测试代码

为了使多个线程并发地分配和释放内存,必须在分配器方法中添加互斥锁。 全局内存管理器(通过new()和delete()实现)是通用的,因此它的开销也非常大。 因为单线程内存管理器要比多线程内存管理器快的多,所以如果要分配的大多数内存块限于单线程…

iOS中几种定时器

一、NSTimer 1. 创建方法 NSTimer *timer [NSTimer scheduledTimerWithTimeInterval:1.0 target:self selector:selector(action:) userInfo:nil repeats:NO];TimerInterval : 执行之前等待的时间。比如设置成1.0,就代表1秒后执行方法target : 需要执行方法的对象…

手把手教你使用Flask轻松部署机器学习模型(附代码链接) | CSDN博文精选

作者 | Abhinav Sagar翻译 | 申利彬校对 | 吴金笛来源 | 数据派THU(ID:DatapiTHU)本文旨在让您把训练好的机器学习模型通过Flask API 投入到生产环境 。当数据科学或者机器学习工程师使用Scikit-learn、Tensorflow、Keras 、PyTorch等框架部署…

JQuery遮罩层

2019独角兽企业重金招聘Python工程师标准>>> css样式&#xff1a;<style type"text/css"> .mask { position: absolute; top: 0px; filter: alpha(opacity60); background-color: #777; z-index: 1002; left: 0px; …

代码覆盖测试工具Kcov简介及使用

Kcov是一个代码覆盖测试工具&#xff0c;最初基于Bcov&#xff0c;它可在FreeBSD、Linux、OSX系统中使用&#xff0c;支持的语言包括编译语言(compiled languages)、Python和Bash。与Bcov一样&#xff0c;Kcov对编译的程序使用DWARF调试信息&#xff0c;以便无需特殊编译器开关…

Google148亿元收购Fitbit,抢占苹果、三星可穿戴设备市场地盘

编译 | 夕颜出品 | AI科技大本营&#xff08;ID:rgznai100&#xff09;11 月 1 日&#xff0c;Google 母公司 Alphabet 和 可穿戴设备公司 Fitbit 同时发布新闻&#xff0c;宣布已经达成了收购后者的最终协议。Google LLC 以每股 7.35 美元的价格收购 Fitbit&#xff0c;总价值…

ios关于用xib创建的cell 自动返回cell的高度问题!

1 设置tableView的属性 self.tableView.rowHeight UITableViewAutomaticDimension; self.tableView.estimatedRowHeight 44.0; // 设置为一个接近“平均”行高的值 2 cell要约束好&#xff0c;要能够让cell知道自己的高度根据哪个控件计算就可以&#xff08;不明白看下图&…

西门子PLC学习笔记二-(工作记录)

今天师傅给讲了讲做自己主动化控制的总体的思路&#xff0c;特进行一下记录&#xff0c;做个备忘。 1.需求分析 本次的项目是对楼宇循环供水的控制&#xff0c;整个项目须要完毕压力、压差、温度等的获取及显示、同一时候完毕电机的控制。 2.设计 使用西门子的Step7工具进行梯形…

Swift 3.0 预告:将 Objc 库转换成更符合 Swift 语法风格的形式

转自&#xff1a;swiftcafe Swift 3.0 更新越来越临近&#xff0c;这次更新会给我们带来很多实用的内容&#xff0c;比如对 Objc 库的迁移&#xff0c;会更符合 Swift 的语法风格。用过之前版本的 Swift&#xff0c;我们会发现很多 Objc 库的方法名称其实还是以 Objc 的风格来命…

非对称加密算法RSA公钥私钥的模数和指数提取方法

生成非对称加密算法RSA公钥、私钥的方法&#xff1a; 1. 通过OpenSSL库生成&#xff0c;可参考 https://github.com/fengbingchun/OpenSSL_Test/blob/master/demo/OpenSSL_Test/funset.cpp 中的Generate_RSA_Key函数&#xff1b; 2. 在Linux下通过命令生成&#xff0c;执行…

数据库“新解”,看这里,get!

自从第一台通用计算机诞生至今&#xff0c;围绕计算机系统硬件的创新迭代就一直“在路上”&#xff0c;伴随着硬件能力的不断提升&#xff0c;软件更新自然不可缺少。通常来说在传统的计算机软件工程领域&#xff0c;操作系统、编译器与数据库被并称为最具难度的“三剑客”系统…

win 64位系统安装带有c编写的python模块出现ValueError: [u'path']解决

2019独角兽企业重金招聘Python工程师标准>>> 关于win 64位机器安装Scrapy的问题&#xff1a;http://steamforge.net/wiki/index.php/How_to_Install_Scrapy_in_64-bit_Windows_7 在安装Scrapy是要安装一系列的依赖模块&#xff0c; 出现问题&#xff1a; 1、error: …

探索 Swift 中的 MVC-N 模式

作者&#xff1a;Marcus Zarra&#xff08;twitter&#xff1a;mzarra&#xff09; Marcus 将会为大家介绍一种设计模式&#xff0c;他曾经在那些需要从互联网进行大量频繁数据请求的 iOS 应用当中使用此设计模式。这个设计采用了著名的 MVC (Model View Controller) 模式&…

MXNet中依赖库介绍及简单使用

MXNet是一种开源的深度学习框架&#xff0c;核心代码是由C实现&#xff0c;在编译源码的过程中&#xff0c;它需要依赖其它几种开源库&#xff0c;这里对MXNet依赖的开源库进行简单的说明&#xff1a; 1. OpenBLAS&#xff1a;全称为Open Basic Linear Algebra Subprograms&am…

Python十大装腔语法

作者 | 许向武 责编 | 郭芮 来源 | CSDN 博客Python 是一种代表简单思想的语言&#xff0c;其语法相对简单&#xff0c;很容易上手。不过&#xff0c;如果就此小视 Python 语法的精妙和深邃&#xff0c;那就大错特错了。本文精心筛选了最能展现 Python 语法之精妙的十个知识点&…

MATLAB——scatter的简单应用

scatter可用于描绘散点图。 1.scatter(X,Y) X和Y是数据向量&#xff0c;以X中数据为横坐标&#xff0c;以Y中数据位纵坐标描绘散点图&#xff0c;点的形状默认使用圈。 样例&#xff1a; X [1:10]; Y X rand(size(X)); scatter(X, Y) 得到&#xff1a; 2.scatter(...,fill…

Windows10上使用VS2017编译MXNet源码操作步骤(C++)

MXNet是一种开源的深度学习框架&#xff0c;核心代码是由C实现。MXNet官网推荐使用VS2015或VS2017编译&#xff0c;因为源码中使用了一些C14的特性&#xff0c;VS2013是不支持的。这里通过VS2017编译&#xff0c;步骤如下&#xff1a; 1. 编译OpenCV&#xff0c;版本为3.4.2&a…

StoryBoard 视图切换和传值

一 于StoryBoard相关的类、方法和属性 1 UIStoryboard // 根据StoryBoard名字获取StoryBoard (UIStoryboard *)storyboardWithName:(NSString *)name bundle:(nullable NSBundle *)storyboardBundleOrNil;// 获取指定StoryBoard的第一个视图控制器- (nullable __kindof UIViewC…

率清华团队研发“天机芯”登《Nature》封面,他说类脑计算是发展人工通用智能的基石...

整理 | AI科技大本营&#xff08;ID:rgznai100&#xff09;8 月&#xff0c;清华大学教授、类脑计算研究中心主任施路平率队研发的关于“天机芯”的论文登上《Nature》封面&#xff0c;这实现了中国在芯片和人工智能两大领域登上该杂志论文零的突破&#xff0c;引发国内外业界一…

IntelliJ IDEA 12详细开发教程(四) 搭建Android应用开发环境与Android项目创建

今天我要给大家讲的是使用Intellij Idea开发Android应用开发。自我感觉使用Idea来进行Android开发要比在Eclipse下开发简单很多。&#xff08;一&#xff09;打开网站&#xff1a;http://developer.android.com/sdk/index.html。从网站上下载SDK下载需要的Android版本&#xff…