真香!Vision Transformer 快速实现 Mnist 识别
作者 | 李秋键
出品 | AI科技大本营(ID:rgznai100)
引言:基于深度学习的方法在计算机视觉领域中最典型的应用就是卷积神经网络CNN。CNN中的数据表示方式是分层的,高层特征表示依赖于底层特征,由浅入深抽象地提取高级特征。CNN的核心是卷积核,具有平移不变性和局部敏感性等特点,可以捕捉局部的空间信息。
在过去的10年间,CNN存在很大的优势,在计算机视觉领域被人们寄予厚望,引领了一个时代。但是卷积这种操作缺乏对图像本身的全局理解,无法建模特征之间的依赖关系,从而不能充分地利用上下文信息。此外,卷积的权重是固定的,并不能动态地适应输入的变化。因此,研究人员尝试将自然语言处理领域中的Transformer模型迁移到计算机视觉任务。
Vision Transformer也因此诞生,一种完全基于自注意力机制的图像分类方法。
相比CNN,Transformer的自注意力机制不受局部相互作用的限制,既能挖掘长距离的依赖关系又能并行计算,可以根据不同的任务目标学习最合适的归纳偏置,在诸多视觉任务中取得了良好的效果。
故今天我们将实现Pytorch搭建transformer模型实现Mnist手写字体识别,效果如下:
Transformer基本介绍
Transformer在计算机视觉领域能够迅速发展的原因:
(1)学习长距离依赖能力强。CNN是通过不断地堆叠卷积层来实现对图像从局部信息到全局信息的提取,这种计算机制显然会导致模型臃肿,计算量大幅增加,带来梯度消失问题,甚至使整个网络无法训练收敛。而Transformer自带的长依赖特性,利用注意力机制来捕获全局上下文信息,抽取更强有力的特征。
(2)多模态融合能力强。CNN使用卷积核来获取图像信息,但不擅长融合其他模态的信息(如声音、文字、时间等)。而Transformer的输入不需要保持二维图像,通常可以直接对像素进行操作得到初始嵌入向量,其他模态的信息转换为向量即可直接在输入端进行融合。
(3)模型更具可解释性。在Transformer的多头注意力结构中,每个头都应用独立的自注意力机制,这使得模型可以针对不同的任务在不同的表示子空间里学习相关的信息。
1.1 Transformer基本结构
(1)编码器-解码器
Transformer采用编码器-解码器架构,由分别堆叠了6层的编码器和解码器组成,是一种避免循环的模型结构。
编码器每个层结构包含两个子层,多头注意力层和前馈连接层。解码器有三个子层结构,mask多头注意力层,多头注意力层,前馈连接层。每个子层后面都加上残差连接和正则化层,结构如下图:
位置编码记录了序列数据之间顺序的相关性,相比较RNN顺序输入,Transformer方法可以直接将数据并行输入,并存储数据之间的位置关系,大大提高了计算速度,减少了存储空间。
(2)自注意力及多头注意力
注意力机制现在已成为神经网络领域的一个重要概念。其快速发展的原因主要有三个。首先,它是解决多任务较为先进的算法,其次被广泛用于提高神经网络的可解释性,第三有助于克服RNN中的一些挑战,如随着输入长度的增加导致性能下降,以及输入顺序不合理导致的计算效率低下。而自注意力机制是注意力机制的改进,其减少了网络对外部信息的依赖,更擅长捕捉数据或特征内部的相关性。
Transformer架构引入自注意力机制,避免在神经网络中使用递归,完全依赖自注意力机制来绘制输入与输出之间的全局依赖。通过使用缩放点积注意力(scaled dot-product attention),相比一般的注意力,缩放点积注意力使用点积进行相似度计算,在实际中会更快更节省空间。在计算时,需要将输入通过线性变换得到矩阵Q(查询)、K(键值)、V(值)。
(3)位置特征编码模块
使用0到9表示分割后的小图像位置编号,并且每个位置设置一个可训练的随机变量,通过梯度下降法获得位置向量。包括以及模块代码可见。
1.2 Vision Transformer基本结构
为了将图像转化成Transformer结构可以处理的序列数据,Vision Transformer引入了图像块(patch)的概念。首先将二维图像做分块处理,每个图像块展平成一维向量,接着对每个向量进行线性投影变换,同时引入位置编码,加入序列的位置信息。此外在输入的序列数据之前添加了一个分类标志位,更好地表示全局信息。ViT模型通常在大型数据集上预训练,针对较小的下游任务进行微调。在ImageNet数据集上,VIT以88.55%的准确率超越了EfficientNet模型,成功打破了基于卷积主导的网络在分类任务上面的垄断,比传统的CNN网络更具效率和可扩展性。
模型搭建
为了从代码层面理解模型,下面用pytorch简单搭建手写字体识别模型。
这里程序的设计分为以下几个步骤,分别为模块构建、模型搭建以及训练等几个步骤。
2.1 模块构建
这里使用到的模块包括:残差模块,放在每个前馈网络和注意力之后;layernorm归一化,放在多头注意力层和激活函数层,用绝对位置编码的BERT,layernorm用来自身通道归一化;FeedForward放置多头注意力后,因为在于多头注意力使用的矩阵乘法为线性变换,后面跟上由全连接网络构成的FeedForward增加非线性结构;多头注意力层,多个自注意力连起来。使用qkv计算。
代码如下:
#残差模块,放在每个前馈网络和注意力之后
class Residual(nn.Module):def __init__(self, fn):super().__init__()self.fn = fndef forward(self, x, **kwargs):return self.fn(x, **kwargs) + x
#layernorm归一化,放在多头注意力层和激活函数层。用绝对位置编码的BERT,layernorm用来自身通道归一化
class PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)
#放置多头注意力后,因为在于多头注意力使用的矩阵乘法为线性变换,后面跟上由全连接网络构成的FeedForward增加非线性结构
class FeedForward(nn.Module):def __init__(self, dim, hidden_dim):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Linear(hidden_dim, dim))def forward(self, x):return self.net(x)
#多头注意力层,多个自注意力连起来。使用qkv计算
class Attention(nn.Module):def __init__(self, dim, heads=8):super().__init__()self.heads = headsself.scale = dim ** -0.5self.to_qkv = nn.Linear(dim, dim * 3, bias=False)self.to_out = nn.Linear(dim, dim)def forward(self, x, mask = None):b, n, _, h = *x.shape, self.headsqkv = self.to_qkv(x)q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scaleif mask is not None:mask = F.pad(mask.flatten(1), (1, 0), value = True)assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'mask = mask[:, None, :] * mask[:, :, None]dots.masked_fill_(~mask, float('-inf'))del maskattn = dots.softmax(dim=-1)out = torch.einsum('bhij,bhjd->bhid', attn, v)out = rearrange(out, 'b h n d -> b n (h d)')out = self.to_out(out)return out
2.2 模型搭建
构建原始Transformer代码,然后构建VIT将图像切割成一个个图像块,组成序列化的数据输入Transformer执行图像分类任务。
代码如下:
class ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):super().__init__()assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'num_patches = (image_size // patch_size) ** 2patch_dim = channels * patch_size ** 2self.patch_size = patch_sizeself.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.patch_to_embedding = nn.Linear(patch_dim, dim)self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.transformer = Transformer(dim, depth, heads, mlp_dim)self.to_cls_token = nn.Identity()self.mlp_head = nn.Sequential(nn.Linear(dim, mlp_dim),nn.GELU(),nn.Linear(mlp_dim, num_classes))def forward(self, img, mask=None):p = self.patch_sizex = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)x = self.patch_to_embedding(x)cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1)x += self.pos_embeddingx = self.transformer(x, mask)x = self.to_cls_token(x[:, 0])return self.mlp_head(x)
2.3 模型训练
patch大小为 7x7(对于 28x28 图像,这意味着每个图像 4 x 4 = 16 个patch)、10 个可能的目标类别(0 到 9)和 1 个颜色通道(因为图像是灰度)。
在网络参数方面,使用了 64 个单元的维度,6 个 Transformer 块的深度,8 个 Transformer 头,MLP 使用 128 维度。
代码如下:
model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,dim=64, depth=6, heads=8, mlp_dim=128)
optimizer = optim.Adam(model.parameters(), lr=0.003)
train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
print('Epoch:', epoch)train_epoch(model, optimizer, train_loader, train_loss_history)evaluate(model, test_loader, test_loss_history)
print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
完整代码:
链接:
https://pan.baidu.com/s/1myFLjiTwgQe8z9WYVONntA
提取码:sbjm
李秋键,CSDN博客专家,CSDN达人课作者。硕士在读于中国矿业大学,开发有taptap竞赛获奖等。
往
期
回
顾
技术
Pandas&SQL语法归纳总结
资讯
Nginx宣布在俄罗斯禁止贡献
资讯
2022人工智能开启未来新密码
技术
一行Python代码能干嘛?来看!
分享
点收藏
点点赞
点在看
相关文章:

(二十一)数组的初始化
class Demo3 {public static void main(String[] args) {//数组的初始化int[] a new int[] {12,13,14,15};int[] b {12,13,14,15};//数组的便利for(int i 0;i<4;i) {System.out.println(a[i]);}for(int i 0;i<b.length;i) {System.out.println(b[i]);}} }转载于:http…

深入探讨PHP中的内存管理问题
一、 内存在PHP中,填充一个字符串变量相当简单,这只需要一个语句"<?php $str hello world ; ?>"即可,并且该字符串能够被自由地修改、拷贝和移动。而在C语言中,尽管你能够编写例如"char …

介绍一个效率爆表的数据采集框架
作者 | 俊欣来源丨关于数据分析与可视化今天我们来聊一下如何用协程来进行数据的抓取,协程又称为是微线程,也被称为是用户级线程,在单线程的情况下完成多任务,多个任务按照一定顺序交替执行。那么aiohttp模块在Python中作为异步的…

最多显示6行并且最多显示三条文本
为什么80%的码农都做不了架构师?>>> private void setCommentContent(ViewHolder vh, String feedId, int commentNum, ArrayList<CommentItem> comment_lists){if(commentNum < 0 || comment_lists null || comment_lists.isEmpty()){for(in…

【刷算法】LeetCode- 两数之和
题目描述 给定一个整数数组和一个目标值,找出数组中和为目标值的两个数。 你可以假设每个输入只对应一种答案,且同样的元素不能被重复利用。 示例: 给定 nums [2, 7, 11, 15], target 9因为 nums[0] nums[1] 2 7 9 所以返回 [0, 1] 复制代码分析 第…

栈区和堆区内存分配区别
一直以来总是对这个问题的认识比较朦胧,我相信很多朋友也是这样的,总是听到内存一会在栈上分配,一会又在堆上分配,那么它们之间到底是怎么的区别呢?为了说明这个问题,我们先来看一下内存内部的组织情况&…
高精度进制转换
高精度进制转换: 对于普通的不是非常大的数的相互转换,我们一般採用不断模取余的方法,比如:将10进制数m转换成n进制数,则对m模n取余就可以。可是,假设是一个有几百、几千或者很多其它位的大数呢?…

远程办公,你希望在家工作几天?
受疫情影响,员工的工作方式不得不发生改变。在过去短短的几个月内,远程办公从偶然一次变成了常态化。随着疫情的反复,远程办公再次成为了许多企业的选择。3月份携程正式启动了“32”混合办公模式,即每周有1-2天,员工可…

python爬虫日志(9)爬取代理
2019独角兽企业重金招聘Python工程师标准>>> 话不多说,直接上代码,很简单,很容易看懂 import requests from bs4 import BeautifulSoup import randomdef get_ip_list():print("正在获取代理列表...")ip_url http://ww…

使php支持mbstring库以及使用
1.执行yum install php-mbstring2. 修改php.ini (这一步非常重要, 部分lxadmin版本无法自动修改)echo ‘extensionmbstring.so’ >>/etc/php.ini #更具php安装目录而定3. 重启web service如果是apache: service httpd restart方法二:php 5.36安装目录…
仿余额宝数字跳动效果 TextCounter
1、TextCounter 效果 2、TextCounter 说明 每次打开余额宝第一件事情就去看看有多少钱,最炫的就是看着钱在跳动相当的舒服,今天放出这个效果。 温馨提示:支持的Android版本最低的是Android 4.0.0 IceCreamSandwich ( API等级14 &a…

年仅 16 岁的黑客少年,竟是搅乱 IT 巨头的幕后主使?
整理 | 郑丽媛出品 | CSDN近来,黑客组织 Lapsus$ 活跃在各大科技网站:窃取英伟达近 1TB 的数据、泄露三星近 190GB 的机密数据、公布微软 Bing 和 Cortana 源码…不同于大部分黑客组织,Lapsus$ 没有刻意隐藏自己,反而行事非常高调…

使用硬盘,安装双系统,Win7+CentOS
我用那个U盘装了很多次都不行,都是说找不到文件。最后就找了一篇博客看如何安装双系统,最后发现原来可以用硬盘安装的。经过5个多小时终于完成了。^-^。 1.首先是分区,可以使用Window7自带的磁盘管理程序进行分区。(PS 我是用Cent…

Linux 文件系统剖析
Linux 文件系统剖析 按照分层结构讨论 Linux 文件系统 M. Tim Jones, 顾问工程师, Emulex Corp. 简介: 在文件系统方面,Linux 可以算得上操作系统中的 “瑞士军刀”。Linux 支持许多种文件系统,从日志型 文件系统到集群文件系统和加密文件系统…

Docker构建Nginx+Tomcat动静分离架构
随着主流Nginx WEB服务器的发展,现在基于Nginx的WEB服务器已广泛应用于各大互联网企业。今天我们来使用docker构建我们的LinuxNginxTomcat动静分离服务器。1) 启动docker镜像查看当前系统存在的镜像,我这里为CentOS6.6,大家可以参考我第一…

硬核!Python 四种变量的代码对象和反汇编分析
作者 | 大奎整理 | 阳哥来源丨Python数据之道在Python基础的学习过程中,对变量和参数的理解有助于我们从更基础层面了解Python语言的运行。在这个过程中,还是有不少冷门和细节的地方需要进一步熟悉。今天我们来分享Python四种变量的代码对象和反汇编分析…

Python--数据存储:pickle模块的使用讲解
在机器学习中,我们常常需要把训练好的模型存储起来,这样在进行决策时直接将模型读出,而不需要重新训练模型,这样就大大节约了时间。Python提供的pickle模块就很好地解决了这个问题,它可以序列化对象并保存到磁盘中&…

Linux虚拟内存和物理内存精华【美】
原文地址: 《Playing with Virtual Memory》 http://www.snailinaturtleneck.com/blog/2011/08/30/playing-with-virtual-memory/ 扩展阅读: 《Understanding Memory》 http://www.ualberta.ca/CNS/RESEARCH/LinuxClusters/mem.html 《Understanding Vir…

留不住客户?该从你的系统上找找原因了
本篇文章暨 CSDN《中国 101 计划》系列数字化转型场景之一。 《中国 101 计划——探索企业数字化发展新生态》为 CSDN 联合《新程序员》、GitCode.net 开源代码仓共同策划推出的系列活动,寻访一百零一个数字化转型场景,聚合呈现并开通评选通道࿰…

系统配置文件备份比较
客户的系统出各种问题,这次出了问题整整一天都没找出原因,都红脸了,最后发现是系统配置文件被改掉了,简直不能忍,所以写了这个脚本,放到定时任务里面,每天备份比较配置文件import difflib impor…

RPC是什么?为什么要学习RPC?
随着近几年分布式、微服务架构的火热,RPC在开发工作中使用的越来越多,也变的越来越重要。 今天我们来看RPC是什么,为什么要了解RPC,通过学习RPC我们能掌握什么内容? 什么是「RPC」 RPC 全称 Remote Procedure Call, wikipedia的部…

Lua学习笔记6:C++和Lua的相互调用
曾经一直用C写代码。话说近期刚换工作。项目组中的是cocos2dx-lua,各种被虐的非常慘啊有木有。新建cocos2dx-lua项目。打开class能够发现,事实上就是C项目啦,只是为什么仅仅有一类Appdelegate类呢?哈哈,我相信聪明的你一定猜到了&…

Redis消息通知系统的实现
Redis消息通知系统的实现Posted on 2012-02-29by 老王 http://huoding.com/2012/02/29/146最近忙着用Redis实现一个消息通知系统,今天大概总结了一下技术细节,其中演示代码如果没有特殊说明,使用的都是PhpRedis扩展来实现的。内存比如要推送一…

用 Python 实现答题卡识别!
作者 | 棒子胡豆来源丨CSDN博客答题卡素材图片:思路读入图片,做一些预处理工作。进行轮廓检测,然后找到该图片最大的轮廓,就是答题卡部分。进行透视变换,以去除除答题卡外的多余部分,并且可以对答题卡进行校…
Confluence 6 计划任务
管理员控制台能够允许你对 Confluence 运行的计划任务进行计划的调整,这些计划任务将会按照你的调整按时执行。可以按照计划执行的任务如下: Confluence 站点备份存储优化任务,清理 Confluence 的临时目录中的文件和缓存索引优化任务…

PHP共享内存段
在asp.net和java中都有共享内存,php除了可以使用Memcached等方式变通以外其实php也是支持共享内存的!需要安装扩展shmop 找到php安装源文件目录# cd /usr/local/php-5.4.0/ext/shmop # /usr/local/php/bin/phpize # ./configure --with-php-config/usr/l…

马尔科夫随机场的基本概念
1、随机过程: 描写叙述某个空间上粒子的随机运动过程的一种方法。它是一连串随机事件动态关系的定量描写叙述。随机过程与其他数学分支,如微分方程、复变函数等有密切联系。是自然科学、project科学及社会科学等领域研究随机现象的重要工具。 2、马尔科夫…

从事了两年 AI 研究,我学到了什么?
作者 | Tom Silver译者 | 弯月出品 | CSDN我从事人工智能研究的工作已经有两年了,有朋友问我都学到了什么,所以我想借本文分享一些迄今为止积累的经验教训。我将在本文中分享一些常见的经历,还会讨论相对具体的人工智能行业技巧。希望对大家能…

Windows server 2008普通用户不能远程登录问题
1、查登录权限 如果文件服务器没有为用户授权,那么用户自然就不能远程登录服务器系统了,为此笔者决定先仔细检查一下文件服务器系统是否为自己使用的登录账号,授予了远程登录权限。在进行这种检查时,笔者先是在文件服务器本地以系…

面向小白的最全 Python 可视化教程,超全的!
作者 | 俊欣来源丨关于数据分析与可视化今天小编总结归纳了若干个常用的可视化图表,并且通过调用plotly、matplotlib、altair、bokeh和seaborn等模块来分别绘制这些常用的可视化图表,最后无论是绘制可视化的代码,还是会指出来的结果都会通过调…