当前位置: 首页 > 编程日记 > 正文

各种 AI 数据增强方法,都在这儿了

来源 | 算法进阶

责编 | 寇雪芹

头图 | 下载于视觉中国

数据、算法、算力是人工智能发展的三要素。数据决定了Ai模型学习的上限,数据规模越大、质量越高,模型就能够拥有更好的泛化能力。

然而在实际工程中,经常有数据量太少(相对模型而言)、样本不均衡、很难覆盖全部的场景等问题,解决这类问题的一个有效途径是通过数据增强(Data Augmentation),使模型学习获得较好的泛化性能。


数据增强介绍

数据增强(Data Augmentation)是在不实质性的增加数据的情况下,从原始数据加工出更多的表示,提高原数据的数量及质量,以接近于更多数据量产生的价值。其原理是,通过对原始数据融入先验知识,加工出更多数据的表示,有助于模型判别数据中统计噪声,加强本体特征的学习,减少模型过拟合,提升泛化能力。

如经典的机器学习例子--哈士奇误分类为狼:通过可解释性方法,可发现错误分类是由于图像上的雪造成的。通常狗对比狼的图像里面雪地背景比较少,分类器学会使用雪作为一个特征来将图像分类为狼还是狗,而忽略了动物本体的特征。此时,可以通过数据增强的方法,增加变换后的数据(如背景换色、加入噪声等方式)来训练模型,帮助模型学习到本体的特征,提高泛化能力。

需要关注的是,数据增强样本也有可能是引入片面噪声,导致过拟合。此时需要考虑的是调整数据增强方法,或者通过算法(可借鉴Pu-Learning思路)选择增强数据的最佳子集,以提高模型的泛化能力。

常用数据增强方法可分为:基于样本变换的数据增强及基于深度学习的数据增强。

基于样本变换的数据增强

样本变换数据增强即采用预设的数据变换规则进行已有数据的扩增,包含单样本数据增强和多样本数据增强。

2.1 单样本增强

单(图像)样本增强主要有几何操作、颜色变换、随机擦除、添加噪声等方法,可参见imgaug开源库。

2.2 多样本数据增强方法

多样本增强是通过先验知识组合及转换多个样本,主要有Smote、SamplePairing、Mixup等方法在特征空间内构造已知样本的邻域值。

  • Smote

Smote(Synthetic Minority Over-sampling Technique)方法较常用于样本均衡学习,核心思想是从训练集随机同类的两近邻样本合成一个新的样本,其方法可以分为三步:

1、 对于各样本X_i,计算与同类样本的欧式距离,确定其同类的K个(如图3个)近邻样本;

2、从该样本k近邻中随机选择一个样本如近邻X_ik,生成新的样本:

Xsmote_ik =  Xi  +  rand(0,1) ∗ ∣X_i − X_ik∣

3、重复2步骤迭代N次,可以合成N个新的样本。

# SMOTE
from imblearn.over_sampling import SMOTEprint("Before OverSampling, counts of label\n{}".format(y_train.value_counts()))
smote = SMOTE()
x_train_res, y_train_res = smote.fit_resample(x_train, y_train)
print("After OverSampling, counts of label\n{}".format(y_train_res.value_counts()))
  • SamplePairing

SamplePairing算法的核心思想是从训练集随机抽取的两幅图像叠加合成一个新的样本(像素取平均值),使用第一幅图像的label作为合成图像的正确label。

  • Mixup

Mixup算法的核心思想是按一定的比例随机混合两个训练样本及其标签,这种混合方式不仅能够增加样本的多样性,且能够使决策边界更加平滑,增强了难例样本的识别,模型的鲁棒性得到提升。其方法可以分为两步:

1、从原始训练数据中随机选取的两个样本(xi, yi) and (xj, yj)。其中y(原始label)用one-hot 编码。

2、对两个样本按比例组合,形成新的样本和带权重的标签

x˜ = λxi + (1 − λ)xj
y˜ = λyi + (1 − λ)yj

最终的loss为各标签上分别计算cross-entropy loss,加权求和。其中 λ ∈ [0, 1], λ是mixup的超参数,控制两个样本插值的强度。

# Mixup
def mixup_batch(x, y, step, batch_size, alpha=0.2):"""get batch data:param x: training data:param y: one-hot label:param step: step:param batch_size: batch size:param alpha: hyper-parameter α, default as 0.2:return:  x y """candidates_data, candidates_label = x, yoffset = (step * batch_size) % (candidates_data.shape[0] - batch_size)# get batch datatrain_features_batch = candidates_data[offset:(offset + batch_size)]train_labels_batch = candidates_label[offset:(offset + batch_size)]if alpha == 0:return train_features_batch, train_labels_batchif alpha > 0:weight = np.random.beta(alpha, alpha, batch_size)x_weight = weight.reshape(batch_size, 1)y_weight = weight.reshape(batch_size, 1)index = np.random.permutation(batch_size)x1, x2 = train_features_batch, train_features_batch[index]x = x1 * x_weight + x2 * (1 - x_weight)y1, y2 = train_labels_batch, train_labels_batch[index]y = y1 * y_weight + y2 * (1 - y_weight)return x, y基于深度学习的数据增强

3.1 特征空间的数据增强

不同于传统在输入空间变换的数据增强方法,神经网络可将输入样本映射为网络层的低维向量(表征学习),从而直接在学习的特征空间进行组合变换等进行数据增强,如MoEx方法等。

3.2 基于生成模型的数据增强

生成模型如变分自编码网络(Variational Auto-Encoding network, VAE)和生成对抗网络(Generative Adversarial Network, GAN),其生成样本的方法也可以用于数据增强。这种基于网络合成的方法相比于传统的数据增强技术虽然过程更加复杂, 但是生成的样本更加多样。

  • 变分自编码器VAE

    变分自编码器(Variational Autoencoder,VAE)其基本思路是:将真实样本通过编码器网络变换成一个理想的数据分布,然后把数据分布再传递给解码器网络,构造出生成样本,模型训练学习的过程是使生成样本与真实样本足够接近。

# VAE模型
class VAE(keras.Model):...def train_step(self, data):with tf.GradientTape() as tape:z_mean, z_log_var, z = self.encoder(data)reconstruction = self.decoder(z)reconstruction_loss = tf.reduce_mean(tf.reduce_sum(keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)))kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))total_loss = reconstruction_loss + kl_lossgrads = tape.gradient(total_loss, self.trainable_weights)self.optimizer.apply_gradients(zip(grads, self.trainable_weights))self.total_loss_tracker.update_state(total_loss)self.reconstruction_loss_tracker.update_state(reconstruction_loss)self.kl_loss_tracker.update_state(kl_loss)return {"loss": self.total_loss_tracker.result(),"reconstruction_loss": self.reconstruction_loss_tracker.result(),"kl_loss": self.kl_loss_tracker.result(),}
  • 生成对抗网络GAN

    生成对抗网络-GAN(Generative Adversarial Network) 由生成网络(Generator, G)和判别网络(Discriminator, D)两部分组成, 生成网络构成一个映射函数G: Z→X(输入噪声z, 输出生成的图像数据x), 判别网络判别输入是来自真实数据还是生成网络生成的数据。

# DCGAN模型class GAN(keras.Model):...def train_step(self, real_images):batch_size = tf.shape(real_images)[0]random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))# G: Z→X(输入噪声z, 输出生成的图像数据x)generated_images = self.generator(random_latent_vectors)# 合并生成及真实的样本并赋判定的标签combined_images = tf.concat([generated_images, real_images], axis=0)labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)# 标签加入随机噪声labels += 0.05 * tf.random.uniform(tf.shape(labels))# 训练判定网络with tf.GradientTape() as tape:predictions = self.discriminator(combined_images)d_loss = self.loss_fn(labels, predictions)grads = tape.gradient(d_loss, self.discriminator.trainable_weights)self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))# 赋生成网络样本的标签(都赋为真实样本)misleading_labels = tf.zeros((batch_size, 1))# 训练生成网络with tf.GradientTape() as tape:predictions = self.discriminator(self.generator(random_latent_vectors))g_loss = self.loss_fn(misleading_labels, predictions)grads = tape.gradient(g_loss, self.generator.trainable_weights)self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))# 更新损失self.d_loss_metric.update_state(d_loss)self.g_loss_metric.update_state(g_loss)return {"d_loss": self.d_loss_metric.result(),"g_loss": self.g_loss_metric.result(),}


3.3 基于神经风格迁移的数据增强

神经风格迁移(Neural Style Transfer)可以在保留原始内容的同时,将一个图像的样式转移到另一个图像上。除了实现类似色彩空间照明转换,还可以生成不同的纹理和艺术风格。

神经风格迁移是通过优化三类的损失来实现的:

style_loss:使生成的图像接近样式参考图像的局部纹理;

content_loss:使生成的图像的内容表示接近于基本图像的表示;

total_variation_loss:是一个正则化损失,它使生成的图像保持局部一致。

# 样式损失
def style_loss(style, combination):S = gram_matrix(style)C = gram_matrix(combination)channels = 3size = img_nrows * img_ncolsreturn tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels ** 2) * (size ** 2))# 内容损失
def content_loss(base, combination):return tf.reduce_sum(tf.square(combination - base))# 正则损失
def total_variation_loss(x):a = tf.square(x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, 1:, : img_ncols - 1, :])b = tf.square(x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, : img_nrows - 1, 1:, :])return tf.reduce_sum(tf.pow(a + b, 1.25))

3.4 基于元学习的数据增强

深度学习研究中的元学习(Meta learning)通常是指使用神经网络优化神经网络,元学习的数据增强有神经增强(Neural augmentation)等方法。

  • 神经增强

神经增强(Neural augmentation)是通过神经网络组的学习以获得较优的数据增强并改善分类效果的一种方法。其方法步骤如下:

1、获取与target图像同一类别的一对随机图像,前置的增强网络通过CNN将它们映射为合成图像,合成图像与target图像对比计算损失;

2、将合成图像与target图像神经风格转换后输入到分类网络中,并输出该图像分类损失;

3、将增强与分类的loss加权平均后,反向传播以更新分类网络及增强网络权重。使得其输出图像的同类内差距减小且分类准确。

直播间地址:

https://live.csdn.net/room/csdnnews/B3423dYF

更多精彩推荐
☞Python 玩出花儿,把罗小黑养在自己桌面☞315 曝光人脸识别摄像头,进店瞬间偷走你的“脸”,自动分析心情☞玩转3D全息图像!AI即刻生成☞在 5G 速度上,iPhone 12 只是个弟弟

相关文章:

ORACLE11g 前期安装环境配置

Linux系统可以拿来直接用的脚本哦#!/bin/bashservice iptables stop &> /dev/nulliptables -F service iptables save &> /dev/nullsed -i s/enforcing/disabled/ /etc/selinux/configsetenforce 0sed /tmpfs/d /etc/fstab &> /dev/nullecho tmpfs …

linux mysql 卸载,安装,測试全过程

Mysql卸载yum remove mysql mysql-server mysql-libs compat-mysql51rm -rf /var/lib/mysqlrm /etc/my.cnf查看是否还有mysql软件:rpm -qa|grep mysql有的话继续删除Mysql安装1>若本地没有安装包 能够考虑使用yum命令进行下载# yum -y install mysql-server# yum…

C#中获取程序当前路径的集中方法

string str1 Process.GetCurrentProcess().MainModule.FileName;//可获得当前执行的exe的文件名。 string str2Environment.CurrentDirectory;//获取和设置当前目录(即该进程从中启动的目录)的完全限定路径。//备注 按照定义,如果该进程在本…

如何开启远程(win7win8)

如何开启远程连接点击我的电脑-属性-高级系统设置-远程-选中“允许远程连接到此计算机”-应用-确定。在局域网内,拥有固定IP的话,就很容易远程处理事情了。若经过此步骤还不能远程的话,则需要查看系统是否开启了远程服务。“我的电脑”--管理…

微软推出“ Group Transcribe”应用,多人多语言会议实时高准确度文字转录并翻译

近期,微软针对面对面对话和会议推出了免费实时语音到文字转录和翻译应用程序——Group Transcribe。一方面,Group Transcribe可以通过手机把会议的语音内容实时转录为文本,供与会者阅读和浏览。 另一方面,在实时交流过程中&#x…

STM32单片机外部中断配置讲解

2019独角兽企业重金招聘Python工程师标准>>> 单片机外部中断简介 所谓外部中断,就是通过外部信号所引起的中断,如单片机引脚上的电平变化(高电平、低电平)、边沿变化(上升沿、下降沿)等。51单片机有5个中断源,其中有两个是外部中断…

Android语音信号波形显示

简单地介绍了AudioRecord和AudioTrack的使用,这次就结合SurfaceView实现一个Android版的手机模拟信号示波器(PS:以前也讲过J2ME版的手机示波器)。最近物联网炒得很火,作为手机软件开发者,如何在不修改手机硬件电路的前提下实现与第…

科研费4年翻3倍,全球科研队伍突破8000人,滴滴致力打造出行领域核心技术

日前,十三届全国人大四次会议表决通过了《国民经济和社会发展第十四个五年规划和2035年远景目标纲要》(下称《规划》)。《规划》强调要坚持创新在我国现代化建设全局中的核心地位,把科技自立自强作为国家发展的战略支撑。 《规划…

c++ 继承访问控制初步

访问控制方式这里有篇很好的文章,其实内容也是总结cprimer上的内容 现在就按照这篇的文章举例进行学习. 思路 不同继承方式的影响主要体现在: 1、派生类成员对基类成员的访问控制。 2、派生类对象对基类成员的访问控制 三种继承方式 公有继承(public) 所有public和p…

Excel在.Net 环境下Web方式下驻留内存问题的解决

这段时间在VS 2003 的WebForm 方式下对Excel 进行操作,遇到一个最为头疼的问题就是对Excel操作完毕后Excel不能够正常关闭,系统退出后,Excel总是驻留在内存中。但是这段代码放到WinForm的程序中又没有问题。在网上进行了查找也没有找到有效可…

2.8 FSM之Moore和Mealy part3

来看看我们的Mealy机的设计吧~~。Mealy机的想法起源于:这里我们有输入,并且根据相应的输入我们的字符识别机能做出相应的应答也就是输出。所以我们为何不把输入和输出同时表达出来呢?这样我们就能把输出和抽象的状态分离出来。好处第一就是我…

​对标GPT-3、AlphaFold,智源研究院发布超大规模智能模型系统“悟道1.0”

出品 | AI科技大本营(ID:rgznai100)3月20日,北京智源人工智能研究院发布我国首个超大规模智能模型系统“悟道1.0”。“悟道1.0”由智源研究院学术副院长、清华大学唐杰教授领衔,带领来自北京大学、清华大学、中国人民大学、中国科…

TCP Cluster for mqtt 技术实施方案

最前沿的网络技术,为你的网站带来国际化的用户体验和易用性,这一切只有Witmart.com能做到。

两台SQL Server数据同步解决方案

复制的概念复制是将一组数据从一个数据源拷贝到多个数据源的技术,是将一份数据发布到多个存储站点上的有效方式。使用复制技术,用户可以将一份数据发布到多台服务器上,从而使不同的服务器用户都可以在权限的许可的范围内共享这份数据。复制技…

一个用微软官方的OpenXml读写Excel 目前网上不太普及的方法。

新版本的xlsx是使用新的存储格式,貌似是处理过的XML。 传统的excel处理方法,我真的感觉像屎。用Oldeb不方便,用com组件要实际调用excel打开关闭,很容易出现死。 对于OpenXML我网上搜了一下,很多人没有介绍。所以我就这…

分析6千万条GitHub帖子,发现你的工作状态与表情符号强相关

作者 | 凌霄出品 | AI科技大本营(ID:rgznai100)新冠疫情使得远程办公的人数大幅度增加,然而,当越来越多的人远程工作时,人们的情绪和心理健康状态也难以通过日常面对面的交流来观察,雇主们也就无法获得员工…

软件定义网络 对我们有多重要?

软件定义网络(简称SDN)属于网络流量控制的下一个步骤。Tech Pro Research发布的调查报告正是以此为中心,旨在为我们展示企业如何使用SDN方案。 过去几年以来,以更为高效方式管理环境的需求正快速普及,这也使得网络领域的更高灵活性与控制手段…

SQL Server数据库六种数据移动方法

1. 通过工具DTS的设计器进行导入或导出DTS的设计器功能强大,支持多任务,也是可视化界面,容易操作,但知道的人一般不多,如果只是进行SQL Server数据库中部分表的移动,用这种方法最好,当然&#x…

[企业化NET]Window Server 2008 R2[3]-SVN 服务端 和 客户端 基本使用

1. 服务器基本安装即问题解决记录 √ 2. SVN环境搭建和客户端使用 2.1 服务端 和 客户端 安装 √ 2.2 项目建立与基本使用 √ 2.3 基本冲突解决,并版,tags 3. 数据库安装 4. 邮件服务器搭建 5. JIRA环境搭建和使用 6. CC.NET项目持续发布工具…

又一个Jupyter神器,操作Excel自动生成Python代码

来源 | Python数据科学(ID: PyDataScience)不得不说,Jupyter对于表的处理真的是越来越方便了,很多库可以直接实现可视化操作,无需写代码。但是这还不够,最近看到一个神器叫Mito,它真的是做到了无…

CIR:2020年全球数据中心应用AOC市场达$42亿

未来十年,QSFP和CXP将占有源光缆销售收入的大部分。到2020年,QSFP和QSFP28销售收入将分别达到7.27亿美元和7.41亿美元。 根据CIR(CommunicationsIndustryResearchers)的最新报告(《2015有源光缆市场:数据中心和高性能计算市场》),…

Visual C#创建资源文件

资源文件顾名思义就是存放资源的文件。资源文件在程序设计中有着自身独特的优势,他独立于源程序,这样资源文件就可以被多个程序使用。同时在程序设计的时候,有时出于安全或者其他方面因素的考虑,把重要东西存放在资源文件中&#…

给IIS添加CA证书以支持https

一、在IIS中生成Certificate Signing Request (CSR) 个人理解:生成CSR就是生成“私钥/公钥对”之后从中提取出公钥。 1. 打开IIS Manager,在根节点中选择Server Certificates(服务器证书),点击右侧的Create Certificat…

MathWorks的AI之路:面向工业场景,打通开发到部署的全链路

作者 | 阿司匹林 AI正在快速发展,并在更多的领域落地。对于MATLAB和Simulink的开发商MathWorks来说,把握AI的机会,显得尤为重要。 不少人对MATLAB等的印象依然停留在学校期间学习的高级线性代数解题器的阶段。然而,MATLAB在几年前…

《Android应用开发攻略》——1.3 从命令行创建 “Hello, World”应用程序

1.3 从命令行创建 “Hello, World”应用程序 Ian Darwin1.3.1 问题你想在不使用Eclipse ADT插件的情况下创建新的Android项目。1.3.2 解决方案使用Android开发工具包(Android Development Kit,ADK)中的android工具,利用creat proj…

将Excel文件数据库导入SQL Server

将Excel文件数据库导入SQL Server的三种方案//方案一: 通过OleDB方式获取Excel文件的数据,然后通过DataSet中转到SQL Server openFileDialog new OpenFileDialog();openFileDialog.Filter "Excel files(*.xls)|*.xls"; if(openFileDialog.…

Android----PopupWindow

Android的对话框有两种:PopupWindow和AlertDialog。它们的不同点在于:  AlertDialog的位置固定,而PopupWindow的位置可以随意  AlertDialog是非阻塞线程的,而PopupWindow是阻塞线程的 PopupWindow的位置按照有无偏移分&#x…

GitLab 在中国成立公司极狐,GitHub 还会远吗?

作者 | 宋慧 责编 | 苏宓出品 | CSDN(ID:CSDNnews)开源的种子已在中国落地开花。今天,中国的开源圈再次迎来一大盛事:全球第二大开源代码托管和项目管理平台 GitLab与红杉宽带等基金正式宣布成立中国合资公司极狐信…

消除危害 让BYOD策略更安全的几个秘诀

自带设备办公(BYOD)已经不是什么新鲜的事情,在近些年,随着移动设备的发展,员工利用自带设备办公已经成为一件非常平常的事情。 但是由于出于安全问题的考虑,一些企业禁止员工通过自带设备连接到公司网络中进行办公。他们不允许个人…