当前位置: 首页 > 编程日记 > 正文

einsum,一个函数走天下

640?wx_fmt=jpeg

作者 | 永远在你身后

转载自知乎

【导读】einsum 全称 Einstein summation convention(爱因斯坦求和约定),又称为爱因斯坦标记法,是爱因斯坦 1916 年提出的一种标记约定,本文主要介绍了einsum 的应用。

简单的说,应用 einsum 就是省去求和式中的求和符号,例如下面的公式:

640?wx_fmt=png

以 einsum 的写法就是:

640?wx_fmt=png

后者将 640?wx_fmt=png 符号给省去了,显得更加简洁;再比如:

640?wx_fmt=png

640?wx_fmt=png

上面两个栗子换成 einsum 的写法就变成:

640?wx_fmt=png

在实现一些算法时,数学表达式已经求出来了,需要将之转换为代码实现,简单的一些还好,有时碰到例如矩阵转置、矩阵乘法、求迹、张量乘法、数组求和等等,若是以分别以 transopse、sum、trace、tensordot 等函数实现的话,不但复杂,还容易出错。

现在,这些问题你统统可以一个函数搞定,没错,就是 einsum,einsum 函数就是根据上面的标记法实现的一种函数,可以根据给定的表达式进行运算,可以替代但不限于以下函数:

矩阵求迹:trace求矩阵对角线:diag张量(沿轴)求和:sum张量转置:transopose矩阵乘法:dot张量乘法:tensordot向量内积:inner外积:outer

该函数在 numpy、tensorflow、pytorch 上都有实现,用法基本一样,定义如下:

equation 是字符串的表达式,operands 是操作数,是一个元组参数,并不是只能有两个,所以只要是能够通过 einsum 标记法表示的乘法求和公式,都可以用一个 einsum 解决,下面以 numpy 举几个栗子:

# 沿轴计算张量元素之和:	
c = a.sum(axis=0)

上面的以 sum 函数的实现代码,设 640?wx_fmt=png为三维张量,上面代码用公式来表达的话就是:

640?wx_fmt=png

换成 einsum 标记法:

640?wx_fmt=png

然后根据此式使用 einsum 函数实现等价功能:

c = np.einsum('ijk->jk', a)	
# 作用与 c = a.sum(axis=0) 一样

更进一步的,如果 640?wx_fmt=png 不止是三维,可以将下标 640?wx_fmt=png 换成省略号,以表示剩下的所有维度:

这种写法 pytorch 与 tensorflow 同样支持,如果不是很理解的话,可以查看其对应的公式:

640?wx_fmt=png

# 矩阵乘法	
c = np.dot(a, b)

矩阵乘法的公式为:

640?wx_fmt=png

然后是 einsum 对应的实现:

最后再举一个张量乘法栗子:

# 张量乘法	
c = np.tensordot(a, b, ([0, 1], [0, 1]))

如果 640?wx_fmt=png 是三维的,对应的公式为:

640?wx_fmt=png

对应的 einsum 实现:

下面以 numpy 做一下测试,对比 einsum 与各种函数的速度,这里使用 python 内建的 timeit 模块进行时间测试,先测试(四维)两张量相乘然后求所有元素之和,对应的公式为:

640?wx_fmt=png

然后是测试代码:

from timeit import Timer	
import numpy as np	# 定义两个全局变量	
a = np.random.rand(64, 128, 128, 64)	
b = np.random.rand(64, 128, 128, 64)	# 定义使用einsum与sum的函数	
def einsum():	temp = np.einsum('ijkl,ijkl->', a, b)	def npsum():	temp = (a * b).sum()	# 打印运行时间	
print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(20))	
print("npsum cost:", Timer("npsum()", "from __main__ import npsum").timeit(20))

上面 Timer 是 timeit 模块内的一个类

Timer(stmt, setup).timeit(number)	# stmt: 要测试的语句	# setup: 传入stmt的运行环境,比如stmt中要导入的模块等。	# 可以写一行语句,也可以写多行语句,写多行语句时要用分号;隔开语句	# number: 执行次数

将两个函数各执行 20 遍,最后的结果为,单位为秒:

einsum cost: 1.5560735	
npsum cost: 8.0874927

可以看到,einsum 比 sum 快了几乎一个量级,接下来测试单个张量求和:

将上面的代码改一下:

def einsum():	temp = np.einsum('ijkl->', a)	def npsum():	temp = a.sum()

相应的运行时间为:

einsum cost: 3.2716003	
npsum cost: 6.7865246

还是 einsum 更快,所以哪怕是单个张量求和,numpy 上也可以用 einsum 替代,同样,求均值(mean)、方差(var)、标准差(std)也是一样。

接下来测试 einsum 与 dot 函数,首先列一下矩阵乘法的公式以以及 einsum表达式:

640?wx_fmt=svg

640?wx_fmt=png

然后是测试代码:

a = np.random.rand(2024, 2024)	
b = np.random.rand(2024, 2024)	# einsum与dot比较	
def einsum():	res = np.einsum('ik,kj->ij', a, b)	def dot():	res = np.dot(a, b)	print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(20))	
print("dot cost:", Timer("dot()", "from __main__ import dot").timeit(20))	# einsum cost: 80.2403851	
# dot cost: 2.0842243

这就很尴尬了,比 dot 慢了 40 倍(并且差距随着矩阵规模的平方增加),这还怎么打天下?不过在 numpy 的实现里,einsum 是可以进行优化的,去掉不必要的中间结果,减少不必要的转置、变形等等,可以提升很大的性能,将 einsum 的实现改一下:

def einsum():	res = np.einsum('ik,kj->ij', a, b, optimize=True)

加了一个参数 optimize=True,官方文档上该参数是可选参数,接受4个值:

optimize 默认为 False,如果设为 True,这默认选择‘greedy(贪心)’方式,再看看速度:

einsum cost: 2.0330937	
dot cost: 1.9866218

可以看到,通过优化,虽然还是稍慢一些,但是 einsum 的速度与 dot 达到了一个量级;不过 numpy 官方手册上有个 einsum_path,说是可以进一步提升速度,但是我在自己电脑上(i7-9750H)测试效果并不稳定,这里简单的介绍一下该函数的用法为:

path = np.einsum_path('ik,kj->ij', a, b)[0]	
np.einsum('ik,kj->ij', a, b, optimize=path)

einsum_path 返回一个 einsum 可使用的优化路径列表,一般使用第一个优化路径;另外,optimize 及 einsum_path 函数只有 numpy 实现了, tensorflow 和 pytorch 上至少现在没有。

最后,再测试 einsum 与另一个常用的函数 tensordot,首先定义两个四维张量的及 tensordot 函数:

a = np.random.rand(128, 128, 64, 64)	
b = np.random.rand(128, 128, 64, 64)	def tensordot():	res = np.tensordot(a, b, ([0, 1], [0, 1]))

该实现对应的公式为:

640?wx_fmt=png

所以 einsum 函数的实现为:

def einsum():	res = np.einsum('ijkl,ijmn->klmn', a, b, optimize=True)

tensordot 也是链接到 BLAS 实现的函数,所以不加 optimize 肯定比不了,最后结果为:

print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(1))	
print("tensordot cost:", Timer("tensordot()", "from __main__ import tensordot").timeit(1))	# einsum cost: 4.2361331	
# tensordot cost: 4.2580409

测试了 10 多次,基本上速度一样,einsum 表现好一点的;不过说是一个函数打天下,肯定是做不到的,还有一些数组的分割、合并、指数、对数等功能没法实现,需要使用别的函数,其他的基本都可以用 einsum 来实现,简单而又高效。

经过进一步测试发现,优化反而出现速度降低的情况,例如:

def einsum():	temp = einsum('...->', a, optimize=True)	def test():	temp = a.sum()

上面两中对数组求和的方法,当a是一维向量时,或者 a 是多维但是规模很小是,优化的 einsum 反而更慢,但是去掉 optimize 参数后表现比内置的 sum函数稍好,我认为优化是有一个固定的成本。

还有一个坑需要注意的是,有些情况的省略号不加 optimize 会报错,就拿上面的栗子而言:

np.einsum('...->', a, optimize=True)   # 正常运行	
np.einsum('...->', a)   # 报错

很无奈,试了很多次,不加 optimize 就是会报错,但是并不是所有的省略号写法都需要加 optimize ,例如:

640?wx_fmt=png

640?wx_fmt=png

使用省略号实现上面两个公式并不需要加 optimize ,能够正常运行

np.einsum('i...->...', a)   # 正常	
np.einsum('...,...->...', a, b)   # 正常

但是如果碰到下面的公式:

640?wx_fmt=png

上式表示将 a 除第一个维度之外,剩下的维度全部累加,这种实现就必须要加 optimize。

再举一个栗子:

c = (a * b).sum()	
# 如果不知道a, b的维数,使用einsum实现上面的功能也必须要加optimize	
c = einsum('...,...->', a, b, optimize=True)

总结一下,在计算量很小时,优化因为有一定的成本,所以速度会慢一些;但是,既然计算量小,慢一点又怎样呢,而且使用优化之后,可以更加肆意的使用省略号写表达式,变量的维数也不用考虑了,所以建议无脑使用优化。

原文链接:

https://zhuanlan.zhihu.com/p/71639781

(*本文为AI科技大本营转载文章,转载请联系作者)

福利时刻

距离大会参与通道关闭还有 1 天,扫描下方二维码或点击阅读原文,马上参与!(学生票特享 598 元,团购票每人立减优惠,倒计时 1 天!)

640?wx_fmt=jpeg

推荐阅读

  • 从垃圾分类到千行百业,如何打响AI“落地战”?

  • 2亿日活,日均千万级视频上传,快手推荐系统如何应对技术挑战

  • 在图数据上做机器学习,应该从哪个点切入?

  • Docker容器化部署Python应用

  • AI 假冒老板骗取 24.3 万美元

  • 编程吸金榜:你排第几?网友神回应了!

  • 吴子宁:手握 280 多项专利的斯坦福技术先锋 | 人物志

  • 阿里云 CDN 业务基于边缘容器的云原生转型实践


640?wx_fmt=png

你点的每个“在看”,我都认真当成了喜欢

相关文章:

常用排序算法的C++实现

排序是将一组”无序”的记录序列调整为”有序”的记录序列。假定在待排序的记录序列中,存在多个具有相同的关键字的记录,若经过排序,这些记录的相对次序保持不变,即在原序列中,rirj,且ri在rj之前&#xff0…

4.65FTP服务4.66测试登录FTP

2019独角兽企业重金招聘Python工程师标准>>> FTP服务 测试登录FTP 4.65FTP服务 文件传输协议(FTP),可以上传和下载文件。比如我们可以把Windows上的文件shan上传到Linux,也可以把Linux上的文件下载到Windows上。 Cent…

JavaScript的应用

DOM, BOM, XMLHttpRequest, Framework, Tool (Functionality) Performance (Caching, Combine, Minify, JSLint) ---------------- 人工做不了,交给程序去做,这样可以流程化。 Maintainability (Pattern) http://www.jmarshall.com/easy/http/ http://dj…

miniz库简介及使用

miniz:Google开源库,它是单一的C源文件,紧缩/膨胀压缩库,使用zlib兼容API,ZIP归档读写,PNG写方式。关于miniz的更详细介绍可以参考:https://code.google.com/archive/p/miniz/miniz.c is a loss…

iOS之runtime详解api(三)

第一篇我们讲了关于Class和Category的api,第二篇讲了关于Method的api,这一篇来讲关于Ivar和Property。 4.objc_ivar or Ivar 首先,我们还是先找到能打印出Ivar信息的函数: const char * _Nullable ivar_getName(Ivar _Nonnull v) …

亚马逊首席科学家李沐「实训营」国内独家直播,马上报名 !

开学了,别人家的学校都开始人工智能专业的学习之旅了,你呢?近年来,国内外顶尖科技企业的 AI 人才抢夺战愈演愈烈。华为开出200万年薪吸引 AI 人才,今年又有 35 所高校新增人工智能本科专业,众多新生即将开展…

人脸检测库libfacedetection介绍

libfacedetection是于仕琪老师放到GitHub上的二进制库,没有源码,它的License是MIT,可以商用。目前只提供了windows 32和64位的release动态库,主页为https://github.com/ShiqiYu/libfacedetection,采用的算法好像是Mult…

倒计时1天 | 2019 AI ProCon报名通道即将关闭(附参会指南)

2019年9月5-7日,面向AI技术人的年度盛会—— 2019 AI开发者大会 AI ProCon,震撼来袭!2018 年由 CSDN 成功举办 AI 开发者大会一年之后,全球 AI 市场正发生着巨大的变化。顶尖科技企业和创新力量不断地进行着技术的更迭和应用的推…

法院判决:优步无罪,无人车安全员可能面临过失杀人控诉

据路透社报道,负责优步无人车在亚利桑那州致人死亡事件调查的律师事务所发布公开信宣布,优步在事故中“不承担刑事责任”,但是当时在车上的安全员Rafaela Vasquez要接受进一步调查,可能面临车辆过失杀人罪指控。2018年3月&#xf…

09 Storage Structure and Relationships

目标:存储结构:Segments分类:Extents介绍:Blocks介绍:转载于:https://blog.51cto.com/eread/1333894

边界框的回归策略搞不懂?算法太多分不清?看这篇就够了

作者 | fivetrees来源 | https://zhuanlan.zhihu.com/p/76477248本文已由作者授权,未经允许,不得二次转载【导读】目标检测包括目标分类和目标定位 2 个任务,目标定位一般是用一个矩形的边界框来框出物体所在的位置,关于边界框的回…

人脸识别引擎SeetaFaceEngine简介及在windows7 vs2013下的编译

SeetaFaceEngine是开源的C人脸识别引擎,无需第三方库,它是由中科院计算所山世光老师团队研发。它的License是BSD-2.SeetaFaceEngine库包括三个模块:人脸检测(detection)、面部特征点定位(alignment)、人脸特征提取与比对(identification)。人…

当移动数据分析需求遇到Quick BI

我叫洞幺,是一名大型婚恋网站“我在这等你”的资深老员工,虽然在公司五六年,还在一线搬砖。“我在这等你”成立15年,目前积累注册用户高达2亿多,在我们网站成功牵手的用户达2千多万。目前我们的公司在CEO的英名带领下&…

为什么选择数据分析师这个职业?

我为什么选择做数据分析师? 我大学专业是物流管理,学习内容偏向于管理学和经济学,但其实最感兴趣的还是心理学,即人在各种刺激下反应的机制以及原理。做数据分析师,某种意义上是对群体行为的研究和量化,两者…

人脸识别引擎SeetaFaceEngine中Detection模块使用的测试代码

人脸识别引擎SeetaFaceEngine中Detection模块用于人脸检测&#xff0c;以下是测试代码&#xff1a;int test_detection() {std::vector<std::string> images{ "1.jpg", "2.jpg", "3.jpg", "4.jpeg", "5.jpeg", "…

基于Pygame写的翻译方法

发布时间&#xff1a;2018-11-01技术&#xff1a;pygameeasygui概述 实现一个翻译功能&#xff0c;中英文的互相转换。并可以播放翻译后的内容。 翻译接口调用的是百度翻译的api接口。详细 代码下载&#xff1a;http://www.demodashi.com/demo/14326.html 一、需求分析 使用pyg…

冠军奖3万元!CSDN×易观算法大赛开赛啦

伴随着5G、物联网与大数据形成的后互联网格局的逐步形成&#xff0c;日益多样化的用户触点、庞杂的行为数据和沉重的业务体量也给我们的数据资产管理带来了不容忽视的挑战。为了建立更加精准的数据挖掘形式和更加智能的机器学习算法&#xff0c;对不断生成的用户行为事件和各类…

快速把web项目部署到weblogic上

weblogic简介 BEA WebLogic是用于开发、集成、部署和管理大型分布式Web应用、网络应用和数据库应 用的Java应用服务器。将Java的动态功能和Java Enterprise标准的安全性引入大型网络应用的开发、集成、部署和管理之中。 BEA WebLogic Server拥有处理关键Web应用系统问题所需的性…

使GDAL库支持中文路径或中文文件名的处理方法

之前生成的gdal 2.1.1动态库&#xff0c;在通过命令行执行时&#xff0c;遇到有中文路径或中文图像名时&#xff0c;GDALOpen函数不能正确的被调用&#xff0c;如下图&#xff1a;解决方法&#xff1a;1. 在所有使用GDALAllRegister();语句后面加上一句CPLSetConfigOption…

创新工场论文入选NeurIPS 2019,研发最强“AI蒙汗药”

9月4日&#xff0c;被誉为机器学习和神经网络领域的顶级会议之一的 NeurIPS 2019 揭晓收录论文名单&#xff0c;创新工场人工智能工程院的论文《Learning to Confuse: Generating Training Time Adversarial Data with Auto-Encoder》被接收在列。这篇论文围绕现阶段人工智能系…

Flutter环境搭建(Windows)

SDK获取 去官方网站下载最新的安装包 &#xff0c;或者在Github中的Flutter项目去 下载 。 将下载的安装包解压 注意&#xff1a;不要将Flutter安装到高权限路径&#xff0c;例如 C:\Program Files\ 配置环境变量&#xff0c;在Path中添加flutter\bin的全路径(如&#xff1a;D…

Android在eoe分享一篇推荐开发组件或者框架的文章

http://www.eoeandroid.com/thread-311194-1-1.html y4078275315 主题 62 帖子 352 e币实习版主 积分314发消息电梯直达楼主 回复 发表于 2013-11-7 09:58:45 | 只看该作者 |只看大图 34本帖最后由 y407827531 于 2013-11-28 15:07 编辑感谢版主推荐&#xff0c;本贴会持续更新…

如何打造高质量的机器学习数据集?这份超详指南不可错过

作者 | 周岩&#xff0c;夕小瑶&#xff0c;霍华德&#xff0c;留德华叫兽转载自知乎博主『运筹OR帷幄』导读&#xff1a;随着计算机行业的发展&#xff0c;人工智能和数据科学近几年成为了学术和工业界关注的热点。特别是这些年人工智能的发展日新月异&#xff0c;每天都有新的…

人脸识别引擎SeetaFaceEngine中Alignment模块使用的测试代码

人脸识别引擎SeetaFaceEngine中Alignment模块用于检测人脸关键点&#xff0c;包括5个点&#xff0c;两个眼的中心、鼻尖、两个嘴角&#xff0c;以下是测试代码&#xff1a;int test_alignment() {std::vector<std::string> images{ "1.jpg", "2.jpg"…

微软宣布 Win10 设备数突破8亿,距离10亿还远吗?

开发四年只会写业务代码&#xff0c;分布式高并发都不会还做程序员&#xff1f; >>> 微软高管 Yusuf Mehdi 昨天在推特发布了一条推文&#xff0c;宣布运行 Windows 10 的设备数已突破 8 亿&#xff0c;比半年前增加了 1 亿。 根据之前的报道&#xff0c;两个月前 W…

领导者必须学会做的十件事情

在这个世界上你可以逃避很多事情并且仍然能够获得成功&#xff0c;但是有些事情你根本就无法走捷径。商业领导者们并不存在于真空之中。成功总是与竞争有关。你可能拥有伟大的产品、服务、概念、战略、团队等等&#xff0c;如果它不能从某种程度上超越竞争对手&#xff0c;这对…

人脸识别引擎SeetaFaceEngine中Identification模块使用的测试代码

人脸识别引擎SeetaFaceEngine中Identification模块用于比较两幅人脸图像的相似度&#xff0c;以下是测试代码&#xff1a;int test_recognize() {const std::string path_images{ "E:/GitCode/Face_Test/testdata/recognization/" };seeta::FaceDetection detector(&…

李沐亲授加州大学伯克利分校深度学习课程移师中国,现场资料新鲜出炉

2019 年 9 月 5 日&#xff0c;AI ProCon 2019 在北京长城饭店正式拉开帷幕。大会的第一天&#xff0c;以亚马逊首席科学家李沐面对面亲自授课完美开启&#xff01;“大神”&#xff0c;是很多人对李沐的印象。除了是亚马逊首席科学家李&#xff0c;李沐还拥有多重身份&#xf…

对Python课的看法

学习Python已经有两周的时间了&#xff0c;我是计算机专业的学生&#xff0c;我抱着可以多了解一种语言的想法报了Python的选修课&#xff0c;从第一次听肖老师的课开始&#xff0c;我便感受到一种好久没有感受到的课堂氛围&#xff0c;感觉十分舒服&#xff0c;不再是那种高中…