CV03-双线性差值pytorch实现
一、双线性差值
1.1 公式
在理解双线性差值(Bilinear Interpolation)的含义基础上,参考pytorch差值的官方实现注释,自己实现了一遍。
差值就是利用已知点来估计未知点的值。一维上,可以用两点求出斜率,再根据位置关系来求插入点的值。
同理,在二维平面上也可以用类似的办法来估计插入点的值。如图,已知四点、
、
、
四点的值与坐标值
、
、
、
,求位于
的点
的值。思路是
- 先用w方向一维的线性差值,根据
、
求出点
,根据
、
求出点
;
- 再用h方向一维线性差值,根据
和
求出点
;
那么就有如下公式
具体到图像的双线性差值问题,我们可以理解成将图片进行了放大,但不使图像变成大块的斑点状,而是增大了图像的分辨率,多出来的像素就是双线性差值的结果。图像上周边4点一定是临近的,也就是说
上面的公式简化为
这样我们就面临将目标图像的坐标映射到原图像上求出
的问题。
1.2 坐标变换
对于第一个问题,目标图像的坐标映射到原图像上求出
,有两种思路。
第一种是把像素点看成是1×1大小的方块,像素点位于方块的中心,坐标转换时,HW方向的坐标都要加0.5才能对应起来。pytorch里面叫做torch.nn.functional.interpolate(align_corners=False)。
举例,如图原图像是一个3×3的图像,放大到5×5,每个像素点都是位于方形内的黑色小点。设是原图像的大小,本例是3×3,
是目标图像的大小,本例是5×5。换算公式为
第二种是上下左右相邻的像素点之间连线,像素点都位于交点上,坐标转换时,HW方向的总长度都要减少1才能对应起来g。pytorch里面叫做torch.nn.functional.interpolate(align_corners=True)。
举例,一个3×3的图像放大到5×5,每个像素点都是位于交点的黑色小点。设是原图像的大小,本例是3×3,
是目标图像的大小,本例是5×5。换算时,我们取边的长度,也就是HW方向各减1,也就是从2×2变成4×4。这样就有个结论就是变换以后目标图像四个顶点的像素值一定和原图像四个顶点像素值一样。换算公式为
二、for循环实现双线性差值(naive实现)
是对一张图像的,维度HWC,采用for循环遍历H、W计算差值点的像素值。这个实现too young,too simple,简直naive,效率低但易于理解;这里只实现了第一种坐标变换。
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import osdef bilinear_interpolation_naive(src, dst_size):"""双线性差值的naive实现:param src: 源图像:param dst_size: 目标图像大小H*W:return: 双线性差值后的图像"""(src_h, src_w, src_c) = src.shape # 原图像大小 H*W*C(dst_h, dst_w), dst_c = dst_size, src_c # 目标图像大小H*W*Cif src_h == dst_h and src_w == dst_w: # 如果大小不变,直接返回copyreturn src.copy()scale_h = float(src_h) / dst_h # 计算H方向缩放比scale_w = float(src_w) / dst_w # 计算W方向缩放比dst = np.zeros((dst_h, dst_w, dst_c), dtype=src.dtype) # 目标图像初始化for h_d, row in enumerate(dst): # 遍历目标图像H方向for w_d, col in enumerate(row): # 遍历目标图像所有W方向h = scale_h * (h_d + 0.5) - 0.5 # 将目标图像H坐标映射到源图像上w = scale_w * (w_d + 0.5) - 0.5 # 将目标图像W坐标映射到源图像上h0 = int(np.floor(h)) # 最近4个点坐标h0w0 = int(np.floor(w)) # 最近4个点坐标w0h1 = min(h0 + 1, src_h - 1) # h0 + 1就是h1,但是不能越界w1 = min(w0 + 1, src_w - 1) # w0 + 1就是w1,但是不能越界r0 = (w1 - w) * src[h0, w0, ...] + (w - w0) * src[h0, w1, ...] # 双线性差值R0r1 = (w1 - w) * src[h1, w0, ...] + (w - w0) * src[h1, w1, ...] # 双线性插值R1p = (h1 - h) * r0 + (h - h0) * r1 # 双线性插值Pdst[h_d, w_d, ...] = p.astype(np.uint8) # 插值结果放进目标像素点return dstif __name__ == '__main__':def unit_test():image_file = os.path.join(os.getcwd(), 'test.jpg')image = mpimg.imread(image_file)image_scale = bilinear_interpolation_naive(image, (256, 256))fig, axes = plt.subplots(1, 2, figsize=(8, 10))axes = axes.flatten()axes[0].imshow(image)axes[1].imshow(image_scale)axes[0].axis([0, image.shape[1], image.shape[0], 0])axes[1].axis([0, image_scale.shape[1], image_scale.shape[0], 0])fig.tight_layout()plt.show()passunit_test()
三、用numpy矩阵实现
是对一张图像的,维度HWC;采用numpy矩阵实现,速度快;
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import os
import torchdef bilinear_interpolation(src, dst_size, align_corners=False):"""双线性插值高效实现:param src: 源图像H*W*C:param dst_size: 目标图像大小H*W:return: 双线性插值后的图像"""(src_h, src_w, src_c) = src.shape # 原图像大小 H*W*C(dst_h, dst_w), dst_c = dst_size, src_c # 目标图像大小H*W*Cif src_h == dst_h and src_w == dst_w: # 如果大小不变,直接返回copyreturn src.copy()# 矩阵方式实现h_d = np.arange(dst_h) # 目标图像H方向坐标w_d = np.arange(dst_w) # 目标图像W方向坐标if align_corners:h = float(src_h - 1) / (dst_h - 1) * h_dw = float(src_w - 1) / (dst_w - 1) * w_delse:h = float(src_h) / dst_h * (h_d + 0.5) - 0.5 # 将目标图像H坐标映射到源图像上w = float(src_w) / dst_w * (w_d + 0.5) - 0.5 # 将目标图像W坐标映射到源图像上h = np.clip(h, 0, src_h - 1) # 防止越界,最上一行映射后是负数,置为0w = np.clip(w, 0, src_w - 1) # 防止越界,最左一行映射后是负数,置为0h = np.repeat(h.reshape(dst_h, 1), dst_w, axis=1) # 同一行映射的h值都相等w = np.repeat(w.reshape(dst_w, 1), dst_h, axis=1).T # 同一列映射的w值都相等h0 = np.floor(h).astype(np.int) # 同一行的h0值都相等w0 = np.floor(w).astype(np.int) # 同一列的w0值都相等h0 = np.clip(h0, 0, src_h - 2) # 最下一行上不大于src_h - 2,相当于paddingw0 = np.clip(w0, 0, src_w - 2) # 最右一列左不大于src_w - 2,相当于paddingh1 = np.clip(h0 + 1, 0, src_h - 1) # 同一行的h1值都相等,防止越界w1 = np.clip(w0 + 1, 0, src_w - 1) # 同一列的w1值都相等,防止越界q00 = src[h0, w0] # 取每一个像素对应的q00q01 = src[h0, w1] # 取每一个像素对应的q01q10 = src[h1, w0] # 取每一个像素对应的q10q11 = src[h1, w1] # 取每一个像素对应的q11h = np.repeat(h[..., np.newaxis], dst_c, axis=2) # 图像有通道C,所有的计算都增加通道Cw = np.repeat(w[..., np.newaxis], dst_c, axis=2)h0 = np.repeat(h0[..., np.newaxis], dst_c, axis=2)w0 = np.repeat(w0[..., np.newaxis], dst_c, axis=2)h1 = np.repeat(h1[..., np.newaxis], dst_c, axis=2)w1 = np.repeat(w1[..., np.newaxis], dst_c, axis=2)r0 = (w1 - w) * q00 + (w - w0) * q01 # 双线性插值的r0r1 = (w1 - w) * q10 + (w - w0) * q11 # 双线性差值的r1q = (h1 - h) * r0 + (h - h0) * r1 # 双线性差值的qdst = q.astype(src.dtype) # 图像的数据类型return dstif __name__ == "__main__":def unit_test2():image_file = os.path.join(os.getcwd(), 'test.jpg')image = mpimg.imread(image_file)image_scale = bilinear_interpolation(image, (256, 256))fig, axes = plt.subplots(1, 2, figsize=(8, 10))axes = axes.flatten()axes[0].imshow(image)axes[1].imshow(image_scale)axes[0].axis([0, image.shape[1], image.shape[0], 0])axes[1].axis([0, image_scale.shape[1], image_scale.shape[0], 0])fig.tight_layout()plt.show()passunit_test2()def unit_test3():src = np.array([[1, 2], [3, 4]])print(src)src = src.reshape((2, 2, 1))dst_size = (4, 4)dst = bilinear_interpolation(src, dst_size)dst = dst.reshape(dst_size)print(dst)tsrc = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)print(tsrc)tdst = F.interpolate(tsrc,size=(4, 4),mode='bilinear')print(tdst)# unit_test3()
四、用torch张量实现
是对tensor的,维度NCHW;和第二段一样,但是采用了张量,可以批量处理。
import torch
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimgdef bilinear_interpolate(src, dst_size, align_corners=False):"""双线性差值:param src: 原图像张量 NCHW:param dst_size: 目标图像spatial大小(H,W):param align_corners: 换算坐标的不同方式:return: 目标图像张量NCHW"""src_n, src_c, src_h, src_w = src.shapedst_n, dst_c, (dst_h, dst_w) = src_n, src_c, dst_sizeif src_h == dst_h and src_w == dst_w:return src.copy()"""将dst的H和W坐标映射到src的H和W坐标"""hd = torch.arange(0, dst_h)wd = torch.arange(0, dst_w)if align_corners:h = float(src_h - 1) / (dst_h - 1) * hdw = float(src_w - 1) / (dst_w - 1) * wdelse:h = float(src_h) / dst_h * (hd + 0.5) - 0.5w = float(src_w) / dst_w * (wd + 0.5) - 0.5h = torch.clamp(h, 0, src_h - 1) # 防止越界,0相当于上边界paddingw = torch.clamp(w, 0, src_w - 1) # 防止越界,0相当于左边界paddingh = h.view(dst_h, 1) # 1维dst_h个,变2维dst_h*1个w = w.view(1, dst_w) # 1维dst_w个,变2维1*dst_w个h = h.repeat(1, dst_w) # H方向重复1次,W方向重复dst_w次w = w.repeat(dst_h, 1) # H方向重复dsth次,W方向重复1次"""求出四点坐标"""h0 = torch.clamp(torch.floor(h), 0, src_h - 2) # -2相当于下边界paddingw0 = torch.clamp(torch.floor(w), 0, src_w - 2) # -2相当于右边界paddingh0 = h0.long() # torch坐标必须是longw0 = w0.long() # torch坐标必须是longh1 = h0 + 1w1 = w0 + 1"""求出四点值"""q00 = src[..., h0, w0]q01 = src[..., h0, w1]q10 = src[..., h1, w0]q11 = src[..., h1, w1]"""公式计算"""r0 = (w1 - w) * q00 + (w - w0) * q01 # 双线性插值的r0r1 = (w1 - w) * q10 + (w - w0) * q11 # 双线性差值的r1dst = (h1 - h) * r0 + (h - h0) * r1 # 双线性差值的qreturn dstif __name__ == '__main__':def unit_test4():# src = torch.randint(0, 100, (1, 3, 3, 3))src = torch.arange(1, 1 + 27).view((1, 3, 3, 3))\.type(torch.float32)print(src)dst = bilinear_interpolate(src,dst_size=(4, 4),align_corners=True)print(dst)pt_dst = F.interpolate(src.float(),size=(4, 4),mode='bilinear',align_corners=True)print(pt_dst)if torch.equal(dst, pt_dst):print('success')image_file = os.path.join(os.getcwd(), 'test.jpg')image = mpimg.imread(image_file)image_in = torch.from_numpy(image.transpose(2, 0, 1))image_in = torch.unsqueeze(image_in, 0)image_out = bilinear_interpolate(image_in, (256, 256))image_out = torch.squeeze(image_out, 0).numpy().astype(int)image_out = image_out.transpose(1, 2, 0)fig, axes = plt.subplots(1, 2, figsize=(8, 10))axes = axes.flatten()axes[0].imshow(image)axes[1].imshow(image_out)axes[0].axis([0, image.shape[1], image.shape[0], 0])axes[1].axis([0, image_out.shape[1], image_out.shape[0], 0])fig.tight_layout()plt.show()unit_test4()
相关文章:

matlab编程实现基于密度的聚类(DBSCAN)
1. DBSCAN聚类的基本原理 详细原理可以参考链接: https://www.cnblogs.com/pinard/p/6208966.html 这是找到的相对很详细的介绍了,此链接基本仍是周志华《机器学习》中的内容,不过这个链接更通俗一点,且算法流程感觉比《机器学习…

EAST 自然场景文本检测
自然场景文本检测是图像处理的核心模块,也是一直想要接触的一个方面。刚好看到国内的旷视今年在CVPR2017的一篇文章:EAST: An Efficient and Accurate Scene Text Detector。而且有开放的代码,学习和测试了下。 题目说的是比较高效࿰…

通过httpmodule获取webapi返回的信息
我写了一个webapi,想在module中获取请求的信息和返回的信息,写进log里,以方便以后查询。request信息很容易能拿到,但是返回信息得费一番周折。不多说,上代码 public class ResponseLoggerModule : IHttpModule {privat…

iOS SwiftUI篇-2 UI控件 Text Button Image List
iOS SwiftUI篇-2 UI控件 Text Button Image List Text 显示文本,相当于UILabel import SwiftUIstruct TextContentView: View {var body: some View {//VStack(垂直排列视图)可以将其内部的多个视图,在垂直方向进行等距排列,VStack最多可以容纳十个子视图,VStack(spacin…

numpy和torch数据操作对比
对numpy和torch数据操作进行对比,避免遗忘。 ndarray和tensor import torch import numpy as npnp_data np.arange(6).reshape((2, 3)) torch_data torch.arange(6) # 张量 tensor2array torch_data.numpy()print(\nnumpy array:\n, np_data,\ntorch tensor\n,…

ZooKeeper学习
一、ZooKeeper 的实现 1.1 ZooKeeper处理单点故障 我们知道可以通过ZooKeeper对分布式系统进行Master选举,来解决分布式系统的单点故障,如图所示。 那么我们继续分析一下,ZooKeeper通过Master选举来帮助分布式系统解决单点故障, 保…

iOS SwiftUI篇-1 项目结构
iOS SwiftUI篇-1 项目结构 介绍Xcode新建的SwiftUI模版项目结构、跟普通Storyboard模版项目的差异、SwiftUI项目的app启动流程、UIScene概念介绍、AppDelegate.swift和Info.plist的差异 1.项目模版 Interface: SwiftUI Life Cycle: UIKit App Delegate Language: Swift Life…

js绑定事件和解绑事件
在js中绑定多个事件用到的是两个方法:attachEvent和addEventListener,但是这两个方法又存在差异性 attachEvent方法 只支持IE678,不兼容其他浏览器addEventListener方法 兼容火狐谷歌,不兼容IE8及以下 addEventListener方法 div.addEventListener(click,fn); div.addEventLi…

基于三维点云数据的主成分分析方法(PCA)的python实现
主成分分析(PCA)获取三维点云的坐标轴方向和点云法向量 # 实现PCA分析和法向量计算,并加载数据集中的文件进行验证import open3d as o3d # import os import numpy as np from scipy.spatial import KDTree# from pyntcloud import PyntClo…

CV02-FCN笔记
目录 一、Convolutionalization 卷积化 二、Upsample 上采样 2.1 Unpool反池化 2.2 Interpolation差值 2.3 Transposed Convolution转置卷积 三、Skip Architecture 3.1 特征融合 3.2 裁剪 FCN原理及实践,记录一些自己认为重要的要点,以免日后遗…

python基础之常用模块
6、TEXT PROCESSING SERVICES :文本处理服务 6.1、re 8、DATA TYPES : 数据类型 8.1、datetime 8.2、collections 8.3、copy 9、 NUMERIC AND MATHEMATICAL MODULES : 数字和数学模块 9.1、random 10、FUNCTIONAL PROGRAMMING MODULES : 函数式编程模块 10.1、iter…

笔记本电脑摄像头实现光流跟踪
看实验室里的师兄在写CSDN,自己也写一个,记录自己的学习进程吧。 研究生从机械转到了毫无基础的SLAM领域。研一半年上课加自学,对SLAM也有一丢丢的了解。最近看光流法时,想到用笔记本电脑的摄像头实现一下,就简单的…

JSON字符串 拼接与解析
常用方式: json字符串拼接(目前使用过两种方式): 1.运用StringBuilder拼接 StringBuilder json new StringBuilder(); json.append("{"); json.append(""uuid":" """ uuid "",&q…

iOS SwiftUI篇-3 排版布局layout
iOS SwiftUI篇-3 排版布局layout swiftUI提供的layout有: ZStack、GeometryReader、HStack、LazyVGrid、LazyHStack、LazyHGrid、LazyVStack、VStack、Spacer、ScrollViewReader等 HStack 水平横向布局容器,子view按顺序水平排列 HStack(alignment: .center, spacing: 10)…

CV04-UNet笔记
目录 一、UNet模型 二、Encoder & Decoder 2.1 Encoder 2.2 Decoder 2.3 classifier 学习U-Net: Convolutional Networks for Biomedical Image Segmentation,记录一些自己认为重要的要点,以免日后遗忘。 代码:https://github.com/…

Scrapy 学习笔记(-)
Scrapy Scrapy 是一个为了爬取网站数据,提取结构性数据而编写的应用框架。 其可以应用在数据挖掘,信息处理或存储历史数据等一系列的程序中。其最初是为了页面抓取 (更确切来说, 网络抓取 )所设计的, 也可以应用在获取API所返回的数据(例如 A…

Ubuntu18.04运行ORB_SLAM2
运行环境:Ubuntu18.04 预先安装的库 需要预先安装一些库,如Eign,Sophus,OpenCV等。笔者在阅读《SLAM十四讲》的时候已经安装,在此不再赘述。 ORB_SLAM2源码的下载与编译 git clone https://github.com/raulmur/ORB…

java中的各种流(老师的有道云笔记)
内存操作流-字节之前的文件操作流是以文件的输入输出为主的,当输出的位置变成了内存,那么就称为内存操作流。此时得使用内存流完成内存的输入和输出操作。如果程序运行过程中要产生一些临时文件,可采用虚拟文件方式实现;直接操作磁…

iOS SwiftUI篇-4 注解@State、@Binding、@ObservedObject、@EnvironmentObject、@Environment
iOS SwiftUI篇-4 注解@State、@Binding、@ObservedObject、@EnvironmentObject、@Environment @State 关联View的状态,当@State修饰的属性改变时,对应的View会跟着刷新,符合MVVM的设计理念 @State var count: Int = 0Section(header: Text("@States")) {Te
CV05-ResNet笔记
目录 一、为什么是ResNet 二、Residual Learning细节 2.1 shortcut计算 2.2 11卷积调整channel维度大小 2.3 ResNet层数 2.4 ResNet里的Basic Block 和 Bottleneck Block 2.5 Global Average Pooling 全局平均池化 2.6 Batch Normalization 学习ResNet,记录…

二叉树的前序,中序,后序的递归、迭代实现
二叉树的前序遍历 递归实现 递归实现没什么好说的。个人感觉将函数功能看成一个整体,不要去想栈中怎么实现的。毕竟自己的脑袋不是电脑,绕着绕着就蒙了。 void preordered_traversal_recursion(TreeNode* root) {if(root NULL) return;container.pus…

DataSet 动态添加列
public DataSet GetNewId(List<string> IdArr){DataSet ds new DataSet();DataTable newtb new DataTable();DataColumn column new DataColumn("cnt", typeof(string));//新增列newtb.Columns.Add(column);for (int i 0; i < IdArr.Count; i){StringBu…

iOS专题1-蓝牙扫描、连接、读写
iOS专题1-蓝牙扫描、连接、读写 概念 外围设备 可以被其他蓝牙设备连接的外部蓝牙设备,不断广播自身的蓝牙名及其数据,如小米手环、共享单车、蓝牙体重秤 中央设备 可以搜索并连接周边的外围设备,并与之进行数据读写通讯,如手机 日常生活中常见的场景是手机app通过蓝…

CV06-Xception笔记
目录 一、为啥是Xception 二、Xception结构 2.1 Xception结构基本描述 2.2 实现细节 2.3 DeepLabV3改进 三、记录pytorch采坑relu激活函数inplaceTrue Xception笔记,记录一些自己认为重要的要点,以免日后遗忘。 复现Xception论文、DeepLabV改进的…

C++排序算法实现(更新中)
比较排序法:如冒泡排序、简单选择排序、合并排序、快速排序。其最优的时间复杂度为O(nlogn)。 其他排序法:如桶排序、基数排序等。时间复杂度可以达到O(n)。但试用范围有要求。 桶排序:排序的数组元素跨距不能很大。因为跨距很大的话…

iOS SwiftUI篇-5 专题NavigationView、NavigationLink
iOS SwiftUI篇-5 专题NavigationView、NavigationLink NavigationView:标题、展示模式、隐藏导航栏、隐藏返回按钮、添加导航栏按钮 NavigationLink:Text文本跳转、Image图片跳转、Button按钮跳转、点击按钮根据业务跳转到不同页面 NavigationView 标题、展示模式 import S…

PHP artisan
Artisan 是 Laravel 提供的 CLI(命令行接口),它提供了非常多实用的命令来帮助我们开发 Laravel 应用。前面我们已使用过 Artisan 命令来生成应用的 App Key 和控制器。在本教程中,我们会用到以下 Artisan 命令,你也可以…

【转载】Pytorch在加载模型参数时指定设备
转载 https://sparkydogx.github.io/2018/09/26/pytorch-state-dict-gpu-to-cpu/ >>> torch.load(tensors.pt) # Load all tensors onto the CPU >>> torch.load(tensors.pt, map_locationtorch.device(cpu)) # Load all tensors onto the CPU, using a fun…

目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)
首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple-faster-rcnn-pytorch-master代码的最后一个train.py文件,是时候认真的总结一下了࿰…

hp-ux 集群,内存 小记
hp-ux 集群,内存 小记 -----查看hp 集群状态信息 # cmviewcl -v CLUSTER STATUS dbsvr up NODE STATUS STATE db01 up running Cluster_Lock_LVM: VOLUM…