Python多阶段框架实现虚拟试衣间,超逼真!
作者 | 李秋键
责编 | 晋兆雨
头图 | CSDN下载自视觉中国
任意姿态下的虚拟试衣因其巨大的应用潜力而引起了人们的广泛关注。然而,现有的方法在将新颖的服装和姿势贴合到一个人身上的同时,很难保留服装纹理和面部特征(面孔、毛发)中的细节。故在论文《Downto the Last Detail: Virtual Try-on with Detail Carving》中提出了一种新的多阶段合成框架,可以很好地保留图像显著区域的丰富细节。
具体地说,就是提出了一个多阶段的框架,将生成分解为空间对齐,然后由粗到细生成。为了更好地保留显著区域的细节,如服装和面部区域,我们提出了一个树块(树扩张融合块)来利用多尺度特征在发生器网络。通过多个阶段的端到端训练,可以联合优化整个框架,最终使得视觉逼真度得到了显著的提高、同时获得了细节更为丰富的结果。在标准数据集上进行的大量实验表明,他们提出的框架实现了最先进的性能,特别是在保存服装纹理和面部识别的视觉细节方面。
故今天我们将在他们代码的基础上,实现虚拟换衣系统。具体流程如下:
实验前的准备
首先我们使用的python版本是3.6.5所用到的模块如下:
opencv是将用来进行图像处理和图片保存读取等操作。
numpy模块用来处理矩阵数据的运算。
pytorch模块是常用的用来搭建模型和训练的深度学习框架,和tensorflow以及Keras等具有相当的地位。
json是为了读取json存储格式的数据。
PIL库可以完成对图像进行批处理、生成图像预览、图像格式转换和图像处理操作,包括图像基本处理、像素处理、颜色处理等。
argparse 是python自带的命令行参数解析包,可以用来方便地读取命令行参数。
网络模型的定义和训练
其中已经训练好的模型地址如下:https://drive.google.com/open?id=1vQo4xNGdYe2uAtur0mDlHY7W2ZR3shWT。其中需要将其中的模型放到"./pretrained_checkpoint"目录下。
对于数据集的存放,分为cloth_image(用来存储衣服图片),cloth_mask(用来分割衣服的mask,可以使用grabcut的方法进行分割保存),image(用来存储人物图片),parse_cihp(用来衣服语义分析的图片结果,可以使用[CIHP_PGN](https://github.com/Engineering-Course/CIHP_PGN)的方法获得)和pose_coco(用来存储提取到的人物姿态特征数据,可以使用openpose进行提取保存为josn数据即可)。
对于模型的训练,我们需要使用到VGG19模型,网络上可以很容易下载到,然后把它放到vgg_model文件夹下。
其中提出的一种基于目标姿态和店内服装图像由粗到细的多阶段图像生成框架,首先是设计了一个解析转换网络来预测目标语义图,该语义图在空间上对齐相应的身体部位,并提供更多关于躯干和四肢形状的结构信息。然后使用一种新的树扩张融合块(tree - block)算法,将空间对齐的布料与粗糙的渲染图像融合在一起,以获得更合理、更体面的结果。其中这个虚拟试穿网络不仅不借助3D信息,可以在任意姿态下将新衣服叠加到人的对应区域上,还保留和增强了显著区域的丰富细节,如布料纹理、面部特征等。同时还使用了空间对齐、多尺度上下文特征聚集和显著的区域增强,以由粗到细的方式各种难题。
(1)其中网络主要使用pix2pix模型,其中的部分代码如下:
class PixelDiscriminator(nn.Module):def __init__(self, input_nc,ndf=64, norm_layer=nn.InstanceNorm2d):super(PixelDiscriminator,self).__init__()if type(norm_layer) ==functools.partial:use_bias =norm_layer.func == nn.InstanceNorm2delse:use_bias = norm_layer ==nn.InstanceNorm2dself.net = nn.Sequential(nn.Conv2d(input_nc, ndf,kernel_size=1, stride=1, padding=0),nn.LeakyReLU(0.2, True),nn.Conv2d(ndf, ndf * 2,kernel_size=1, stride=1, padding=0, bias=use_bias),norm_layer(ndf * 2),nn.LeakyReLU(0.2, True),nn.Conv2d(ndf * 2, 1,kernel_size=1, stride=1, padding=0, bias=use_bias),nn.Sigmoid())def forward(self, input):return self.net(input)class PatchDiscriminator(nn.Module):def __init__(self, input_nc,ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):super(PatchDiscriminator,self).__init__()if type(norm_layer) ==functools.partial: # no need to use biasas BatchNorm2d has affine parametersuse_bias =norm_layer.func == nn.InstanceNorm2delse:use_bias = norm_layer ==nn.InstanceNorm2dkw = 4padw = 1sequence =[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),nn.LeakyReLU(0.2, True)]nf_mult = 1nf_mult_prev = 1# channel upfor n in range(1,n_layers): # gradually increase thenumber of filtersnf_mult_prev = nf_mult #1,2,4,8nf_mult = min(2 ** n, 8)sequence += [nn.Conv2d(ndf *nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,bias=use_bias),norm_layer(ndf *nf_mult),nn.LeakyReLU(0.2,True)]# channel downnf_mult_prev = nf_multnf_mult = min(2 ** n_layers,8)sequence += [nn.Conv2d(ndf *nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw,bias=use_bias),norm_layer(ndf *nf_mult),nn.LeakyReLU(0.2, True)]# channel = 1 (bct, 1, x, x)sequence += [nn.Conv2d(ndf *nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction mapsequence += [nn.Sigmoid()]self.model =nn.Sequential(*sequence)
(2)生成器部分代码:
class GenerationModel(BaseModel):def name(self):return 'Generation model:pix2pix | pix2pixHD'def __init__(self, opt):self.t0 = time()BaseModel.__init__(self,opt)self.train_mode =opt.train_mode# resume of networksresume_gmm = opt.resume_gmmresume_G_parse =opt.resume_G_parseresume_D_parse =opt.resume_D_parseresume_G_appearance =opt.resume_G_appresume_D_appearance =opt.resume_D_appresume_G_face = opt.resume_G_faceresume_D_face =opt.resume_D_face# define networkself.gmm_model =torch.nn.DataParallel(GMM(opt)).cuda()self.generator_parsing =Define_G(opt.input_nc_G_parsing, opt.output_nc_parsing, opt.ndf, opt.netG_parsing,opt.norm,not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)self.discriminator_parsing =Define_D(opt.input_nc_D_parsing, opt.ndf, opt.netD_parsing, opt.n_layers_D,opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)self.generator_appearance =Define_G(opt.input_nc_G_app, opt.output_nc_app, opt.ndf, opt.netG_app,opt.norm,not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids,with_tanh=False)self.discriminator_appearance = Define_D(opt.input_nc_D_app, opt.ndf,opt.netD_app, opt.n_layers_D,opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)self.generator_face =Define_G(opt.input_nc_D_face, opt.output_nc_face, opt.ndf, opt.netG_face,opt.norm,notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)self.discriminator_face =Define_D(opt.input_nc_D_face, opt.ndf, opt.netD_face, opt.n_layers_D,opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)if opt.train_mode == 'gmm':setattr(self,'generator', self.gmm_model)else:setattr(self,'generator', getattr(self, 'generator_' + self.train_mode))setattr(self, 'discriminator',getattr(self, 'discriminator_' + self.train_mode))# load networksself.networks_name = ['gmm','parsing', 'parsing', 'appearance', 'appearance', 'face', 'face']self.networks_model =[self.gmm_model, self.generator_parsing, self.discriminator_parsing,self.generator_appearance, self.discriminator_appearance,self.generator_face, self.discriminator_face]self.networks =dict(zip(self.networks_name, self.networks_model))self.resume_path =[resume_gmm, resume_G_parse, resume_D_parse, resume_G_appearance,resume_D_appearance, resume_G_face, resume_D_face]for network, resume inzip(self.networks_model, self.resume_path):if network != [] andresume != '':assert(osp.exists(resume), 'the resume not exits')print('loading...')self.load_network(network, resume, ifprint=False)# define optimizerself.optimizer_gmm =torch.optim.Adam(self.gmm_model.parameters(), lr=opt.lr, betas=(0.5, 0.999))self.optimizer_parsing_G =torch.optim.Adam(self.generator_parsing.parameters(), lr=opt.lr,betas=[opt.beta1, 0.999])self.optimizer_parsing_D =torch.optim.Adam(self.discriminator_parsing.parameters(), lr=opt.lr,betas=[opt.beta1, 0.999])self.optimizer_appearance_G= torch.optim.Adam(self.generator_appearance.parameters(), lr=opt.lr,betas=[opt.beta1, 0.999])self.optimizer_appearance_D= torch.optim.Adam(self.discriminator_appearance.parameters(), lr=opt.lr,betas=[opt.beta1, 0.999])self.optimizer_face_G =torch.optim.Adam(self.generator_face.parameters(), lr=opt.lr, betas=[opt.beta1,0.999])self.optimizer_face_D =torch.optim.Adam(self.discriminator_face.parameters(), lr=opt.lr,betas=[opt.beta1, 0.999])if opt.train_mode == 'gmm':self.optimizer_G =self.optimizer_gmmelif opt.joint_all:self.optimizer_G =[self.optimizer_parsing_G, self.optimizer_appearance_G, self.optimizer_face_G]setattr(self,'optimizer_D', getattr(self, 'optimizer_' + self.train_mode + '_D'))else:setattr(self,'optimizer_G', getattr(self, 'optimizer_' + self.train_mode + '_G'))setattr(self,'optimizer_D', getattr(self, 'optimizer_' + self.train_mode + '_D'))self.t1 = time()
模型的使用
在模型训练完成之后,通过命令“python demo.py --batch_size_v 80--num_workers 4 --forward_save_path 'demo/forward'”生成图片。
(1)分别定义读取模型函数和模型调用处理图片函数
def load_model(model, path):checkpoint = torch.load(path)try:model.load_state_dict(checkpoint)except:model.load_state_dict(checkpoint.state_dict())model = model.cuda()model.eval()print(20*'=')for param in model.parameters():param.requires_grad = Falsedef forward(opt, paths, gpu_ids, refine_path):cudnn.enabled = Truecudnn.benchmark = Trueopt.output_nc = 3gmm = GMM(opt)gmm =torch.nn.DataParallel(gmm).cuda()# 'batch'generator_parsing =Define_G(opt.input_nc_G_parsing, opt.output_nc_parsing, opt.ndf,opt.netG_parsing, opt.norm,notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)generator_app_cpvton =Define_G(opt.input_nc_G_app, opt.output_nc_app, opt.ndf, opt.netG_app,opt.norm,notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids, with_tanh=False)generator_face =Define_G(opt.input_nc_D_face, opt.output_nc_face, opt.ndf, opt.netG_face,opt.norm,notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)models = [gmm,generator_parsing, generator_app_cpvton, generator_face]for model, path in zip(models,paths):load_model(model, path) print('==>loaded model')augment = {}if '0.4' in torch.__version__:augment['3'] =transforms.Compose([# transforms.Resize(256),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) # change to [C, H, W]augment['1'] = augment['3']else:augment['3'] =transforms.Compose([#transforms.Resize(256),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]) # change to [C, H, W]augment['1'] =transforms.Compose([# transforms.Resize(256),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))]) # change to [C, H, W]val_dataset = DemoDataset(opt,augment=augment)val_dataloader = DataLoader(val_dataset,shuffle=False,drop_last=False,num_workers=opt.num_workers,batch_size = opt.batch_size_v,pin_memory=True)with torch.no_grad():for i, result inenumerate(val_dataloader):'warped cloth'warped_cloth =warped_image(gmm, result)if opt.warp_cloth:warped_cloth_name =result['warped_cloth_name']warped_cloth_path =os.path.join('dataset', 'warped_cloth', warped_cloth_name[0])if notos.path.exists(os.path.split(warped_cloth_path)[0]):os.makedirs(os.path.split(warped_cloth_path)[0])utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)print('processing_%d'%i)continuesource_parse =result['source_parse'].float().cuda()target_pose_embedding =result['target_pose_embedding'].float().cuda()source_image =result['source_image'].float().cuda()cloth_parse =result['cloth_parse'].cuda()cloth_image =result['cloth_image'].cuda()target_pose_img =result['target_pose_img'].float().cuda()cloth_parse =result['cloth_parse'].float().cuda()source_parse_vis =result['source_parse_vis'].float().cuda()"filter add clothinfomation"real_s =source_parse index = [x for x inlist(range(20)) if x != 5 and x != 6 and x != 7]real_s_ =torch.index_select(real_s, 1, torch.tensor(index).cuda())input_parse =torch.cat((real_s_, target_pose_embedding, cloth_parse), 1).cuda()'P'generate_parse =generator_parsing(input_parse) # tanhgenerate_parse =F.softmax(generate_parse, dim=1)generate_parse_argmax =torch.argmax(generate_parse, dim=1, keepdim=True).float()res = []for index in range(20):res.append(generate_parse_argmax == index)generate_parse_argmax =torch.cat(res, dim=1).float()"A"image_without_cloth =create_part(source_image, source_parse, 'image_without_cloth', False)input_app = torch.cat((image_without_cloth,warped_cloth, generate_parse), 1).cuda()
U^2-Net跨界肖像画,完美复刻人物细节,GitHub标星2.5K+
升级版“绝悟”AI自带“军师”,解禁王者荣耀全英雄池
文本分类六十年
赠书 | 新手指南——如何通过HuggingFace Transformer整合表格数据
由浅入深,解决三道【只出现一次的数】!
相关文章:

百度重置页面自动跳转脚本
大家都知道的原因,百度现在不允许其它搜索引擎直接进入的它旗下的所有站点,在痛苦的被增加了很多点击后写了这个自动跳转的脚本。 原来不只搜索引擎,其它网站的链接也被搞了,nnd,诅咒百度。 使用方法:用xxx…

MYSQL 数据库迁移 ***
1. 导出数据库数据mysqldump -uroot -p webCompile > webCompileOut.sql其中:root 是账户名webCompile 是需要导出的数据库名称webCompileOut.sql 存储导出的数据2. 将导出SecureCRT sz【下载】的数据webCompileOut.sql放到你的目标机器…

exec函数族的使用
调用shell脚本命令:execlp("sh","sh","filename",(char*)0);exec用被执行的程序完全替换调用它的程序的影像。fork创建一个新的进程就产生了一个新的PID,exec启动一个新程序,替换原有的进程,因此这…

全球首个突破200种语言互译的翻译引擎,百度翻译打破世界沟通壁垒
机器翻译作为人工智能关键技术之一,正日益成为企业智能化升级的重要应用场景。12月1日,百度大脑开放日举办了以“机器翻译 沟通全世界”为主题的专场活动。 IDC 中国副总裁兼首席分析师武连峰、百度 AI 技术生态部总经理刘倩、百度人工智能技术委员会主席…

倍福TwinCAT(贝福Beckhoff)基础教程5.1 TwinCAT-2 运行可执行文件
个人认为这条命令做的参数比较混乱,PATHSTR是指可执行文件路径最终文件名,DIRNAME是指可执行文件路径,最后COMNDLINE可有可无,是指带参数运行启动的文件 测试可以正常运行

Linux系统的大小端模式
大端模式所谓的大端模式,是指数据的低位(就是权值较小的后面那几位)保存在内存的高地址中,而数据的高位,保存在内存的低地址中,这样的存储模式有点儿类似于把数据当作字符串顺序处理:地址由小向…
CSDN插件限时内测,新用户抢永久免费去广告特权!
经过程序猿哥哥们和产品小姐姐马不停蹄的疯狂加班,CSDN 官方出品的PC浏览器插件–开发者助手 终于正式上线啦!一键万能操作,新标签页极简个性,让你的浏览器更酷更高效!还有超多实用彩蛋功能等你来解锁!现在…

你必须知道的.net学习总结
着几天在看《你必须知道的.net》,这次看书和以往不同,以前是把自己喜欢的章节看了。但是这次决定把一本书详细的看看。 在第一章第一节中主要讲的是“对象”,我想每一个程序员都对,“对象”有理解。 我来说说书中所说的对象吧。。 我只是把认…

Mybatis 基本配置, 面向接口
< 一 > 主配置文件 <?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE configuration PUBLIC "-//mybatis.org//DTD Config 3.0//EN" "http://mybatis.org/dtd/mybatis-3-config.dtd"> <configuration><…

据说看完这21个故事的人,30岁前都成了亿万富翁。你是下一个吗?
1.甲去买烟,烟29元,但他没火柴,跟店员说:“顺便送一盒火柴吧。”店员没给。 乙去买烟,烟29元,他也没火柴,跟店员说:“便宜一毛吧。”最后,他用这一毛买一盒火柴。 这是…

助力视障人士,微软等公司捐赠首批AI有声内容
12月2日,微软与周迅AI语音红丹丹公益项目发起人鹿音苑文化传播公司,以及来自微软及各界的150名余志愿者,将创作的首批人工智能有声内容,包括鲁迅、老舍、萧红、朱自清等作家的一系列经典作品、红丹丹文化期刊,正式捐赠…

能和LoadRunner匹敌的VS2010/2012Web负载测试
VS自带的Web负载测试真的很大程度上能和专业的loadrunner媲美(只是Web方面),上个report图吧(如何实现,请往下拉): 看,能探测一堆的计数器(上面红色打叉的是代表超过了基线…

设置编码格式为utf8
response.setCharacterEncoding("UTF-8"); 在Servlet2.3中是不行的,至少要2.4版本才可以,如果低于2.4版本,可以用如下办法: response.setContentType("text/html;charsetUTF-8"); pageEncoding"UTF-8&qu…

Linux 命令集锦
linux下查看监听端口对应的进程 # lsof -i:9000 # lsof -Pnl M -i4# lsof -i | grep 9054 如果退格键变成了:"^h"。 终端连接unix删除退格键,按住CTL键同时按deleteLinux搜索 # find / -name "xxx.conf"查看linux是32位还是64位的命…

mysql (双主,互主)
Master-Master(双主) 1、测试环境 Master/Slave Master1/Slave1 IP 192.168.1.13 192.168.1.10 为了保持干净的环境:两边服务器 rm -rf /var/lib/mysql/* service mysqld re…
熬夜都要看完的 Python 干货!
结合我最近这些年的Python学习、开发经验,发现超90%的人在初学Python时都会遇到下面这些问题:没经验不知道怎么开始学,应用方向太多了根本不知道该怎么选择!各基础入门看似简单,但一到进阶部分就举步维艰,遇…
Linux下二进制文件安装MySQL
MySQL 下载地址:https://dev.mysql.com/downloads/mysql/ 并按如下方式选择来下载安装包。 1. 设置配置文件/etc/my.cnmore /etc/my.cnf [client] port 3306 socket /tmp/mysql.sockdefault-character-setutf8[mysqld] usermysql port 3306 server_id 1 socket/…

解决远程桌面无法连接问题
如果 出现的提示如下:---------------------------中断远程桌面连接---------------------------客户端无法建立跟远程计算机的连接。导致这个错误的可能的原因是:1) 远程计算机上的远程连接可能没有启用。2) 已超出远程计算机上的连接最大数。3) 建立连接时出现了一…
这些算法在印度农村医疗中发挥极大作用,未来还将发挥哪些作用?
作者 | Apoorva Mandavilli译者 | Jhonny,责编 | Carol来源 | Unitimes在结核病猖獗的印度农村等地方,用于扫描肺部X射线的 AI 技术可能有助于消除这种疾病之苦。印度农村马哈拉施特拉邦的 Chinchpada Mission 医院在世界上一些最偏远和最贫困的角落&…

四层和七层交换技术-loadbalance
1 四层交换技术简介 我们知道,二层交换机是根据第二层数据链路层的MAC地址和通过站表选择路由来完成端到端的数据交换的。三层交换机是直接根据第三层网络层IP地址来完成端到端的数据交换的。四 层交换机不仅可以完成端到端交换,还能根据端口主机的应用特…

sql server mvp 發糞塗牆
http://blog.csdn.net/dba_huangzj/article/details/38295753
几行代码完成动态图表绘制 | Python实战
作者 | 小F来源 | 法纳斯特头图 | CSDN下载自视觉中国关于动态条形图,小F以前推荐过「Bar Chart Race」这个库。三行代码就能实现动态条形图的绘制。有些同学在使用的时候,会出现一些错误。一个是加载文件报错,另一个是生成GIF的时候报错。这…

inline-block元素4px空白间隙的解决办法
为什么80%的码农都做不了架构师?>>> http://www.hujuntao.com/archives/inline-block-elements-the-4px-blank-gap-solution.html 转载于:https://my.oschina.net/i33/blog/124448

Ubuntn删除软件
删除dpkg -r 软件清除dpkg -P 软件也可以用sudo apt-get remove 软件 这种方式移除这种方式install的

对象Equals相等性比较的通用实现
对象Equals相等性比较的通用实现 最近编码的过程中,使用了对象本地内存缓存,缓存用了Dictionary<string,object>, ConcurrentDictionary<string,oject>,还可以是MemoryCache(底层基于Hashtable)。使用缓存,肯定要处理数据变化缓存…

Android:ViewPager为页卡内视图组件添加事件
在数据适配器PagerAdapter的初始化方法中添加按钮事件,这里是关键,首先判断当前的页卡编号。必须使用当前的view来获取按钮。 Overridepublic Object instantiateItem(View arg0, int arg1) {if (arg1 < 3) {((ViewPager) arg0).addView(mListViews.g…

解析C语言中的sizeof
一、sizeof的概念 sizeof是C语言的一种单目操作符,如C语言的其他操作符、--等。它并不是函数。sizeof操作符以字节形式给出了其操作数的存储大小。操作数可以是一个表达式或括在括号内的类型名。操作数的存储大小由操作数的类型决定。 二、sizeof的使用方法 1、用于…

开源大咖齐聚2020启智开发者大会,共探深度学习技术未来趋势
2020年12月2日,“OpenI/O 2020启智开发者大会”在北京国家会议中心召开。大会以“启智筑梦 开源先行”为主题,立足于国际国内开源大环境和发展趋势。开源领域顶尖专家学者和企业领军人物共聚一堂,探讨开源开放呈现出的新形势、新格局、新机…

linux中编译C语言程序
1.首先安装gcc编辑器 yum install gcc* -y 2.编写C语言程序 [roottest ~]# vim aa.c #include<stdio.h> int main( ) {int a;printf("请输入一个三位数的整数:");scanf("%d",&a);if(a>100&&a<1000)printf("百位是:…

typedef的四个用途和两大陷阱
typedef的四个用途和两个陷阱 -------------------…