当前位置: 首页 > 编程日记 > 正文

CV06-Xception笔记

目录

一、为啥是Xception

二、Xception结构

2.1 Xception结构基本描述

2.2 实现细节

2.3 DeepLabV3+改进

三、记录pytorch采坑relu激活函数inplace=True


Xception笔记,记录一些自己认为重要的要点,以免日后遗忘。

复现Xception论文、DeepLabV+改进的Xception,代码地址https://github.com/Ascetics/Pytorch-Xception

一、为啥是Xception

Xception脱胎于Inception,Inception的思想是将卷积分成cross-channel conv和spatial conv,更准确的说是先用1x1卷积得到几个不同channel(小于输入channel)的结果,再在这些结果上分别用3x3、5x5 conv,也就是论文Figure 1描述的那样。Inception的这种算法背后,本质上是将cross-channel conv和spatial conv解耦。

考虑将Inception简化:去掉平均池化层,只用3x3 conv(2个3x3 conv相当于1个5x5 conv)。就测到了论文Figure 2描述的这种结构。

 

在Figure 2的基础上,用1个channel很大的1x1 conv 将输入映射到一个channel很大的输出上。再将这个输出“切成几段”,“切成几段”分别做3x3 spatial conv,就得到了论文中Figure 3的结构。作者在此提出一个问题,这样将cross-channel conv和spatial conv完全解耦分开合理吗?完全解耦分开,可以这样做吗?
基于Figure 3提出的假设,做一个极端的Inception模型。还是先用1个channel很大的1x1 conv 将输入映射到一个channel很大的输出上,然后“切成几段”变成“切片”,每个channel切一片。对每个channel做3x3卷积。这样极端的设计就接近于深度可分离卷积depthwise separable convolution。

为什么是“接近”,而不是“就是”呢?因为和depthwise separable convolution的操作顺序、操作内容不一样。

  1. 顺序上,depthwise separable convolution,用3x3 conv进行spatial conv,用1x1 conv进行cross-channel conv;极端版本Inception先用1x1 conv再用3x3 conv;
  2. 内容上,depthwise separable convolution,spatial conv和cross-channel conv之间没有非线性(ReLU激活函数);极端版本Inception,卷积之间有非线性(ReLU激活函数);

作者认为第一个区别是不重要的,特别是因为这些操作要在堆叠(深度学习)的环境中使用。第二个区别重要,作者研究了一下,结论见论文Figure 10。本文后面会解释。

要看懂Xception,需要了解VGG、Inception、Depthwise Separable Convlution和ResNet,都会用到。

二、Xception结构

2.1 Xception结构基本描述

卷积神经网络特征提取中的卷积都可以完全解耦,变成深度可分离卷积(Xception也就是Extreme Inception的意思)。接收了这一设定,Xception结构被解释为论文Figure 5的样子。

Xception的特征提取基础由36个conv layer构成。这36个conv layer被组织成14个module,除了第一个和最后一个module,其余的module都带有residual connection(残差,参看何凯明大神的ResNet)。简言之,Xception结构就是连续使用depthwise separable convolution layer和residual connection。

2.2 实现细节

如Figure 5 描述所述。

输入先经过Entry flow,不重复;再经过Middle flow,Middle flow重复8次;最后经过Exit flow,不重复。

所有的Conv 和 Separable Conv后面都加BN层,但是Figure 5没有画出来。

所有的Separable Conv都用depth=1,也就是每个depth-wise都是“切片”的。

注意, depthwise separable convolution在spatial conv和cross-channel conv之间不要加ReLU激活函数,任何激活函数都不要加。论文Figure 10展示了,这里不加激活函数效果最好,加ReLU、ELU都不好。

还有一些是论文中没有明说的细节。

Residual Connection在1x1卷积后面也加上BN。Residual Connection加上以后,不要着急做激活函数,仔细看图,激活函数ReLU是属于下一个Block的。这就导致了代码实现上采坑,下一节详细记录一下。也算长个记性。

2.3 DeepLabV3+改进

这一部分在下一篇博客,学DeepLabV3+中再记录。

三、记录pytorch采坑relu激活函数inplace=True

上面2.2写了一个细节,如果严格按照论文的示意图来实现Xception,那么每个block第一个操作不是SeparableConv,而是ReLU(红色框)。如果仅仅是第一个操作是ReLU也没有关系,但是旁边还有个Residual Connection(蓝色框)。自古红蓝出CP,于是坑来了,在反向传播的时候,报了个错,明显是因为inplace导致了某个对象被modify了,反向传播求梯度报错。(此时我所有代码的ReLU用的都是inplace=True,省内存嘛)

以Entry Flow的Block为例,先来欣赏一下错误的代码。

class _PoolEntryBlock(nn.Module):def __init__(self, in_channels, out_channels, relu1=True):"""Entry Flow的3个下采样module按论文所说,每个Conv和Separable Conv都需要跟BN论文Figure 5中,第1个Separable Conv前面没有ReLU,需要判断一下,论文Figure 5中,每个module的Separable Conv的out_channels一样,MaxPool做下采样:param in_channels: 输入channels:param out_channels: 输出channels:param relu1: 判断有没有第一个ReLU,默认是有的"""super(_PoolEntryBlock, self).__init__()self.project = ResidualConnection(in_channels, out_channels, stride=2)self.relu1 = Noneif relu1:self.relu1 = nn.ReLU(inplace=True)  self.sepconv1 = SeparableConv2d(in_channels, out_channels,kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu2 = nn.ReLU(inplace=True)self.sepconv2 = SeparableConv2d(out_channels, out_channels,kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)passdef forward(self, x):identity = self.project(x)  # residual connection 准备if self.relu1:  # 第1个Separable Conv前面没有ReLU,需要判断一下x = self.relu1(x)x = self.sepconv1(x)  # 第2个Separable Convx = self.bn1(x)x = self.relu2(x)x = self.sepconv2(x)  # 第2个Separable Convx = self.bn2(x)x = self.maxpool(x)  # 下采样2倍x = x + identity  # residual connection 相加return xpass

采坑时的做法,先算出Residual Connection,再做relu、SeparableConv。Residual Connection时,已经进行过一次卷积操作,此时要求输入x本身不能发生改变,不能再被modify。后面的ReLU(inplace=True)恰恰就modify了x。所以反向传播时报错。

改变执行的先后顺序呢?也不行。如果先ReLU(inplace=True),那么x也被modify了,再做Residual Connection时输入就不是block输入的那个x了。

解决的办法,改为ReLU(inplace=False),或者Residual Connection的输入改为x.clone(),总之不能省内存……正确的代码已经push到github上了,地址详见文章开头。

为此,我写了一个简化的模型:

  1. class Wrong就是采坑的错误实现;
  2. class RightOne就是改为ReLU(inplace=False);
  3. class RightTwo就是Residual Connection的输入改为x.clone();

一杯茶,一包烟,一个bug改一天……

import torch
import torch.nn as nnclass Wrong(nn.Module):def __init__(self):super(Wrong, self).__init__()self.convs = nn.Sequential(nn.ReLU(inplace=True),nn.Conv2d(3, 3, 3, padding=1))self.residual = nn.Conv2d(3, 3, 3, padding=1)passdef forward(self, x):r = self.residual(x)  # 卷积之后,x就不能modify了h = self.convs(x)  # relu就modify了x,反向传播时候会报错h = h + rreturn hpassclass RightOne(nn.Module):def __init__(self):super(RightOne, self).__init__()self.convs = nn.Sequential(nn.ReLU(inplace=False),  # 改法1,别省内存了nn.Conv2d(3, 3, 3, padding=1))self.residual = nn.Conv2d(3, 3, 3, padding=1)passdef forward(self, x):r = self.residual(x)h = self.convs(x)h = h + rreturn hpassclass RightTwo(nn.Module):def __init__(self):super(RightTwo, self).__init__()self.convs = nn.Sequential(nn.ReLU(inplace=True),nn.Conv2d(3, 3, 3, padding=1))self.residual = nn.Conv2d(3, 3, 3, padding=1)passdef forward(self, x):r = self.residual(x.clone())  # 改法2,clone还是消耗内存的h = self.convs(x)h = h + rreturn hpassif __name__ == '__main__':in_data = torch.randint(-2, 2, (1, 3, 2, 2), dtype=torch.float)in_label = torch.randint(0, 3, (1, 2, 2))print(in_data.shape)func = nn.CrossEntropyLoss()t = RightTwo()in_data = in_data.cuda()in_label = in_label.cuda()t.cuda()out_data = t(in_data)print(out_data.shape)loss = func(out_data, in_label)loss.backward()

相关文章:

C++排序算法实现(更新中)

比较排序法:如冒泡排序、简单选择排序、合并排序、快速排序。其最优的时间复杂度为O(nlogn)。 其他排序法:如桶排序、基数排序等。时间复杂度可以达到O(n)。但试用范围有要求。 桶排序:排序的数组元素跨距不能很大。因为跨距很大的话&#xf…

iOS SwiftUI篇-5 专题NavigationView、NavigationLink

iOS SwiftUI篇-5 专题NavigationView、NavigationLink NavigationView:标题、展示模式、隐藏导航栏、隐藏返回按钮、添加导航栏按钮 NavigationLink:Text文本跳转、Image图片跳转、Button按钮跳转、点击按钮根据业务跳转到不同页面 NavigationView 标题、展示模式 import S…

PHP artisan

Artisan 是 Laravel 提供的 CLI(命令行接口),它提供了非常多实用的命令来帮助我们开发 Laravel 应用。前面我们已使用过 Artisan 命令来生成应用的 App Key 和控制器。在本教程中,我们会用到以下 Artisan 命令,你也可以…

【转载】Pytorch在加载模型参数时指定设备

转载 https://sparkydogx.github.io/2018/09/26/pytorch-state-dict-gpu-to-cpu/ >>> torch.load(tensors.pt) # Load all tensors onto the CPU >>> torch.load(tensors.pt, map_locationtorch.device(cpu)) # Load all tensors onto the CPU, using a fun…

目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)

首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple-faster-rcnn-pytorch-master代码的最后一个train.py文件,是时候认真的总结一下了&#xff0…

hp-ux 集群,内存 小记

hp-ux 集群,内存 小记 -----查看hp 集群状态信息 # cmviewcl -v CLUSTER STATUS dbsvr up NODE STATUS STATE db01 up running Cluster_Lock_LVM: VOLUM…

iOS SwiftUI篇-6 专题TabView

iOS SwiftUI篇-6 专题TabView TabView: 图片+文字组成tabItem,选中时改变图片和文字颜色 跳转到二级页面时隐藏tabbar,返回到首页时显示tabbar 首页、我的两个tab,效果图: 图片文字组成tabItem,选中时改变图片和文字颜色 代码: struct MainContentView: View {@State…

三维刚体变化中Rcw,tcw的含义

高翔博士的《视觉SLAM十四讲》中,介绍Tcw指从世界坐标w到c的变换矩阵。但研一学机器人学的时候,讲T12的含义是,坐标系2相对于坐标系1的变换。于是一脸懵逼。昨天想了一晚上,有了一点自己的想法,在这记录一下&#xff0…

CV07-DeepLab v3+笔记

目录 一、Dilated Convolution 膨胀卷积 二、ASPP与Encoder & Decoder 三、深度可分离卷积 3.1 深度可分离卷积原理 3.2 深度可分离卷积减小参数量和计算量 3.3 深度可分离卷积实现细节 四、Xception作为Backbone DeepLab v3笔记,记录一些自己认为重要的…

1116.加减乘除

题目描述:根据输入的运算符对输入的整数进行简单的整数运算。 运算符只会是加、减-、乘*、除/、求余%、阶乘!六个运算符之一。 输出运算的结果,如果出现除数为零,则输出“error”,如果求余运算的第二个运算数为0,也输出…

Flutter专题1-环境搭建

Flutter专题1-环境搭建和创建项目 这里以MaciOS为例,其他平台参考官网https://flutter.dev/docs/get-started/install 1. 系统要求 系统:macOS (64-bit) 硬盘空间:2.8G 工具:Git 2.获取Flutter SDK 2.1下载SDK,从https://flutter.dev/docs/development/tools/s…

ORB_SLAM2源码:ORBmatcher.cc

ORBmatcher.cc中的函数,主要实现(1)路标点和特征点的匹配(2D-3D点对)。(2)特征点和特征点的匹配(2D-2D点对)。SearchByProjection的函数重载看得我一脸懵逼。在这做一下笔…

iOS国际化技巧

参考链接:http://www.cocoachina.com/ios/20151120/14258.html http://www.jianshu.com/p/88c1b65e3ddb http://www.cnblogs.com/levilinxi/p/4296712.html http://www.cocoachina.com/appstore/20160310/15632.html http://www.cocoachina.com/ios/20170214/18681.html转载于:…

CV08-数据预处理与数据增强

复现车道线分割项目(Lane Segmentation赛事说明在这里),学习数据预处理和数据增强。学习分为Model、Data、Training、Inference、Deployment五个阶段,也就是建模、数据、训练、推断、部署这五个阶段。现在进入的是Data阶段。项目的…

ORB_SLAM2程序入口(System.cc)

程序入口 ORB_SLAM2的程序入口为src/System.cc。在CMakeList.txt中可知,ORB_SLAM2的可执行程序为: Examples/Stereo/stereo_kitti.cc等。 add_executable(stereo_kitti Examples/Stereo/stereo_kitti.cc) target_link_libraries(stereo_kitti ${PROJECT…

HDU 6229 Wandering Robots 找规律+离散化

题目链接:Wandering Robots 题解:先讲一下规律,对于每一个格子它可以从多少个地方来有一个值(可以从自己到自己),然后答案就是统计合法格子上的数与所有格子的数的比值 比如说样例的3 0格子上的值就是 3 4 …

app、H5、safari、appstore应用主页评分页之间拉起调用、打开手机某些系统功能、app打开文档

定义打开URL的方法 - (void)openURL:(NSString *)urlStr {NSURL *url [NSURL URLWithString:urlStr];UIApplication *app [UIApplication sharedApplication];if ([app canOpenURL:url]) { #ifdef __IPHONE_10_0[app openURL:url options:[NSDictionary dictionary] complet…

XML学习总结

1、XML结构 2、XmlNodeType值为一个枚举类型: 假设我们对一个XML文件进行遍历,不推断节点是否为Element类型。就会将文本节点遍历出来,出现#test。 3、XmlElement和XmlNode的差别:(摘自CSDN论坛) &#xff…

Linux01-基本操作与Shell

目录 一、环境 二、Linux目录结构及基本操作 2.1 Linux目录结构 2.2 基本操作 三、shell 3.1 shell的意义 3.2 su - 一、环境 2019年搞下RHCE的证书,但是一直没有整理Linux学习的笔记,为了不让到手的知识被遗忘,从今天起整理Linux学习…

ORB_SLAM2中Tracking线程的三种追踪方式

1、参考关键帧追踪模式 bool Tracking::TrackReferenceKeyFrame()对参考关键帧中的路标点进行跟踪。在Tracking线程中,每传入一帧,都会进行位姿优化。 以上一帧的位姿为当前位姿进行优化。 (1)计算当前帧的词袋 mCurrentFra…

nodejs 中间件 反向代理 接口转发

背景 随着后端业务系统的增加,纵向需求不断扩展,一个业务系统已经无法满足需求了,衍生出多个业务系统,对外暴露的ip、端口就可能有多个,此时不方便外部接口调用,有些特殊行业客户出于安全性考虑不发提供多…

oneinstack

https://oneinstack.com/转载于:https://www.cnblogs.com/diyunpeng/p/9740895.html

最近在做托盘时,发现 CnTrayIcon1的OnClick 事件,不能被其它按钮来执行,蛋疼。...

比如: procedure TForm1.Button1Click(Sender: TObject);begin CnTrayIcon1.OnClick ; // 这句就是不能通过!!end; 有过路的高手,指点学生一下。谢谢转载于:https://www.cnblogs.com/hahy8008/p/6783614.html

Linux02-帮助手册

目录 一、man手册 1.1 man的基本使用 1.2 mandb更新文档 二、/usr/share/doc 三、access.redhat.com 门户 一、man手册 1.1 man的基本使用 man就是mannual的缩写,手册的意思。Linux的命令很多,参数选项更多,人脑一般是记不住的&…

ORB_SLAM2中Tracking线程

Tracking线程是ORB_SLAM2的主线程。在System.cc中,使用构造函数进行了初始化,开启了三个线程。 可执行程序—>System构造函数(初始化三个线程)—>处理输入的帧(TrackMonocular)—>调用Tracking线程…

selenium的基础知识点

from selenium import webdriver from scrapy.selector import Selector#模拟登陆 browser webdriver.Chrome(executable_pathChromedriver.exe) #路径是Chromedriver.exe的存放位置,windows下只要配置好这个环境就不需要了browser.get(http://w) #需要登陆的那个网…

iOS 直播专题2-音视频采集

从设备(手机)的摄像头、MIC中采集音频、视频的原始数据ios的音视频采集可以从AVFoundation框架里采集 视频采集 这里我们选取GPUImage来采集视频,因为这个框架集成了很多视频滤镜,例如美颜 采集流程: 摄像头采集视频代码 GPUImageVideoCamera.m // 从前摄像头或后摄像头…

bzoj 4871: [Shoi2017]摧毁“树状图”

4871: [Shoi2017]摧毁“树状图” Time Limit: 25 Sec Memory Limit: 512 MBSubmit: 53 Solved: 9[Submit][Status][Discuss]Description 自从上次神刀手帮助蚯蚓国增添了上千万人口(蚯口?),蚯蚓国发展得越来越繁荣了&#xff01…

Linux03-本地账户和组

目录 一、本地账户/etc/passwd 二、本地组/etc/group 三、切换账户su - 四、增删改本地账户useradd、userdel、usermod 五、账户默认配置文件/etc/login.defs 六、设置密码passwd(5)命令 七、增删改组groupadd、groupdel和groupmod 八、通过sudo以root身份运行命令 九…

ORB_SLAM2单目初始化策略

基本流程 单目初始化程序存储在Initializer.cc中   需要注意,对于双目/RGB-D相机,初始化时,由于可以直接获得相机的深度信息,因此无需求H/F,直接作为关键帧插入就行。   使用RANSACDLT求解H,RANSAC八点…