当前位置: 首页 > 编程日记 > 正文

掌握这些步骤,机器学习模型问题药到病除

640?wx_fmt=jpeg

作者 | Cecelia Shao

编译 | ronghuaiyang

来源 | AI公园(ID:AI_Paradise)

【导读】这篇文章提供了切实可行的步骤来识别和修复机器学习模型的训练、泛化和优化问题。

众所周知,调试机器学习代码非常困难。即使对于简单的前馈神经网络也是这样,你经常会在网络体系结构做出一些决定,重初始化和网络优化——所有这些会都导致在你的机器学习代码中出现bug。

正如Chase Roberts在一篇关于“How to unit test machine learning code”的优秀文章中所写的,他遇到的麻烦来自于常见的陷阱:

  1. 代码不会崩溃,不会引发异常,甚至不会变慢。
  2. 训练网络仍在运行,损失仍将下降。
  3. 几个小时后,数值收敛了,但结果很差

那么我们该怎么做呢?

本文将提供一个框架来帮助你调试神经网络:

  1. 从最简单的开始
  2. 确认你的损失
  3. 检查中间输出和连接
  4. 对参数进行诊断
  5. 跟踪你的工作

请随意跳转到特定的部分或通读下面的内容!请注意:我们不包括数据预处理或特定的模型算法选择。对于这些主题,网上有很多很好的资源。


1. 从最简单的开始


一个具有复杂结构和正则化以及学习率调度程序的神经网络将比一个简单的网络更难调试。我们在第一点上有点欺骗性,因为它与调试你已经构建的网络没有什么关系,但是它仍然是一个重要的建议!
从最简单的开始:

  • 首先建立一个更简单的模型
  • 在单个数据点上训练模型


首先,构建一个更简单的模型


首先,构建一个只有一个隐藏层的小型网络,并验证一切正常。然后逐步增加模型的复杂性,同时检查模型结构的每个方面(附加层、参数等),然后再继续。


在单个数据点上训练模型


作为一个快速的完整性检查,你可以使用一两个训练数据点来确认你的模型是否能够过拟合。神经网络应该立即过拟合,训练精度为100%,验证精度与你的模型随机猜测相匹配。如果你的模型不能对这些数据点进行过拟合,那么要么是它太小,要么就是存在bug。
即使你已经验证了模型是有效的,在继续之前也可以尝试训练一个(或几个)epochs。


2. 确认你的损失


你的模型的损失是评估你的模型性能的主要方法,也是模型评估的重要参数,所以你要确保:

  • 损失适合于任务(对于多分类问题使用类别交叉熵损失或使用focal loss来解决类不平衡)
  • 你的损失函数在以正确的尺度进行测量。如果你的网络中使用了不止一种类型的损失,例如MSE、adversarial、L1、feature loss,那么请确保所有损失都按正确的顺序进行了缩放

注意,你最初的损失也很重要。如果模型一开始就随机猜测,检查初始损失是否接近预期损失。在Stanford CS231n coursework中,Andrej Karpathy提出了以下建议:
在随机表现上寻找正确的损失。确保在初始化小参数时得到预期的损失。最好先单独检查数据的loss(将正则化强度设置为零)。例如,对于使用Softmax分类器的CIFAR-10,我们期望初始损失为2.302,因为我们期望每个类的随机概率为0.1(因为有10个类),而Softmax损失是正确类的负对数概率,因此:-ln(0.1) = 2.302。
对于二分类的例子,只需对每个类执行类似的计算。假设数据是20%的0和80%的1。预期的初始损失是- 0.2ln(0.5) - 0.8ln(0.5) = 0.693147。如果你的初始损失比1大得多,这可能表明你的神经网络权重不平衡(即初始化很差)或者你的数据没有标准化。

3. 检查内部的输出和连接

要调试神经网络,通常了解神经网络内部的动态以及各个中间层所起的作用以及这些中间层之间如何连接是很有用的。你可能会遇到以下错误:

  • 梯度更新的表达式不正确
  • 权重更新没有应用
  • 梯度消失或爆炸

如果梯度值为零,这可能意味着优化器中的学习率可能太小,或者你遇到了上面的错误#1,其中包含梯度更新的不正确的表达式。

除了查看梯度更新的绝对值之外,还要确保监视激活的大小、权重的大小和每个层的更新相匹配。例如,参数更新的大小(权重和偏差)应该是1-e3。

有一种现象叫做“死亡的ReLU”或“梯度消失问题”,ReLU神经元在学习了一个表示权重的大的负偏置项后,会输出一个零。这些神经元再也不会在任何数据点上被激活。

你可以使用梯度检查来检查这些错误,通过使用数值方法来近似梯度。如果它接近计算的梯度,则正确地实现了反向传播。

Faizan Shaikh描述了可视化神经网络的三种主要方法:

  • 初步方法 - 向我们展示训练模型整体结构的简单方法。这些方法包括打印出神经网络各层的形状或滤波器以及各层的参数。
  • 基于激活的方法 - 在这些方法中,我们解码单个神经元或一组神经元的激活情况,以直观地了解它们在做什么
  • 基于梯度的方法 - 这些方法倾向于在训练模型时操作由前向和后向传递形成的梯度(包括显著性映射和类激活映射)。

有许多有用的工具可以可视化单个层的激活和连接,比如ConX和Tensorboard。

640?wx_fmt=png

使用ConX生成的动态呈现可视化示例

4. 参数诊断

神经网络有大量的参数相互作用,使得优化变得困难。请注意,这是一个活跃的研究领域,所以下面的建议只是简单的出发点。

  • Batch size - 你希望batch size足够大,能够准确地估计错误梯度,但又足够小,以便小批随机梯度下降(SGD)能够使你的网络归一化。小的batch size将导致学习过程以训练过程中的噪声为代价快速收敛,并可能导致优化困难。论文On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima描述了:
在实践中已经观察到,当使用一个较大的batch size时,模型的质量会下降,这可以通过它的泛化能力来衡量。我们研究了在大批量情况下泛化下降的原因,并给出了支持large-batch方法趋向于收敛于训练和测试函数的sharp的极小值这一观点的数值证据——众所周知,sharp的极小值导致较差的泛化。相比之下,小batch size的方法始终收敛于平坦的最小值,我们的实验支持一个普遍的观点,即这是由于梯度估计中的固有噪声造成的。
  • 学习速率 - 学习率过低将导致收敛速度慢或陷入局部最小值的风险,而学习速率过大导致优化分歧,因为你有跳过损失函数的更深但是更窄部分的风险。考虑将学习率策略也纳入其中,以随着训练的进展降低学习率。CS231n课程有一大部分是关于实现退火学习率的不同技术。
  • 梯度裁剪 - 在反向传播期间的通过最大值或最大范数对梯度进行裁剪。对于处理可能在上面的步骤3中遇到的任何梯度爆炸非常有用。
  • Batch normalization - Batch normalization用于对每一层的输入进行归一化,以解决内部协变量移位问题。如果你同时使用Dropout和Batch Norm,请确保在Dropout上阅读下面的要点。
本文来自Dishank Bansal的”TensorFlow中batch norm的陷阱和训练网络的健康检查“,里面包括了很多使用batch norm的常见错误。
  • 随机梯度下降(SGD) - 有几种使用动量,自适应学习率的SGD,和Nesterov相比并没有训练和泛化性能上的优胜者。一个推荐的起点是Adam或使用Nesterov动量的纯SGD。
  • 正则化 - 正则化对于构建可泛化模型至关重要,因为它增加了模型复杂度或极端参数值的代价。它显著降低了模型的方差,而没有显著增加其偏差。如CS231n课程所述:
通常情况下,损失函数是数据损失和正则化损失的总和(例如L2对权重的惩罚)。需要注意的一个危险是正则化损失可能会超过数据损失,在这种情况下,梯度将主要来自正则化项(它通常有一个简单得多的梯度表达式)。这可能会掩盖数据损失的梯度的不正确实现。
为了检查这个问题,应该关闭正则化并独立检查数据损失的梯度。
  • Dropout - Dropout是另一种正则化你的网络,防止过拟合的技术。在训练过程中,只有保持神经元以一定的概率p(超参数)活动,否则将其设置为零。因此,网络必须在每个训练批中使用不同的参数子集,这减少了特定参数的变化成为主导。
  • 这里需要注意的是:如果您同时使用dropout和批处理规范化(batch norm),那么要注意这些操作的顺序,甚至要同时使用它们。这仍然是一个活跃的研究领域,但你可以看到最新的讨论:
来自Stackoverflow的用户MiloMinderBinder:Dropout是为了完全阻断某些神经元的信息,以确保神经元不相互适应。因此,batch norm必须在dropout之后进行,否则你将通过标准化统计之后的数据传递信息。”
来自arXiv:Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift — 从理论上讲,我们发现,当我们将网络状态从训练状态转移到测试状态时,Dropout会改变特定神经单元的方差。但是BN在测试阶段会保持其统计方差,这是在整个学习过程中积累的。当在BN之前的使用Dropout时,该方差的不一致性(我们将此方案命名为“方差漂移”)导致不稳定的推断数值行为,最终导致更多的错误预测。


5. 跟踪你的网络

你很容易忽视记录实验的重要性,直到你忘记你使用的学习率或分类权重。通过更好的跟踪,你可以轻松地回顾和重现以前的实验,以减少重复的工作(也就是说,遇到相同的错误)。

然而,手工记录信息对于多个实验来说是很困难的。工具如 Comet.ml可以帮助自动跟踪数据集、代码更改、实验历史和生产模型(这包括关于模型的关键信息,如超参数、模型性能指标和环境细节)。

你的神经网络对数据、参数甚至版本中的细微变化都非常敏感,这会导致模型性能的下降。跟踪你的工作是开始标准化你的环境和建模工作流的第一步。

640?wx_fmt=png

快速回顾

我们希望这篇文章为调试神经网络提供了一个坚实的起点。要总结要点,你应该:

  1. 从简单的开始 — 先建立一个更简单的模型,然后通过对几个数据点的训练进行测试
  1. 确认您的损失 — 检查是否使用正确的损失,并检查初始损失
  1. 检查中间输出和连接 — 使用梯度检查和可视化检查看图层是否正确连接,以及梯度是否如预期的那样更新
  1. 诊断参数 — 从SGD到学习率,确定正确的组合(或找出错误的)?
  2. 跟踪您的工作 — 作为基线,跟踪你的实验过程和关键的建模组件

英文原文:
https://towardsdatascience.com/checklist-for-debugging-neural-networks-d8b2a9434f21

(*本文为AI科技大本营转载文章,转载联系作者)


精彩推荐




推荐阅读

640?wx_fmt=png

你点的每个“在看”,我都认真当成了喜欢“

相关文章:

How to list/dump dm thin pool metadata device?

2019独角兽企业重金招聘Python工程师标准>>> See: How to create metadata-snap for thin tools using? I dont think LVM provides any support for metadata snapshots so you will need to drive this process through dmsetup. The kernel interface is descri…

Linux基础(二)--基础的命令ls和date的详细用法

本文中主要介绍了linu系统下一些基础命令的用法,重点介绍了ls和date的用法。1.basename:作用:返回一个字符串参数的基本文件名称。用法:basename PATH例如:basename /usr/share/doc 返回结果为doc2.dirname:作用:返回一…

Caffe中对MNIST执行train操作执行流程解析

之前在 http://blog.csdn.net/fengbingchun/article/details/49849225 中简单介绍过使用Caffe train MNIST的文章,当时只是仿照caffe中的example实现了下,下面说一下执行流程,并精简代码到仅有10余行:1. 先注册所有层&…

华为云垃圾分类AI大赛三强出炉,ModelArts2.0让行业按下AI开发“加速键”

9月20日,华为云人工智能大赛垃圾分类挑战杯决赛在上海世博中心2019华为全联接大会会场顺利举办。经过近两个月赛程的层层筛选,入围决赛阵列的11支战队的高光时刻也如期而至。最终华为云垃圾分类挑战杯三强出炉。本次华为云人工智能大赛垃圾分类挑战杯聚焦…

王坚十年前的坚持,才有了今天世界顶级大数据计算平台MaxCompute...

如果说十年前,王坚创立阿里云让云计算在国内得到了普及,那么王坚带领团队自主研发的大数据计算平台MaxCompute则推动大数据技术向前跨越了一大步。数据是企业的核心资产,但十年前阿里巴巴的算力已经无法满足当时急剧增长数据量的需求。基于Ha…

tomcat简单配置

-----------------------------------------一、前言二、环境三、安装JDK四、安装tomcat五、安装mysql六、安装javacenter七、tomcat后台管理-----------------------------------------一、前言Tomcat是Apache 软件基金会(Apache Software Foundation)的…

使用Caffe进行手写数字识别执行流程解析

之前在 http://blog.csdn.net/fengbingchun/article/details/50987185 中仿照Caffe中的examples实现对手写数字进行识别,这里详细介绍下其执行流程并精简了实现代码,使用Caffe对MNIST数据集进行train的文章可以参考 http://blog.csdn.net/fengbingchun/…

前端也能玩转机器学习?Google Brain 工程师来支招

演讲嘉宾 | 俞玶编辑 | 伍杏玲来源 | CSDN(ID:CSDNnews)导语:9 月 7 日,在CSDN主办的「AI ProCon 2019」上,Google Brain 工程师,TensorFlow.js 项目负责人俞玶发表《TensorFlow.js 遇到小程序》的主题演讲&#xff0c…

mongoDB设置用户名密码的一个要点

2019独角兽企业重金招聘Python工程师标准>>> 增加用户之前, 先选好库 use <库名> #选择admin库后可查看system.users里面的用户数据 db.system.users.find() db.createUser 这个函数填写用户名密码与权限就行了, 在这里设置库的名称没用. 一定要用用use选择好…

基于HTML5的电信网管3D机房监控应用

先上段视频&#xff0c;不是在玩游戏哦&#xff0c;是规规矩矩的电信网管企业应用&#xff0c;嗯&#xff0c;全键盘的漫游3D机房:随着PC端支持HTML5浏览器的普及&#xff0c;加上主流移动终端Android和iOS都已支持HTML5技术&#xff0c;新一代的电信网管应用几乎一致性的首选H…

从原理到实现,详解基于朴素ML思想的协同过滤推荐算法

作者丨gongyouliu编辑丨Zandy来源 | 大数据与人工智能&#xff08;ID: ai-big-data&#xff09;作者在《协同过滤推荐算法》、《矩阵分解推荐算法》这两篇文章中介绍了几种经典的协同过滤推荐算法。我们在本篇文章中会继续介绍三种思路非常简单朴素的协同过滤算法&#xff0c;这…

C++/C++11中引用的使用

引用(reference)是一种复合类型(compound type)。引用为对象起了另外一个名字&#xff0c;引用类型引用(refer to)另外一种类型。通过将声明符写成&d的形式来定义引用类型&#xff0c;其中d是声明的变量名。 一、一般引用&#xff1a;一般在初始化变量时&#xff0c;初始值…

node.js学习5--------------------- 返回html内容给浏览器

/*** http服务器的搭建,相当于php中的Apache或者java中的tomcat服务器*/ // 导包 const httprequire("http"); const fsrequire("fs"); //创建服务器 /*** 参数是一个回调函数,回调函数2个参数,1个是请求参数,一个是返回参数*/ let serverhttp.createServe…

内核分析阅读笔记

内核分析阅读笔记 include/Linux/stddef.h中macro offsetof define,list: #define offsetof(TYPE,MEMBER) ((size_t) &((TYPE *)0)->MEMBER) offsetof macro对于上述示例的展开剂分析:&((struct example_struct *)0)->list表示当结构example_struct正好在地址0上…

杨强教授力荐,快速部署落地深度学习应用的实践手册

香港科技大学计算机科学与工程学系讲座教授、国际人工智能联合会&#xff08;IJCAI&#xff09;理事会主席&#xff08;2017—2019&#xff09;、深圳前海微众银行首席AI 官 杨强为《深度学习模型及应用详解》一书撰序&#xff0c;他提到现在亟需一本介绍深度学习技术实践的图书…

OpenFace库(Tadas Baltrusaitis)中基于HOG进行正脸人脸检测的测试代码

Tadas Baltrusaitis的OpenFace是一个开源的面部行为分析工具&#xff0c;它的源码可以从https://github.com/TadasBaltrusaitis/OpenFace下载。OpenFace主要包括面部关键点检测(facial landmard detection)、头部姿势估计(head pose estimation)、面部动作单元识别(facial acti…

nginx conf 文件配置

打印输出: location / { default_type text/plain; return 502 "$uri"; } $remode_addr 获取访问者的ID$request_method 判断提交方式 GET POST$http_user_agent 获取浏览器软件 if (条件) {} #if之后要有空格 条件3种写法: 1: 来判断相等,用于字符串比较 …

js中 字符串与Unicode 字符值序列的相互转换

一. 字符串转Unicode 字符值序列 var str "abcdef"; var codeArr []; for(var i0;i<str.length;i){codeArr.push(str.charCodeAt(i)); } console.log(codeArr);-->[97, 98, 99, 100, 101, 102] 二.Unicode 字符值序列转字符串 var str String.fromCharCode…

OpenFace库(Tadas Baltrusaitis)中基于Haar Cascade Classifiers进行人脸检测的测试代码

Tadas Baltrusaitis的OpenFace是一个开源的面部行为分析工具&#xff0c;它的源码可以从 https://github.com/TadasBaltrusaitis/OpenFace 下载。OpenFace主要包括面部关键点检测(facial landmard detection)、头部姿势估计(head pose estimation)、面部动作单元识别(facial a…

Uber提出损失变化分配方法LCA,揭秘神经网络“黑盒”

作者 | Janice Lan,Rosanne Liu等译者 | 清儿爸责编 | 夕颜出品 | AI科技大本营&#xff08;ID: rgznai100&#xff09;【导读】神经网络&#xff08;Neural networks&#xff0c;NN&#xff09;在过去十年来硕果累累&#xff0c;推动了整个行业的机器学习进程。然而&#xff0…

范登读书解读《亲密关系》(婚姻、爱情) 笔记

来源&#xff1a;邀请你看《樊登解读《亲密关系》&#xff08;已婚人士必看&#xff09;》&#xff0c;https://url.cn/5HJvLk5?sfuri 人们在童年的时候始终追寻着两种东西&#xff0c;第一种叫做归属感&#xff0c;第二叫做确认自己的重要性、价值感。 如果再童年的时候缺失这…

myeclipse莫名其妙的问题

2019独角兽企业重金招聘Python工程师标准>>> 怎么刷新&#xff0c;clean项目都不管用&#xff0c;结果删除相应工作空间下的那个项目就行。类似路径如D:\workspace\.metadata\.plugins\org.eclipse.core.resources\.projects 转载于:https://my.oschina.net/u/14488…

数据科学家需要知道的5种图算法

作者&#xff1a;Rahul Agarwal编译&#xff1a;ronghuaiyang来源 | AI公园&#xff08;ID:AI_Paradise&#xff09;【导读】因为图分析是数据科学家的未来。作为数据科学家&#xff0c;我们对pandas、SQL或任何其他关系数据库非常熟悉。我们习惯于将用户的属性以列的形式显示在…

在Windows7/10上快速搭建深度学习框架Caffe开发环境

之前在 http://blog.csdn.net/fengbingchun/article/details/50987353 中介绍过在Windows7上搭建Caffe开发环境的操作步骤&#xff0c;那时caffe的项目是和其它依赖项目分开的&#xff0c;每次换新的PC机时再次重新配置搭建还是很不方便&#xff0c;而且caffe的版本较老&#x…

扫码下单支持同桌单人点餐FAQ

一、使用场景 满足较多商户希望同一桌台&#xff0c;各自点各自的菜品的业态场景&#xff08;例如杭味面馆&#xff0c;黄焖鸡米饭店&#xff0c;面馆等大多数轻快餐店&#xff09; 二、配置步骤及注意事项 管理员后台配置--配置管理--店铺配置--扫码点餐tab页 1、开启扫码下单…

使用photoshop 10.0制作符合社保要求的照片

2019独角兽企业重金招聘Python工程师标准>>> 北京市社保新参统人员照片修制方法 修改目标&#xff1a;照片规格:358像素&#xff08;宽&#xff09;&times;441像素&#xff08;高&#xff09;&#xff0c;分辨率350dpi。 颜色模式:24位RGB真彩色。 储存格式&am…

C++11中std::addressof的使用

C11中的std::addressof获得一个对象的实际地址&#xff0c;即使 operator& 操作符已被重载。它常用于原本要使用 operator& 的地方&#xff0c;它接受一个参数&#xff0c;该参数为要获得地址的那个对象的引用。一般&#xff0c;若operator &()也被重载且不一致的话…

一份职位信息的精准推荐之旅,从AI底层架构说起

整理 | 夕颜出品 | AI科技大本营&#xff08;ID:rgznai100&#xff09;【导读】也许&#xff0c;每天早上你的邮箱中又多了一封职位推荐信息&#xff0c;点开一看&#xff0c;你可能发现这些推荐正合你意&#xff0c;于是按照这些信息&#xff0c;你顺利找到一份符合自己期待的…

Vue.js 生命周期

2019独角兽企业重金招聘Python工程师标准>>> 每个 Vue 实例在被创建之前都要经过一系列的初始化过程 vue在生命周期中有这些状态&#xff0c; beforeCreate,created,beforeMount,mounted,beforeUpdate,updated,beforeDestroy,destroyed。Vue在实例化的过程中&#x…

AX2009取销售订单的税额

直接用以下方法即可&#xff1a; Tax::calcTaxAmount(salesLine.TaxGroup, salesLine.TaxItemGroup, systemDateGet(), salesLine.CurrencyCode, salesParmLine.LineAmount, salesTable.taxModuleType()); salesParmLine.LineAmount&#xff1a;这个直接取的是装箱单或者发票…