用 Java 训练深度学习模型,原来这么简单
作者 | DJL-Keerthan&Lanking
来源 | HelloGitHub
头图 | CSDN下载自东方IC
前言
很长时间以来,Java 都是一个很受企业欢迎的编程语言。得益于丰富的生态以及完善维护的包和框架,Java 拥有着庞大的开发者社区。尽管深度学习应用的不断演进和落地,提供给 Java 开发者的框架和库却十分短缺。现今主要流行的深度学习模型都是用 Python 编译和训练的。对于 Java 开发者而言,如果要进军深度学习界,就需要重新学习并接受一门新的编程语言同时还要学习深度学习的复杂知识。这使得大部分 Java 开发者学习和转型深度学习开发变得困难重重。
为了减少 Java 开发者学习深度学习的成本,AWS 构建了 Deep Java Library (DJL),一个为 Java 开发者定制的开源深度学习框架。它为 Java 开发者对接主流深度学习框架提供了一个桥梁。
在这篇文章中,我们会尝试用 DJL 构建一个深度学习模型并用它训练 MNIST 手写数字识别任务。
什么是深度学习?
在我们正式开始之前,我们先来了解一下机器学习和深度学习的基本概念。
机器学习是一个通过利用统计学知识,将数据输入到计算机中进行训练并完成特定目标任务的过程。这种归纳学习的方法可以让计算机学习一些特征并进行一系列复杂的任务,比如识别照片中的物体。由于需要写复杂的逻辑以及测量标准,这些任务在传统计算科学领域中很难实现。
深度学习是机器学习的一个分支,主要侧重于对于人工神经网络的开发。人工神经网络是通过研究人脑如何学习和实现目标的过程中归纳而得出一套计算逻辑。它通过模拟部分人脑神经间信息传递的过程,从而实现各类复杂的任务。深度学习中的“深度”来源于我们会在人工神经网络中编织构建出许多层(layer)从而进一步对数据信息进行更深层的传导。深度学习技术应用范围十分广泛,现在被用来做目标检测、动作识别、机器翻译、语意分析等各类现实应用中。
训练 MNIST 手写数字识别
3.1 项目配置
你可以用如下的 gradle 配置来引入依赖项。在这个案例中,我们用 DJL 的 api 包 (核心 DJL 组件) 和 basicdataset 包 (DJL 数据集) 来构建神经网络和数据集。这个案例中我们使用了 MXNet 作为深度学习引擎,所以我们会引入 mxnet-engine 和 mxnet-native-auto 两个包。这个案例也可以运行在 PyTorch 引擎下,只需要替换成对应的软件包即可。
plugins {id 'java'}repositories { jcenter()}dependencies {implementation platform("ai.djl:bom:0.8.0")implementation "ai.djl:api"implementation "ai.djl:basicdataset"// MXNetruntimeOnly "ai.djl.mxnet:mxnet-engine"runtimeOnly "ai.djl.mxnet:mxnet-native-auto"}
3.2 NDArray 和 NDManager
NDArray 是 DJL 存储数据结构和数学运算的基本结构。一个 NDArray 表达了一个定长的多维数组。NDArray 的使用方法类似于 Python 中的 numpy.ndarray。
NDManager 是 NDArray 的老板。它负责管理 NDArray 的产生和回收过程,这样可以帮助我们更好的对 Java 内存进行优化。每一个 NDArray 都会是由一个 NDManager 创造出来,同时它们会在 NDManager 关闭时一同关闭。NDManager 和 NDArray 都是由 Java 的 AutoClosable 构建,这样可以确保在运行结束时及时进行回收。
Model
在 DJL 中,训练和推理都是从 Model class 开始构建的。我们在这里主要讲训练过程中的构建方法。下面我们为 Model 创建一个新的目标。因为 Model 也是继承了 AutoClosable 结构体,我们会用一个 try block 实现:
try (Model model = Model.newInstance()) { ...// 主体训练代码 ...}
准备数据
MNIST(Modified National Institute of Standards and Technology)数据库包含大量手写数字的图,通常被用来训练图像处理系统。DJL 已经将 MNIST 的数据集收录到了 basicdataset 数据集里,每个 MNIST 的图的大小是 28 x 28。如果你有自己的数据集,你也可以通过 DJL 数据集导入教程来导入数据集到你的训练任务中。
数据集导入教程:
http://docs.djl.ai/docs/development/how_to_use_dataset.html#how-to-create-your-own-dataset
int batchSize = 32; // 批大小Mnist trainingDataset = Mnist.builder() .optUsage(Usage.TRAIN) // 训练集 .setSampling(batchSize, true) .build();Mnist validationDataset = Mnist.builder() .optUsage(Usage.TEST) // 验证集 .setSampling(batchSize, true) .build();
这段代码分别制作出了训练和验证集。同时我们也随机排列了数据集从而更好的训练。除了这些配置以外,你也可以添加对于图片的进一步处理,比如设置图片大小,对图片进行归一化等处理。
制作 model(建立 Block)
当你的数据集准备就绪后,我们就可以构建神经网络了。在 DJL 中,神经网络是由 Block(代码块)构成的。一个 Block 是一个具备多种神经网络特性的结构。它们可以代表 一个操作, 神经网络的一部分,甚至是一个完整的神经网络。然后 Block 可以顺序执行或者并行。同时 Block 本身也可以带参数和子 Block。这种嵌套结构可以帮助我们构造一个复杂但又不失维护性的神经网络。在训练过程中,每个 Block 中附带的参数会被实时更新,同时也包括它们的各个子 Block。这种递归更新的过程可以确保整个神经网络得到充分训练。
当我们构建这些 Block 的过程中,最简单的方式就是将它们一个一个的嵌套起来。直接使用准备好 DJL 的 Block 种类,我们就可以快速制作出各类神经网络。
根据几种基本的神经网络工作模式,我们提供了几种 Block 的变体。SequentialBlock 是为了应对顺序执行每一个子 Block 构造而成的。它会将前一个子 Block 的输出作为下一个 Block 的输入 继续执行到底。与之对应的,是 ParallelBlock 它用于将一个输入并行输入到每一个子 Block 中,同时将输出结果根据特定的合并方程合并起来。最后我们说一下 LambdaBlock,它是帮助用户进行快速操作的一个 Block,其中并不具备任何参数,所以也没有任何部分在训练过程中更新。
我们来尝试创建一个基本的 多层感知机(MLP)神经网络吧。多层感知机是一个简单的前向型神经网络,它只包含了几个全连接层 (LinearBlock)。那么构建这个网络,我们可以直接使用 SequentialBlock。
int input = 28 * 28; // 输入层大小int output = 10; // 输出层大小int[] hidden = new int[] {128, 64}; // 隐藏层大小SequentialBlock sequentialBlock = new SequentialBlock();sequentialBlock.add(Blocks.batchFlattenBlock(input));for (int hiddenSize : hidden) {// 全连接层 sequentialBlock.add(Linear.builder().setUnits(hiddenSize).build());// 激活函数 sequentialBlock.add(activation);}sequentialBlock.add(Linear.builder().setUnits(output).build());
当然 DJL 也提供了直接就可以拿来用的 MLP Block :
Block block = new Mlp( Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES,new int[] {128, 64});
训练
当我们准备好数据集和神经网络之后,就可以开始训练模型了。在深度学习中,一般会由下面几步来完成一个训练过程:
初始化:我们会对每一个 Block 的参数进行初始化,初始化每个参数的函数都是由 设定的 Initializer 决定的。
前向传播:这一步将输入数据在神经网络中逐层传递,然后产生输出数据。
计算损失:我们会根据特定的损失函数 Loss 来计算输出和标记结果的偏差。
反向传播:在这一步中,你可以利用损失反向求导算出每一个参数的梯度。
更新权重:我们会根据选择的优化器(Optimizer)更新每一个在 Block 上参数的值。
DJL 利用了 Trainer 结构体精简了整个过程。开发者只需要创建 Trainer 并指定对应的 Initializer、Loss 和 Optimizer 即可。这些参数都是由 TrainingConfig 设定的。下面我们来看一下具体的参数设置:
TrainingListener:这个是对训练过程设定的监听器。它可以实时反馈每个阶段的训练结果。这些结果可以用于记录训练过程或者帮助 debug 神经网络训练过程中的问题。用户也可以定制自己的 TrainingListener 来对训练过程进行监听。
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) .addTrainingListeners(TrainingListener.Defaults.logging());try (Trainer trainer = model.newTrainer(config)){// 训练代码}
当训练器产生后,我们可以定义输入的 Shape。之后就可以调用 fit 函数来进行训练。fit 函数会对输入数据,训练多个 epoch 是并最终将结果存储在本地目录下。
/* * MNIST 包含 28x28 灰度图片并导入成 28 * 28 NDArray。 * 第一个维度是批大小, 在这里我们设置批大小为 1 用于初始化。 */Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);int numEpoch = 5;String outputDir = "/build/model";
// 用输入初始化 trainertrainer.initialize(inputShape);
TrainingUtils.fit(trainer, numEpoch, trainingSet, validateSet, outputDir, "mlp");
这就是训练过程的全部流程了!用 DJL 训练是不是还是很轻松的?之后看一下输出每一步的训练结果。如果你用了我们默认的监听器,那么输出是类似于下图:
[INFO ] - Downloading libmxnet.dylib ...[INFO ] - Training on: cpu().[INFO ] - Load MXNet Engine Version 1.7.0 in 0.131 ms.Training: 100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24, speed: 1235.20 items/secValidating: 100% |████████████████████████████████████████|[INFO ] - Epoch 1 finished.[INFO ] - Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24[INFO ] - Validate: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14Training: 100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10, speed: 2851.06 items/secValidating: 100% |████████████████████████████████████████|[INFO ] - Epoch 2 finished.NG [1m 41s][INFO ] - Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10[INFO ] - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09[INFO ] - train P50: 12.756 ms, P90: 21.044 ms[INFO ] - forward P50: 0.375 ms, P90: 0.607 ms[INFO ] - training-metrics P50: 0.021 ms, P90: 0.034 ms[INFO ] - backward P50: 0.608 ms, P90: 0.973 ms[INFO ] - step P50: 0.543 ms, P90: 0.869 ms[INFO ] - epoch P50: 35.989 s, P90: 35.989 s
当训练结果完成后,我们可以用刚才的模型进行推理来识别手写数字。如果刚才的内容哪里有不是很清楚的,可以参照下面两个链接直接尝试训练。
手写数据集训练:
https://docs.djl.ai/examples/docs/train_mnist_mlp.html
手写数据集推理:
https://docs.djl.ai/jupyter/tutorial/03_image_classification_with_your_model.html
最后
在这个文章中,我们介绍了深度学习的基本概念,同时还有如何优雅的利用 DJL 构建深度学习模型并进行训练。DJL 也提供了更加多样的数据集和神经网络。
Deep Java Library(DJL)是一个基于 Java 的深度学习框架,同时支持训练以及推理。DJL 博取众长,构建在多个深度学习框架之上 (TenserFlow、PyTorch、MXNet 等) 也同时具备多个框架的优良特性。你可以轻松使用 DJL 来进行训练然后部署你的模型。
它同时拥有着强大的模型库支持:只需一行便可以轻松读取各种预训练的模型。现在 DJL 的模型库同时支持高达 70 多个来自 GluonCV、 HuggingFace、TorchHub 以及 Keras 的模型。
项目地址:https://github.com/awslabs/djl/
更多精彩推荐
九问国产操作系统,九大掌门人首次同台激辩
一文读懂机器学习“数据中毒”
性能超越图神经网络,将标签传递和简单模型结合实现SOTA
小霸王被申请破产重整;虎牙员工自曝被HR抬出公司;Office 2010被微软终止服务|极客头条
Unity “出圈”:游戏引擎的技术革新和跨界商机
相关文章:

重装操作系统的20条原则(转载)
系统是否需重装,三条法则帮你忙: 如果系统出现以下三种情况之一,应该是你考虑重装系统的时候了: 1)系统运行效率变得低下,垃圾文件充斥硬盘且散乱分布又不便于集中清理和自动清理; 2)系统频繁出错&…

RKLayout
2019独角兽企业重金招聘Python工程师标准>>> RKLayout 是 iOS 上一个简单的布局管理器 转载:http://www.adobex.com/ios/source/details/00000978.htm 转载于:https://my.oschina.net/u/868244/blog/107107

【网络流24题】最小路径覆盖问题
【题目】1738: 最小路径覆盖问题 【题解】网络流 关于输出路径,因为即使有反向弧经过左侧点也一定会改变左侧点的去向,若没连向右侧就会被更新到0,所以不用在意。 mark记录有入度的右侧点,然后从没入度的右侧点开始把整条路径输出…

Windows的端口列表(转载)
按端口号可分为3大类: (1)公认端口(Well Known Ports):从0到1023,它们紧密绑定(binding)于一些服务。通常这些端口的通讯明确表明了某种服务的协议。例如:80端…
有了图分析,可解释的AI还远吗?
GraphAI 更多新可能随着深度学习、机器学习等人工智能技术的逐级深入,企业对挖掘大数据的关联性去探索“隐藏”在背后的商业价值提出了更高的要求。尤其是,新一代人工智能技术正从“感知智能”迈向“认知智能”,让机器实现“理解、推理、决策…

智能公交GPS
蓝斯通信推出LZ8713B*-X车载终端。这款设备是按交通部JT/T794-2011《道路运输车辆卫星定位系统车载终端技术要求 》和JT/T808-2011《道路运输车辆卫星定位系统车载终端通讯协议及数据格式》技术标准进行设计的。 产品集GPS(可选配BD2双模模块)…
数据化管理在餐饮业中的应用
一、为什么要重视数据化运营和管理? “从经营到管理,管理方向需要数据灯塔” 餐饮市场和社会各业具有相似之处,也有很明确的本质不同。 1、首先,餐饮市场不像电信、石油市场是垄断性的,餐饮市场充分透明,符…
1 元秒杀 1000+ 册爆款电子书,错过再等一年!
wow代码人们让钱包瑟瑟发抖的双十一已经来啦与此同时码不停蹄地向你奔赴而来的还有 CSDN 为你准备的???? 1 元秒杀 ????价值 3.5 万元的爆款电子书限时特惠,仅需 1 元你,准备好了吗仅限 500 人速领????????????错过悔10年系列好书一

关于XP进程问题(转载)
Smss.exe 会话管理子系统,它负责启动用户会话。这个进程是通过系统进程来初始化的,包括对已经正在运行的Winlogon,Win32(Csrss.exe)线程和设定的系统变量作出反映。在它启动这些进程后,它等待Winlogon或者C…

NoticeView
2019独角兽企业重金招聘Python工程师标准>>> NoticeView 是 iOS 的消息提醒组件,类似 TweetBot 的提醒。 转载:http://www.adobex.com/ios/source/details/00001068.htm 转载于:https://my.oschina.net/u/868244/blog/107344

elasticsearch 分片恢复经历了哪些步骤?
why 服务重启,或者与集群断网重连时,需要和集群当前的主分片的数据保持一致。 how 上图中,RecoverTarget 代表加入集群前想要同步数据的分片,RecoverSource代表当前集群中的正常分片。 同步过程本质上来说,就是通过拷贝…

Java 事件适配器 Adapter
事件适配器Adapters 在上一篇文章中: http://www.cnblogs.com/mengdd/archive/2013/02/06/2908241.html 第二个例子中,可以看到要实现相应的事件监听器接口,就必须实现其中的所有方法。 有的接口中包含多个方法(多个事件处理器&am…
Facebook面经全披露,我是怎么拿到机器学习工程师offer的?
作者 | Rahul Agarwal翻译 | Katie,责编 | 晋兆雨出品 | AI科技大本营头图 | 付费下载于视觉中国去年八月,我正在接受面试。那时,我已经分别接受Google India和Amazon India的机器学习和数据科学职位面试。然后我的上级建议我申请Facebook伦敦…

内存性能参数详解(转载)
内存性能参数详解 先说说最有效提高你机器内存性能的几个参数:CL,TRP,TRCD CAS Latency “列地址选通脉冲潜伏期” BIOS中可能的其他描述为:tCL、CAS Latency Time、CAS Timing Delay,这个值一般是1.5~3之间࿰…

一些关于Hibernate延迟加载的误区
最近面试别人,正好出的笔试题中有道关于Hibernate延迟加载的问题,聊天过程中发现很多人对Hibernate的延迟加载有些理解误区,写 些东东在这里,希望对大家有所帮助。 首先是第一个误区:延迟加载只能作用于关联实体看到这…

Java单元测试与Jutil详解(一) 简介
1.什么是单元测试 单元测试(unit testing),是指对软件中的最小可测试单元进行检查和验证。对于单元测试中单元的含义,Java里单元指一个类。总的来说,单元就是人为规定的最小的被测功能模块。单元测试是在软件开发过程中…
反转!BAT编程吸金榜来了,AI程序员刷爆了......
从2017年开始,人工智能便波澜不断,无论是从BAT高调布局AI,还是从年薪80万招聘AI应届生,炽手可热形容AI工程师一点都不过分。百度推出“少帅计划”,针对30岁以下的深度学习科学家,开出100万以上年薪!阿里巴巴…

Windows启动文件
Windows启动文件 Files Used in the Windows 2000 Boot Process FileLocationBoot stageNtldr System partition root (C:/ )Preboot and bootBoot.iniSystem partition rootBootBootsect.dosSystem partition rootBoot (optional)Ntdetect.com System partition rootBootNtboo…

Sublime Text 3 个人使用总结
待更新 Sublime Text 3\Packages\FileHeader\template\header转载于:https://www.cnblogs.com/yourstars/p/6739965.html

破解出cmos密码(转载)
----CMOS (Award)密码简介与破解0--3法---- 计算机启动时,由存放在主板ROM中的bios将cmos数据调入内存中,以实现控制系统。 其中,Award主板上的一小块RAM用于存放CMOS数据,地址为00-7F的共128个字节中。 当中的字节 1c和1d存放的就…
NLP实战:利用Python理解、分析和生成文本 | 赠书
导读:本文内容参考自《自然语言处理实战:利用Python理解、分析和生成文本》一书,由Hobson Lane等人所著。本书是介绍自然语言处理(NLP)和深度学习的实战书。NLP已成为深度学习的核心应用领域,而深度学习是N…

Servlet入门 代码
1. 第一个Servlet程序 package com.allanlxf.serv.basic; import javax.servlet.*; import java.io.*; public class TimeServlet implements Servlet {private ServletConfig config;public TimeServlet(){System.out.println("TimeServlet()");}public void init(S…

统计学习方法:朴素贝叶斯
作者:桂。 时间:2017-04-20 18:31:37 链接:http://www.cnblogs.com/xingshansi/p/6740308.html 前言 本文为《统计学习方法》第四章:朴素贝叶斯(naive bayes),主要是借助先验知识统计估计&…

Windows自动启动程序的十大藏身之所(转载)
Windows自动启动程序的十大藏身之所 Windows启动时通常会有一大堆程序自动启动。不要以为管好了“开始→程序→启动”菜单就万事大吉,实际上,在Windows XP/2K中,让Windows自动启动程序的办法很多,下文告诉你最重要的两个文件夹和八…
警惕!银行风控模型或将“摇身一变”,成为风险缔造者
作者 | 祝世虎来源 | 现代金融风险管理头图 | CSDN下载自视觉中国2011年,美联储发布了《模型风险管理监督指南(SR11-7)》(《SRLetter 11-7: Supervisory Guidance on Model Risk Management》),该指南逐步成…

Spring注解注入
spring注入方式-----注解注入(1)操作:首先在要注入的类前面加上:Component(与后面三个是等价的)Repository(持久层),Service业务层,Controller和控制层应为不能自动识别某个类是否是持久层,业务…

zip 的压缩原理与实现
http://www.blueidea.com/bbs/newsdetail.asp?id1819267&page2&posts&Daysprune5&lp1无损数据压缩是一件奇妙的事情,想一想,一串任意的数据能够根据一定的规则转换成只有原来 1/2 - 1/5 长度的数据,并且能够按照相应的规则还…
上海交大发布 MedMNIST 医学图像分析数据集 新基准
来源 | HyperAI超神经责编 | 晋兆雨头图 | 付费下载于视觉中国内容概要:医学图像分析是一个非常复杂的跨学科领域,近日上海交通大学发布了 MedMNIST 数据集,有望促进医学图像分析的发展。关键词:医学图像分析 公开数据集令人头秃…

VS 2010中对WPF4有哪些多点触摸支持?
随着多点触摸输入和操作处理支持的引进, WPF 4提供了一个极棒的方式,可在Windows 7中使你的客户端应用大放光彩,新的特性包括:UIElement上的多点触摸操作、惯性(漫游(Pan)、缩放(Zoo…

业务组件架构的思考
在iOS开发中,我们接触比较多的是MVC架构,下面我们先来分析一下MVC架构。 1.MVC MVC是一种软件架构模式,在1978年由Trygve Reenskaug提出,它把软件系统分为三个基本部分:模型(Model)、视图&#…