深度解析MegEngine亚线性显存优化技术
基于梯度检查点的亚线性显存优化方法[1]由于较高的计算/显存性价比受到关注。MegEngine经过工程扩展和优化,发展出一套行之有效的加强版亚线性显存优化技术,既可在计算存储资源受限的条件下,轻松训练更深的模型,又可使用更大batch size,进一步提升模型性能,稳定batchwise算子。使用MegEngine训练ResNet18/ResNet50,显存占用分别最高降低23%/40%;在更大的Bert模型上,降幅更是高达75%,而额外的计算开销几乎不变。该技术已在MegEngine开源,欢迎大家上手使用:https://github.com/MegEngine。
作者 | 旷视研究院
深度神经网络训练是一件复杂的事情,它体现为模型的时间复杂度和空间复杂度,分别对应着计算和内存;而训练时内存占用问题是漂浮在深度学习社区上空的一块乌云,如何拨云见日,最大降低神经网络训练的内存占用,是一个绕不开的课题。
GPU显卡等硬件为深度学习提供了必需的算力,但硬件自身有限的存储,限制了可训练模型的尺寸,尤其是大型深度网络,由此诞生出一系列相关技术,比如亚线性显存优化、梯度累加、混合精度训练、分布式训练,进行GPU显存优化。
其中,亚线性显存优化方法[1]由于较高的计算/显存性价比备受关注;旷视基于此,经过工程扩展和优化,发展出加强版的MegEngine亚线性显存优化技术,轻松把大模型甚至超大模型装进显存,也可以毫无压力使用大batch训练模型。
这里将围绕着深度学习框架MegEngine亚线性显存优化技术的工程实现和实验数据,从技术背景、原理、使用、展望等多个方面进行首次深入解读。
背景
在深度学习领域中,随着训练数据的增加,需要相应增加模型的尺寸和复杂度,进行模型「扩容」;而ResNet [2] 等技术的出现在算法层面扫清了训练深度模型的障碍。不断增加的数据和持续创新的算法给深度学习框架带来了新挑战,能否在模型训练时有效利用有限的计算存储资源,尤其是减少GPU显存占用,是评估深度学习框架性能的重要指标。
在计算存储资源一定的情况下,深度学习框架有几种降低显存占用的常用方法,其示例如下:
通过合适的梯度定义,让算子的梯度计算不再依赖于前向计算作为输入,从而in-place地完成算子的前向计算,比如Sigmoid、Relu等;
在生命周期没有重叠的算子之间共享显存;
通过额外的计算减少显存占用,比如利用梯度检查点重新计算中间结果的亚线性显存优化方法[1];
通过额外的数据传输减少显存占用,比如把暂时不用的数据从GPU交换到CPU,需要时再从CPU交换回来。
上述显存优化技术在MegEngine中皆有不同程度的实现,这里重点讨论基于梯度检查点的亚线性显存优化技术。
原理
一个神经网络模型所占用的显存空间大体分为两个方面:1)模型本身的参数,2)模型训练临时占用的空间,包括参数的梯度、特征图等。其中最大占比是 2)中以特征图形式存在的中间结果,比如,从示例[1]可知,根据实现的不同,从70%到90%以上的显存用来存储特征图。
这里的训练过程又可分为前向计算,反向计算和优化三个方面,其中前向计算的中间结果最占显存,还有反向计算的梯度。第 1)方面模型自身的参数内存占用最小。
MegEngine加强版亚线性显存优化技术借鉴了[1]的方法,尤其适用于计算存储资源受限的情况,比如一张英伟达2080Ti,只有11G的显存;而更贵的Tesla V100,最大显存也只有32G。
图1:亚线性显存优化原理,其中 (b) 保存了Relu结果,实际中Relu结果可用in-place计算
图 1(a) 给出了卷积神经网络的基本单元,它由Conv-BN-Relu组成。可以看到,反向计算梯度的过程依赖于前向计算获取的中间结果,一个网络需要保存的中间结果与其大小成正比,即显存复杂度为O(n)。
本质上,亚线性显存优化方法是以时间换空间,以计算换显存,如图 1(b) 所示,它的算法原理如下:
选取神经网络中k个检查点,从而把网络分成k个block,需要注意的是,初始输入也作为一个检查点;前向计算过程中只保存检查点处的中间结果;
反向计算梯度的过程中,首先从相应检查点出发,重新计算单个block需要的中间结果,然后计算block内部各个block的梯度;不同block的中间结果计算共享显存。
这种方法有着明显的优点,即大幅降低了模型的空间复杂度,同时缺点是增加了额外的计算:
显存占用从O(n)变成O(n/k)+ O(k),O(n/k)代表计算单个节点需要的显存,O(k)代表k个检查点需要的显存, 取k=sqrt(n),O(n/k)+ O(k)~O(sqrt(n)),可以看到显存占用从线性变成了亚线性;
因为在反向梯度的计算过程中需要从检查点恢复中间结果,整体需要额外执行一次前向计算。
工程
在[1]的基础上,MegEngine结合自身实践,做了工程扩展和优化,把亚线性显存优化方法扩展至任意的计算图,并结合其它常见的显存优化方法,发展出一套行之有效的加强版亚线性显存优化技术。
亚线性优化方法采用简单的网格搜索(grid search)选择检查点,MegEngine在此基础上增加遗传算法,采用边界移动、块合并、块分裂等策略,实现更细粒度的优化,进一步降低了显存占用。
如图2所示,采用型号为2080Ti的GPU训练ResNet50,分别借助基准、亚线性、亚线性+遗传算法三种显存优化策略,对比了可使用的最大batch size。仅使用亚线性优化,batch size从133增至211,是基准的1.6x;而使用亚线性+遗传算法联合优化,batch size进一步增至262,较基准提升2x。
图2:三种显存优化方法优化batch size的对比:ResNet50
通过选定同一模型、给定batch size,可以更好地观察遗传算法优化显存占用的情况。如图3所示,随着迭代次数的增加,遗传算法逐渐收敛显存占用,并在第5次迭代之后达到一个较稳定的状态。
图3:遗传算法收敛示意图
此外,MegEngine亚线性优化技术通过工程改良,不再局限于简单的链状结构和同质计算节点, 可用于任意的计算图,计算节点也可异质,从而拓展了技术的适用场景;并可配合上述显存优化方法,进一步降低模型的显存占用。
实验
MegEngine基于亚线性显存技术开展了相关实验,这里固定batch size=64,在ResNet18和ResNet50两个模型上,考察模型训练时的显存占用和计算时间。
如图4所示,相较于基准实现,使用MegEngine亚线性显存技术训练ResNet18时,显存占用降低32%, 计算时间增加24%;在较大的ReNet50上,显存占用降低40%,计算时间增加25%。同时经过理论分析可知,模型越大,亚线性显存优化的效果越明显,额外的计算时间则几乎不变。
图4:MegEngine亚线性优化技术实验显存/时间对比:ReNet18/ReNet50
在更大模型Bert上实验数据表明,借助MegEngine亚线性显存技术,显存占用最高降低75%,而计算时间仅增加23%,这与理论分析相一致。有兴趣的同学可前往MegEngine ModeHub试手更多模型实验:https://megengine.org.cn/model-hub/。
使用
MegEngine官网提供了亚线性显存优化技术的使用文档。当你的GPU显存有限,苦于无法训练较深、较大的神经网络模型,或者无法使用大batch进一步提升深度神经网络的性能,抑或想要使batchwise算子更加稳定,那么,MegEngine亚线性显存优化技术正是你需要的解决方案。
上手MegEngine亚线性优化技术非常便捷,无需手动设定梯度检查点,通过几个简单的参数,轻松控制遗传算法的搜索策略。具体使用时,在MegEngine静态图接口中调用SublinearMemoryConfig设置trace的参数sublinear_memory_config,即可打开亚线性显存优化:
from megengine.jit import trace, SublinearMemoryConfig
config = SublinearMemoryConfig()
@trace(symbolic=True, sublinear_memory_config=config)def train_func(data, label, *, net, optimizer): ...
MegEngine在编译计算图和训练模型时,虽有少量的额外时间开销,但会显著缓解显存不足问题。下面以ResNet50为例,说明MegEngine可有效突破显存瓶颈,训练batch size从100最高增至200:
import osfrom multiprocessing import Processdef train_resnet_demo(batch_size, enable_sublinear, genetic_nr_iter=0):import megengine as mgeimport megengine.functional as Fimport megengine.hub as hubimport megengine.optimizer as optimfrom megengine.jit import trace, SublinearMemoryConfigimport numpy as npprint("Run with batch_size={}, enable_sublinear={}, genetic_nr_iter={}".format( batch_size, enable_sublinear, genetic_nr_iter ) )# 使用GPU运行这个例子assert mge.is_cuda_available(), "Please run with GPU"try:# 我们从 megengine hub 中加载一个 resnet50 模型。 resnet = hub.load("megengine/models", "resnet50")optimizer = optim.SGD(resnet.parameters(), lr=0.1,)config = Noneif enable_sublinear: config = SublinearMemoryConfig(genetic_nr_iter=genetic_nr_iter)@trace(symbolic=True, sublinear_memory_config=config)def train_func(data, label, *, net, optimizer): pred = net(data) loss = F.cross_entropy_with_softmax(pred, label) optimizer.backward(loss)resnet.train()for i in range(10): batch_data = np.random.randn(batch_size, 3, 224, 224).astype(np.float32) batch_label = np.random.randint(1000, size=(batch_size,)).astype(np.int32) optimizer.zero_grad() train_func(batch_data, batch_label, net=resnet, optimizer=optimizer) optimizer.step()except: print("Failed") returnprint("Sucess")
# 以下示例结果在2080Ti GPU运行得到,显存容量为 11 GB# 不使用亚线性内存优化,允许的batch_size最大为 100 左右p = Process(target=train_resnet_demo, args=(100, False))p.start()p.join()# 报错显存不足p = Process(target=train_resnet_demo, args=(200, False))p.start()p.join()# 使用亚线性内存优化,允许的batch_size最大为 200 左右p = Process(target=train_resnet_demo, args=(200, True, 20))p.star
展望
如上所述,MegEngine的亚线性显存优化技术通过额外做一次前向计算,即可达到O(sqrt(n))的空间复杂度。如果允许做更多次的前向计算,对整个网络递归地调用亚线性显存算法,有望在时间复杂度为O(n log n)的情况下,达到 O(log n)的空间复杂度。
更进一步,MegEngine还将探索亚线性显存优化技术与数据并行/模型并行、混合精度训练的组合使用问题,以期获得更佳的集成效果。最后,在RNN以及GNN、Transformer等其他类型网络上的使用问题,也是MegEngine未来的一个探索方向。
参考文献
Chen, T., Xu, B., Zhang, C., & Guestrin, C. (2016). Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174.
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).
推荐阅读
饿了么交易系统 5 年演化史
360金融首席科学家张家兴:别指望AI Lab做成中台
干货 | 时间序列预测类问题下的建模方案探索实践
写了Bug,误执行 rm -fr /*,我删删删删库了,要跑路吗?| 原力计划
中国 App 出海“变形记”
从货币历史,看可编程货币的升级
你点的每个“在看”,我都认真当成了AI
相关文章:

2016-04-28
2019独角兽企业重金招聘Python工程师标准>>> 1.提交form表单之前的函数(校验不错):onsubmit"return A();".2.解析XML的方式:2.1.DOM是用与平台和语言无关的方式表示XML文档的官方W3C标准,基于"树"(DocumentBuilderFactory).2.2.SAX的优点类似于…
Spring源码分析【8】-MyBatis注解方法不能重载
代码如下: 这是不可以的,会报错: 2016-08-18 11:36:00,267 [main] ERROR [org.mybatis.spring.mapper.MapperFactoryBean] - Error while adding the mapper interface com.unix21.mapper.UserMapper to configuration.java.lang.IllegalArgu…
不知道这 7 大 OpenCV 函数怎么向计算机视觉专家进阶?
作者 | Lazar Gugleta译者 | Arvin,责编 | 夕颜头图 | CSDN付费下载自视觉中国出品 | CSDN(ID:CSDNnews)计算机视觉和计算机图形学现在非常流行,因为它们与人工智能息息相关,它们主要的共同点是使用同一个OpenCV库&…

MySQL5.5复制新特性
MySQL5.5复制新特性一.MySQL5.5复制改进MySQL5.5版本对MySQL Replication进行了多项的改良,以提供数据的完整性,性能和应用灵活性更高水平。1.Semisynchronous Replication:主从之间的等待机制2.Slave fsync tuning:调整slave fsync包括sync-…

GitLab 8.7发布
日前,GitLab 8.7版发布。该版本中,添加了新功能和优化,并小幅提升了性能。\\8.7版本发布于8.6版本整整30天之后,跟上了每月22日次版本的进度。最新的版本增加了在单个问题上设置到期日期的支持以及以用户所在时区而不是UTC来显示所…
Java飞行记录器 JRockit Flight Recorder JFR诊断JVM的历史性能和操作
需要展开子树,复制堆栈跟踪,就可以查看到代码调用链,看到自己的业务代码,从而定位到最耗时的代码位置:

vi/vim: 使用taglist插件
本节所用命令的帮助入口: :help helptags :help taglist.txt 上篇文章介绍了在vim中如何使用tag文件,本文主要介绍如何使用taglist插件(plugin)。 想必用过Source Insight的人都记得这样一个功能:SI能够把当前文件中的宏、全局变量、函数等t…
学会这些Python美图技巧,就等女朋友夸我了
来源 | ZackSock(ID: ZackSock)Python中有许多用于图像处理的库,像是Pillow,或者是OpenCV。而很多时候感觉学完了这些图像处理模块没有什么用,其实只是你不知道怎么用罢了。今天就给大家带了一些美图技巧,让…

Linux下的softlink和hardlink(转)
Linux中包括两种链接:硬链接(hard link)和软链接(soft link),软链接又称为符号链接(symbolic link)创建命令:ln -s destfile/directory softlink #建立软连接 ln destfile hardlink #建立硬连接in…

ubuntu安装之后的最初几天一路杂记
我就随便写了啊,没那么正式,想到什么就写什么。 由于大四的毕业设计要做一个牵扯到linux的项目,最近不得不再次玩起了ubuntu,其实前一次(大二的时候吧)就已经在电脑上安装过一个ubuntu了,只不过…

百万级访问量网站的技术准备工作[转帖]
当今从纯网站技术上来说,因为开源模式的发展,现在建一个小网站已经很简单也很便宜,所以很多人都把创业方向定位在互联网应用。这些人里大多数不是 很懂技术,或者不是那么精通,而网站开发维护方面的知识又很分散&#x…
智能驾驶L2的黄金时代,打磨地图是关键
作者 | 自动驾驶从业者,中寰卫星黄亮出品 | AI科技大本营(ID:rgznai100)智能驾驶L2,以我们通俗的定义是,以高级辅助驾驶的产品为主的各种巡航产品,包括定速巡航,自适应巡航ACC,预见性…

css中的垂直居中方法
单行文字 (外行高度固定) line-height 行高, 将line-height值与外部标签盒子的高度值设置成一致就可以了。 height:3em; line-height:3em; 多行文字 图文结合(图和单行文字) 图文结合(图和多行文字…
U盘挂载,gedit,vi,文本模式中文乱码等等问题
U盘或硬盘挂载 首先,我们要查看一下磁盘的分区信息sudo fdisk -l (注意注意,是小写的L,不是1,也不是i) 这里可以看到我的硬盘情况,前面几个是win7系统下的C,D ,E ,F 盘。我现在是在图书馆,没…
一次对语音技术的彻底批判
作者 | Alexander Veysov译者 | 孙薇,编辑 | 夕颜出品 | AI科技大本营(ID:rgznai100)ImageNet的出现带来计算机视觉领域的突破发展,掀起了一股预训练之风,这就是所谓的ImageNet时刻。但与计算机视觉同样重要…

Windows下编译Chrome V8
主要还是参考google的官方文档: How to Download and Build V8 Building on Windows 同时也参考了一些其它的中文博客: 脚本引擎小pk:SpiderMonkey vs V8 Windows 下编译V8引擎-with visual sudio 2010 将google V8 编译成 dll v8学习笔记 步…

mysql子查询
一句话就是子查询的结果作为外部查询的比较条件 所谓子查询是指一个查询语句嵌套在另一个查询语句的内部的查询,也就是select里面还有select。 在select语句中先计算子查询,子查询的结果作为外层另一个查询的过滤条件。 子查询中常用的操作符有ÿ…

Ubuntu查看系统位数及版本
怎么查看本机cup是几位的呢?命令: more /proc/cpuinfo 该命令列出了很多cup信息 找到clflush size ,其值就是cup位数 我的是clflush size: 64 那怎么查看你所装的ubuntu系统是几位的呢?命令: uname -ar Linux wen-lapt…
百度翻译Q1 DAU增长40%,疫情期学生在线学习率猛增
5月11日,百度翻译公布最新的DAU(日活跃用户数量)相关数据,2020年Q1较上一个季度环比增长10%,较去年Q1同比增长40%。 此外,百度翻译还在一个季度内,将翻译的语种扩充了近7倍,目前百度…

Oracle 10g配置RMAN RECOVERY CATALOG
Oracle的RMAN配置信息默认存放在target数据库的控制文件中,当然也可以配置一个recovery catalog服务器来存储这些信息,下面是控制文件和恢复的特性比较,一般来说维护10台以下的oracle数据库备份,可以不需要配置恢复目录. Control …

android Spinner 例子
为什么80%的码农都做不了架构师?>>> 一、主xml:activity_main.xml <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width&q…
ubuntu下vim的配置
写在前面,我写本文的目的不在于教大家怎么来配置VIM,因为我是新手,我也是参考了各位前辈的方法,在此只是记录一下过程,当然我个人觉得更重要的是心得体会。其实大家可能也发觉,国内的抄袭转载现象很严重&am…
赠书 | 从阿里到Facebook,一线大厂这样做深度学习推荐系统
本文内容节选自《深度学习推荐系统》一书。由美国Roku推荐系统架构负责人、前Hulu高级研究员王喆精心编著,书中包含了这场革命中一系列的主流技术要点:深度学习推荐模型、Embedding技术、推荐系统工程实现、模型评估体系、业界前沿实践…………深度学习在…

使用 CAS 在 Tomcat 中实现单点登录
CAS 介绍 CAS 是 Yale 大学发起的一个开源项目,旨在为 Web 应用系统提供一种可靠的单点登录方法,CAS 在 2004 年 12 月正式成为 JA-SIG 的一个项目。CAS 具有以下特点: 开源的企业级单点登录解决方案。CAS Server 为需要独立部署的 Web 应用。…
Windows SDK 7.1 (包含directshow)安装配置
最近一直在做毕业设计的事情,需要利用directshow进行视频开发,但是现在单独的directshow包已经没有了,从directx9.0c开始directshow和directx分开发布,现在的directshow已经集成到windows SDK当中了。 但是说实话,由于…
20行Python代码实现视频字符化
来源 | ZackSock(ID:ZackSock)我们经常在B站上看到一些字符鬼畜视频,主要就是将一个视频转换成字符的样子展现出来。看起来是非常高端,但是实际实现起来确实非常简单,我们只需要接触opencv模块,就能很快的实…

隔年的衣服发黄处理方法
1.用菠菜水,将菠菜煮水五分钟,然后用菠菜水除旧衣服黄渍特灵 2.用淘米水泡洗就可以了 3.用温盐水泡上20分钟再洗 4.如果是白颜色衣服的话,你不妨在洗衣服的时候放一点蓝色墨水或者用漂白 转载于:https://blog.51cto.com/wanghu2009/519490

linux监控(陆续补充)
一 定时任务for user in $(cat /etc/passwd | cut -f1 -d:);do crontab -l -u $user;done是否有用户执行了隐藏定时任务? 是否有某个任务正在备份二 网络sysctl -a | grep xx 查看网络内核参数信息ss -s 显示所有存在的连接cat /proc/interrupts 查看中断请求是否…

自绘按钮的实现
如果你希望能够在自己的程序中表现出新意,那么你一定不会仅仅满足于MFC提供那些标准控件。这时,我们就必须自己另外多做些工作了。就改变控件外观这一点来说,主要是利用控件的自绘功能(Owner Draw)实现的。本篇将和各位…

24/4毕业设计小记
折腾了很久了,关于我的毕业设计,一直就没有时间来写博客,今天感冒了,趁着思路不太好的时候就写一篇博客吧!写什么好呢,就写基于vlc sdk的播放器开发吧! 我的项目是关于windows和linux两个平台的…