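1. Motivation
1. To extract suitable image features, the paper proposes a generator with multiple parallel convolutional branches. Each branch uses a different receptive field, so the image is decomposed into components at different receptive-field sizes.
2. To find similar patches for the missing region, the paper proposes an implicit diversified Markov random field (ID-MRF) term.
3. Because a missing region admits many plausible completions, the paper proposes a new confidence-driven reconstruction loss (similar to a spatially discounted loss) that constrains the generated content according to its spatial position within the missing region.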
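2. Method
The network is trained end to end. The input is a corrupted image X together with a mask M. Pixels in the missing region are filled with 0, and M is a binary mask in which 0 marks known pixels and 1 marks missing pixels.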
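A minimal sketch of how the network input can be assembled from these definitions. NumPy is used only for illustration; the array names, the random toy image, and the 8x8 size are assumptions, not values from the paper.

```python
import numpy as np

# Hypothetical toy data: one 3-channel 8x8 image with values in [-1, 1].
rng = np.random.default_rng(0)
image = rng.uniform(-1.0, 1.0, size=(1, 3, 8, 8)).astype(np.float32)

# Binary mask M: 0 = known pixel, 1 = missing pixel (a 4x4 hole here).
mask = np.zeros((1, 1, 8, 8), dtype=np.float32)
mask[:, :, 2:6, 2:6] = 1.0

# Missing pixels are filled with 0; known pixels keep their values.
corrupted = image * (1.0 - mask)

# The generator input stacks the corrupted image and the mask: 3 + 1 channels.
net_input = np.concatenate([corrupted, mask], axis=1)
```

The resulting tensor has 4 channels, which is consistent with the `in_channels=4` argument passed to the model in the training script.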
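2.1 Network architecture
As shown in the architecture figure (Figure 2 of the paper), the framework contains three sub-networks: a generator, a global-and-local discriminator, and a pretrained VGG network used to compute the ID-MRF loss. Only the generator is used at test time.
The generator consists of three parallel encoder-decoder convolutional branches that extract features at different levels from the input (the corrupted image and the mask M). A shared decoder takes the concatenation of the three branch outputs (these feature maps have the same spatial size as the original image) and decodes it into natural image space, that is, it produces the inpainted image. As Figure 2 shows, the three branches use different receptive fields for feature extraction. The branches also work at different spatial resolutions, so their feature maps cannot be concatenated directly. The paper upsamples them with bilinear interpolation to bring them to the same size.
Although the three branches look independent, they share one decoder, so they influence each other during training.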
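The fusion step can be sketched as follows. This is an illustration under stated assumptions, not the authors' implementation: nearest-neighbor upsampling stands in for the bilinear interpolation described in the text, and the channel counts (4c, 2c, c) and the 16x16 resolution are chosen to mirror the `ch * 7` input of the shared decoder in the code listing later in this post.

```python
import numpy as np

def upsample_nearest(feat, scale):
    """Repeat every pixel `scale` times along height and width of an (N, C, H, W) array."""
    return feat.repeat(scale, axis=2).repeat(scale, axis=3)

c = 32      # base channel count (assumed, matches the g_cnum default)
h = w = 16  # hypothetical input resolution

# Hypothetical branch outputs before the final upsampling step.
branch1 = np.zeros((1, 4 * c, h // 4, w // 4), dtype=np.float32)  # coarsest branch
branch2 = np.zeros((1, 2 * c, h // 2, w // 2), dtype=np.float32)
branch3 = np.zeros((1, c, h, w), dtype=np.float32)                # full resolution

# Bring all branches to the input resolution, then concatenate along channels.
fused = np.concatenate(
    [upsample_nearest(branch1, 4), upsample_nearest(branch2, 2), branch3],
    axis=1,
)
```

After upsampling, the three maps share one spatial size, and the shared decoder sees 4c + 2c + c = 7c channels.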
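2.2 ID-MRF Regularization
This part addresses the two problems mentioned above: matching semantic structure, and the high cost of iterative MRF optimization. The idea is to apply MRF-style regularization only during training. ID-MRF minimizes, in feature space, the difference between the generated (inpainted) content and its nearest neighbors in the corresponding ground-truth image. Because it is used only in training, the complete ground-truth image is available, which yields high-quality nearest neighbors and gives the network appropriate constraints.
To compute the ID-MRF loss, one could simply use a direct similarity measure (such as cosine similarity) to find the nearest neighbor of each patch in the generated content. This tends to produce over-smoothed structures, however, because a flat region easily matches similar patterns, which quickly reduces structural diversity. The paper instead adopts a relative distance measure [17,16,22] to model the relationship between local features and the target feature set, which can recover subtle details, as shown in Figure 3(b).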
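Specifically, let $Y_g^{*}$ denote the generated content for the missing region, and let $Y_g^{*L}$ and $Y^L$ denote the features of $Y_g^{*}$ and $Y$ at layer $L$ of a pretrained model.
For patches $v$ and $s$ taken from $Y_g^{*L}$ and $Y^L$ respectively, the relative similarity of $v$ to $s$ is defined as:
$$RS(v,s)=\exp\left(\frac{\mu(v,s)}{\max_{r\in\rho_s(Y^L)}\mu(v,r)+\epsilon}\Big/h\right) \tag{1}$$
Note: $Y$ is the ground-truth image.
Here $\mu(\cdot,\cdot)$ is the cosine similarity, $r\in\rho_s(Y^L)$ means that $r$ is a patch of $Y^L$ other than $s$, and $h$ and $\epsilon$ are two positive constants. If $v$ is more similar to $s$ than to the other patches in $Y^L$, $RS(v,s)$ becomes large.
Next, $RS(v,s)$ is normalized as:
$$\overline{RS}(v,s)=\frac{RS(v,s)}{\sum_{r\in\rho_s(Y^L)}RS(v,r)} \tag{2}$$
Finally, based on Eq. (2), the ID-MRF loss between $Y_g^{*L}$ and $Y^L$ is defined as:
$$\mathcal{L}_M(L)=-\log\left(\frac{1}{Z}\sum_{s\in Y^L}\max_{v\in Y_g^{*L}}\overline{RS}(v,s)\right) \tag{3}$$
where $Z$ is a normalization factor. For each patch $s\in Y^L$, let $v'=\arg\max_{v\in Y_g^{*L}}\overline{RS}(v,s)$.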
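This means that $v'$ is closer to patch $s$ than the other patches in $Y_g^{*L}$ are. An extreme case is when all patches in $Y_g^{*L}$ are very close to a single patch $s$. Then for every other patch $r$, the best normalized similarity $\max_v \overline{RS}(v,r)$ is small, so $\mathcal{L}_M(L)$ is large.
Conversely, when the patches in $Y_g^{*L}$ are close to different candidates in $Y^L$, each patch $r$ in $Y^L$ has its own nearest neighbor in $Y_g^{*L}$. As a result, $\max_v \overline{RS}(v,r)$ is large and $\mathcal{L}_M(L)$ is small.
From this point of view, minimizing $\mathcal{L}_M(L)$ encourages each patch $v$ in $Y_g^{*L}$ to approach a different patch in $Y^L$, which keeps the generated content diverse.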
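A clear benefit of this approach is that it increases the similarity between the feature distributions of $Y_g^{*L}$ and $Y^L$. By minimizing the ID-MRF loss, local neural patches not only find corresponding candidate textures in $Y^L$, but the two feature distributions also move closer together, which helps capture variation in complex textures.
The final ID-MRF loss is computed on several feature layers of VGG19. Following common practice [5,14], conv4_2 is used to describe the semantic structure of the image, and conv3_2 and conv4_2 are used to describe image texture:
$$\mathcal{L}_{mrf}=\mathcal{L}_M(conv4\_2)+\sum_{t=3}^{4}\mathcal{L}_M(conv\,t\_2) \tag{4}$$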
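The loss for a single feature layer can be sketched in NumPy as below. This is a simplified illustration, not the repository code: it follows the distance-based formulation used in loss.py later in this post (cosine distance divided by each generated patch's smallest distance), treats every feature vector as a 1x1 patch, and skips the mean-centering step. The function name and the defaults `h=0.5`, `bias=1.0`, `eps=1e-5` are assumptions taken from that listing.

```python
import numpy as np

def id_mrf_loss(gen, tar, h=0.5, bias=1.0, eps=1e-5):
    """gen: (Ng, C) generated patch features; tar: (Nt, C) ground-truth patch features."""
    gen_n = gen / np.linalg.norm(gen, axis=1, keepdims=True)
    tar_n = tar / np.linalg.norm(tar, axis=1, keepdims=True)
    cos = gen_n @ tar_n.T                    # (Ng, Nt) cosine similarity
    dist = (1.0 - cos) / 2.0                 # cosine distance mapped to [0, 1]
    # Distance of v to each s, relative to v's best match in the target set.
    rel = dist / (dist.min(axis=1, keepdims=True) + eps)
    rs = np.exp((bias - rel) / h)
    rs_bar = rs / rs.sum(axis=1, keepdims=True)  # normalize over target patches
    best = rs_bar.max(axis=0)                # for each s, its best generated patch
    return -np.log(best.mean())

targets = np.eye(3)                          # three mutually orthogonal target patches
diverse = np.eye(3)                          # each generated patch matches a different target
collapsed = np.tile(targets[0], (3, 1))      # all generated patches match the same target

loss_diverse = id_mrf_loss(diverse, targets)
loss_collapsed = id_mrf_loss(collapsed, targets)
```

In this toy case the collapsed configuration gets the larger loss, which matches the argument in the text: the loss is small only when every target patch has its own close generated patch.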
2.3 Information Fusion
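- Spatial reconstruction loss: missing pixels close to the hole boundary should be constrained more strongly than pixels far from it.
- Adversarial loss: the improved WGAN (WGAN-GP) is used.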
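The idea behind the spatial reconstruction loss in the list above can be sketched as an iterative propagation of confidence from the known region into the hole. This is a hedged illustration only: a 3x3 mean filter replaces the Gaussian filter, and the function names and iteration count are hypothetical, not the settings of the repository's `ConfidenceDrivenMaskLayer`.

```python
import numpy as np

def blur3x3(x):
    """3x3 mean filter with zero padding; a simple stand-in for a Gaussian filter."""
    p = np.pad(x, 1)
    h, w = x.shape
    return sum(p[i:i + h, j:j + w] for i in range(3) for j in range(3)) / 9.0

def confidence_weights(mask, iters=3):
    """mask: 2-D float array, 1 = missing, 0 = known. Returns weights inside the hole."""
    conf = np.zeros_like(mask)
    for _ in range(iters):
        source = (1.0 - mask) + conf       # known pixels have confidence 1
        conf = blur3x3(source) * mask      # propagate confidence into the hole only
    return conf

mask = np.zeros((9, 9))
mask[2:7, 2:7] = 1.0                       # a 5x5 hole
weights = confidence_weights(mask)
```

With this construction the weights decay from the hole boundary toward its center, so a loss weighted by them constrains boundary pixels more strongly.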
2.4 Final loss function
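No formula survives under this heading. As a hedged reconstruction from the losses of sections 2.2 and 2.3 and from the weights $\lambda_{mrf}$ and $\lambda_{adv}$ named in the training section, the overall objective can be written as below. The symbols $\mathcal{L}_{c}$ (confidence-driven reconstruction loss) and $\mathcal{L}_{adv}$ (adversarial loss) are notation assumed here.

```latex
\mathcal{L} = \mathcal{L}_{c} + \lambda_{mrf}\,\mathcal{L}_{mrf} + \lambda_{adv}\,\mathcal{L}_{adv}
```

The PyTorch code in section 3 additionally scales the reconstruction term by `lambda_rec` and adds an autoencoder term weighted by `lambda_ae`.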
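2.5 Training procedure
The model is first trained with the reconstruction loss only, that is, with $\lambda_{mrf}$ and $\lambda_{adv}$ set to 0, which stabilizes the later adversarial training.
After the generator G converges, $\lambda_{mrf}=0.05$ and $\lambda_{adv}=0.001$ are set and the model is fine-tuned until convergence. Training uses the Adam optimizer [13] with a learning rate of 1e-4, $\beta_1=0.5$, and $\beta_2=0.9$. The batch size is 16.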
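The two-stage schedule can be summarized in a few lines. This is a sketch with hypothetical helper names (`loss_weights`, `total_loss`); only the weight values come from the text above.

```python
def loss_weights(stage):
    """Return (lambda_mrf, lambda_adv) for the given training stage."""
    if stage == "pretrain":   # reconstruction loss only
        return 0.0, 0.0
    if stage == "finetune":   # ID-MRF and adversarial terms switched on
        return 0.05, 0.001
    raise ValueError("unknown stage: %s" % stage)

def total_loss(rec, mrf, adv, stage):
    """Combine already-computed loss values according to the stage."""
    lambda_mrf, lambda_adv = loss_weights(stage)
    return rec + lambda_mrf * mrf + lambda_adv * adv

pretrain_value = total_loss(1.0, 2.0, 3.0, "pretrain")
finetune_value = total_loss(1.0, 2.0, 3.0, "finetune")
```

In the code of section 3 the same switch is made with the `--pretrain_network` option rather than by setting the weights to 0.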
3. GMCNN PyTorch source code: walkthrough and implementation
3.1 Training configuration: train_options.py
import argparse
import os
import time


class TrainOptions:
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False

    def initialize(self):
        # experiment specifics
        self.parser.add_argument('--dataset', type=str, default='Celebhq',
                                 help='dataset of the experiment.')
        # self.parser.add_argument('--data_file', type=str, default='', help='the file storing training image paths')
        # This file stores the absolute path of every training image.
        self.parser.add_argument('--data_file', type=str, default='/root/workspace/pyproject/inpainting_gmcnn-master/pytorch/util/celeba_256_train.txt', help='the file storing training image paths')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2')
        self.parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='models are saved here')
        # self.parser.add_argument('--load_model_dir', type=str, default='', help='pretrained models are given here')
        self.parser.add_argument('--load_model_dir', type=str, default='/root/workspace/pyproject/inpainting_gmcnn-master/pytorch/checkpoints/20210509-164655_GMCNN_Celebhq_b8_s256x256_gc32_dc64_randmask-rect_pretrain', help='pretrained models are given here')
        self.parser.add_argument('--phase', type=str, default='train')

        # input/output sizes
        # self.parser.add_argument('--batch_size', type=int, default=16, help='input batch size')
        self.parser.add_argument('--batch_size', type=int, default=8, help='input batch size')

        # for setting inputs
        self.parser.add_argument('--random_crop', type=int, default=1,
                                 help='using random crop to process input image when '
                                      'the required size is smaller than the given size')
        self.parser.add_argument('--random_mask', type=int, default=1)
        self.parser.add_argument('--mask_type', type=str, default='rect')
        # 1 = pretraining (generator trained with the reconstruction loss only);
        # 0 = fine-tuning (ID-MRF and adversarial losses are added).
        self.parser.add_argument('--pretrain_network', type=int, default=0)
        self.parser.add_argument('--lambda_adv', type=float, default=1e-3)
        self.parser.add_argument('--lambda_rec', type=float, default=1.4)
        self.parser.add_argument('--lambda_ae', type=float, default=1.2)
        self.parser.add_argument('--lambda_mrf', type=float, default=0.05)
        self.parser.add_argument('--lambda_gp', type=float, default=10)
        self.parser.add_argument('--random_seed', type=bool, default=False)
        self.parser.add_argument('--padding', type=str, default='SAME')
        # Number of discriminator updates performed before each generator update.
        self.parser.add_argument('--D_max_iters', type=int, default=5)
        self.parser.add_argument('--lr', type=float, default=1e-5, help='learning rate for training')

        self.parser.add_argument('--train_spe', type=int, default=1000)
        self.parser.add_argument('--epochs', type=int, default=40)
        self.parser.add_argument('--viz_steps', type=int, default=5)
        self.parser.add_argument('--spectral_norm', type=int, default=1)

        self.parser.add_argument('--img_shapes', type=str, default='256,256,3',
                                 help='given shape parameters: h,w,c or h,w')
        self.parser.add_argument('--mask_shapes', type=str, default='128,128',
                                 help='given mask parameters: h,w')
        self.parser.add_argument('--max_delta_shapes', type=str, default='32,32')
        self.parser.add_argument('--margins', type=str, default='0,0')

        # for generator
        self.parser.add_argument('--g_cnum', type=int, default=32,
                                 help='# of generator filters in first conv layer')
        self.parser.add_argument('--d_cnum', type=int, default=64,
                                 help='# of discriminator filters in first conv layer')

        # for id-mrf computation
        self.parser.add_argument('--vgg19_path', type=str, default='vgg19_weights/imagenet-vgg-verydeep-19.mat')
        # for instance-wise features
        self.initialized = True

    def parse(self):
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()

        self.opt.dataset_path = self.opt.data_file

        str_ids = self.opt.gpu_ids.split(',')
        self.opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                self.opt.gpu_ids.append(str(id))

        # Convert the 0/1 integer options to booleans.
        assert self.opt.random_crop in [0, 1]
        self.opt.random_crop = True if self.opt.random_crop == 1 else False

        assert self.opt.random_mask in [0, 1]
        self.opt.random_mask = True if self.opt.random_mask == 1 else False

        assert self.opt.pretrain_network in [0, 1]
        self.opt.pretrain_network = True if self.opt.pretrain_network == 1 else False

        assert self.opt.spectral_norm in [0, 1]
        self.opt.spectral_norm = True if self.opt.spectral_norm == 1 else False

        assert self.opt.padding in ['SAME', 'MIRROR']

        assert self.opt.mask_type in ['rect', 'stroke']

        # Parse the comma-separated shape options into lists of ints.
        str_img_shapes = self.opt.img_shapes.split(',')
        self.opt.img_shapes = [int(x) for x in str_img_shapes]

        str_mask_shapes = self.opt.mask_shapes.split(',')
        self.opt.mask_shapes = [int(x) for x in str_mask_shapes]

        str_max_delta_shapes = self.opt.max_delta_shapes.split(',')
        self.opt.max_delta_shapes = [int(x) for x in str_max_delta_shapes]

        str_margins = self.opt.margins.split(',')
        self.opt.margins = [int(x) for x in str_margins]

        # model name and date
        self.opt.date_str = time.strftime('%Y%m%d-%H%M%S')
        self.opt.model_name = 'GMCNN'
        self.opt.model_folder = self.opt.date_str + '_' + self.opt.model_name
        self.opt.model_folder += '_' + self.opt.dataset
        self.opt.model_folder += '_b' + str(self.opt.batch_size)
        self.opt.model_folder += '_s' + str(self.opt.img_shapes[0]) + 'x' + str(self.opt.img_shapes[1])
        self.opt.model_folder += '_gc' + str(self.opt.g_cnum)
        self.opt.model_folder += '_dc' + str(self.opt.d_cnum)

        self.opt.model_folder += '_randmask-' + self.opt.mask_type if self.opt.random_mask else ''
        self.opt.model_folder += '_pretrain' if self.opt.pretrain_network else ''

        if os.path.isdir(self.opt.checkpoint_dir) is False:
            os.mkdir(self.opt.checkpoint_dir)

        self.opt.model_folder = os.path.join(self.opt.checkpoint_dir, self.opt.model_folder)
        if os.path.isdir(self.opt.model_folder) is False:
            os.mkdir(self.opt.model_folder)

        # set gpu ids
        if len(self.opt.gpu_ids) > 0:
            os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(self.opt.gpu_ids)

        args = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        return self.opt
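One detail in this configuration is worth a small demonstration. Most switches are declared as 0/1 integers and converted to booleans in `parse()`, while `--random_seed` is declared with `type=bool`. With argparse, `type=bool` simply calls `bool()` on the raw string, so any non-empty value, including the string `'False'`, parses as `True`. The sketch below uses a stand-alone parser (not the `TrainOptions` class) to show the difference.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--random_seed', type=bool, default=False)  # pattern used for --random_seed
parser.add_argument('--random_crop', type=int, default=1)       # 0/1 integer pattern

opt = parser.parse_args(['--random_seed', 'False', '--random_crop', '0'])

seed_flag = opt.random_seed        # bool('False') evaluates to True
crop_flag = opt.random_crop == 1   # the integer pattern behaves as intended
```

This is why the 0/1 integer pattern used for the other options is the safer choice when a switch must be turned off from the command line.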
3.2 Training script: train.py
import os
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from data.data import InpaintingDataset, ToTensor
from model.net import InpaintingModel_GMCNN
from options.train_options import TrainOptions
from util.utils import getLatest
import tqdm

config = TrainOptions().parse()  # training configuration and hyperparameters
print("training config:", config)

print('loading data........')
# Load the dataset from the file that lists the absolute image paths.
dataset = InpaintingDataset(config.dataset_path, '', transform=transforms.Compose([
    # Converts the images to tensors. The training loop below rescales with
    # gt / 127.5 - 1, so pixel values are expected to still be in [0, 255] here.
    ToTensor()
]))
# Batch iterator over the dataset.
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, drop_last=True)
print('data load end.........')

print('configuring model..')
# Build a GMCNN model from the training configuration (4 input channels: RGB + mask).
ourModel = InpaintingModel_GMCNN(in_channels=4, opt=config)
ourModel.print_networks()  # print the network structure
if config.load_model_dir != '':
    print('Loading pretrained model from {}'.format(config.load_model_dir))
    ourModel.load_networks(getLatest(os.path.join(config.load_model_dir, '*.pth')))
    print('Loading done.')
# ourModel = torch.nn.DataParallel(ourModel).cuda()
print('model setting up..')
print('training initializing..')
writer = SummaryWriter(log_dir=config.model_folder)  # tensorboardX logger
cnt = 0  # number of batches processed so far
# config.epochs=30
for epoch in range(config.epochs):

    for i, data in enumerate(dataloader):
        gt = data['gt'].cuda()
        # normalize to values between -1 and 1
        gt = gt / 127.5 - 1

        data_in = {'gt': gt}
        ourModel.setInput(data_in)       # feed one batch of images to the model
        ourModel.optimize_parameters()   # one optimization step on this batch

        if (i+1) % config.viz_steps == 0:  # viz_steps defaults to 5
            ret_loss = ourModel.get_current_losses()  # losses computed on the current batch
            if config.pretrain_network is False:
                print(
                    '[%d, %5d] G_loss: %.4f (rec: %.4f, ae: %.4f, adv: %.4f, mrf: %.4f), D_loss: %.4f'
                    % (epoch + 1, i + 1, ret_loss['G_loss'], ret_loss['G_loss_rec'], ret_loss['G_loss_ae'],
                       ret_loss['G_loss_adv'], ret_loss['G_loss_mrf'], ret_loss['D_loss']))
                writer.add_scalar('adv_loss', ret_loss['G_loss_adv'], cnt)
                writer.add_scalar('D_loss', ret_loss['D_loss'], cnt)
                writer.add_scalar('G_mrf_loss', ret_loss['G_loss_mrf'], cnt)
            else:
                print('[%d, %5d] G_loss: %.4f (rec: %.4f, ae: %.4f)'
                      % (epoch + 1, i + 1, ret_loss['G_loss'], ret_loss['G_loss_rec'], ret_loss['G_loss_ae']))

            # Log the loss values; cnt is the index of the current batch.
            writer.add_scalar('G_loss', ret_loss['G_loss'], cnt)
            writer.add_scalar('reconstruction_loss', ret_loss['G_loss_rec'], cnt)
            writer.add_scalar('autoencoder_loss', ret_loss['G_loss_ae'], cnt)

            # images holds three kinds of pictures
            images = ourModel.get_current_visuals_tensor()
            im_completed = vutils.make_grid(images['completed'], normalize=True, scale_each=True)  # inpainted result
            im_input = vutils.make_grid(images['input'], normalize=True, scale_each=True)  # masked input
            im_gt = vutils.make_grid(images['gt'], normalize=True, scale_each=True)  # ground truth
            # Log the images produced during training.
            writer.add_image('gt', im_gt, cnt)
            writer.add_image('input', im_input, cnt)
            writer.add_image('completed', im_completed, cnt)

        # Save a checkpoint every train_spe batches (1000 by default).
        if (i+1) % config.train_spe == 0:
            print('saving model ..')
            ourModel.save_networks(epoch+1)
        cnt += 1
    ourModel.save_networks(epoch+1)  # save a checkpoint at the end of the epoch

writer.export_scalars_to_json(os.path.join(config.model_folder, 'GMCNN_scalars.json'))
writer.close()
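The pixel scaling used in the training loop (`gt / 127.5 - 1`) and its inverse can be checked in isolation. The sketch below is a NumPy illustration with hypothetical helper names. Note one deliberate difference: it rounds before casting back to `uint8`, whereas the `evaluate` method in net.py casts directly with `astype(np.uint8)`, which truncates.

```python
import numpy as np

def to_model_range(img_uint8):
    """Map pixel values from [0, 255] to [-1, 1]."""
    return img_uint8.astype(np.float32) / 127.5 - 1.0

def to_uint8(img):
    """Map values from [-1, 1] back to [0, 255], rounding to the nearest integer."""
    return np.clip(np.round((img + 1.0) * 127.5), 0, 255).astype(np.uint8)

pixels = np.arange(256, dtype=np.uint8)   # every possible 8-bit value
scaled = to_model_range(pixels)
restored = to_uint8(scaled)
```

With rounding, the round trip is exact for every 8-bit value.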
3.3 Building the GMCNN network: net.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.basemodel import BaseModel
from model.basenet import BaseNet
from model.loss import WGANLoss, IDMRFLoss
from model.layer import init_weights, PureUpsampling, ConfidenceDrivenMaskLayer, SpectralNorm
import numpy as np

# generative multi-column convolutional neural net
# 1. The multi-branch generator (the inpainting network). Each branch uses a
# different kernel size, and therefore a different receptive field.
class GMCNN(BaseNet):
    def __init__(self, in_channels, out_channels, cnum=32, act=F.elu, norm=F.instance_norm, using_norm=False):
        super(GMCNN, self).__init__()
        self.act = act
        self.using_norm = using_norm
        if using_norm is True:
            self.norm = norm
        else:
            self.norm = None
        ch = cnum

        # network structure
        self.EB1 = []  # first branch (7x7 kernels)
        self.EB2 = []  # second branch (5x5 kernels)
        self.EB3 = []  # third branch (3x3 kernels)
        self.decoding_layers = []  # shared decoder

        self.EB1_pad_rec = []
        self.EB2_pad_rec = []
        self.EB3_pad_rec = []

        self.EB1.append(nn.Conv2d(in_channels, ch, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch, ch * 2, kernel_size=7, stride=2))
        self.EB1.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=7, stride=2))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=2))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=4))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=8))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=16))

        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))

        self.EB1.append(PureUpsampling(scale=4))

        self.EB1_pad_rec = [3, 3, 3, 3, 3, 3, 6, 12, 24, 48, 3, 3, 0]

        self.EB2.append(nn.Conv2d(in_channels, ch, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch, ch * 2, kernel_size=5, stride=2))
        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, stride=2))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=2))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=4))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=8))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=16))

        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))

        self.EB2.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))
        self.EB2.append(PureUpsampling(scale=2))
        self.EB2_pad_rec = [2, 2, 2, 2, 2, 2, 4, 8, 16, 32, 2, 2, 0, 2, 2, 0]

        self.EB3.append(nn.Conv2d(in_channels, ch, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch, ch * 2, kernel_size=3, stride=2))
        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=3, stride=2))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=2))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=4))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=8))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=16))

        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))

        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))
        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB3.append(nn.Conv2d(ch * 2, ch, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch, ch, kernel_size=3, stride=1))

        self.EB3_pad_rec = [1, 1, 1, 1, 1, 1, 2, 4, 8, 16, 1, 1, 0, 1, 1, 0, 1, 1]

        self.decoding_layers.append(nn.Conv2d(ch * 7, ch // 2, kernel_size=3, stride=1))
        self.decoding_layers.append(nn.Conv2d(ch // 2, out_channels, kernel_size=3, stride=1))

        self.decoding_pad_rec = [1, 1]

        # Register the layer lists as sub-modules.
        self.EB1 = nn.ModuleList(self.EB1)
        self.EB2 = nn.ModuleList(self.EB2)
        self.EB3 = nn.ModuleList(self.EB3)
        self.decoding_layers = nn.ModuleList(self.decoding_layers)

        # padding operations: one reflection-padding layer per pad width 0..48
        padlen = 49
        self.pads = [0] * padlen
        for i in range(padlen):
            self.pads[i] = nn.ReflectionPad2d(i)
        self.pads = nn.ModuleList(self.pads)

    def forward(self, x):
        # The same input is fed to all three branches.
        x1, x2, x3 = x, x, x
        for i, layer in enumerate(self.EB1):
            pad_idx = self.EB1_pad_rec[i]
            x1 = layer(self.pads[pad_idx](x1))  # pad the feature map, then apply the layer
            if self.using_norm:
                x1 = self.norm(x1)
            if pad_idx != 0:
                x1 = self.act(x1)
        # x1: output of branch 1

        for i, layer in enumerate(self.EB2):
            pad_idx = self.EB2_pad_rec[i]
            x2 = layer(self.pads[pad_idx](x2))
            if self.using_norm:
                x2 = self.norm(x2)
            if pad_idx != 0:
                x2 = self.act(x2)
        # x2: output of branch 2

        for i, layer in enumerate(self.EB3):
            pad_idx = self.EB3_pad_rec[i]
            x3 = layer(self.pads[pad_idx](x3))
            if self.using_norm:
                x3 = self.norm(x3)
            if pad_idx != 0:
                x3 = self.act(x3)
        # x3: output of branch 3

        x_d = torch.cat((x1, x2, x3), 1)  # concatenate the three branch outputs along channels
        # shared decoder
        x_d = self.act(self.decoding_layers[0](self.pads[self.decoding_pad_rec[0]](x_d)))
        x_d = self.decoding_layers[1](self.pads[self.decoding_pad_rec[1]](x_d))
        x_out = torch.clamp(x_d, -1, 1)  # clamp values to [-1, 1]
        return x_out  # a batch of images as a tensor with values in [-1, 1]


# return one dimensional output indicating the probability of realness or fakeness
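The hand-written `*_pad_rec` lists follow the usual size-preserving padding rule for a stride-1 convolution, pad = dilation * (kernel_size - 1) / 2. The sketch below (the helper name `same_padding` is hypothetical) regenerates the convolution entries of the three lists; the `0` entries in the listing correspond to the upsampling layers, which need no padding.

```python
def same_padding(kernel_size, dilation=1):
    """Padding that keeps the spatial size for a stride-1 convolution with an odd kernel."""
    return dilation * (kernel_size - 1) // 2

# Dilation of the first 12 convolutions in every branch.
dilations = [1, 1, 1, 1, 1, 1, 2, 4, 8, 16, 1, 1]

pads_7x7 = [same_padding(7, d) for d in dilations]
pads_5x5 = [same_padding(5, d) for d in dilations]
pads_3x3 = [same_padding(3, d) for d in dilations]
largest_pad = max(pads_7x7 + pads_5x5 + pads_3x3)
```

The largest value is 48 (7x7 kernel with dilation 16), which explains `padlen = 49`: one `ReflectionPad2d` module is created for each width from 0 to 48.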
# 2. Basic discriminator module
class Discriminator(BaseNet):
    def __init__(self, in_channels, cnum=32, fc_channels=8*8*32*4, act=F.elu, norm=None, spectral_norm=True):
        super(Discriminator, self).__init__()
        self.act = act
        self.norm = norm
        self.embedding = None
        self.logit = None

        ch = cnum
        self.layers = []
        if spectral_norm:
            self.layers.append(SpectralNorm(nn.Conv2d(in_channels, ch, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch, ch * 2, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, padding=2, stride=2)))
            # Outputs one score per image: high for real images, low for inpainted ones.
            self.layers.append(SpectralNorm(nn.Linear(fc_channels, 1)))
        else:
            self.layers.append(nn.Conv2d(in_channels, ch, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch, ch * 2, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch*2, ch*4, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch*4, ch*4, kernel_size=5, padding=2, stride=2))
            # Outputs one score per image: high for real images, low for inpainted ones.
            self.layers.append(nn.Linear(fc_channels, 1))
        self.layers = nn.ModuleList(self.layers)  # register the layers as sub-modules

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = layer(x)
            if self.norm is not None:
                x = self.norm(x)
            x = self.act(x)
        self.embedding = x.view(x.size(0), -1)  # flatten the feature map into a vector
        self.logit = self.layers[-1](self.embedding)
        return self.logit  # one score per image


# 3. Combined discriminator: a global and a local discriminator built from the
# basic module. They differ only in the input size, and therefore in the length
# of the flattened vector fed to the final linear layer.
class GlobalLocalDiscriminator(BaseNet):
    def __init__(self, in_channels, cnum=32, g_fc_channels=16*16*32*4, l_fc_channels=8*8*32*4, act=F.elu, norm=None,
                 spectral_norm=True):
        super(GlobalLocalDiscriminator, self).__init__()
        self.act = act
        self.norm = norm

        self.global_discriminator = Discriminator(in_channels=in_channels, fc_channels=g_fc_channels, cnum=cnum,
                                                  act=act, norm=norm, spectral_norm=spectral_norm)
        self.local_discriminator = Discriminator(in_channels=in_channels, fc_channels=l_fc_channels, cnum=cnum,
                                                 act=act, norm=norm, spectral_norm=spectral_norm)

    def forward(self, x_g, x_l):
        x_global = self.global_discriminator(x_g)
        x_local = self.local_discriminator(x_l)
        return x_global, x_local  # scores of the global and the local discriminator


from util.utils import generate_mask

# 4. The GMCNN inpainting model, assembled from the modules above
class InpaintingModel_GMCNN(BaseModel):
    def __init__(self, in_channels, act=F.elu, norm=None, opt=None):
        super(InpaintingModel_GMCNN, self).__init__()
        self.opt = opt
        self.init(opt)

        # Layer that turns the binary mask into the confidence-driven weight mask
        # used by the reconstruction loss.
        self.confidence_mask_layer = ConfidenceDrivenMaskLayer()

        # Generator: three parallel branches plus a shared decoder, placed on the GPU.
        self.netGM = GMCNN(in_channels, out_channels=3, cnum=opt.g_cnum, act=act, norm=norm).cuda()
        init_weights(self.netGM)  # initialize the generator weights
        self.model_names = ['GM']
        if self.opt.phase == 'test':
            return

        self.netD = None
        # Adam optimizer over the generator parameters
        self.optimizer_G = torch.optim.Adam(self.netGM.parameters(), lr=opt.lr, betas=(0.5, 0.9))
        self.optimizer_D = None

        self.wganloss = None
        self.recloss = nn.L1Loss()
        self.aeloss = nn.L1Loss()
        self.mrfloss = None
        self.lambda_adv = opt.lambda_adv  # weight of the adversarial loss
        self.lambda_rec = opt.lambda_rec  # weight of the reconstruction loss
        self.lambda_ae = opt.lambda_ae
        self.lambda_gp = opt.lambda_gp    # gradient-penalty weight (WGAN-GP)
        self.lambda_mrf = opt.lambda_mrf  # weight of the ID-MRF loss
        self.G_loss = None
        self.G_loss_reconstruction = None
        self.G_loss_mrf = None
        self.G_loss_adv, self.G_loss_adv_local = None, None
        self.G_loss_ae = None
        self.D_loss, self.D_loss_local = None, None
        self.GAN_loss = None

        self.gt, self.gt_local = None, None
        self.mask, self.mask_01 = None, None
        self.rect = None
        self.im_in, self.gin = None, None

        self.completed, self.completed_local = None, None
        self.completed_logit, self.completed_local_logit = None, None
        self.gt_logit, self.gt_local_logit = None, None

        self.pred = None

        # Outside the pretraining stage (pretraining = reconstruction loss only),
        # a discriminator is also needed.
        if self.opt.pretrain_network is False:
            if self.opt.mask_type == 'rect':
                self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act,
                                                     g_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4,
                                                     l_fc_channels=opt.mask_shapes[0]//16*opt.mask_shapes[1]//16*opt.d_cnum*4,
                                                     spectral_norm=self.opt.spectral_norm).cuda()
            else:
                self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act,
                                                     spectral_norm=self.opt.spectral_norm,
                                                     g_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4,
                                                     l_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4).cuda()
            init_weights(self.netD)  # initialize the discriminator weights
            # Adam optimizer over the discriminator parameters
            self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()), lr=opt.lr,
                                                betas=(0.5, 0.9))
            self.wganloss = WGANLoss()  # WGAN loss
            self.mrfloss = IDMRFLoss()  # ID-MRF loss

    # Initialize the per-batch variables and build the generator input.
    def initVariables(self):
        self.gt = self.input['gt']  # a batch of ground-truth images
        # Generate the mask and the position of the rectangular hole.
        mask, rect = generate_mask(self.opt.mask_type, self.opt.img_shapes, self.opt.mask_shapes)
        # 0 = known region, 1 = missing region; converted from NumPy to a tensor
        self.mask_01 = torch.from_numpy(mask).cuda().repeat([self.opt.batch_size, 1, 1, 1])
        self.mask = self.confidence_mask_layer(self.mask_01)  # weight mask for the reconstruction loss
        if self.opt.mask_type == 'rect':
            self.rect = [rect[0, 0], rect[0, 1], rect[0, 2], rect[0, 3]]
            # local ground-truth patch
            self.gt_local = self.gt[:, :, self.rect[0]:self.rect[0] + self.rect[1],
                            self.rect[2]:self.rect[2] + self.rect[3]]
        else:
            self.gt_local = self.gt
        self.im_in = self.gt * (1 - self.mask_01)  # known pixels keep their values, the hole is set to 0
        self.gin = torch.cat((self.im_in, self.mask_01), 1)  # 4-channel generator input

    # Forward pass for the generator: compute all generator losses.
    def forward_G(self):
        # Loss between the final result and the ground truth, weighted by the confidence mask
        self.G_loss_reconstruction = self.recloss(self.completed * self.mask, self.gt.detach() * self.mask)
        self.G_loss_reconstruction = self.G_loss_reconstruction / torch.mean(self.mask_01)
        # Loss between the prediction and the ground truth on the known region
        self.G_loss_ae = self.aeloss(self.pred * (1 - self.mask_01), self.gt.detach() * (1 - self.mask_01))
        self.G_loss_ae = self.G_loss_ae / torch.mean(1 - self.mask_01)
        # Weighted sum of the two reconstruction terms
        self.G_loss = self.lambda_rec * self.G_loss_reconstruction + self.lambda_ae * self.G_loss_ae
        if self.opt.pretrain_network is False:
            # Outside pretraining, add the adversarial and ID-MRF losses.
            # discriminator: global and local scores for the inpainted image
            self.completed_logit, self.completed_local_logit = self.netD(self.completed, self.completed_local)
            self.G_loss_mrf = self.mrfloss((self.completed_local+1)/2.0, (self.gt_local.detach()+1)/2.0)  # ID-MRF loss
            self.G_loss = self.G_loss + self.lambda_mrf * self.G_loss_mrf

            self.G_loss_adv = -self.completed_logit.mean()              # global adversarial loss
            self.G_loss_adv_local = -self.completed_local_logit.mean()  # local adversarial loss
            self.G_loss = self.G_loss + self.lambda_adv * (self.G_loss_adv + self.G_loss_adv_local)  # total loss

    # Forward pass for the discriminator: compute all discriminator losses.
    def forward_D(self):
        # global and local scores for the inpainted image
        self.completed_logit, self.completed_local_logit = self.netD(self.completed.detach(), self.completed_local.detach())
        # global and local scores for the real image
        self.gt_logit, self.gt_local_logit = self.netD(self.gt, self.gt_local)
        # hinge loss
        self.D_loss_local = nn.ReLU()(1.0 - self.gt_local_logit).mean() + nn.ReLU()(1.0 + self.completed_local_logit).mean()  # local discriminator
        self.D_loss = nn.ReLU()(1.0 - self.gt_logit).mean() + nn.ReLU()(1.0 + self.completed_logit).mean()  # global discriminator
        self.D_loss = self.D_loss + self.D_loss_local

    # Backpropagate the generator loss.
    def backward_G(self):
        self.G_loss.backward()

    # Backpropagate the discriminator loss.
    def backward_D(self):
        self.D_loss.backward(retain_graph=True)

    # One full training step.
    def optimize_parameters(self):
        self.initVariables()

        self.pred = self.netGM(self.gin)  # run the generator on the corrupted input
        # Keep the prediction only inside the hole; known pixels come from the ground truth.
        self.completed = self.pred * self.mask_01 + self.gt * (1 - self.mask_01)
        if self.opt.mask_type == 'rect':
            self.completed_local = self.completed[:, :, self.rect[0]:self.rect[0] + self.rect[1],
                                   self.rect[2]:self.rect[2] + self.rect[3]]
        else:
            self.completed_local = self.completed

        if self.opt.pretrain_network is False:
            # Outside pretraining, update the discriminator D_max_iters times first.
            for i in range(self.opt.D_max_iters):
                self.optimizer_D.zero_grad()  # clear discriminator gradients
                self.optimizer_G.zero_grad()  # clear generator gradients
                self.forward_D()
                self.backward_D()
                self.optimizer_D.step()       # update the discriminator

        self.optimizer_G.zero_grad()  # clear generator gradients
        self.forward_G()
        self.backward_G()
        self.optimizer_G.step()       # update the generator

    # Return all current losses as a dictionary.
    def get_current_losses(self):
        l = {'G_loss': self.G_loss.item(), 'G_loss_rec': self.G_loss_reconstruction.item(),
             'G_loss_ae': self.G_loss_ae.item()}  # pretraining only has the reconstruction terms
        if self.opt.pretrain_network is False:
            l.update({'G_loss_adv': self.G_loss_adv.item(),
                      'G_loss_adv_local': self.G_loss_adv_local.item(),
                      'D_loss': self.D_loss.item(),
                      'G_loss_mrf': self.G_loss_mrf.item()})
        return l

    # Current input, ground truth, and inpainted result as NumPy arrays.
    def get_current_visuals(self):
        return {'input': self.im_in.cpu().detach().numpy(), 'gt': self.gt.cpu().detach().numpy(),
                'completed': self.completed.cpu().detach().numpy()}

    # Current input, ground truth, and inpainted result as tensors.
    def get_current_visuals_tensor(self):
        return {'input': self.im_in.cpu().detach(), 'gt': self.gt.cpu().detach(),
                'completed': self.completed.cpu().detach()}

    # Inpaint a given image with a given mask (used at test time).
    def evaluate(self, im_in, mask):
        im_in = torch.from_numpy(im_in).type(torch.FloatTensor).cuda() / 127.5 - 1
        mask = torch.from_numpy(mask).type(torch.FloatTensor).cuda()
        im_in = im_in * (1-mask)
        xin = torch.cat((im_in, mask), 1)
        ret = self.netGM(xin) * mask + im_in * (1-mask)
        ret = (ret.cpu().detach().numpy() + 1) * 127.5
        return ret.astype(np.uint8)
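The `g_fc_channels` and `l_fc_channels` arguments are the lengths of the flattened feature maps that reach the final `nn.Linear` layer. A short sketch of the arithmetic (the helper name `fc_in_features` is hypothetical): each of the four stride-2 convolutions with kernel 5 and padding 2 halves the spatial size, and the last convolution outputs `cnum * 4` channels.

```python
def fc_in_features(height, width, cnum, num_stride2_layers=4):
    """Length of the flattened discriminator feature map before the linear layer."""
    factor = 2 ** num_stride2_layers   # 4 stride-2 convolutions shrink each side by 16
    return (height // factor) * (width // factor) * cnum * 4

global_fc = fc_in_features(256, 256, 64)  # defaults: img_shapes=256,256 and d_cnum=64
local_fc = fc_in_features(128, 128, 64)   # defaults: mask_shapes=128,128

# The unparenthesized expression from the listing, evaluated left to right.
listing_global = 256 // 16 * 256 // 16 * 64 * 4
listing_local = 128 // 16 * 128 // 16 * 64 * 4
```

For sizes divisible by 16 the unparenthesized expression in the listing gives the same value as the explicit `(h // 16) * (w // 16)` form. For other sizes the two can differ, so the parenthesized form is the safer one to reuse.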
3.4 Common losses, including the ID-MRF loss: loss.py
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.nn.functional as F
from model.layer import VGG19FeatLayer
from functools import reduce


class WGANLoss(nn.Module):
    def __init__(self):
        super(WGANLoss, self).__init__()

    def __call__(self, input, target):
        d_loss = (input - target).mean()
        g_loss = -input.mean()
        return {'g_loss': g_loss, 'd_loss': d_loss}


def gradient_penalty(xin, yout, mask=None):
    gradients = autograd.grad(yout, xin, create_graph=True,
                              grad_outputs=torch.ones(yout.size()).cuda(), retain_graph=True, only_inputs=True)[0]
    if mask is not None:
        gradients = gradients * mask
    gradients = gradients.view(gradients.size(0), -1)
    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gp


def random_interpolate(gt, pred):
    batch_size = gt.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).cuda()
    # alpha = alpha.expand(gt.size()).cuda()
    interpolated = gt * alpha + pred * (1 - alpha)
    return interpolated


class IDMRFLoss(nn.Module):
    def __init__(self, featlayer=VGG19FeatLayer):
        super(IDMRFLoss, self).__init__()
        self.featlayer = featlayer()
        self.feat_style_layers = {'relu3_2': 1.0, 'relu4_2': 1.0}
        self.feat_content_layers = {'relu4_2': 1.0}
        self.bias = 1.0
        self.nn_stretch_sigma = 0.5
        self.lambda_style = 1.0
        self.lambda_content = 1.0

    def sum_normalize(self, featmaps):
        reduce_sum = torch.sum(featmaps, dim=1, keepdim=True)
        return featmaps / reduce_sum

    def patch_extraction(self, featmaps):
        # Extract 1x1 patches and arrange them as convolution filters (O, I, H, W).
        patch_size = 1
        patch_stride = 1
        patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold(3, patch_size, patch_stride)
        self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5)
        dims = self.patches_OIHW.size()
        self.patches_OIHW = self.patches_OIHW.view(-1, dims[3], dims[4], dims[5])
        return self.patches_OIHW

    def compute_relative_distances(self, cdist):
        # Divide each distance by the smallest distance over all target patches.
        epsilon = 1e-5
        div = torch.min(cdist, dim=1, keepdim=True)[0]
        relative_dist = cdist / (div + epsilon)
        return relative_dist

    def exp_norm_relative_dist(self, relative_dist):
        scaled_dist = relative_dist
        dist_before_norm = torch.exp((self.bias - scaled_dist)/self.nn_stretch_sigma)
        self.cs_NCHW = self.sum_normalize(dist_before_norm)
        return self.cs_NCHW

    def mrf_loss(self, gen, tar):
        meanT = torch.mean(tar, 1, keepdim=True)
        gen_feats, tar_feats = gen - meanT, tar - meanT

        gen_feats_norm = torch.norm(gen_feats, p=2, dim=1, keepdim=True)
        tar_feats_norm = torch.norm(tar_feats, p=2, dim=1, keepdim=True)

        gen_normalized = gen_feats / gen_feats_norm
        tar_normalized = tar_feats / tar_feats_norm

        cosine_dist_l = []
        BatchSize = tar.size(0)

        for i in range(BatchSize):
            tar_feat_i = tar_normalized[i:i+1, :, :, :]
            gen_feat_i = gen_normalized[i:i+1, :, :, :]
            patches_OIHW = self.patch_extraction(tar_feat_i)

            # Cosine similarity between every generated location and every target patch
            cosine_dist_i = F.conv2d(gen_feat_i, patches_OIHW)
            cosine_dist_l.append(cosine_dist_i)
        cosine_dist = torch.cat(cosine_dist_l, dim=0)
        cosine_dist_zero_2_one = - (cosine_dist - 1) / 2
        relative_dist = self.compute_relative_distances(cosine_dist_zero_2_one)
        rela_dist = self.exp_norm_relative_dist(relative_dist)
        dims_div_mrf = rela_dist.size()
        # For each target patch, take the best-matching generated location.
        k_max_nc = torch.max(rela_dist.view(dims_div_mrf[0], dims_div_mrf[1], -1), dim=2)[0]
        div_mrf = torch.mean(k_max_nc, dim=1)
        div_mrf_sum = -torch.log(div_mrf)
        div_mrf_sum = torch.sum(div_mrf_sum)
        return div_mrf_sum

    def forward(self, gen, tar):
        gen_vgg_feats = self.featlayer(gen)
        tar_vgg_feats = self.featlayer(tar)

        style_loss_list = [self.feat_style_layers[layer] * self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_style_layers]
        # reduce accumulates the list elements into their sum
        self.style_loss = reduce(lambda x, y: x+y, style_loss_list) * self.lambda_style

        content_loss_list = [self.feat_content_layers[layer] * self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_content_layers]
        self.content_loss = reduce(lambda x, y: x+y, content_loss_list) * self.lambda_content

        return self.style_loss + self.content_loss


class StyleLoss(nn.Module):
    def __init__(self, featlayer=VGG19FeatLayer, style_layers=None):
        super(StyleLoss, self).__init__()
        self.featlayer = featlayer()
        if style_layers is not None:
            self.feat_style_layers = style_layers
        else:
            self.feat_style_layers = {'relu2_2': 1.0, 'relu3_2': 1.0, 'relu4_2': 1.0}

    def gram_matrix(self, x):
        b, c, h, w = x.size()
        feats = x.view(b * c, h * w)
        g = torch.mm(feats, feats.t())
        return g.div(b * c * h * w)

    def _l1loss(self, gen, tar):
        return torch.abs(gen-tar).mean()

    def forward(self, gen, tar):
        gen_vgg_feats = self.featlayer(gen)
        tar_vgg_feats = self.featlayer(tar)
        style_loss_list = [self.feat_style_layers[layer] * self._l1loss(self.gram_matrix(gen_vgg_feats[layer]), self.gram_matrix(tar_vgg_feats[layer])) for
                           layer in self.feat_style_layers]
        style_loss = reduce(lambda x, y: x + y, style_loss_list)
        return style_loss


class ContentLoss(nn.Module):
    def __init__(self, featlayer=VGG19FeatLayer, content_layers=None):
        super(ContentLoss, self).__init__()
        self.featlayer = featlayer()
        if content_layers is not None:
            self.feat_content_layers = content_layers
        else:
            self.feat_content_layers = {'relu4_2': 1.0}

    def _l1loss(self, gen, tar):
        return torch.abs(gen-tar).mean()

    def forward(self, gen, tar):
        gen_vgg_feats = self.featlayer(gen)
        tar_vgg_feats = self.featlayer(tar)
        content_loss_list = [self.feat_content_layers[layer] * self._l1loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for
                             layer in self.feat_content_layers]
        content_loss = reduce(lambda x, y: x + y, content_loss_list)
        return content_loss


class TVLoss(nn.Module):
    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, x):
        h_x, w_x = x.size()[2:]
        h_tv = torch.abs(x[:, :, 1:, :] - x[:, :, :h_x-1, :])
        w_tv = torch.abs(x[:, :, :, 1:] - x[:, :, :, :w_x-1])
        loss = torch.sum(h_tv) + torch.sum(w_tv)
        return loss
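As a small check of what `TVLoss` measures, the same computation can be written in NumPy and applied to two toy inputs. This is an illustrative re-implementation with a hypothetical function name, not part of the repository.

```python
import numpy as np

def tv_loss(x):
    """Sum of absolute differences between vertically and horizontally adjacent pixels."""
    h_tv = np.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
    w_tv = np.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
    return h_tv.sum() + w_tv.sum()

flat = np.ones((1, 1, 4, 4))                                  # constant image
ramp = np.arange(4.0).reshape(1, 1, 1, 4).repeat(4, axis=2)   # each row is 0, 1, 2, 3

tv_flat = tv_loss(flat)
tv_ramp = tv_loss(ramp)
```

A constant image has zero total variation, while the ramp contributes 3 per row over 4 rows, so the loss penalizes every intensity change between neighboring pixels.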
4. References
4.1 Original paper
Image Inpainting via Generative Multi-column Convolutional Neural Networks
4.2 Source code
https://github.com/shepnerd/inpainting_gmcnn