XGBoost缺失值引发的问题及其深度分析 | CSDN博文精选
作者 | 兆军(美团配送事业部算法平台团队技术专家)
来源 | 美团技术团队
(*点击阅读原文,查看美团技术团队更多文章)
背景
XGBoost模型作为机器学习中的一大“杀器”,被广泛应用于数据科学竞赛和工业领域,XGBoost官方也提供了可运行于各种平台和环境的对应代码,如适用于Spark分布式训练的XGBoost on Spark。然而,在XGBoost on Spark的官方实现中,却存在一个因XGBoost缺失值和Spark稀疏表示机制而带来的不稳定问题。
事情起源于美团内部某机器学习平台使用方同学的反馈,在该平台上训练出的XGBoost模型,使用同一个模型、同一份测试数据,在本地调用(Java引擎)与平台(Spark引擎)计算的结果不一致。但是该同学在本地运行两种引擎(Python引擎和Java引擎)进行测试,两者的执行结果是一致的。因此质疑平台的XGBoost预测结果会不会有问题?
该平台对XGBoost模型进行过多次定向优化,在XGBoost模型测试时,并没有出现过本地调用(Java引擎)与平台(Spark引擎)计算结果不一致的情形。而且平台上运行的版本,和该同学本地使用的版本,都来源于Dmlc的官方版本,JNI底层调用的应该是同一份代码,理论上,结果应该是完全一致的,但实际中却不同。
从该同学给出的测试代码上,并没有发现什么问题:
//测试结果中的一行,41列
double[] input = new double[]{1, 2, 5, 0, 0, 6.666666666666667, 31.14, 29.28, 0, 1.303333, 2.8555, 2.37, 701, 463, 3.989, 3.85, 14400.5, 15.79, 11.45, 0.915, 7.05, 5.5, 0.023333, 0.0365, 0.0275, 0.123333, 0.4645, 0.12, 15.082, 14.48, 0, 31.8425, 29.1, 7.7325, 3, 5.88, 1.08, 0, 0, 0, 32];
//转化为float[]
float[] testInput = new float[input.length];
for(int i = 0, total = input.length; i < total; i++){testInput[i] = new Double(input[i]).floatValue();
}
//加载模型
Booster booster = XGBoost.loadModel("${model}");
//转为DMatrix,一行,41列
DMatrix testMat = new DMatrix(testInput, 1, 41);
//调用模型
float[][] predicts = booster.predict(testMat);
上述代码在本地执行的结果是333.67892,而平台上执行的结果却是328.1694030761719。
两次结果怎么会不一样,问题出现在哪里呢?
执行结果不一致问题排查历程
如何排查?首先想到排查方向就是,两种处理方式中输入的字段类型会不会不一致。如果两种输入中字段类型不一致,或者小数精度不同,那结果出现不同就是可解释的了。仔细分析模型的输入,注意到数组中有一个6.666666666666667,是不是它的原因?
一个个Debug仔细比对两侧的输入数据及其字段类型,完全一致。
这就排除了两种方式处理时,字段类型和精度不一致的问题。
第二个排查思路是,XGBoost on Spark按照模型的功能,提供了XGBoostClassifier和XGBoostRegressor两个上层API,这两个上层API在JNI的基础上,加入了很多超参数,封装了很多上层能力。会不会是在这两种封装过程中,新加入的某些超参数对输入结果有着特殊的处理,从而导致结果不一致?
与反馈此问题的同学沟通后得知,其Python代码中设置的超参数与平台设置的完全一致。仔细检查XGBoostClassifier和XGBoostRegressor的源代码,两者对输出结果并没有做任何特殊处理。
再次排除了XGBoost on Spark超参数封装问题。
再一次检查模型的输入,这次的排查思路是,检查一下模型的输入中有没有特殊的数值,比方说,NaN、-1、0等。果然,输入数组中有好几个0出现,会不会是因为缺失值处理的问题?
快速找到两个引擎的源码,发现两者对缺失值的处理真的不一致!
XGBoost4j中缺失值的处理
XGBoost4j缺失值的处理过程发生在构造DMatrix过程中,默认将0.0f设置为缺失值:
/*** create DMatrix from dense matrix** @param data data values* @param nrow number of rows* @param ncol number of columns* @throws XGBoostError native error*/public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {long[] out = new long[1];//0.0f作为missing的值XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));handle = out[0];}
XGBoost on Spark中缺失值的处理
而XGBoost on Spark将NaN作为默认的缺失值。
scala/*** @return A tuple of the booster and the metrics used to build training summary*/@throws(classOf[XGBoostError])def trainDistributed(trainingDataIn: RDD[XGBLabeledPoint],params: Map[String, Any],round: Int,nWorkers: Int,obj: ObjectiveTrait = null,eval: EvalTrait = null,useExternalMemory: Boolean = false,//NaN作为missing的值missing: Float = Float.NaN,hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {//...}
也就是说,本地Java调用构造DMatrix时,如果不设置缺失值,默认值0被当作缺失值进行处理。而在XGBoost on Spark中,默认NaN会被为缺失值。原来Java引擎和XGBoost on Spark引擎默认的缺失值并不一样。而平台和该同学调用时,都没有设置缺失值,造成两个引擎执行结果不一致的原因,就是因为缺失值不一致!
修改测试代码,在Java引擎代码上设置缺失值为NaN,执行结果为328.1694,与平台计算结果完全一致。
//测试结果中的一行,41列double[] input = new double[]{1, 2, 5, 0, 0, 6.666666666666667, 31.14, 29.28, 0, 1.303333, 2.8555, 2.37, 701, 463, 3.989, 3.85, 14400.5, 15.79, 11.45, 0.915, 7.05, 5.5, 0.023333, 0.0365, 0.0275, 0.123333, 0.4645, 0.12, 15.082, 14.48, 0, 31.8425, 29.1, 7.7325, 3, 5.88, 1.08, 0, 0, 0, 32];float[] testInput = new float[input.length];for(int i = 0, total = input.length; i < total; i++){testInput[i] = new Double(input[i]).floatValue();}Booster booster = XGBoost.loadModel("${model}");//一行,41列DMatrix testMat = new DMatrix(testInput, 1, 41, Float.NaN);float[][] predicts = booster.predict(testMat);
XGBoost on Spark源码中缺失值引入的不稳定问题
然而,事情并没有这么简单。
Spark ML中还有隐藏的缺失值处理逻辑:SparseVector,即稀疏向量。
SparseVector和DenseVector都用于表示一个向量,两者之间仅仅是存储结构的不同。
其中,DenseVector就是普通的Vector存储,按序存储Vector中的每一个值。
而SparseVector是稀疏的表示,用于向量中0值非常多场景下数据的存储。
SparseVector的存储方式是:仅仅记录所有非0值,忽略掉所有0值。具体来说,用一个数组记录所有非0值的位置,另一个数组记录上述位置所对应的数值。有了上述两个数组,再加上当前向量的总长度,即可将原始的数组还原回来。
因此,对于0值非常多的一组数据,SparseVector能大幅节省存储空间。
SparseVector存储示例见下图:
如上图所示,SparseVector中不保存数组中值为0的部分,仅仅记录非0值。因此对于值为0的位置其实不占用存储空间。下述代码是Spark ML中VectorAssembler的实现代码,从代码中可见,如果数值是0,在SparseVector中是不进行记录的。
scalaprivate[feature] def assemble(vv: Any*): Vector = {val indices = ArrayBuilder.make[Int]val values = ArrayBuilder.make[Double]var cur = 0vv.foreach {case v: Double =>//0不进行保存if (v != 0.0) {indices += curvalues += v}cur += 1case vec: Vector =>vec.foreachActive { case (i, v) =>//0不进行保存if (v != 0.0) {indices += cur + ivalues += v}}cur += vec.sizecase null =>throw new SparkException("Values to assemble cannot be null.")case o =>throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")}Vectors.sparse(cur, indices.result(), values.result()).compressed}
不占用存储空间的值,也是某种意义上的一种缺失值。SparseVector作为Spark ML中的数组的保存格式,被所有的算法组件使用,包括XGBoost on Spark。而事实上XGBoost on Spark也的确将Sparse Vector中的0值直接当作缺失值进行处理:
scalaval instances: RDD[XGBLabeledPoint] = dataset.select(col($(featuresCol)),col($(labelCol)).cast(FloatType),baseMargin.cast(FloatType),weight.cast(FloatType)).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>val (indices, values) = features match {//SparseVector格式,仅仅将非0的值放入XGBoost计算case v: SparseVector => (v.indices, v.values.map(_.toFloat))case v: DenseVector => (null, v.values.map(_.toFloat))}XGBLabeledPoint(label, indices, values, baseMargin = baseMargin, weight = weight)}
XGBoost on Spark将SparseVector中的0值作为缺失值为什么会引入不稳定的问题呢?
重点来了,Spark ML中对Vector类型的存储是有优化的,它会自动根据Vector数组中的内容选择是存储为SparseVector,还是DenseVector。也就是说,一个Vector类型的字段,在Spark保存时,同一列会有两种保存格式:SparseVector和DenseVector。而且对于一份数据中的某一列,两种格式是同时存在的,有些行是Sparse表示,有些行是Dense表示。选择使用哪种格式表示通过下述代码计算得到:
scala/*** Returns a vector in either dense or sparse format, whichever uses less storage.*/@Since("2.0.0")def compressed: Vector = {val nnz = numNonzeros// A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes.if (1.5 * (nnz + 1.0) < size) {toSparse} else {toDense}}
在XGBoost on Spark场景下,默认将Float.NaN作为缺失值。如果数据集中的某一行存储结构是DenseVector,实际执行时,该行的缺失值是Float.NaN。而如果数据集中的某一行存储结构是SparseVector,由于XGBoost on Spark仅仅使用了SparseVector中的非0值,也就导致该行数据的缺失值是Float.NaN和0。
也就是说,如果数据集中某一行数据适合存储为DenseVector,则XGBoost处理时,该行的缺失值为Float.NaN。而如果该行数据适合存储为SparseVector,则XGBoost处理时,该行的缺失值为Float.NaN和0。
即,数据集中一部分数据会以Float.NaN和0作为缺失值,另一部分数据会以Float.NaN作为缺失值! 也就是说在XGBoost on Spark中,0值会因为底层数据存储结构的不同,同时会有两种含义,而底层的存储结构是完全由数据集决定的。
因为线上Serving时,只能设置一个缺失值,因此被选为SparseVector格式的测试集,可能会导致线上Serving时,计算结果与期望结果不符。
问题解决
查了一下XGBoost on Spark的最新源码,依然没解决这个问题。
赶紧把这个问题反馈给XGBoost on Spark, 同时修改了我们自己的XGBoost on Spark代码。
scalaval instances: RDD[XGBLabeledPoint] = dataset.select(col($(featuresCol)),col($(labelCol)).cast(FloatType),baseMargin.cast(FloatType),weight.cast(FloatType)).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>//这里需要对原来代码的返回格式进行修改val values = features match {//SparseVector的数据,先转成Densecase v: SparseVector => v.toArray.map(_.toFloat)case v: DenseVector => v.values.map(_.toFloat)}XGBLabeledPoint(label, null, values, baseMargin = baseMargin, weight = weight)}
scala/*** Converts a [[Vector]] to a data point with a dummy label.** This is needed for constructing a [[ml.dmlc.xgboost4j.scala.DMatrix]]* for prediction.*/def asXGB: XGBLabeledPoint = v match {case v: DenseVector =>XGBLabeledPoint(0.0f, null, v.values.map(_.toFloat))case v: SparseVector =>//SparseVector的数据,先转成DenseXGBLabeledPoint(0.0f, null, v.toArray.map(_.toFloat))}
问题得到解决,而且用新代码训练出来的模型,评价指标还会有些许提升,也算是意外之喜。
希望本文对遇到XGBoost缺失值问题的同学能够有所帮助,也欢迎大家一起交流讨论。
技术的道路一个人走着极为艰难?
一身的本领得不施展?
优质的文章得不到曝光?
别担心,
即刻起,CSDN 将为你带来创新创造创变展现的大舞台,
扫描下方二维码,欢迎加入 CSDN 「原力计划」!
(*本文为AI科技大本营转载文章,转载请联系原作者)
◆
精彩公开课
◆
推荐阅读
滴滴开源在2019:十大重点项目盘点,DoKit客户端研发助手首破1万Star
你的 App 在 iOS 13 上被卡死了吗
12306 回应软件崩了;微信发布新版本,朋友圈可“斗图”;Ant Design 3.26.4 发布 | 极客头条
2019 最烂密码排行榜大曝光!网友:已中招!
如何用Redis实现微博关注关系?
扎心了!互联网公司福利缩水指南
“对不起,我们只招有出色背景的技术人员!”
2019中国区块链开发者大会圆满落幕!10大烧脑核心技术演讲干货全送上!
你点的每个“在看”,我都认真当成了AI
相关文章:

什么是CPI指数和GDP
即消费者物价指数(Consumer Price Index),英文缩写为CPI,是反映与居民生活有关的产品及劳务价格统计出来的物价变动指标,通常作为观察通货膨胀水平的重要指标。如果消费者物价指数升幅过大,表明通胀已经成为经济不稳定因素&#x…

The Ultimate Guide To iPhone Resolutions
2019独角兽企业重金招聘Python工程师标准>>> ios 屏幕尺寸 像素 等说明 转载于:https://my.oschina.net/starmier/blog/467271
DllMain中不当操作导致死锁问题的分析--进程对DllMain函数的调用规律的研究和分析
不知道大家是否思考过一个过程:系统试图运行我们写的程序,它是怎么知道程序起始位置的?很多同学想到,我们在编写程序时有个函数,类似Main这样的名字。是的!这就是系统给我们提供的控制程序最开始的地方&…
力挺Python!同是程序员,为啥同事年前就实现了财务自由?
人红是非多,最近Python就遇到了这样的问题。与技术社区上一片「形势大好」对比鲜明的是,国内技术圈却一直存在对Python,「力挺」和「吃瓜」两派阵营,针锋相对,那么,Python到底有没有用,真相究竟…

C# 判断远程文件是否存在
#region 判断远程文件是否存在/// <summary>/// 判断远程文件是否存在/// </summary>/// <param name"fileUrl"></param>/// <returns></returns>public static bool RemoteFileExists(string fileUrl){HttpWebRequest re null…

DllMain中不当操作导致死锁问题的分析--导致DllMain中死锁的关键隐藏因子
有了前面两节的基础,我们现在切入正题:研究下DllMain为什么会因为不当操作导致死锁的问题。首先我们看一段比较经典的“DllMain中死锁”代码。(转载请指明出于breaksoftware的csdn博客) //主线程中 HMODULE h LoadLibraryA(strD…
性能超FPN!北大、阿里等提多层特征金字塔网络
作者 | Qijie Zhao等编译 | 李杰出品 | AI科技大本营(ID:rgznai100)特征金字塔网络具有处理不同物体尺度变化的能力,因此被广泛应用到one-stage目标检测网络(如DSSD,RetinaNet,RefineDet)和two-…

什么是WIFI
WIFI全称Wireless Fidelity,又称802.11b标准,它的最大优点就是传输速度较高,可以达到11Mbps,另外它的有效距离也很长,同时也与已有的各种802.11DSSS设备兼容。 WIFI是由AP(Access Point)和无线网卡组成的无线网络。…
Android入门——电话拨号器和4种点击事件
关于HelloWorld为,电话拨号程序还AndroidA入门demo,从这个样例我们要理清楚做安卓项目的思路。大体分为三步: 1.理解需求,理清思路 2.设计UI 3.代码实现 电话拨号器 1. 理解需求: *一个文本框——用来接收电话号码 *一个button——用来触发事…

DllMain中不当操作导致死锁问题的分析--导致DllMain中死锁的关键隐藏因子2
本文介绍使用Windbg去验证《DllMain中不当操作导致死锁问题的分析--导致DllMain中死锁的关键隐藏因子》中的结论,调试对象是文中刚开始那个例子。(转载请指明出于breaksoftware的csdn博客) 1 g 让程序运行起来 2 ctrlbreak 中断程序 3 ~ 查看…
从入门到深入:移动平台模型裁剪与优化的技术探索与工程实践
可以看到,通过机器学习技术,软件或服务的功能和体验得到了质的提升。比如,我们甚至可以通过启发式引擎智能地预测并调节云计算分布式系统的节点压力,以此改善服务的弹性和稳定性,这是多么美妙。而对移动平台来说&#…

我在不炎熱也不抑鬱的秋天,依然不抽煙
写过几次电影的观后感,挺过瘾.最近看到my little airport的那张新唱片,再也没有办法保持沉默了 为什么人家的唱片名都起的和小说一样,难得是为了证明听歌的人们都不喜欢动笔吗? 于是,我建了个类别,叫 我也会听歌.很明显,这里面会塞一些和歌相关的东西 这是第一篇

ubuntu安装redis的方法以及PHP安装redis扩展、CI框架sess使用redis的方法
为什么80%的码农都做不了架构师?>>> 再一次被网上那些教程误导后决定自己写一个。真心被那些奇怪的教程误导了好几次,之前研究其它东西的时候也是。蛋疼啊。 安装redis 直接用apt-get命令即可 sudo apt-get install redis-server 安装的时候…

浅谈数据库设计技巧
说到数据库,我认为不能不先谈数据结构。1996年,在我初入大学学习计算机编程时,当时的老师就告诉我们说:计算机程序=数据结构+算法。尽管现在的程序开发已由面向过程为主逐步过渡到面向对象为主,…
避免神经网络过拟合的5种技术(附链接) | CSDN博文精选
作者 | Abhinav Sagar翻译 | 陈超校对 | 王琦来源 | 数据派THU(ID:DatapiTHU)(*点击阅读原文,查看作者更多精彩文章)本文介绍了5种在训练神经网络中避免过拟合的技术。 最近一年我一直致力于深度学习领域。这段时间里,我使用过很多神经网络&a…

DllMain中不当操作导致死锁问题的分析--加载卸载DLL与DllMain死锁的关系
前几篇文章一直没有在源码级证明:DllMain在收到DLL_PROCESS_ATTACH和DLL_PROCESS_DETACH时会进入临界区。这个论证非常重要,因为它是使其他线程不能进入临界区从而导致死锁的关键。我构造了在DLL被映射到进程地址空间的场景,请看死锁时加载DL…

LinearLayout增加divider分割线
2019独角兽企业重金招聘Python工程师标准>>> 在android3.0及后面的版本在LinearLayout里增加了个分割线 android:divider"drawable/shape"<!--分割线图片--> android:showDividers"middle|beginning|end" <!--分割线位置--> 分割线…

JAVA游戏编程之二----j2me MIDlet 手机游戏入门开发--贪吃蛇
作者:雷神 QQ:38929568 QQ群:28048051JAVA游戏编程(满) 28047782(将满) 与前一款扫雷比较,这个游戏多了一个 类,用来显示动画,也是蛇要吃的物品类, 也有了代码…

DllMain中不当操作导致死锁问题的分析——线程中调用GetModuleFileName、GetModuleHandle等导致死锁
之前的几篇文章已经讲解了在DllMain中创建并等待线程导致的死锁的原因。是否还记得,我们分析了半天汇编才知道在线程中的死锁位置。如果对于缺乏调试经验的同学来说,可能发现这个位置有点麻烦。那么本文就介绍几个例子,它们会在线程明显的位置…
如何从菜鸡变成收割机,大厂面试的算法,你懂了吗?
是什么?让大厂面试显得逼格很高,是算法和数据结构吗?是的!!!Google工程师曾总结过,大厂之所以爱考察算法和数据结构是因为:算法能力能够准确辨别一个程序员的技术功底是否扎实&#…

Ejabberd源码解析前奏--配置
一、基本配置 配置文件将在你第一次启动ejabberd时加载,从该文件中获得的内容将被解析并存储到内部的ejabberd数据库中,以后的配置将从数据库加载,并且任何配置文件里的命令都会被添加到数据库里。 需要注意的是:ejabberd从不编辑…

DllMain中不当操作导致死锁问题的分析——DllMain中要谨慎写代码(完结篇)
之前几篇文章主要介绍和分析了为什么会在DllMain做出一些不当操作导致死锁的原因。本文将总结以前文章的结论,并介绍些DllMain中还有哪些操作会导致死锁等问题。(转载请指明出于breaksoftware的csdn博客) DllMain的相关特性 首先列出…
滴滴叶杰平:年运送乘客百亿次,AI如何“服务”出行领域?| BDTC 2019
出品 | AI科技大本营(ID:rgznai100)“如果把北京一天滴滴的轨迹数据放在一起,要覆盖北京所有道路差不多四百次,数据非常大、非常完整。”超5.5亿用户,年运送乘客100亿人次,除了中国地区,滴滴也在…

分析部署无线局域网的关键要素
在部署无线局域网时需要考虑的关键问题包括:确定单个接入点的RF覆盖,保证足够的支持所有用户的容量,以及考虑RF信号损耗因素。 单个AP的覆盖 网络设计师必须通过研究AP的服务范围来决定单个AP的覆盖。数据速率是一种距离函数ÿ…

Delphi调用java开发的WebService,传入参数出错
http://www.cnblogs.com/zhangzhifeng/p/3397053.html 调用没有参数的服务正常,当调用有参数的服务出现以下错误java.util.concurrent.ExecutionException: java.lang.NullPointerException 另外加了RIO.HTTPWebNode.UseUTF8InHeader : True;InvRegistry.RegisterInvokeOptions…
B站收藏6.1w+!这门课拯救你薄弱的计算机基础
作者 | Rocky0429来源 | Python空间大家好,我是 Rocky0429,一个对计算机基础一无所知的蒟蒻...作为一个所谓的计算机科班出身的人来说,特别难为情的是自己的计算机基础很差,比如计算机网络当年一度差点挂掉,多亏当时…

一种不会导致资源泄露的“终止”线程的方法
在项目工程中,我们可能会使用第三方开发的模块。该模块提供一个接口用于完成非常复杂和耗时的工作。我们一般不会将该API放在UI线程中执行,而是启动一个线程,用工作线程去执行这个耗时的操作。(转载请指明出于breaksoftware的csdn…

TCP/IP详解学习笔记(9)-TCP协议概述
终于看到了TCP协议,这是TCP/IP详解里面最重要也是最精彩的部分,要花大力气来读。前面的TFTP和BOOTP都是一些简单的协议,就不写笔记了,写起来也没啥东西。TCP和UDP处在同一层---运输层,但是TCP和UDP最不同的地方是&…

在windows程序中嵌入Lua脚本引擎--使用VS IDE编译Luajit脚本引擎
前些天听到一个需求:某业务方需要我们帮忙清理用户电脑上的一些废弃文件。同事完成这个逻辑的方案便是在我们程序中加入了一个很“独立”的业务逻辑:检索和删除某个程序产生的废弃文件。试想,该“独立”的逻辑之后会如何?被删掉&a…
优酷智能档在大型直播场景下的技术实践
作者 | 阿里文娱高级技术专家 肖文良 本文为阿里文娱高级技术专家肖文良在【阿里文娱2019双11猫晚技术沙龙】中的演讲,主要内容为如何通过优酷智能档,降低用户卡顿尤其是双11直播场景下,提升用户观看体验。具体包括智能档的落地挑战、算法架…