当前位置: 首页 > 编程日记 > 正文

8比特数值也能训练模型?商汤提训练加速新算法丨CVPR 2020

出品 | AI科技大本营(ID:rgznai100)

在CVPR 2020上,商汤研究院链接与编译团队、高性能计算团队和北航刘祥龙老师团队合作提出了用于加速卷积神经网络训练过程的INT8训练技术。该工作通过将网络的输入、权重和梯度量化到8比特来加速网络的前向传播和反向传播过程,缩短卷积神经网络训练时间。

论文观察到梯度的独特分布给量化训练带来了极大挑战,为了解决梯度量化带来的精度损失和不稳定问题,该论文进行了量化训练收敛稳定性的理论分析并基于此提出了误差敏感的学习率调节和基于方向自适应的梯度截断方法。同时为了保证更高的加速比,该论文还提出使用周期更新、量化卷积融合等技术来减少量化操作带来的时间开销。应用了上述方法之后,INT8训练在图像分类任务和检测任务上都仅仅损失微小的精度,且训练过程相比浮点训练加速了22%。

动机与背景

卷积神经网络被广泛应用在多种计算机视觉任务中并且取得了优异的精度。由于拥有庞大的参数量,训练和部署卷积神经网络需要耗费大量计算资源和漫长的训练时间,如何用更少资源训练卷积神经网络一直是一个学术研究热点,也是工业界关心的话题。

神经网络量化技术是一种使用定点计算代替浮点的加速技术,目前被广泛地应用在神经网络部署中,可以极大地提升部署速度,并降低内存资源占用。现有很多工作均表明将网络前向过程的浮点计算替换成INT8计算,不会带来明显的精度下降[1][2]。

下图展示了现代神经网络加速芯片对于不同精度计算的理论计算峰值对比,可以看到,INT8算力相比于FP32和FP/INT16均能有超过2倍峰值性能提升。

当考虑将神经网络量化技术应用在卷积神经网络训练中时,为了加速卷积的反向梯度传播过程,不得不对梯度进行量化操作。在将浮点的梯度量化到INT8数值范围内之后,训练过程变得极其不稳定,并且收敛到非常差的精度。如何解决量化梯度给训练带来的收敛稳定性问题,是十分重要的问题。与此同时,在提升训练精度的同时,也不应当进入过多额外的计算,否则加速效果将会大打折扣。

一方面是高效的计算峰值保障,一方面是困难重重的算法设计,这是INT8训练技术的机遇与挑战。

何为INT8训练

标准的线性量化操作指的是,将一个浮点张量(tensor)进行线性映射,变换到整数空间中[3]。这个整数空间的大小由于量化比特数来决定,比如常见的8bit量化数,就有256个取值,本文中使用的是对称量化,因此量化数的取值是从-128到127。具体公式如下,其中x是被量化的数据,q是量化后的数据,s是量化系数,clip是截断函数:

在8bit的场景里,截断函数和量化系数的计算公式如下:

为了降低量化带来的误差,一个常见做法是对取整过程进行随机化,使得取整函数从期望上更接近原始的数,具体随机取整的公式如下:

相反的,将8bit量化数变换回浮点的过程称之为反量化。反量化公式如下所示,其中q为量化计算结果,s为量化系数,为反量化后的结果。

上图的上半部分展示了标准的卷积神经网络量化计算前向过程,该过程被广泛应用在INT8部署加速中。在卷积计算之前,量化器会对输入和权重进行量化操作,将浮点数量化到8bit数值上,通过INT8卷积计算核心,即可完成一次INT8前向计算,最终将求和得到的32bit数进行反量化操作回算到浮点数域中,以供给下一层计算使用。

INT8训练的一个核心的加速点在于卷积计算的反向过程,上图展示了INT8训练中卷积计算在反向传播过程中的计算细节。在卷积的反向梯度传播过程,同样的将梯度进行浮点量化操作,不过为了降低量化的误差,针对梯度的量化采用了随机取整操作。通过INT8的反向卷积计算核心,可以得到下一层所需的回传梯度,以及当前层的权重所需的梯度。由于INT8反向卷积输出的是32bit数,与前传类似,需要引入一次反量化操作,将32bit数反算回到浮点数域中。

梯度为何难以量化

为什么对梯度进行量化会给网络训练带来如此大的影响?我们可以观察训练过程中的梯度分布情况来进一步的分析。

通过图(a)中对比梯度和输入、权重的分布,可以发现:梯度分布相比输入和权重分布更加尖锐,同时范围更大。相比于输入和权重,梯度有更多的值集中在0附近,但同时梯度还有许多较大值,让梯度的分布范围变得相当广,这些特征都会导致梯度量化的量化误差比输入和权重更大。

图(b)展示的是layers16随着训练,其梯度从epoch 0到epoch 300的变化情况。从中可以看出,随着训练的进行,梯度分布越变得更加尖锐,同时仍然保持着较广的分布范围,这意味着梯度量化的误差会随着训练的进行变得越来越大。

梯度的分布随网络深度变化情况从图(c)中可以看出。很容易发现,卷积层的深度越浅,梯度分布越尖锐,这也会导致梯度量化的误差更大。

从图(d)中可以看出卷积的结构也会影响梯度分布,对于MobileNetV2来说,conv2为depthwise卷积其相比conv1和conv3具有更加尖锐的分布。

由于卷积神经网络的梯度具有如上四个特点,所以当我们直接在训练中对梯度进行量化时,训练精度非常容易出现突发的崩溃情况。下图展示了在CIFAR-10数据集上进行实验的精度和损失函数变化曲线,以MobileNetv2在CIFAR-10数据集上训练为例,其训练的精度曲线和loss曲线如下图,从图中可以发现INT8训练的loss在训练初期正常下降,但随后迅速上升,对应的精度也不断下降。

是什么影响了收敛稳定性

根据以上的观察和初步启发,我们希望通过理论的分析和推导,对量化训练的收敛稳定性进行建模。根据Adam等相关论文的经验和优化理论中的Regret analysis,不失一般性地定义R(T)为

其中f是损失函数,t是训练轮数,T是训练总轮数,为t轮的权重,是最优权重。

基于以下两个朴素的假设:

通过推导证明可以得到:

其中为t轮的学习率,d为权重的维度,为t轮的量化误差,是t轮的量化后梯度。为了确保网络能够稳定收敛,在T变大时需要能够达到足够小。通过上式可以发现,在T趋于无穷大时,第(1)项可以忽略不计,主要考虑减小第(2)项和第(3)项。我们发现,第(2)项与量化误差正相关,第(3)项与学习率以及量化后的梯度大小有关。

因此我们不难得到两个直观的提升训练收敛稳定性的策略:

  • 通过调节量化函数中的截断减小量化误差

  • 通过适当调低学习率来提高量化训练精度

主要方法

依据以上分析,我们针对量化误差和学习率提出了基于方向自适应的梯度截断和误差敏感的学习率调节两个方法来解决量化训练带来的精度损失问题。同时,为了减少量化操作带来的额外开销,本文还提出了周期更新和量化卷积融合的方法。

1、基于方向自适应的梯度截断:调整截断值,让梯度方向保持正确

为了最小化量化误差,之前有很多研究提出优化截断值的方法,其中就有研究提出通过假设数据分布直接求解最优截断值。但是已有的研究都针对于权重量化的截断值进行优化。就如本文观察所显示,梯度的分布特征与权重区别较大,无法直接使用。本文通过KS检验发现梯度的分布并不符合常见的高斯分布、拉普拉斯分布和学生t分布,因此很难通过假设梯度分布来直接求解最优的截断值。

基于以上的分析,本文采用梯度下降的方法来自适应地学习最优截断值,常见的目标函数有均方误差函数,但是由于梯度的分布特征,均方误差的大小会受到梯度的影响,影响优化过程;同时对于梯度来说,均方误差并不能很好地体现梯度的量化误差对于优化过程的影响,因此本文提出使用能够体现梯度方向的余弦距离来衡量梯度的量化误差,并以余弦距离为目标函数来优化求解最优截断值。余弦距离定义如下:

其中,g是梯度,是量化后的梯度。

2、误差敏感的学习率调节:在错误的方向上尽量少更新

根据上述的理论分析,降低学习率能够有助于模型量化训练的收敛。针对学习率的调整,本文提出误差敏感的学习率调节方法,使用学习率系数对原学习率进行调整,学习率系数与余弦距离负相关,学习率系数定义如下:

其中是超参数,用于控制衰减程度和调节下界。

3、周期更新:降低由于统计而带来的额外计算耗时

由于量化操作需要的统计数据范围和计算截断值等操作十分耗时,为了减少这些操作的时间开销,本文采用周期更新的方式,周期性地统计数据范围和计算截断值。通过周期更新的方法能够有效地提高减少因量化引入的额外时间开销。下表为ResNet50在ImageNet数据集上不同周期的单次训练时间统计表。

4、量化卷积融合:减少访存次数、节省cuda kernel launch次数

通过将量化和反量化操作融合入卷积计算的CUDA核函数里,可以减少一次数据的访存,有效地减少量化和反量化操作的时间开销。

实验结果

图像分类任务:本文在CIFAR10和ImageNet等图像分类数据集进行INT8训练实验。从下表结果中可以看出,在大多数网络结构中均取得了比现有最好方法更优的精度,并且首次在MobileNet、Inception等网络上进行量化训练实验,精度损失也在1.5%以内。

目标检测任务:同时,本文也首次尝试在PASCAL和COCO等目标检测数据集上进行INT8训练实验,精度损失也在2%以内。

已有的少量探究梯度量化的论文[4]均未报告算法在实际训练任务中的真实加速性能,为了最大限度将方法实用化,本文在 GeForce GTX1080TI显卡上编写并优化了用于支持INT8训练的卷积前向和后向计算核心。实测结果表明,使用INT8卷积计算的前向和后向过程相比于浮点计算有明显的加速,其中前向过程平均加速1.63倍,后向过程平均加速1.94倍。如下图所示:

同时,本文在实际训练过程中进行了完整的端到端测试,可以看到,INT8训练可以将ResNet50的一轮训练过程从0.360秒降低到0.293秒,整体训练过程提速了22%。

论文地址:

https://arxiv.org/pdf/1912.12607.pdf 

References

[1].Ruihao Gong, Xianglong Liu, Shenghu Jiang, Tianxiang Li,Peng Hu, Jiazhen Lin, Fengwei Yu, and Junjie Yan. Differen-tiable soft quantization:  Bridging full-precision and low-bitneural networks. In ICCV, October 2019.

[2].Rundong Li, Yan Wang, Feng Liang, Hongwei Qin, Junjie Yan, and Rui Fan. Fully quantized network for object detection. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2019.

[3].   Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, and Dmitry Kalenichenko. Quantization and training of neural networks for efficient integer-arithmetic-only inference. 2018 IEEE Conference on Computer Vision and Pattern Recognition(CVPR), June 2018.

[4].   Yukuan Yang, Shuang Wu, Lei Deng, Tianyi Yan, Yuan Xie, and Guoqi Li. Training high-performance and large-scale deep neural networks with full 8-bit integers, 2019.

欢迎所有开发者扫描下方二维码填写《开发者与AI大调研》,只需2分钟,便可收获价值299元的「AI开发者万人大会」在线直播门票!

推荐阅读

  • GitHub标星2000+,如何用30天啃完TensorFlow2.0?

  • 斩获GitHub 2000+ Star,阿里云开源的Alink机器学习平台如何跑赢双11数据“博弈”?

  • 百年 IBM 终于 All In 人工智能和混合云!

  • 微软为一人收购一公司?破解索尼程序、写黑客小说,看他彪悍的程序人生!

  • 机器学习项目模板:ML项目的6个基本步骤

  • BM、微软、苹果、谷歌、三星……这些区块链中的科技巨头原来已经做了这么多事!

  • 你点的每个“在看”,我都认真当成了AI

相关文章:

×××作,不知写些什么

博客,老是有写的冲动,不过,没什么韧劲坚持,自己感觉文采一般般啦,有时兴起,挥毫泼墨,蜡笔重唱一番,呵呵,自个爽朗了,呵呵 所以,自己坚持&#xff…

centos7 install 安装mysql

CentOS 7的yum源中貌似没有正常安装mysql时的mysql-sever文件,需要去官网上下载 # wget http://dev.mysql.com/get/mysql-community-release-el7-5.noarch.rpm # rpm -ivh mysql-community-release-el7-5.noarch.rpm # yum install mysql-community-server成功安装之…

AI四巨头Google、DeepMind、Microsoft、Uber深度学习框架大比拼

编者按:Google、Uber、DeepMind和Microsoft这四大科技公司是当前将深度学习研究广泛应用于自身业务的典型代表,跻身全球深度学习研究水平最高的科技公司之列。GPipe、Horovod、TF Replicator和DeepSpeed分别是这四家公司开发应用的深度学习框架&#xff…

转《刘润的数字化家庭》

数字家庭也是我的一大梦想,感谢刘润让我的想法更加丰富和具体。。。 转载自刘润的博客,原文地址:http://blog.run2me.com/runliu/archive/2010/06/12/37082.aspx 1 of 22 (大图):用数字化的技术&#xff0c…

自己写的内存池Slabs

看memcached的源码写的&#xff0c;虽然很粗糙&#xff0c;但是基本思想还是有的&#xff0c;自娱自乐&#xff0c;后期不断改进。 #include <stdio.h> #include <stdlib.h> #include <string.h>struct st{void * start;void * end;char ptr[10]; }; struct …

Eclispse Che(2):启动Che服务,进入IDE界面

本文的原文连接是: http://blog.csdn.net/freewebsys/article/details/50888878 未经博主允许不得转载。 博主地址是&#xff1a;http://blog.csdn.net/freewebsys 1&#xff0c;关于Docker 上次使用Che的时候没有成功创建Project。 其实主要问题就是docker的网络问题。 使用…

使用strace和ltrace跟踪程序调用

ltrace能够跟踪进程的库函数调用,它会显现出哪个库函数被调用,而strace则是跟踪程序的每个系统调用.1.系统调用的输出对比程序代码&#xff1a;#include <stdio.h> main(){char str[] "Abcde";printf("\n string %s length %d \n",str,str_length(…

NeHe OpenGL第三十三课:TGA文件

NeHe OpenGL第三十三课&#xff1a;TGA文件 加载压缩和未压缩的TGA文件: 在这一课里&#xff0c;你将学会如何加载压缩和为压缩的TGA文件&#xff0c;由于它使用RLE压缩&#xff0c;所以非常的简单&#xff0c;你能很快地熟悉它的。 我见过很多人在游戏开发论坛或其它地方询问…

阿里自动驾驶新突破!达摩院自研ISP图像处理器大幅提升安全性

阿里巴巴达摩院在自动驾驶领域取得新突破&#xff01;4月8日&#xff0c;据记者了解&#xff0c;达摩院已经自主研发出用于车载摄像头的ISP处理器&#xff0c;保障自动驾驶车辆在夜间拥有更好的“视力”&#xff0c;“看”得更清晰&#xff0c;从而大幅提升自动驾驶安全性, 而背…

3月14号作业

<!DOCTYPE html> <html> <head lang"en"><meta charset"UTF-8"><title></title> </head> <body> <br/> <br/> <img src"51job表单_03.gif"</br> <br/> <br/>…

NeHe OpenGL第三十五课:播放AVI

NeHe OpenGL第三十五课&#xff1a;播放AVI 在OpenGL中播放AVI: 在OpenGL中如何播放AVI呢&#xff1f;利用Windows的API把每一帧作为纹理绑定到OpenGL中&#xff0c;虽然很慢&#xff0c;但它的效果不错。你可以试试。 首先我得说我非常喜欢这一章节.Jonathan de Blok使我产生…

为什么TCP的TIME_WAIT状态要保持2MSL?

TIMEWAIT状态也称为 2MSL等待状态。每个具体TCP实现必须选择一个报文段最大生存时间MSL(Maximum Segment Lifetime)。它是任何报文段被丢弃前在网络内的最长时间。我们知道这个时间是有限的&#xff0c;因为TCP报文段以IP数据报在网络内传输&#xff0c;而IP数据报则有限制其生…

深度 | 一文读懂“情感计算”在零售中的应用发展

作者 | 黄程韦博士、刘刚、包飞博士、杨现博士、孙皓博士、沈艺博士来源 | 苏宁零售技术研究院零售商需要不断通过创新服务来提高顾客的购物体验&#xff0c;而情感计算在该领域具有独特优势。它在零售行业的应用&#xff0c;主要集中在提升购物体验的服务中。在这个科技逐步改…

mysql基于replication实现最简单的M-S主从复制

2019独角兽企业重金招聘Python工程师标准>>> 什么是replication Replication可以实现数据从一台数据库服务器&#xff08;master&#xff09;复制到一到多台数据库服务器。 默认情况下&#xff0c;属于异步复制&#xff0c;因此无需维持长连接。 通过配置&#xff0…

Linux下高并发socket最大连接数所受的各种限制

修改最大打开文件数 # ulimit -n 修改最大进程数 # ulimit -u ------------------------------------------------------ Linux下高并发socket最大连接数所受的各种限制 转自&#xff1a;http://blog.csdn.net/guowake/article/details/6615728 1、修改用户进程可打开…

linux安全问答(1)

一、如何限制对系统资源的过度使用&#xff1f; &#xff08;1&#xff09;、编辑/etc/security/limits.conf文件&#xff0c;在其中加入或改变下面这些内容&#xff1a; * hard core 0 //禁止创建core文件 * hard rss 5000 //表示除root用户之外&#xff0c;其他用户都只能最多…

快速搭建对话机器人,就用这一招!

作者 | Milvus.io 责编 | 胡巍巍问答系统是自然语言处理领域一个很经典的问题&#xff0c;它用于回答人们以自然语言形式提出的问题&#xff0c;有着广泛的应用。其经典应用场景包括&#xff1a;智能语音交互、在线客服、知识获取、情感类聊天等。常见的分类有&#xff1a;生成…

目前流行的源程序版本管理软件和项目管理软件都有哪些?各有什么优缺点?...

目前流行的源程序版本管理软件和项目管理软件&#xff1a;Microsoft TFS&#xff0c;Github&#xff0c;SVN&#xff0c;Coding 各自的优缺点&#xff1a; Microsoft TFS&#xff1a;优点&#xff1a;任务版上能将需求、项目进度一览无余&#xff0c;对于小团队而言&#xff0c…

孙鑫mfc学习笔记第十四课

第十四课网络的相关知识&#xff0c;网络程序的编写&#xff0c;Socket是连接应用程序与网络驱动程序的桥梁&#xff0c;Socket在应用程序中创建&#xff0c;通过bind与驱动程序建立关系。此后&#xff0c;应用程序送给Socket的数据&#xff0c;由Socket交给驱动程序向网络上发…

Linux环境编译安装Mysql以及补装innodb引擎方法

mysql安装 5.6以后可能会收费&#xff0c;所以选择5.1以下从台湾中山大学镜像下载 1.首先要安装C编译环境 # yum install gcc-c 2.下载解压 # wget http://mysql.cdpa.nsysu.edu.tw/Downloads/MySQL-5.1/mysql-5.1.73.tar.gz# tar zxvf mysql-5.1.73.tar.gz# cd mysql-5…

Python 炫技操作:合并字典的七种方法

来源 | Python编程时光&#xff08;ID: Cool-Python&#xff09;Python 语言里有许多&#xff08;而且是越来越多&#xff09;的高级特性&#xff0c;是 Python 发烧友们非常喜欢的。在这些人的眼里&#xff0c;能够写出那些一般开发者看不懂的高级特性&#xff0c;就是高手&am…

shell脚本编程基础(1)及RAID阵列

shell脚本&#xff1a;Linux从底层到上层的系统架构&#xff1a;硬件-->内核-->库(lib)-->shell-->用户。shell既是一种命令语言&#xff0c;也是程序设计语言&#xff08;shell脚本&#xff09;&#xff0c;作为一种命令语言&#xff0c;它提供了用户与内核的交互…

freemarker基本语法及实例

EG.一个对象BOOK 1.输出 ${book.name} 空值判断&#xff1a;${book.name?if_exists }, ${book.name?default(‘xxx’)}//默认值xxx ${ book.name!"xxx"}//默认值xxx 日期格式&#xff1a;${book.date?string(yyyy-MM-dd)} 数字格式&#xff1a;${boo…

前百度主任架构师创业,两年融资千万美元,他说AI新药研发将迎来黄金十年...

「AI技术生态论」 人物访谈栏目是CSDN发起的百万人学AI倡议下的重要组成部分。通过对AI生态专家、创业者、行业KOL的访谈&#xff0c;反映其对于行业的思考、未来趋势的判断、技术的实践&#xff0c;以及成长的经历。2020年&#xff0c;CSDN将对1000人物进行访谈&#xff0c;形…

Linux环境安装卸载JDK以及安装Tomcat和发布Java的web程序

Linux环境&#xff1a;CentOS7.2 一.安装JDK 安装好的CentOS会自带OpenJdk&#xff0c;最好还是先卸载系统自带的JDK&#xff0c;然后自己重新去Oracle网站下载最新的JDK安装。 1.卸载系统自带的JDK 查看java信息 # java -version 查看JDK # rpm -qa | grep java 或者 还…

(转)详解css3弹性盒模型(Flexbox)

今天刚学了css3的弹性盒模型&#xff0c;这是一个可以让你告别浮动、完美实现垂直水平居中的新特性。 Flexbox是布局模块&#xff0c;而不是一个简单的属性&#xff0c;它包含父元素和子元素的属性。 Flexbox布局的主体思想是似的元素可以改变大小以适应可用空间&#xff0c;当…

Java开发环境的搭建以及使用eclipse创建项目

一、Java 开发环境的搭建 这里主要说windows环境下怎么配置Java环境。如果是Linux环境参考本博客另一篇文章即可&#xff1a; Linux环境安装卸载JDK 1.首先安装JDK java的SDK简称JDK。 去官网下载最新的JDK即可&#xff1a; http://www.oracle.com/technetwork/java/javase…

​MMIT冠军方案 | 用于行为识别的时间交错网络,商汤公开视频理解代码库

作者 | 商汤出品 | AI科技大本营&#xff08;ID:rgznai100&#xff09;本文主要介绍三个部分&#xff1a;一个高效的SOTA视频特征提取网络TIN&#xff0c;发表于AAAI2020ICCV19 MMIT多标签视频理解竞赛冠军方案&#xff0c;基于TIN和SlowFast一个基于PyTorch&#xff0c;包含大…

MySQL的主从服务器配置

MySQL的主从服务器配置常见开源数据库有&#xff1a;MySQL&#xff0c;PostgreSQL&#xff0c;SQLite等&#xff0c;商业性质的&#xff1a;Oracle&#xff0c;Sql Server&#xff0c;DB2&#xff0c;Sybase&#xff0c;Infomix其中&#xff0c;Oracle的版本有Oracle 11g,Oracl…

Anaconda中安装Orange3脚本-完整版

2019独角兽企业重金招聘Python工程师标准>>> #Anaconda中安装Orange3脚本&#xff0c;完整版。包括插件的安装&#xff0c;在脚本中一次完成。 sudo apt-get update sudo apt-get -y install git python-pip python-virtualenv python-qt4-dev python3-pyqt4 libqt…