当前位置: 首页 > 编程日记 > 正文

PyTorch中nn.Module类简介

torch.nn.Module类是所有神经网络模块(modules)的基类,它的实现在torch/nn/modules/module.py中。你的模型也应该继承这个类,主要重载__init__、forward和extra_repr函数。Modules还可以包含其它Modules,从而可以将它们嵌套在树结构中。

只要在自己的类中定义了forward函数,backward函数就会利用Autograd被自动实现。只要实例化一个对象并传入对应的参数就可以自动调用forward函数。因为此时会调用对象的__call__方法,而nn.Module类中的__call__方法会调用forward函数。

nn.Module类中函数介绍:

__init__:初始化内部module状态。

register_buffer:向module添加buffer,不作为模型参数,可作为module状态的一部分。默认情况下,buffer是持久(persistent)的,将与参数一起保存。buffer是否persistent的区别在于这个buffer是否被放入self.state_dict()中被保存下来。

register_parameter:向module添加参数。

add_module:添加一个submodule(children)到当前module中。

apply:将fn递归应用于每个submodule(children),典型用途为初始化模型参数。

cuda:将所有模型参数和buffers转移到GPU上。

xpu:将所有模型参数和buffers转移到XPU上。

cpu:将所有模型参数和buffers转移到CPU上。

type:将所有参数和buffers转换为所需的类型。

float:将所有浮点参数和buffers转换为float32数据类型。

double:将所有浮点参数和buffers转换为double数据类型。

half:将所有浮点参数和buffers转换为float16数据类型。

bfloat16:将所有浮点参数和buffers转换为bfloat16数据类型。

to:将参数和buffers转换为指定的数据类型或转换到指定的设备上。

register_backward_hook:在module中注册一个反向钩子。不推荐使用。

register_full_backward_hook:在module中注册一个反向钩子。每次计算梯度时都会调用此钩子。使用此钩子时不允许就地(in place)修改输入或输出,否则会触发error。

register_forward_pre_hook:在module中注册前向pre-hook。每次调用forward之前都会调用此钩子。

register_forward_hook:在module中注册一个前向钩子。每次forward计算输出后都会调用此钩子。

state_dict:返回包含了module的整个状态的字典。其中keys是对应的参数和buffer名称。

load_state_dict:将参数和buffers从state_dict复制到module及其后代(descendants)中。

parameters:返回module的参数的迭代器。

named_parameters:返回module的参数的迭代器,产生(yield)参数的名称以及参数本身。不会返回重复的parameter。

buffers:返回module的buffers的迭代器。

named_buffers:返回module的buffers的迭代器,产生(yield)buffer的名称以及buffer本身。不会返回重复的buffer。

children:返回直接子module的迭代器。

named_children:返回直接子module的迭代器,产生(yield)子module的名称以及子module本身。不会返回重复的children。

modules:返回网络中所有modules的迭代器。

named_modules:返回网络中所有modules的迭代器,产生(yield)module的名称以及module本身。不会返回重复的module。

train:将module设置为训练模式。这仅对某些module起作用。module.py实现中会修改self.training并通过self.children()来调整所有submodule的状态。

eval:将module设置为评估模式。这仅对某些module起作用。module.py实现中直接调用train(False)。

requires_grad_:更改autograd是否应记录对此module中参数的操作。此方法就地(in place)设置参数的requires_grad属性。

zero_grad:将所有模型参数的梯度设置为零。

share_memory:

extra_repr:设置module的额外表示。你应该在自己的modules中重新实现此方法。

测试代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F # nn.functional.py中存放激活函数等的实现@torch.no_grad()
def init_weights(m):print("xxxx:", m)if type(m) == nn.Linear:m.weight.fill_(1.0)print("yyyy:", m.weight)class Model(nn.Module):def __init__(self):# 在实现自己的__init__函数时,为了正确初始化自定义的神经网络模块,一定要先调用super().__init__super(Model, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5) # submodule(child module)self.conv2 = nn.Conv2d(20, 20, 5)self.add_module("conv3", nn.Conv2d(10, 40, 5)) # 添加一个submodule到当前module,等价于self.conv3 = nn.Conv2d(10, 40, 5)self.register_buffer("buffer", torch.randn([2,3])) # 给module添加一个presistent(持久的) bufferself.param1 = nn.Parameter(torch.rand([1])) # module参数的tensorself.register_parameter("param2", nn.Parameter(torch.rand([1]))) # 向module添加参数# nn.Sequential: 顺序容器,module将按照它们在构造函数中传递的顺序添加,它允许将整个容器视为单个moduleself.feature = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))self.feature.apply(init_weights) # 将fn递归应用于每个submodule,典型用途为初始化模型参数self.feature.to(torch.double) # 将参数数据类型转换为doublecpu = torch.device("cpu")self.feature.to(cpu) # 将参数数据转换到cpu设备上def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))model = Model()
print("## Model:", model)model.cpu() # 将所有模型参数和buffers移动到CPU上
model.float() # 将所有浮点参数和buffers转换为float数据类型
model.zero_grad() # 将所有模型参数的梯度设置为零# state_dict:返回一个字典,保存着module的所有状态,参数和persistent buffers都会包含在字典中,字典的key就是参数和buffer的names
print("## state_dict:", model.state_dict().keys())for name, parameters in model.named_parameters(): # 返回module的参数(weight and bias)的迭代器,产生(yield)参数的名称以及参数本身print(f"## named_parameters: name: {name}; parameters size: {parameters.size()}")for name, buffers in model.named_buffers(): # 返回module的buffers的迭代器,产生(yield)buffer的名称以及buffer本身print(f"## named_buffers: name: {name}; buffers size: {buffers.size()}")# 注:children和modules中重复的module只被返回一次
for children in model.children(): # 返回当前module的child module(submodule)的迭代器print("## children:", children)for name, children in model.named_children(): # 返回直接submodule的迭代器,产生(yield) submodule的名称以及submodule本身print(f"## named_children: name: {name}; children: {children}")for modules in model.modules(): # 返回当前模型所有module的迭代器,注意与children的区别print("## modules:", modules)for name, modules in model.named_modules(): # 返回网络中所有modules的迭代器,产生(yield)module的名称以及module本身,注意与named_children的区别print(f"## named_modules: name: {name}; module: {modules}")model.train() # 将module设置为训练模式
model.eval() # 将module设置为评估模式print("test finish")

GitHub:https://github.com/fengbingchun/PyTorch_Test

相关文章:

什么是三层交换机、网关、DNS、子网掩码、MAC地址

一、什么是vlan? 二、单臂路由与三层交换机 三、什么是网关 一、什么是网关 二、如何来理解网关 三、网关的ip地址 四、网关是如何实现通信? 五、什么是默认网关? 四、什么是DNS 五、MAC地址 六、子网掩码 很多朋友多次问到什么是网关、dns、子网掩码&…

20行代码发一篇NeurIPS:梯度共享已经不安全了

整理 | 夕颜,Jane出品 | AI科技大本营(ID:rgznai100)【导读】12 月 8 日-14 日,NeurIPS 2019 在加拿大温哥华举行,和往常一样,今年大会吸引了数万名专家参会,并展示了计算机领域的最新进展。其中…

关于页面打印window.print()的样式问题

当我们打印网页的时候。有时候会发现。打印出来的。跟网页上看到的样式的差别有点大。这其中可能有的问题是。样式问题。 当调用打印(window.print())方法时。打印机会在网页的样式中查找 media print{}的样式,并适应到要打印的网页中。 所以 如果要打印的页面符合看…

Python3中参数*args和**kwargs介绍

在Python中,我们可以使用两种特殊符号将可变数量的参数传递给函数:*args和**kwargs。你可以使用任何单词代替args和kwargs,但通常做法是使用args和kwargs。 *args允许函数接受任意数量的位置参数(positional arguments)。 **kwargs收集所有未…

4大主流CPU处理器技术架构,不知道就out了!

作者 | 王艺威责编 | 阿秃RISC(精简指令集计算机)是一种执行较少类型计算机指令的微处理器,起源于80年代的MIPS主机(即RISC机),RISC机中采用的微处理器统称RISC处理器。这样一来,它能够以更快的…

grunt-connect-proxy解决开发时跨域问题

最近的项目中前后端是完全分离开发的,前端用grunt管理项目。这样就会导致一个问题:开发时前端调用后台的接口时因为不在一个服务器,所以会出现跨域问题。但是也不能用JSONP或CROS方式实现真正的跨域,因为项目发布时其实是在同一个…

混合推荐系统就是多个推荐系统“大杂烩”吗?

作者丨gongyouliu编辑丨zandy【导读】在本篇文章中,我们会介绍混合推荐系统(Hybrid Recommender Systems),就是利用多种推荐算法配合起来做推荐,期望避免单个推荐算法存在的问题,最终获得比单个算法更好的推荐效果。本篇文章我们从…

Python3中collections.OrderedDict介绍

Python3中的collections模块实现了特定目标的容器,以提供Python标准内建容器dict、list、set和tuple的替代选择,包括namedtuple、deque、ChainMap、Counter、OrderedDict、defaultdict、UserDict、UserList、UserString。这里介绍下OrderedDict&#xff…

汗!雅虎中国个人空间

今天发现雅虎中国有了个人空间,偷偷试了下,让人失望到极点,几乎没有什么特点,和MSN很相似,空间相册放着好好的Flickr不用,偏偏弄了个很垃圾的相册,还有整合能力也不行。都不知道del.icio.us和Fl…

关于v$process与v$session中process的理解

v$session有个process字段,V$PROCESS有个SPID字段,这两个字段是不是一个意思呢?是不是都代表会话的操作系统进程呢?官方文档上的解释:SPID VARCHAR2(12) Operating system process identifierPROCESS VARCHAR2…

Python3中lambda表达式介绍

Python3中的lambda表达式或lambda函数是匿名函数(anonymous function),意味着该函数没有名称。def关键字用于在Python3中创建一个普通函数,类似地,lambda关键字用于在Python3中创建匿名函数。 Python3 lambda函数语法: lambda pa…

6大理由,告诉你为什么这个大会你不能错过! | 文末有福利

作者 | Carol出品 | 区块链大本营(blockchain_camp)* 文末可参与活动赢赠票!如果说有一个什么领域,能让中科院、华为、腾讯、京东、360、微众银行的大咖汇聚在一起,那一定是——区块链。悄咪咪地给大家剧透一下&#x…

魔与道的反复较量 反垃圾邮件技术

反垃圾邮件武器库不同的反垃圾邮件产品采用的技术有所不同,但总体来说,不外乎以下几种技术,其中,针对垃圾邮件的核心技术有贝叶斯智能分析、垃圾邮件评分、垃圾邮件指纹识别。转载于:https://blog.51cto.com/aonlin/17074

在Centos 7下编译openwrt+njit-client

首先要有一个centos7 step1:更新系统的源: yum install update 但是发现官方的源好像被墙了,于是自己又去换源,找163的源换。具体的操作最后的链接。 可是换完源之后发现163的源只支持到centos6、、、、、、但是就泪崩了。于是又把源换了回来…

Python3中内置函数callable介绍

Python3中的内置函数callable接受一个对象参数,如果此对象参数看起来可调用,则callable函数返回True,否则返回False。如果返回True,则调用仍有可能失败;但如果返回False,则调用对象将永远不会成功。 类是可…

户外广告新创意

近来,各大城市纷纷加大了对户外广告的监管力度,部分城市甚至停止审批户外广告牌。这让户外广告运营者和广告发布商甚为头疼。 长期以来,户外广告牌扮演着截然相反的“双重角色”,在户外广告运营者和广告发布商眼中,“寸…

百度重新定义「智能屏」,瞄准10后

加入「公开课」交流群,获取更多学习资料、课程及热招岗位等信息记者 | 阿司匹林作为中国智能音箱主力推手中的一员,百度从 2017 年已经开始布局。根据数据机构Strategy Analytics发布智能音箱市场报告,2019年第三季度,百度旗下人工…

jQuery最简单的表单提交方式

第一步:绑定事件 常用的与ajax相关的事件参考如下: 1、$(selector).click(function) 2、$(selector).change(function) 3、$(selector).keyup(function) 4、$(selector).submit(function) 提交表单前&#…

Python3中typing模块介绍

typing.py的源码在:https://github.com/python/cpython/blob/main/Lib/typing.py。此模块为类型提示(Type Hints)提供运行时支持(This module provides runtime support for type hints)。从python 3.5版本开始将Typing作为标准库引入。 python3中增加了Function An…

显示所有文件和文件夹无论如何 无法被设置

问题:XP系统选显示所有文件和文件夹确定后没有任何反应再次打开文件夹选项里面仍是不显示隐藏的文件和文件夹 答案:在记事本粘贴下面文字,另存为所有文件, .reg 格式。成功的话图标变为绿色碎方块。在双击它。 Windows Registry E…

工作5年后才明白的道理:不起眼的技能中,藏着你的未来

编程圈儿一直都流传着一个调侃的段子:一流程序员靠数学二流靠算法三流靠逻辑四流靠SDK五流靠Google和StackOverFlow六流靠百度和自己琢磨低端的看高端的就是黑魔法!从过来人的角度看,这不仅仅是个段子,而是目前程序员的真实写照。…

Transform-style和Perspective属性

在《CSS3 Transform——transform-origin》一文中主要介绍了CSS3 Transform属性中的transform-origin属性的使用,其实在transform属性中,transform-origin属性仅是其中之一,要彻底理解transform属性,这是不够的,必须的…

Python3中__call__方法介绍

如果Python3类中有__call__方法,那么此类实例的行为类似于函数并且可以像函数一样被调用。当实例作为函数被调用时,如果定义了此方法,则x(arg1, arg2, …)是x.__call__(arg1, arg2, …)的简写。 为了将一个类实例当作函数调用,我们…

切尔西携手YouTube 英超第一家共享视频球队诞生

英格兰超级足球联赛冠军球队切尔西日前表示,已经与互联网视频服务网站YouTube签订了合作协议,未来将通过YouTube发布每日新闻和视频内容,从而也成为英超首支在线视频服务的球队. 据路透社报道,根据协议的内容,切尔西将建立一个YouTube旗下的品牌网站,其中将发布每日更新内容,当…

商汤联手华科:提出文字检测模型GNNets,新颖模块可解决几何分布难题

加入「公开课」交流群,获取更多学习资料、课程及热招岗位等信息编辑 | Jane出品 | AI科技大本营(ID:rgznai100)【导读】今年的ICCV,商汤科技及联合实验室共有57篇论文入选ICCV 2019(包含11篇Oral&#xff0…

(链表)反转链表Reverse List

逆转链表是简单而又简单的链表问题,其问题的方法之一可以设置三个指针,一个指向当前结点,一个指向前驱结点,一个指向后继指针 代码如下: class Solution { public:ListNode* ReverseList(ListNode* pHead) { // if(pHeadNULL || pHead->nextNULL) // return pH…

很长时间没有来了

好长时间没有来自己的博客了,更新的速度实在是太慢了,自己已经找了一份新的工作,给自己一个好的环境吧,有时间可以去学习更多的网络知识了.学习万岁!加一下,博友:思念狗的骨头:[url]http://starger.blog.51cto.com/ [/url] 他的文章还是比较不错的!转载于:https://blog.51cto.c…

十年磨一剑,可重构计算架构将引领未来芯片市场

2019 年 6 月,AI 芯片创业公司清微智能首款可重构计算架构 AI 芯片实现量产的消息在业内迅速传开,可重构计算架构芯片再次引发一波讨论的热潮。经历过十多年的技术积累,这枚小小的芯片在全球芯片市场中开启了全新的篇章。时光倒流&#xff0c…

PyTorch中nn.Module类中__call__方法介绍

在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[…, Any] _call_impl forward: Callable[…, Any] _forward_unimplemented 在PyTorch中nn.Module类是所有神经网络模块…

压缩和归档及vi的使用

1.cat(more less head tail) /etc/passwd :查看/etc/passwd文件内容2.head -13 /etc/passwd | tail -1 :只查看/etc/passwd文件中第13行3.wc -l /etc/passwd :统计/etc/passwd文件有多少行4.grep -v "^#" /etc/inittab | grep -v &…