
Image Completion: Reproducing the SIGGRAPH 2017 Paper "Globally and Locally Consistent Image Completion" with PaddlePaddle

Published: 2024-02-11 18:56:18


The code for this project is implemented with the PaddlePaddle framework.

Use cases: image completion, object removal

GitHub repository for this project: https://github.com/Eric-Hjx/PaddlePaddle_Image_Completion

AI Studio project for this post: https://aistudio.baidu.com/aistudio/projectdetail/632313

Image Completion: Globally and Locally Consistent Image Completion

  • Source: SIGGRAPH 2017

  • Download link: Globally and Locally Consistent Image Completion


  • Dataset: the CelebA face dataset, already downloaded and attached to this project as a dataset

  • Goal
    Fill missing regions of arbitrary shape, so that images of arbitrary resolution can be completed.

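In this paper, the authors propose an image completion method that automatically fills in the missing parts of an image so that the result is consistent both locally and globally. A fully convolutional network can complete missing regions of any shape. To keep the completed image consistent with the original, the network is trained with two discriminators: a global one (the whole image) and a local one (the completed region). The global discriminator looks at the entire image to judge whether it is coherent as a whole, while the local discriminator looks only at a small area centered on the completed region to ensure that the generated patch is locally consistent. The completion network is then trained to fool both context discriminators, which requires it to generate images that are indistinguishable from real ones both overall and in detail. The authors show that the method can complete a wide variety of scenes. Moreover, unlike patch-based methods such as PatchMatch, it can generate fragments that do not appear anywhere else in the image, which allows it to naturally complete images of objects with familiar and highly specific structures, such as faces.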

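  • Network components
    Completion network: fully convolutional; it fills in the missing region of the image.
    Global context discriminator: takes the whole image as input and judges whether the scene is globally consistent.
    Local context discriminator: looks only at a small area around the completed region to judge the quality of the details.
    The completion network is trained to fool both context discriminators, which requires it to generate images that are indistinguishable from real ones in both overall consistency and detail.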
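The two discriminators see different views of the same completed image. The NumPy sketch below shows one way to build both inputs; `local_window` and `discriminator_inputs` are hypothetical helpers written for illustration. They center the local window on the hole as the paper describes, whereas this project's `get_points` (shown later) samples the window first and then places the hole inside it.

```python
import numpy as np


def local_window(hole, local_size, image_size):
    """Square window of side local_size centered on the hole, clamped to the image.

    hole: (p1, q1, p2, q2) = left, top, right, bottom of the hole (right/bottom exclusive)
    returns (x1, y1, x2, y2) in the same convention; the image is assumed square
    """
    p1, q1, p2, q2 = hole
    cx, cy = (p1 + p2) // 2, (q1 + q2) // 2           # center of the hole
    x1 = int(np.clip(cx - local_size // 2, 0, image_size - local_size))
    y1 = int(np.clip(cy - local_size // 2, 0, image_size - local_size))
    return x1, y1, x1 + local_size, y1 + local_size


def discriminator_inputs(image, hole, local_size=64):
    """Return (global input, local input) for one H x W x C completed image."""
    x1, y1, x2, y2 = local_window(hole, local_size, image.shape[0])
    return image, image[y1:y2, x1:x2, :]              # whole image, patch around the hole


completed = np.zeros((128, 128, 3), dtype=np.float32)
g, l = discriminator_inputs(completed, hole=(50, 60, 80, 90))
print(g.shape, l.shape)  # (128, 128, 3) (64, 64, 3)
```

Clamping shifts the window back inside the image when the hole lies near a border, so the local input always has the fixed size the local discriminator expects.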

  • Network architecture:

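The model consists of a completion network and two auxiliary context discriminators. The discriminators are used only to train the completion network and are not used at test time. The global discriminator takes the whole image as input, while the local discriminator takes only a small area around the completed region. Both discriminators are trained to decide whether an image is real or was completed by the completion network, while the completion network is trained to fool them, so that the completed images become as realistic as real ones.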
[Figure: overview of the network architecture]
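For reference, the paper combines a masked MSE reconstruction loss with a GAN loss. In the paper's notation, C is the completion network, D the combined context discriminator, x an input image, M_c the hole mask (1 inside the hole), M_d a random mask used to pick the local patch of a real image, and α a weighting hyperparameter:

```latex
\begin{aligned}
L(x, M_c) &= \left\| M_c \odot \big( C(x, M_c) - x \big) \right\|^2 \\
\min_{C}\,\max_{D}\;& \mathbb{E}\Big[\, L(x, M_c) + \alpha \log D(x, M_d) + \alpha \log\big(1 - D(C(x, M_c), M_c)\big) \Big]
\end{aligned}
```

In the PaddlePaddle code of this post, the completion network (`dg_program`) is optimized only with the MSE term, and the cross-entropy terms appear only in the discriminator loss `calc_d_loss`.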

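  • Completion network
    The completion network first reduces the resolution of the image with convolutions and then increases it again with deconvolutions (transposed convolutions) to produce the completed image. To keep the generated region from becoming blurry, the paper downsamples with strided convolutions only twice, which reduces the width and height to a quarter of the original. The middle layers also use dilated convolutions, which enlarge the receptive field so that each output pixel draws on a much larger area of the image without any further loss of resolution.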

[Figure: structure of the completion network]
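The effect of the two stride-2 convolutions and of the dilated convolutions can be checked with a little arithmetic. The plain-Python sketch below traces the spatial size and the receptive field through `conv1` to `conv8` of the `generator()` defined later in this post, assuming the output sizes of "SAME" padding; `trace` is a hypothetical helper. For a k×k convolution with dilation d, the receptive field grows by (k − 1)·d times the product of all earlier strides.

```python
# (name, kernel_size, stride, dilation) for conv1 ... conv8 of generator()
ENCODER = [
    ("conv1", 5, 1, 1), ("conv2", 3, 2, 1), ("conv3", 3, 1, 1),
    ("conv4", 3, 2, 1), ("conv5", 3, 1, 1), ("conv6", 3, 1, 1),
    ("dilated1", 3, 1, 2), ("dilated2", 3, 1, 4),
    ("dilated3", 3, 1, 8), ("dilated4", 3, 1, 16),
    ("conv7", 3, 1, 1), ("conv8", 3, 1, 1),
]


def trace(layers, size=128):
    """Return (name, output size, receptive field) after each layer."""
    rf, jump, out = 1, 1, []
    for name, k, s, d in layers:
        rf += (k - 1) * d * jump     # a dilated k x k kernel spans (k - 1) * d + 1 inputs
        jump *= s                    # spacing of neighbouring outputs, in input pixels
        size = -(-size // s)         # "SAME"-style output size: ceil(size / stride)
        out.append((name, size, rf))
    return out


for name, size, rf in trace(ENCODER):
    print(f"{name:9s} {size:3d} x {size:<3d} receptive field {rf}")

no_dilation = [(name, k, s, 1) for name, k, s, _ in ENCODER]
print("last layer without dilation:", trace(no_dilation)[-1][2])  # 79
```

With the dilated layers, the receptive field at the 32×32 bottleneck reaches 287 pixels, more than the 128×128 input; with ordinary 3×3 convolutions in their place it would stay at 79 pixels.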

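  • Context discriminators
    These networks are convolutional neural networks that compress the image into a small feature vector. The outputs of the two networks are fused by a concatenation layer, which predicts a single continuous value: the probability that the image is real. The network structure is shown below:
    [Figure: structure of the context discriminators]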
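The fusion step can be illustrated without PaddlePaddle. In this NumPy sketch, random arrays stand in for the 1024-dimensional outputs of the global and local branches and for the weights of the final fully connected layer; `fuse` is a hypothetical helper. Note that the project's `discriminator()` returns the raw logit and leaves the sigmoid to `sigmoid_cross_entropy_with_logits` in the loss.

```python
import numpy as np

rng = np.random.default_rng(0)


def fuse(global_feat, local_feat, w, b):
    """Concatenate the two feature vectors and map them to P(real)."""
    h = np.concatenate([global_feat, local_feat], axis=1)   # (N, 2048)
    logit = h @ w + b                                       # fully connected layer -> (N, 1)
    return 1.0 / (1.0 + np.exp(-logit))                     # sigmoid: probability the image is real


n = 4
global_feat = rng.standard_normal((n, 1024))   # stand-in for the global branch output
local_feat = rng.standard_normal((n, 1024))    # stand-in for the local branch output
w = rng.standard_normal((2048, 1)) * 0.01      # stand-in weights of the final layer
p = fuse(global_feat, local_feat, w, 0.0)
print(p.shape)  # (4, 1)
```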

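Note: the original authors trained on four K80 GPUs with 256×256 input images, and training took two months. To shorten training, this project simplifies the setup: the input images are 128×128, and training first trains the generator alone and then trains the generator and the discriminator together. (This shortens training, but the results do not match those of the paper.)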
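The paper describes a three-phase schedule: the completion network is first trained with the MSE loss for T_C iterations, then the discriminators alone for T_D iterations, and finally both together. This project skips the discriminator-only phase (T_D = 0) and counts in epochs. A hypothetical helper that makes the schedule explicit:

```python
def training_phase(step, t_c, t_d=0):
    """Networks updated at `step` (0-based) under a three-phase schedule.

    t_c: number of steps in which only the completion network is trained (MSE loss)
    t_d: number of following steps in which only the discriminators are trained
         (0 skips this phase, as this project does)
    """
    if step < t_c:
        return ("completion",)
    if step < t_c + t_d:
        return ("discriminator",)
    return ("completion", "discriminator")


# this project: epochs 0..100 completion network only, then both (t_c = 101, t_d = 0)
print([training_phase(e, t_c=101) for e in (0, 100, 101)])
# [('completion',), ('completion',), ('completion', 'discriminator')]
```

With `t_c = NUM_TRAIN_TIMES_OF_DG + 1 = 101`, this reproduces the `pass_id <= NUM_TRAIN_TIMES_OF_DG` check in the training loop.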

Code implementation

  • Read the images and save them to .npy files
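```python
# to_npy: read the first 20,000 CelebA images, resize them and store them as .npy files
import glob
import os

import cv2
import numpy as np

ratio = 0.95          # fraction of the images used for training
image_size = 128
x = []
paths = glob.glob('./img_align_celeba/*.jpg')
for path in paths[:20000]:
    # read the image (OpenCV loads it as BGR)
    img = cv2.imread(path)
    # resize to image_size x image_size
    img = cv2.resize(img, (image_size, image_size))
    # convert the channel order from BGR to RGB
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    x.append(img)
x = np.array(x, dtype=np.float32)
np.random.shuffle(x)
p = int(ratio * len(x))
x_train = x[:p]
x_test = x[p:]
os.makedirs('./npy', exist_ok=True)
np.save('./npy/x_train.npy', x_train)
np.save('./npy/x_test.npy', x_test)
```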
  • Define the data-loading function
```python
# Load the data
def load(dir_='./npy'):
    x_train = np.load(os.path.join(dir_, 'x_train.npy'))
    x_test = np.load(os.path.join(dir_, 'x_test.npy'))
    return x_train, x_test
```
  • Define the L2 loss
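```python
import paddle.fluid as fluid


# L2 loss: element-wise squared error (averaged in calc_g_loss)
def L2_loss(yhat, y):
    # yhat and y are fluid Variables in a static graph, so the loss must be built
    # from fluid ops; NumPy functions such as np.dot cannot operate on them
    return fluid.layers.square(fluid.layers.elementwise_sub(y, yhat))
```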
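Because `completion` copies the original pixels everywhere outside the hole, `completion - x` is zero there, so the plain mean squared error computed by `calc_g_loss` equals the paper's masked loss ‖M_c ⊙ (C(x, M_c) − x)‖² up to a constant normalization factor. A quick NumPy check, with random arrays standing in for the images and the generator output:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, (2, 8, 8, 3)).astype(np.float32)          # original images
imitation = rng.uniform(-1, 1, x.shape).astype(np.float32)       # generator output
mask = np.zeros((2, 8, 8, 1), dtype=np.float32)
mask[:, 2:5, 3:6] = 1                                            # 1 marks the hole

completion = imitation * mask + x * (1 - mask)                   # same blend as in dg_program

plain_mse = np.mean((completion - x) ** 2)                       # what calc_g_loss computes
masked_mse = np.mean((mask * (imitation - x)) ** 2)              # paper-style masked error
print(np.isclose(plain_mse, masked_mse))  # True
```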
  • Build the mask that cuts a hole in the original image
```python
# Cut a hole in the original image: build the mask
def get_points():
    points = []
    mask = []
    for i in range(BATCH_SIZE):
        # choose the local window (LOCAL_SIZE x LOCAL_SIZE)
        x1, y1 = np.random.randint(0, IMAGE_SIZE - LOCAL_SIZE + 1, 2)
        x2, y2 = np.array([x1, y1]) + LOCAL_SIZE
        points.append([x1, y1, x2, y2])
        # cut a hole of random size inside the local window
        w, h = np.random.randint(HOLE_MIN, HOLE_MAX + 1, 2)
        p1 = x1 + np.random.randint(0, LOCAL_SIZE - w)
        q1 = y1 + np.random.randint(0, LOCAL_SIZE - h)
        p2 = p1 + w
        q2 = q1 + h
        # build the mask: 1 marks the hole
        m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.float32)
        m[q1:q2 + 1, p1:p2 + 1] = 1
        mask.append(m)
    return np.array(points), np.array(mask)
```
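A parameterized NumPy version of the same sampling logic makes it easy to check two properties: the hole always lies inside the local window, and because the slice ends are `q2 + 1` and `p2 + 1`, the hole is actually (w + 1) × (h + 1) pixels, i.e. 25 to 49 pixels per side with the settings used here. `sample_window_and_hole` is a hypothetical helper, and it uses NumPy's `Generator` API instead of `np.random.randint`.

```python
import numpy as np


def sample_window_and_hole(rng, image_size=128, local_size=64, hole_min=24, hole_max=48):
    """Hypothetical single-sample variant of get_points(): returns the window and the mask."""
    x1, y1 = rng.integers(0, image_size - local_size + 1, 2)
    w, h = rng.integers(hole_min, hole_max + 1, 2)
    p1 = x1 + rng.integers(0, local_size - w)
    q1 = y1 + rng.integers(0, local_size - h)
    mask = np.zeros((image_size, image_size, 1), dtype=np.float32)
    mask[q1:q1 + h + 1, p1:p1 + w + 1] = 1      # same inclusive end as get_points()
    return (x1, y1, x1 + local_size, y1 + local_size), mask


rng = np.random.default_rng(0)
for _ in range(1000):
    (x1, y1, x2, y2), m = sample_window_and_hole(rng)
    outside = m.copy()
    outside[y1:y2, x1:x2] = 0
    assert outside.sum() == 0                          # the hole never leaves the window
    rows, cols = np.nonzero(m[..., 0])
    assert 25 <= rows.max() - rows.min() + 1 <= 49     # h + 1 rows
    assert 25 <= cols.max() - cols.min() + 1 <= 49     # w + 1 columns
print("all checks passed")
```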
  • Build the networks
```python
import paddle.fluid as fluid


# Build the networks (all feature maps use the NHWC layout)
def bn_relu(x, name):
    # batch_norm must be told about the NHWC layout (its default is NCHW)
    x = fluid.layers.batch_norm(
        x, momentum=0.99, epsilon=0.001, data_layout='NHWC',
        param_attr=fluid.ParamAttr(name=name + '_bn.scale'),
        bias_attr=fluid.ParamAttr(name=name + '_bn.offset'),
        moving_mean_name=name + '_bn.mean',
        moving_variance_name=name + '_bn.variance')
    return fluid.layers.relu(x)


def conv(x, num_filters, filter_size, name, stride=1, dilation=1):
    # Explicit parameter names make repeated calls with the same name share weights,
    # so both discriminator() calls (real and completed images) use one network.
    # Numeric padding keeps H x W for stride 1 and halves it for stride 2; it is used
    # instead of padding='SAME' because the 'SAME' mode may ignore the dilation.
    return fluid.layers.conv2d(
        input=x, num_filters=num_filters, filter_size=filter_size,
        stride=stride, dilation=dilation, padding=dilation * (filter_size - 1) // 2,
        name=name, data_format='NHWC',
        param_attr=fluid.ParamAttr(name=name + '.w'),
        bias_attr=fluid.ParamAttr(name=name + '.b'))


def conv_bn_relu(x, num_filters, filter_size, name, stride=1, dilation=1):
    return bn_relu(conv(x, num_filters, filter_size, name, stride, dilation), name)


def deconv_bn_relu(x, num_filters, output_size, name):
    # 4x4 transposed convolution, stride 2, padding 1: doubles height and width
    x = fluid.layers.conv2d_transpose(
        input=x, num_filters=num_filters, output_size=output_size,
        filter_size=4, stride=2, padding=1, name=name, data_format='NHWC',
        param_attr=fluid.ParamAttr(name=name + '.w'),
        bias_attr=fluid.ParamAttr(name=name + '.b'))
    return bn_relu(x, name)


def fc(x, size, name):
    return fluid.layers.fc(input=x, size=size, name=name,
                           param_attr=fluid.ParamAttr(name=name + '.w'),
                           bias_attr=fluid.ParamAttr(name=name + '.b'))


def generator(x):
    # Encoder: two stride-2 convolutions, 128x128 -> 32x32
    x = conv_bn_relu(x, 64, 5, 'generator_conv1')
    x = conv_bn_relu(x, 128, 3, 'generator_conv2', stride=2)
    x = conv_bn_relu(x, 128, 3, 'generator_conv3')
    x = conv_bn_relu(x, 256, 3, 'generator_conv4', stride=2)
    x = conv_bn_relu(x, 256, 3, 'generator_conv5')
    x = conv_bn_relu(x, 256, 3, 'generator_conv6')
    # Dilated convolutions enlarge the receptive field at 32x32
    for i, d in enumerate([2, 4, 8, 16], start=1):
        x = conv_bn_relu(x, 256, 3, 'generator_dilated%d' % i, dilation=d)
    x = conv_bn_relu(x, 256, 3, 'generator_conv7')
    x = conv_bn_relu(x, 256, 3, 'generator_conv8')
    # Decoder: two transposed convolutions, 32x32 -> 128x128
    x = deconv_bn_relu(x, 128, [64, 64], 'generator_deconv1')
    x = conv_bn_relu(x, 128, 3, 'generator_conv9')
    x = deconv_bn_relu(x, 64, [128, 128], 'generator_deconv2')
    x = conv_bn_relu(x, 32, 3, 'generator_conv10')
    # Output layer: 3 channels, tanh maps the result to [-1, 1]
    x = conv(x, 3, 3, 'generator_conv11')
    return fluid.layers.tanh(x)


def discriminator(global_x, local_x):
    def global_discriminator(x):
        # Six 5x5 stride-2 convolutions: 128x128 -> 2x2
        for i, c in enumerate([64, 128, 256, 512, 512, 512], start=1):
            x = conv_bn_relu(x, c, 5, 'discriminator_global_conv%d' % i, stride=2)
        return fc(x, 1024, 'discriminator_global_fc1')

    def local_discriminator(x):
        # Five 5x5 stride-2 convolutions: 64x64 -> 2x2
        for i, c in enumerate([64, 128, 256, 512, 512], start=1):
            x = conv_bn_relu(x, c, 5, 'discriminator_local_conv%d' % i, stride=2)
        return fc(x, 1024, 'discriminator_local_fc1')

    global_output = global_discriminator(global_x)
    local_output = local_discriminator(local_x)
    # Concatenate the two 1024-d vectors and predict a single real/fake logit
    output = fluid.layers.concat([global_output, local_output], axis=1)
    return fc(output, 1, 'discriminator_concatenation_fc1')
```
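To see why both discriminator branches produce vectors of the same length, trace the feature-map size through their stride-2 convolutions. A plain-Python sketch (the helper name is hypothetical):

```python
def same_conv_sizes(size, n_layers, stride=2):
    """Feature-map sizes after n_layers size-halving convolutions with the given stride."""
    sizes = []
    for _ in range(n_layers):
        size = -(-size // stride)    # ceil(size / stride)
        sizes.append(size)
    return sizes


global_sizes = same_conv_sizes(128, 6)   # global branch: six stride-2 convs
local_sizes = same_conv_sizes(64, 5)     # local branch: five stride-2 convs
print(global_sizes)  # [64, 32, 16, 8, 4, 2]
print(local_sizes)   # [32, 16, 8, 4, 2]
print(global_sizes[-1] ** 2 * 512, local_sizes[-1] ** 2 * 512)  # 2048 2048
```

Both branches flatten a 2 × 2 × 512 map (2048 values) before their 1024-unit fully connected layer; `fluid.layers.fc` flattens everything after the batch dimension by default.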
  • Define the loss functions
```python
# Loss functions
def calc_g_loss(x, completion):
    # mean squared error between the original and the completed images
    loss = L2_loss(x, completion)
    return fluid.layers.reduce_mean(loss)


def calc_d_loss(real, fake):
    alpha = 0.1
    # real images should be classified as 1, completed images as 0
    d_loss_real = fluid.layers.reduce_mean(
        fluid.layers.sigmoid_cross_entropy_with_logits(x=real, label=fluid.layers.ones_like(real)))
    d_loss_fake = fluid.layers.reduce_mean(
        fluid.layers.sigmoid_cross_entropy_with_logits(x=fake, label=fluid.layers.zeros_like(fake)))
    return fluid.layers.elementwise_add(d_loss_real, d_loss_fake) * alpha
```
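`calc_d_loss` relies on `sigmoid_cross_entropy_with_logits`, which takes raw logits. The NumPy sketch below re-implements the same quantity in the usual numerically stable form max(x, 0) − x·z + log(1 + e^(−|x|)) to show what the discriminator loss evaluates to; it is an illustration, not Paddle's implementation.

```python
import numpy as np


def sigmoid_bce_with_logits(logits, labels):
    """Stable form of -[z * log(sigmoid(x)) + (1 - z) * log(1 - sigmoid(x))]."""
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


def d_loss(real_logits, fake_logits, alpha=0.1):
    """NumPy mirror of calc_d_loss: real images labelled 1, completed images labelled 0."""
    real = sigmoid_bce_with_logits(real_logits, np.ones_like(real_logits)).mean()
    fake = sigmoid_bce_with_logits(fake_logits, np.zeros_like(fake_logits)).mean()
    return alpha * (real + fake)


print(d_loss(np.zeros((4, 1)), np.zeros((4, 1))))  # 0.2 * ln 2, about 0.1386
```

When the discriminator outputs logit 0 for every sample (it cannot tell real from completed images), each term is ln 2, so the loss is 0.2 · ln 2 with `alpha = 0.1`; a confident, correct discriminator drives it toward 0.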
  • Define the training parameters
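```python
# Hyperparameters
IMAGE_SIZE = 128     # input image size
LOCAL_SIZE = 64      # size of the local discriminator's window
HOLE_MIN = 24        # minimum hole width/height
HOLE_MAX = 48        # maximum hole width/height
LEARNING_RATE = 1e-3
BATCH_SIZE = 64
use_gpu = True
```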
  • Define the Programs
```python
d_program = fluid.Program()
dg_program = fluid.Program()

# Program that trains the discriminator
with fluid.program_guard(d_program):
    # Original images
    x = fluid.layers.data(name='x', shape=[IMAGE_SIZE, IMAGE_SIZE, 3], dtype='float32')
    # Hole mask: 1 marks the region to fill
    mask = fluid.layers.data(name='mask', shape=[IMAGE_SIZE, IMAGE_SIZE, 1], dtype='float32')
    # Completed whole images
    global_completion = fluid.layers.data(name='global_completion',
                                          shape=[IMAGE_SIZE, IMAGE_SIZE, 3], dtype='float32')
    # Completed local patches
    local_completion = fluid.layers.data(name='local_completion',
                                         shape=[LOCAL_SIZE, LOCAL_SIZE, 3], dtype='float32')
    # Original local patches
    local_x = fluid.layers.data(name='local_x', shape=[LOCAL_SIZE, LOCAL_SIZE, 3], dtype='float32')
    # Discriminator logits for real images
    real = discriminator(x, local_x)
    # Discriminator logits for completed images
    fake = discriminator(global_completion, local_completion)
    # Discriminator loss: real images -> 1, completed images -> 0
    d_loss = calc_d_loss(real, fake)

# Program that trains the completion network (generator)
with fluid.program_guard(dg_program):
    # Original images
    x = fluid.layers.data(name='x', shape=[IMAGE_SIZE, IMAGE_SIZE, 3], dtype='float32')
    # Hole mask: 1 marks the region to fill
    mask = fluid.layers.data(name='mask', shape=[IMAGE_SIZE, IMAGE_SIZE, 1], dtype='float32')
    # Erase the hole before feeding the image to the network
    input_data = x * (1 - mask)
    imitation = generator(input_data)
    # Keep the generated pixels only inside the hole; elsewhere use the original image
    completion = imitation * mask + x * (1 - mask)
    # Copies made before the loss and optimizer ops are added:
    # g_program to collect the generator's parameters, g_program_test for inference
    g_program = dg_program.clone()
    g_program_test = dg_program.clone(for_test=True)
    # MSE between the original and the completed images
    dg_loss = calc_g_loss(x, completion)
```
  • Set up the optimizers
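```python
# Separate Adam optimizers for the discriminator and the completion network
d_opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)
d_opt.minimize(loss=d_loss)
# Update only the completion network's parameters with the MSE loss
parameters = [p.name for p in g_program.global_block().all_parameters()]
g_opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)
g_opt.minimize(loss=dg_loss, parameter_list=parameters)
```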
  • Normalize the dataset
```python
# Normalize the data from [0, 255] to [-1, 1], the output range of the generator's tanh
x_train, x_test = load()
x_train = x_train / 127.5 - 1
x_test = x_test / 127.5 - 1
```
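The generator ends in `tanh`, so images are mapped from [0, 255] to [-1, 1] for training and back again for display. A NumPy sketch of the round trip; the helper names are hypothetical, and `to_uint8` adds rounding and clipping, whereas the code in this post converts with `np.array(..., dtype=np.uint8)`, which truncates.

```python
import numpy as np


def to_model_range(img_uint8):
    """[0, 255] -> [-1, 1], the output range of the generator's tanh."""
    return img_uint8.astype(np.float32) / 127.5 - 1


def to_uint8(img):
    """[-1, 1] -> [0, 255], with rounding and clipping instead of truncation."""
    return np.clip(np.rint((img + 1) * 127.5), 0, 255).astype(np.uint8)


pixels = np.arange(256, dtype=np.uint8)
scaled = to_model_range(pixels)
print(scaled.min(), scaled.max())                # -1.0 1.0
print(np.array_equal(to_uint8(scaled), pixels))  # True
```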
  • Initialization
```python
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# Initialize the parameters
exe.run(fluid.default_startup_program())
```
  • Start training
```python
import os

import cv2
import numpy as np
import tqdm

# Number of epochs in which only the completion network (generator) is trained
NUM_TRAIN_TIMES_OF_DG = 100
# Total number of epochs
epoch = 200
step_num = int(len(x_train) / BATCH_SIZE)
save_pretrain_model_path = 'models/'
os.makedirs('./output', exist_ok=True)


def save_sample(pass_id, mask_batch):
    """Complete one test image and save [original | masked input | completion]."""
    np.random.shuffle(x_test)
    x_batch = x_test[:BATCH_SIZE]
    # Use the inference program so that no parameters are updated on test data
    completion_n = exe.run(g_program_test, feed={'x': x_batch, 'mask': mask_batch},
                           fetch_list=[completion])[0][0]
    # Completed image
    sample = np.array((completion_n + 1) * 127.5, dtype=np.uint8)
    # Original image
    x_im = np.array((x_batch[0] + 1) * 127.5, dtype=np.uint8)
    # Input image with the hole painted white
    input_im = np.array(x_im * (1 - mask_batch[0]) + mask_batch[0] * 255, dtype=np.uint8)
    output_im = np.concatenate((x_im, input_im, sample), axis=1)
    cv2.imwrite('./output/pass_id:{}.jpg'.format(pass_id),
                cv2.cvtColor(output_im, cv2.COLOR_RGB2BGR))


np.random.shuffle(x_train)
for pass_id in range(epoch):
    # Phase 1 (pass_id <= NUM_TRAIN_TIMES_OF_DG): completion network only;
    # phase 2: completion network and discriminators together
    joint = pass_id > NUM_TRAIN_TIMES_OF_DG
    g_loss_value = 0
    d_loss_value = 0
    for i in tqdm.tqdm(range(step_num)):
        x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        points_batch, mask_batch = get_points()
        # One update of the completion network; in phase 2 the completed images are
        # fetched from the same run, so the batch is not trained on a second time
        fetch = [dg_loss, completion] if joint else [dg_loss]
        outs = exe.run(dg_program, feed={'x': x_batch, 'mask': mask_batch}, fetch_list=fetch)
        g_loss_value += outs[0]
        if not joint:
            continue
        completion_n = outs[1]
        # Crop the local windows from the original and the completed images
        local_x_batch = []
        local_completion_batch = []
        for j in range(BATCH_SIZE):
            x1, y1, x2, y2 = points_batch[j]
            local_x_batch.append(x_batch[j][y1:y2, x1:x2, :])
            local_completion_batch.append(completion_n[j][y1:y2, x1:x2, :])
        d_loss_n = exe.run(d_program,
                           feed={'x': x_batch, 'mask': mask_batch,
                                 'local_x': np.array(local_x_batch),
                                 'global_completion': completion_n,
                                 'local_completion': np.array(local_completion_batch)},
                           fetch_list=[d_loss])[0]
        d_loss_value += d_loss_n
    print('Pass_id:{}, Completion loss: {}'.format(pass_id, g_loss_value))
    if joint:
        print('Pass_id:{}, Discriminator loss: {}'.format(pass_id, d_loss_value))
    save_sample(pass_id, mask_batch)
    # Save the completion network's parameters
    fluid.io.save_params(executor=exe, dirname=save_pretrain_model_path, main_program=dg_program)
```
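How many updates does one epoch contain? Assuming the data were prepared as in the to_npy step (20,000 images, `ratio = 0.95`) and `BATCH_SIZE = 64`, the arithmetic works out as follows; `int(len(x) / BATCH_SIZE)` in the training loop and in `test()` simply drops the last partial batch.

```python
n_images, ratio, batch_size = 20000, 0.95, 64

n_train = int(ratio * n_images)       # same split as in the to_npy step
n_test = n_images - n_train
steps_train = int(n_train / batch_size)
steps_test = int(n_test / batch_size)

print(n_train, steps_train, n_train - steps_train * batch_size)  # 19000 296 56
print(n_test, steps_test, n_test - steps_test * batch_size)      # 1000 15 40
```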
  • Test code
```python
import os

import matplotlib.pyplot as plt
import numpy as np
import tqdm

IMAGE_SIZE = 128
LOCAL_SIZE = 64
HOLE_MIN = 24
HOLE_MAX = 48
BATCH_SIZE = 64
PRETRAIN_EPOCH = 100
test_npy = './npy/x_test.npy'


def test():
    # Reuses dg_program, g_program_test and completion from the training code, so it
    # must run in the same session. The startup program is not re-run here: that would
    # reset the batch-norm moving statistics, which save_params does not store.
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # Load the saved parameters
    save_pretrain_model_path = 'models/'
    fluid.io.load_params(executor=exe, dirname=save_pretrain_model_path, main_program=dg_program)
    x_test = np.load(test_npy)
    np.random.shuffle(x_test)
    x_test = x_test / 127.5 - 1
    print(len(x_test))
    os.makedirs('./output_test', exist_ok=True)
    step_num = int(len(x_test) / BATCH_SIZE)
    cnt = 0
    for i in tqdm.tqdm(range(step_num)):
        x_batch = x_test[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        _, mask_batch = get_points()
        # Inference program: batch norm uses its moving statistics and nothing is trained
        completion_test = exe.run(g_program_test, feed={'x': x_batch, 'mask': mask_batch},
                                  fetch_list=[completion])[0]
        for j in range(BATCH_SIZE):
            cnt += 1
            raw = np.array((x_batch[j] + 1) * 127.5, dtype=np.uint8)
            # Input image with the hole painted white
            masked = np.array(raw * (1 - mask_batch[j]) + mask_batch[j] * 255, dtype=np.uint8)
            img = np.array((completion_test[j] + 1) * 127.5, dtype=np.uint8)
            dst = './output_test/{:06d}.jpg'.format(cnt)
            output_image([['Input', masked], ['Output', img], ['Ground Truth', raw]], dst)


def get_points():
    # Same mask generation as in training: random local window, random hole inside it
    points = []
    mask = []
    for i in range(BATCH_SIZE):
        x1, y1 = np.random.randint(0, IMAGE_SIZE - LOCAL_SIZE + 1, 2)
        x2, y2 = np.array([x1, y1]) + LOCAL_SIZE
        points.append([x1, y1, x2, y2])
        w, h = np.random.randint(HOLE_MIN, HOLE_MAX + 1, 2)
        p1 = x1 + np.random.randint(0, LOCAL_SIZE - w)
        q1 = y1 + np.random.randint(0, LOCAL_SIZE - h)
        p2 = p1 + w
        q2 = q1 + h
        m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.float32)
        m[q1:q2 + 1, p1:p2 + 1] = 1
        mask.append(m)
    return np.array(points), np.array(mask)


def output_image(images, dst):
    # Save the images side by side, each with its caption below it
    fig = plt.figure()
    for i, (text, img) in enumerate(images):
        fig.add_subplot(1, 3, i + 1)
        plt.imshow(img)
        plt.tick_params(labelbottom=False, labelleft=False)
        plt.gca().get_xaxis().set_ticks_position('none')
        plt.gca().get_yaxis().set_ticks_position('none')
        plt.xlabel(text)
    plt.savefig(dst)
    plt.close()


if __name__ == '__main__':
    test()
```
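Both the training loop and `test()` build an "original / masked input / completion" comparison. A compact NumPy helper for the same idea is sketched below; `comparison_strip` is hypothetical and, like the training code, paints the hole white and concatenates the panels horizontally, so the result could be written to disk with `cv2.imwrite` as in the training loop.

```python
import numpy as np


def comparison_strip(raw, mask, completed):
    """[ground truth | input with the hole painted white | completion] side by side.

    raw, completed: uint8 arrays of shape (H, W, 3); mask: float array (H, W, 1), 1 marks the hole.
    """
    masked = (raw * (1 - mask) + 255 * mask).astype(np.uint8)
    return np.concatenate([raw, masked, completed], axis=1)


raw = np.full((128, 128, 3), 100, dtype=np.uint8)
mask = np.zeros((128, 128, 1), dtype=np.float32)
mask[40:70, 50:90] = 1
strip = comparison_strip(raw, mask, raw.copy())
print(strip.shape)           # (128, 384, 3)
print(strip[50, 128 + 60])   # [255 255 255]
```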

Results

[Figures: sample completion results]
