当前位置: 首页 > 编程日记 > 正文

知否?知否?一文看懂深度文本分类之DPCNN原理与代码

【导读】ACL2017年中,腾讯AI-lab提出了Deep Pyramid Convolutional Neural Networks for Text Categorization(DPCNN)。论文中提出了一种基于word-level级别的网络-DPCNN,由于上一篇文章介绍的TextCNN 不能通过卷积获得文本的长距离依赖关系,而论文中DPCNN通过不断加深网络,可以抽取长距离的文本依赖关系。实验证明在不增加太多计算成本的情况下,增加网络深度就可以获得最佳的准确率。‍


作者 | 何从庆

本文经授权转载自 AI算法之心


DPCNN结构

究竟是多么牛逼的网络呢?我们下面来窥探一下模型的芳容。

640


DPCNN结构细节

模型是如何通过加深网络来捕捉文本的长距离依赖关系的呢?下面我们来一一道来。为了更加简单的解释DPCNN,这里我先不解释是什么是Region embedding,我们先把它当作word embedding。


等长卷积

首先交代一下卷积的的一个基本概念。一般常用的卷积有以下三类:

假设输入的序列长度为n,卷积核大小为m,步长(stride)为s,输入序列两端各填补p个零(zero padding),那么该卷积层的输出序列为(n-m+2p)/s+1。

(1) 窄卷积(narrow convolution): 步长s=1,两端不补零,即p=0,卷积后输出长度为n-m+1。

(2) 宽卷积(wide onvolution) :步长s=1,两端补零p=m-1,卷积后输出长度 n+m-1。

(3) 等长卷积(equal-width convolution): 步长s=1,两端补零p=(m-1)/2,卷积后输出长度为n。如下图所示,左右两端同时补零p=1,s=3。


池化

那么DPCNN是如何捕捉长距离依赖的呢?这里我直接引用文章的小标题——Downsampling with the number of feature maps fixed。

作者选择了适当的两层等长卷积来提高词位embedding的表示的丰富性。然后接下来就开始 Downsampling (池化)。再每一个卷积块(两层的等长卷积)后,使用一个size=3和stride=2进行maxpooling进行池化。序列的长度就被压缩成了原来的一半。其能够感知到的文本片段就比之前长了一倍

例如之前是只能感知3个词位长度的信息,经过1/2池化层后就能感知6个词位长度的信息啦,这时把1/2池化层和size=3的卷积层组合起来如图所示。

640


固定feature maps(filters)的数量

为什么要固定feature maps的数量呢?许多模型每当执行池化操作时,增加feature maps的数量,导致总计算复杂度是深度的函数。与此相反,作者对feature map的数量进行了修正,他们实验发现增加feature map的数量只会大大增加计算时间,而没有提高精度。

另外,夕小瑶小姐姐在知乎也详细的解释了为什么要固定feature maps的数量。有兴趣的可以去知乎搜一搜,讲的非常透彻。

固定了feature map的数量,每当使用一个size=3stride=2进行maxpooling进行池化时,每个卷积层的计算时间减半(数据大小减半),从而形成一个金字塔。

640

这就是论文题目所谓的 Pyramid

好啦,看似问题都解决了,目标成功达成。剩下的我们就只需要重复的进行等长卷积+等长卷积+使用一个size=3和stride=2进行maxpooling进行池化就可以啦,DPCNN就可以捕捉文本的长距离依赖啦!


Shortcut connections with pre-activation

但是!如果问题真的这么简单的话,深度学习就一下子少了超级多的难点了。

(1) 初始化CNN的时,往往各层权重都初始化为很小的值,这导致了最开始的网络中,后续几乎每层的输入都是接近0,这时的网络输出没有意义;

(2) 小权重阻碍了梯度的传播,使得网络的初始训练阶段往往要迭代好久才能启动;

(3) 就算网络启动完成,由于深度网络中仿射矩阵(每两层间的连接边)近似连乘,训练过程中网络也非常容易发生梯度爆炸或弥散问题。

当然,上述这几点问题本质就是梯度弥散问题。那么如何解决深度CNN网络的梯度弥散问题呢?当然是膜一下何恺明大神,然后把ResNet的精华拿来用啦! ResNet中提出的shortcut-connection/ skip-connection/ residual-connection(残差连接)就是一种非常简单、合理、有效的解决方案。

类似地,为了使深度网络的训练成为可能,作者为了恒等映射,所以使用加法进行shortcut connections,即z+f(z),其中 f 用的是两层的等长卷积。这样就可以极大的缓解了梯度消失问题。

另外,作者也使用了 pre-activation,这个最初在何凯明的“Identity Mappings in Deep Residual Networks上提及,有兴趣的大家可以看看这个的原理。直观上,这种“线性”简化了深度网络的训练,类似于LSTM中constant error carousels的作用。而且实验证明  pre-activation优于post-activation。

整体来说,巧妙的结构设计,使得这个模型不需要为了维度匹配问题而担忧。


Region embedding

同时DPCNN的底层貌似保持了跟TextCNN一样的结构,这里作者将TextCNN的包含多尺寸卷积滤波器的卷积层的卷积结果称之为Region embedding,意思就是对一个文本区域/片段(比如3gram)进行一组卷积操作后生成的embedding。

另外,作者为了进一步提高性能,还使用了tv-embedding (two-views embedding)进一步提高DPCNN的accuracy

上述介绍了DPCNN的整体架构,可见DPCNN的架构之精美。本文是在原始论文以及知乎上的一篇文章的基础上进行整理。本文可能也会有很多错误,如果有错误,欢迎大家指出来!建议大家为了更好的理解DPCNN ,看一下原始论文和参考里面的知乎。


用Keras实现DPCNN网络

这里参考了一下kaggle的代码,模型一共用了七层,模型的参数与论文不太相同。这里滤波器通道个数为64(论文中为256),具体的参数可以参考下面的代码,部分我写了注释。

def CNN(x):
   block = Conv1D(filter_nr, kernel_size=filter_size, padding=same, activation=linear,
               kernel_regularizer=conv_kern_reg, bias_regularizer=conv_bias_reg)(x)
   block = BatchNormalization()(block)
   block = PReLU()(block)
   block = Conv1D(filter_nr, kernel_size=filter_size, padding=same, activation=linear,
               kernel_regularizer=conv_kern_reg, bias_regularizer=conv_bias_reg)(block)
   block = BatchNormalization()(block)
   block = PReLU()(block)
   return block

def DPCNN():
   filter_nr = 64 #滤波器通道个数
   filter_size = 3 #卷积核
   max_pool_size = 3 #池化层的pooling_size
   max_pool_strides = 2 #池化层的步长
   dense_nr = 256 #全连接层
   spatial_dropout = 0.2
   dense_dropout = 0.5
   train_embed = False
   conv_kern_reg = regularizers.l2(0.00001)
   conv_bias_reg = regularizers.l2(0.00001)

   comment = Input(shape=(maxlen,))
   emb_comment = Embedding(max_features, embed_size, weights=[embedding_matrix], trainable=train_embed)(comment)
   emb_comment = SpatialDropout1D(spatial_dropout)(emb_comment)

   #region embedding层
   resize_emb = Conv1D(filter_nr, kernel_size=1, padding=same, activation=linear,
               kernel_regularizer=conv_kern_reg, bias_regularizer=conv_bias_reg)(emb_comment)
   resize_emb = PReLU()(resize_emb)
   #第一层
   block1 = CNN(emb_comment)
   block1_output = add([block1, resize_emb])
   block1_output = MaxPooling1D(pool_size=max_pool_size, strides=max_pool_strides)(block1_output)
   #第二层
   block2 = CNN(block1_output)
   block2_output = add([block2, block1_output])
   block2_output = MaxPooling1D(pool_size=max_pool_size, strides=max_pool_strides)(block2_output)
   #第三层
   block3 = CNN(block2_output)
   block3_output = add([block3, block2_output])
   block3_output = MaxPooling1D(pool_size=max_pool_size, strides=max_pool_strides)(block3_output)  
   #第四层
   block4 = CNN(block3_output)
   block4_output = add([block4, block3_output])
   block4_output = MaxPooling1D(pool_size=max_pool_size, strides=max_pool_strides)(block4_output)
   #第五层
   block5 = CNN(block4_output)
   block5_output = add([block5, block4_output])
   block5_output = MaxPooling1D(pool_size=max_pool_size, strides=max_pool_strides)(block5_output)
   #第六层
   block6 = CNN(block5_output)
   block6_output = add([block6, block5_output])
   block6_output = MaxPooling1D(pool_size=max_pool_size, strides=max_pool_strides)(block6_output)
   #第七层
   block7 = CNN(block6_output)
   block7_output = add([block7, block6_output])
   output = GlobalMaxPooling1D()(block7_output)
   #全连接层
   output = Dense(dense_nr, activation=linear)(output)
   output = BatchNormalization()(output)
   output = PReLU()(output)
   output = Dropout(dense_dropout)(output)
   output = Dense(6, activation=sigmoid)(output)

   model = Model(comment, output)
   model.summary()
   model.compile(loss=binary_crossentropy,
               optimizer=optimizers.Adam(),
               metrics=[accuracy])
   return model



DPCNN实战

上面我们用keras实现了我们的DPCNN网络,这里我们借助kaggle的有毒评论文本分类竞赛来实战下我们的DPCNN网络。

具体地代码,大家可以去我的GitHub上面找到源码: 

https://github.com/hecongqing/TextClassification/blob/master/DPCNN.ipynb

参考:

https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf

https://zhuanlan.zhihu.com/p/35457093

https://www.kaggle.com/michaelsnell/conv1d-dpcnn-in-keras


AI算法之心是一个介绍python、pyspark、机器学习、自然语言处理、深度学习、算法竞赛的平台。


(本文为 AI科技大本营转载文章,转载请微信联系原作者。)

征稿

640?wx_fmt=png


推荐阅读

  • PDF翻译神器,再也不担心读不懂英文Paper了

  • Facebook增强版LASER开源:零样本迁移学习,支持93种语言

  • 啥是佩奇排名算法

  • 网络爬虫的法律边界

  • Caicloud 开源 Nirvana:让 API 从对框架的依赖中涅槃重生

  • 程序员有话说 | 那个拒绝加班的程序员后来怎么样了

  • 告别摩拜

  • 6大改进:盘点以太坊的2018冒险之旅

  • 不难!月薪 50K大牛,悉心整理程序员必备技能!


640?wx_fmt=png

相关文章:

linux驱动:设备-总线-驱动(以TI+DM8127中GPIO为例)

一:说明:这次学习设备-总线-驱动是以TIDM8127的GPIO为例 1、GPIO资源注册到omap_hwmod链表中 2、初始化GPIO 3、将GPIO注册到plarform层 4、将GPIO注册到device层 二、流程图 1、GPIO资源注册到omap_hwmod链表中 2、初始化GPIO 3、将GPIO注册到pla…

生活总是在推着你一步一步往前走

上早班的时候,无意间看到了关于高考这个字眼。对于我的高考已经过去五年了,但回想起来记忆依旧是那么深刻。记得五年前的那个日子,阳光明媚,空气中到处都是一股夏天的气息,我妈和我哥早早的从家里搭车到县城&#xff0…

急!!!求从字符串中提取形如: div([MC0010000000006],此若干个字符或数字,0) 的正则表达式...

如题, 形如: div([MC0010000000006],此处有若干个字符或数字, 此处只有一个字符) 静坐等待.

C# 如何创建Excel多级分组

在Excel中如果能够将具有多级明细的数据进行分组显示,可以清晰地展示数据表格的整体结构,使整个文档具有一定层次感。根据需要设置显示或者隐藏分类数据下的详细信息,在便于数据查看、管理的同时也使文档更具美观性。那么,在C#中如…

苹果裁员逾200人,拿无人驾驶“开刀”

整理 | 琥珀出品 | AI科技大本营1 月 14日,据美国媒体 CNBC 援引知情人士消息报道称,本周,苹果泰坦项目(Project Titan)的 200 多名员工遭到解雇。据悉,泰坦项目是苹果未公开的自动驾驶汽车项目。一名苹果发…

linux驱动:i2c驱动(一)

I2C系统框架:I2C核心层、I2C总线驱动、I2C设备驱动 -------------------------------------------------------------------------------- 【I2C核心层】 代码在driver/i2c/i2c-core.c中 【I2C总线驱动】也叫I2C适配器驱动 1、每个适配器视为一个字符设备文件 …

关于SQLServer2005的学习笔记——XML的处理

在 SQLServer2005 中对 XML 的处理功能显然增强了很多,提供了 query(),value(),exist(),modify(),nodes() 等函数。关于 xml ,难以理解的不是 SQLServer 提供的函数,而是对 xml 本身的理解,看似很简单的文件格式,处理起…

2019最新实战!给程序员的7节深度学习必修课,最好还会Python!

整理 | 琥珀出品 | AI科技大本营从 2017 年开始,fast.ai 创始人、数据科学家 Jeremy Howard 以每年一迭代的方式更新“针对编程者的深度学习课程”(Practical Deep Learning For Coders)。这场免费的课程可以教大家如何搭建最前沿的模型、了解…

linux驱动:i2c驱动(二)

3、驱动源码分析 IPNC_RDK_V3.8.0.1/Source/ti_tools/ipnc_psp_arago/kernel/sound/soc/codecs/tlv320aic3x.c 3.1 注册模块 module_init(aic3x_modinit); 3.2 在初始化函数中添加i2c驱动 static int __init aic3x_modinit(void) { intret 0; #if defined(CONFIG_I2C) ||…

01 使用AFN3 0上传图片时间慢的问题

##iOS中修改图片的大小:修改分辨率和裁剪 ###第一步:裁剪图片 // 裁剪// 要裁剪的图片区域,按照原图的像素大小来,超过原图大小的边自动适配CGSize size CGSizeMake(1000, 1000);UIImage *img [self imageWithImageSimple:image scaledToS…

配置telnet

配置telnet<?xml:namespace prefix o ns "urn:schemas-microsoft-com:office:office" />允许root账号能够登录telnet&#xff0c;但是拒绝某一台主机登录且只允许在9&#xff1a;00-14&#xff1a;00 14&#xff1a;00-18&#xff1a;00能够访问&#xff0…

04 pod setup 慢的问题

解决方式一: 可以直接从别人的电脑中拷贝解决方式二转载于:https://juejin.im/post/5a3c5a985188257d391d3a39

linux驱动:i2c驱动(三)流程图之注册设备

一、设备注册过程 1、将i2c设备信息保存到i2c_board_info结构体中&#xff1b; 2、在注册i2c_board_info时&#xff08;i2c_register_board_info&#xff09;将它加入一个全局列表__i2c_board_list中&#xff0c; 3、在注册I2c adapter适配器驱动后&#xff0c;再从全局列表…

AI找Bug,一键快速预测

作者 | Jane出品 | AI科技大本营&#xff08;ID:rgznai100&#xff09;在程序开发中&#xff0c;程序员每天都要和 Bug 打交道&#xff0c;对新手程序员而言&#xff0c;debug 是一件非常让人头疼的事情。好不容易写完一段代码&#xff0c;一运行&#xff0c;全是红色&#xff…

专业研究HP procurve网络、阿姆瑞特和系统集成的论坛

一个专业研究HP procurve网络、阿姆瑞特防火墙和系统集成的论坛http://www.vlan2.com确实不错。转载于:https://blog.51cto.com/showrouter/284235

到底是什么特征影响着CNN的性能?

作者 | 刘畅 编辑 | Jane出品 | AI科技大本营&#xff08;ID:rgznai100&#xff09;开门见山。最近阅读了一篇论文&#xff0c;加上看了一些之前的工作。记录一下&#xff0c;CNN 到底学到了什么东西&#xff0c;或者换句话讲。到底是什么样的特征在影响着CNN 的性能&#xff1…

Java数据结构与算法(八)-二叉树

一、为什么要使用树 有序数组插入、删除数据慢。链表查找数据慢树可以解决这两个问题二、相关术语 树的结点&#xff1a;包含一个数据元素及若干指向子树的分支&#xff1b;孩子结点&#xff1a;结点的子树的根称为该结点的孩子&#xff1b;双亲结点&#xff1a;B 结点是A 结点…

linux驱动:i2c驱动(四)流程图之注册驱动

二、i2c设备的驱动部分 1、i2c驱动i2c_driver 2、通过i2c_add_driver注册 2、注册过程中 比较i2c_device_id数组中各成员的id与i2c_client中的名字&#xff0c;找到设备 3、执行i2c_driver驱动中的probe

Expression Blend实例中文教程(2) - 界面快速入门

上一篇主要介绍Expression系列产品&#xff0c;另外概述了Blend的强大功能&#xff0c;本篇将用Blend 3创建一个新Silverlight项目&#xff0c;通过创建的过程&#xff0c;对Blend进行快速入门学习。 在开始使用Blend前&#xff0c;首先需要进行Silverlight的开发环境搭建&…

Lua基本语法-书写规范以及自带常用函数

Lua基本语法-书写规范和常用函数本文提供全流程&#xff0c;中文翻译。Chinar坚持将简单的生活方式&#xff0c;带给世人&#xff01;&#xff08;拥有更好的阅读体验 —— 高分辨率用户请根据需求调整网页缩放比例&#xff09; 1String Operation —— 字符串操作2Table ——…

linux驱动:音频驱动(一)ALSA

一、【基础知识】 1、J2 《--HPR_OUTHPL_OUT 《-- U13&#xff08;TLV320AIC3104IRHBR&#xff09;的HPROUTHPLOUT 2、驱动源码 IPNC_RDK_V3.8.0.1/Source/ti_tools/ipnc_psp_arago/kernel/sound/soc/codecs/tlv320aic3x.c 3、依赖于I2C驱动 4、声卡驱动框架&#xff1a;…

秘籍 | 机器学习数据集网址大全

作者 | Will Badr译者 | Linstancy整理 | Jane出品 | AI科技大本营&#xff08;ID:rgznai100&#xff09;要找到一定特定的数据集可以解决各种机器学习问题&#xff0c;是一件很难的事情。越来越多企业或研究机构将自己的数据集公开&#xff0c;已经成为全球的趋势&#xff0c;…

为asa防火墙配置ssh登陆

由于最近事情超多&#xff0c;单位下发某些令人恶心的制度&#xff0c;今天突然说北京分公司和总公司之间要做***的连接&#xff0c;虽然俺是个CCNP&#xff0c;但是对于***来说接触的少之又少&#xff0c;并且工作繁忙&#xff0c;每天头大&#xff0c;北京分公司的安全ie同事…

70.nodejs操作mongodb

转自&#xff1a;https://www.cnblogs.com/whoamme/p/3467374.html 首先安装nodejs mongodb npm install mongodb var mongodb require(mongodb); var server new mongodb.Server(localhost, 27017, {auto_reconnect:true}); var db new mongodb.Db(mydb, server, {saf…

明晚8点公开课 | 用AI给旧时光上色!详解GAN在黑白照片上色中的应用

在改革开放40周年之际&#xff0c;百度联合新华社推出了一个刷屏级的H5应用——用AI技术为黑白老照片上色&#xff0c;浓浓的怀旧风勾起了心底快被遗忘的时光。想了解如何给老照片上色&#xff1f;本次公开课中&#xff0c;我们邀请到了百度高级研发工程师李超&#xff0c;他的…

linux驱动:音频驱动(二)ASoc

五、【ASoC声卡驱动框架】 1、ASoC将嵌入式设备的音频系统从软件层面划分为3个组件 1.1 codec驱动&#xff1a;音频编解码器驱动&#xff0c;与平台无关&#xff0c;实现音频控制项添加、音频接口实现、DAPM&#xff08;动态音频电源管理&#xff09;、音频编解码器的IO功能 …

把32位的SharePoint服务器场迁移到64位, 应该怎么做?

总体步骤如下: 1. 迁移已经存在了的数据库服务器到新的数据库服务器. 先迁移这一层的目的是避免可能发生的一些由64位系统对32位系统执行查询或写入操作所引起的性能问题. 2. 迁移WFE服务器到64位环境下. 准备工作: 1. 重新编译已经存在的32位的应用程序和自定义的程序集(web p…

testem方便的web tdd 测试框架使用

备注&#xff1a;单元测试&#xff0c;对于日常的开发是比较重要的&#xff0c;testem 简化了我们的代码编写&#xff0c;以及运行。主要特性&#xff1a;a. 支持的测试框架有&#xff1a;jasmine quint mocha buster.js &#xff0c;同时也包含一些其他的适配器&#xff0c;支…

程序员老在改Bug,就不能一次改好吗?

作者丨伍杏玲来源 | 程序人生&#xff08;ID&#xff1a;coder_life&#xff09;程序员的日常三件事&#xff1a;写Bug、改Bug、背锅。连程序员都自我调侃道&#xff0c;为什么每天都在加班&#xff1f;因为我的眼里常含Bug。但是真的有这么多Bug要改吗&#xff1f;就不能一次改…