
[AI Paper Reproduction] PanNet: A Deep Network Architecture for Pan-Sharpening (based on TensorFlow 1.x)

Published: 2023-12-13 01:41:03

"PanNet: A Deep Network Architecture for Pan-Sharpening", ICCV 2017

Preface

Deep learning frameworks

  • MatConvNet
  • Caffe
  • TensorFlow1.0
  • TensorFlow2.0
  • PyTorch

Tensor -> a tensor, i.e. the data. Flow -> to flow. TensorFlow -> data flowing through the network.

Configuring the GPU from Python is fairly simple. Usually you first test the code on the CPU, then run it on the GPU.

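PanNet (a convolutional neural network) <-> Pan-Sharpening (remote-sensing image fusion)

PAN image + low-resolution multispectral image, fused = high-spatial-resolution multispectral image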

1. Environment Setup

Installing the packages

pip install tensorflow==1.15.0
pip install opencv-python
pip install scipy
pip install h5py

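TensorFlow 2.0 removed tf.contrib entirely. If the paper code you are reproducing uses tf.contrib, you must install a TensorFlow version below 2.0, such as the 1.15.0 pinned above. Note that the official TensorFlow 1.15 pip packages only support Python up to 3.7.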

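2. Close Reading of the Paper

2.1 Getting the paper and code

The paper can be found on this site: https://www.paperswithcode.com/

The code can be obtained from GitHub, or downloaded from here…

2.2 Reading the paper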

[Figure]

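Remote-sensing image fusion (pan-sharpening):

A satellite captures a panchromatic image (PAN image, 64x64x1). At the same time it captures a low-spatial-resolution multispectral image (LRMS image, 16x16x8). Fusing the two yields a high-spatial-resolution multispectral image (HRMS image, 64x64x8).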

[Figure]

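Upsampling the image directly only makes it blurry.

The figure is annotated with 8 bands but actually shows 4. That doesn't matter; it is only a schematic to aid understanding.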
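Here is a minimal NumPy sketch of why direct upsampling adds no detail. A 16x16x8 LRMS patch is brought up to the 64x64 PAN grid. Nearest-neighbour repetition is an assumption chosen for simplicity; the text only says the MS image is upsampled 4x and does not specify the interpolation. The random data is a stand-in for a real patch.

```python
import numpy as np

# Hypothetical LRMS patch: 16 x 16 pixels, 8 bands (random values stand in for real data).
rng = np.random.default_rng(0)
lrms = rng.random((16, 16, 8)).astype(np.float32)

def upsample_nearest(x, ratio=4):
    """Nearest-neighbour upsampling: every pixel becomes a ratio x ratio block."""
    return np.repeat(np.repeat(x, ratio, axis=0), ratio, axis=1)

lms = upsample_nearest(lrms)   # 64 x 64 x 8, the same grid as the 64 x 64 PAN image
print(lms.shape)               # (64, 64, 8)

# Every 4 x 4 block is constant: upsampling adds pixels but no new spatial detail.
blocks = lms.reshape(16, 4, 16, 4, 8)
print(np.allclose(blocks, blocks[:, :1, :, :1, :]))  # True
```

Smoother interpolators such as bicubic avoid the blocky look. They still cannot invent the fine detail that only the PAN image carries.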

Training data: PAN image, LRMS image, GT image

Convolutional neural network (CNN):

[Figure]

The figure above is equivalent to:

[Figure]

Residual network:

[Figure]

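The residual captures things such as high-frequency details.

Residual networks (ResNet) were proposed by Kaiming He et al., CVPR 2016.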
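The claim that the residual is mostly high-frequency detail can be checked in a toy setting. The sketch below assumes, purely for illustration, that the low-resolution image is produced by averaging 4x4 blocks and brought back by nearest-neighbour upsampling. This is a simplification, not the paper's degradation model. The residual gt - lms is what PanNet's output is added to lms to recover. In this setting it has zero mean over every 4x4 block, so it contains only detail below the block scale.

```python
import numpy as np

# Hypothetical 64 x 64 x 8 "ground truth" patch with random content.
rng = np.random.default_rng(1)
gt = rng.random((64, 64, 8))

# Simulated degradation (assumption): average each 4 x 4 block, then upsample back.
lrms = gt.reshape(16, 4, 16, 4, 8).mean(axis=(1, 3))     # 16 x 16 x 8
lms = np.repeat(np.repeat(lrms, 4, axis=0), 4, axis=1)   # 64 x 64 x 8

residual = gt - lms   # what the network has to predict on top of lms
block_means = residual.reshape(16, 4, 16, 4, 8).mean(axis=(1, 3))
print(np.abs(block_means).max() < 1e-12)   # True: no block-level (low-frequency) content left
```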

2.3、代码精读

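There are two kinds of data:

  • Training data
    • train.mat (4-D tensors)
      • pan: 100x64x64x1 (100 samples; train.py reads it as 100x64x64 and adds the channel axis itself)
      • ms: 100x16x16x8 (100 samples)
      • gt: 100x64x64x8 (100 samples)
      • lms (ms upsampled 4x): 100x64x64x8 (100 samples)
    • validation.mat (used for tuning hyperparameters)
  • Test data
    • pan
    • ms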
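
To make the shapes concrete, here is a hypothetical stand-in for train.mat built from random 11-bit integers. The dictionary keys match those train.py reads ('pan', 'ms', 'lms', 'gt'); scipy.io.loadmat returns a dict like this for a real file. The helpers make_dummy_train_data, normalize and check_shapes are illustrative only. The division by 2047 mirrors train.py's normalization for 11-bit WorldView data.

```python
import numpy as np

def make_dummy_train_data(n=100, size=64, bands=8, ratio=4, seed=0):
    """Hypothetical stand-in for train.mat: random 11-bit values with the shapes listed above."""
    rng = np.random.default_rng(seed)
    return {
        'pan': rng.integers(0, 2048, size=(n, size, size)),  # N x H x W (train.py adds the channel axis)
        'ms': rng.integers(0, 2048, size=(n, size // ratio, size // ratio, bands)),
        'lms': rng.integers(0, 2048, size=(n, size, size, bands)),
        'gt': rng.integers(0, 2048, size=(n, size, size, bands)),
    }

def normalize(data, max_value=2047.0):
    """Scale 11-bit values to [0, 1], as train.py does with / 2047."""
    return {k: np.asarray(v, dtype=np.float32) / max_value for k, v in data.items()}

def check_shapes(data, ratio=4):
    """Verify that the PAN / MS / LMS / GT shapes agree with each other."""
    n, h, w = data['pan'].shape
    bands = data['gt'].shape[-1]
    assert data['gt'].shape == (n, h, w, bands)
    assert data['lms'].shape == (n, h, w, bands)
    assert data['ms'].shape == (n, h // ratio, w // ratio, bands)
    return n, h, w, bands

train = normalize(make_dummy_train_data())
print(check_shapes(train))   # (100, 64, 64, 8)
```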

train_data

[Figure]

test_data
[Figure]

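batch_size = 32 means that 32 of the 100 samples are picked at random. train.py uses np.random.randint, so the 32 indices are drawn with replacement and the same sample can appear twice in a batch. This gives: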

  • pan: 32x64x64x1 (32 samples)
  • ms: 32x16x16x8 (32 samples)
  • gt: 32x64x64x8 (32 samples)
  • lms (ms upsampled 4x): 32x64x64x8 (32 samples)

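Working in mini-batches is more efficient than feeding one sample at a time, because the GPU processes the whole batch in parallel. Averaging the loss over a batch also gives a less noisy gradient than a single sample does.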
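The sketch below shows batch sampling as train.py does it, drawing indices with np.random.randint (with replacement). It also shows an alternative that is not in the original code: np.random.Generator.choice with replace=False, which guarantees 32 distinct patches. The PAN array is a dummy of zeros.

```python
import numpy as np

rng = np.random.default_rng(0)
N, bs = 100, 32

# What train.py does: bs independent draws from [0, N), so repeats are possible.
idx_with_repeats = np.random.randint(0, N, size=bs)

# Alternative (not in the original code): sample without replacement.
idx_unique = rng.choice(N, size=bs, replace=False)

pan = np.zeros((N, 64, 64), dtype=np.float32)  # dummy PAN patches, N x H x W
pan_batch = pan[idx_unique]                    # fancy indexing picks the batch: 32 x 64 x 64
pan_batch = pan_batch[:, :, :, np.newaxis]     # add the channel axis: 32 x 64 x 64 x 1
print(pan_batch.shape)                         # (32, 64, 64, 1)
```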

# num_fm = 32: the convolution uses 32 kernels (32 output feature maps)
# stride = 1: the kernel moves by 1 pixel at each step
rs = ly.conv2d(ms, num_outputs = num_fm, kernel_size = 3, stride = 1, weights_regularizer = ly.l2_regularizer(weight_decay), weights_initializer = ly.variance_scaling_initializer(),activation_fn = tf.nn.relu)   # 32x 64 x 64 x32

What the call means:
[Figure]
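If stride = 2, the kernel skips every other pixel, so the output becomes 32x32x32 per image (32x32x32x32 for the whole batch).
With stride = 4, the output is (64/4)x(64/4)x32 = 16x16x32.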
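The output sizes above follow from 'SAME' padding, which is the default padding of tf.contrib.layers.conv2d and conv2d_transpose: output = ceil(input / stride). The small functions below are illustrative, not part of the original code. They reproduce 64, 32 and 16 for strides 1, 2 and 4. For contrast, they show the 'VALID' case and the transposed convolution PanNet uses to upsample ms by 4x.

```python
def same_conv_output(size, stride):
    """Output height/width of a 'SAME'-padded convolution: ceil(size / stride)."""
    return (size + stride - 1) // stride

def valid_conv_output(size, kernel, stride):
    """Output size with 'VALID' padding (no zero padding), for comparison."""
    return (size - kernel) // stride + 1

def same_transpose_output(size, stride):
    """Output size of a 'SAME'-padded transposed convolution: size * stride."""
    return size * stride

for s in (1, 2, 4):
    n = same_conv_output(64, s)
    print("stride", s, "->", n, "x", n, "x 32")
print(valid_conv_output(64, 3, 1))    # 62: a 3x3 kernel without padding trims one pixel per side
print(same_transpose_output(16, 4))   # 64: 16x16 ms patches become 64x64
```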

Residual network

for i in range(num_res):   # ResNet
    # first conv of the block + ReLU (nonlinearity): 32 x 64 x 64 x 32
    rs1 = ly.conv2d(rs, num_outputs = num_fm, kernel_size = 3, stride = 1, weights_regularizer = ly.l2_regularizer(weight_decay), weights_initializer = ly.variance_scaling_initializer(), activation_fn = tf.nn.relu)
    # second conv of the block, no ReLU: 32 x 64 x 64 x 32
    rs1 = ly.conv2d(rs1, num_outputs = num_fm, kernel_size = 3, stride = 1, weights_regularizer = ly.l2_regularizer(weight_decay), weights_initializer = ly.variance_scaling_initializer(), activation_fn = None)
    rs = tf.add(rs, rs1)   # ResNet: 32 x 64 x 64 x 32; add the block input x to the output, as in the ResNet figure

Compare with the figure below:
[Figure]

Each ResNet block contains two convolutional layers, and there are num_res = 4 blocks in total.
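To see the block's data flow without TensorFlow, here is a NumPy-only sketch of one block: conv + ReLU, conv without ReLU, then add the input. The 3x3 'SAME' convolution is written by hand, for illustration only. Like TensorFlow's conv2d it is a cross-correlation. The feature-map size is kept small, and the random weights are hypothetical.

```python
import numpy as np

def conv3x3_same(x, w, b):
    """3x3 convolution, stride 1, zero 'SAME' padding, NHWC layout.

    x: N x H x W x Cin, w: 3 x 3 x Cin x Cout, b: Cout (cross-correlation, kernel not flipped).
    """
    n, h, wd, _ = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))  # zero-pad H and W by one pixel
    out = np.zeros((n, h, wd, w.shape[3]), dtype=x.dtype)
    for dy in range(3):
        for dx in range(3):
            patch = xp[:, dy:dy + h, dx:dx + wd, :]   # input shifted by (dy - 1, dx - 1)
            out += np.einsum('nhwc,cd->nhwd', patch, w[dy, dx])
    return out + b

def relu(x):
    return np.maximum(x, 0)

def res_block(x, w1, b1, w2, b2):
    """One PanNet-style ResNet block: conv + ReLU, conv without ReLU, then add the input x."""
    rs1 = relu(conv3x3_same(x, w1, b1))
    rs1 = conv3x3_same(rs1, w2, b2)
    return x + rs1

rng = np.random.default_rng(0)
num_fm = 32
x = rng.standard_normal((2, 16, 16, num_fm)).astype(np.float32)  # small dummy feature map
w1, w2 = (0.01 * rng.standard_normal((3, 3, num_fm, num_fm)).astype(np.float32) for _ in range(2))
b1 = np.zeros(num_fm, dtype=np.float32)
b2 = np.zeros(num_fm, dtype=np.float32)

y = res_block(x, w1, b1, w2, b2)
print(y.shape)  # (2, 16, 16, 32): the block keeps the shape, so x + F(x) is well defined
```

Because the skip connection adds x back, a block whose second convolution outputs zero is exactly the identity. This is one reason residual blocks are easy to stack.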

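PanNet therefore has 10 convolutional layers in total: 1 + (2x4) + 1 = 10. The transposed convolution that upsamples ms is not counted.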
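The count can be checked in a few lines, together with the number of trainable parameters. The sketch assumes every layer has a bias, as in train.py, where ly.conv2d and ly.conv2d_transpose add one by default. (test.py disables the transposed convolution's bias, which would remove 8 parameters.) The function name is hypothetical.

```python
def pannet_param_count(num_spectral=8, num_res=4, num_fm=32, k=3, up_kernel=8):
    """Count conv layers and trainable parameters (weights + biases) of PanNet as built in train.py."""
    layers = []
    # transposed conv that upsamples ms: 8x8 kernel, num_spectral -> num_spectral
    layers.append(('conv2d_transpose', up_kernel * up_kernel * num_spectral * num_spectral + num_spectral))
    # first conv on concat(ms, pan): (num_spectral + 1) -> num_fm
    layers.append(('conv', k * k * (num_spectral + 1) * num_fm + num_fm))
    # num_res ResNet blocks with two num_fm -> num_fm convs each
    for _ in range(2 * num_res):
        layers.append(('conv', k * k * num_fm * num_fm + num_fm))
    # last conv: num_fm -> num_spectral
    layers.append(('conv', k * k * num_fm * num_spectral + num_spectral))
    n_conv = sum(1 for kind, _ in layers if kind == 'conv')
    n_params = sum(p for _, p in layers)
    return n_conv, n_params

print(pannet_param_count())   # (10, 83024)
```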

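With the loss defined (the mean squared error between the network output plus lms and gt), Adam or SGD is used to learn the parameters θ.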
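Written out, the loss minimized in train.py is as follows. The formula is reconstructed from the code (`mse = tf.reduce_mean(tf.square(mrs - gt))` with `mrs = PanNet(ms_hp, pan_hp) + lms`), not copied from the paper. Here f_θ is PanNet, HP(·) is the high-pass filter get_edge, and M_n, P_n, M̃_n and G_n are the n-th ms, pan, lms and gt patches. The sizes are B = 32, H = W = 64 and C = 8. The L2 weight-decay terms declared inside the network are not added to this loss in the code as written.

```latex
\mathcal{L}(\theta) = \frac{1}{B\,H\,W\,C} \sum_{n=1}^{B} \sum_{i=1}^{H} \sum_{j=1}^{W} \sum_{c=1}^{C}
\Big[ f_\theta\big(\mathrm{HP}(M_n), \mathrm{HP}(P_n)\big) + \tilde{M}_n - G_n \Big]_{i,j,c}^{2}
```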

The test data:
[Figure]

Note that pan is now 256x256, while ms is 64x64x8.

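PanNet is fully convolutional, and test.py builds its placeholders from the actual image size. Testing therefore works on data of any size, not only the 16x16x8 MS patches used in training, as long as the PAN image is 4x the MS image in height and width.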
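Before feeding a new image to test.py, it is worth checking the 4x scale relation. Below is a hypothetical helper, not part of the original code. It takes the shapes test.py uses before it adds the batch axes: pan is H x W, ms is h x w x C, lms is H x W x C.

```python
import numpy as np

def check_test_inputs(pan, ms, lms, ratio=4):
    """Hypothetical helper: check that test inputs satisfy PanNet's 4x scale relation."""
    H, W = pan.shape
    h, w, c = ms.shape
    if (H, W) != (h * ratio, w * ratio):
        raise ValueError("PAN must be %dx the MS size in each dimension" % ratio)
    if lms.shape != (H, W, c):
        raise ValueError("lms must have the PAN size and the MS band count")
    return H, W, c

# Any size works as long as the ratio holds, e.g. the 256x256 PAN / 64x64x8 MS test image:
print(check_test_inputs(np.zeros((256, 256)), np.zeros((64, 64, 8)), np.zeros((256, 256, 8))))
# (256, 256, 8)
```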


Appendix

train.py

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
# This is a re-implementation of training code of this paper:
# J. Yang, X. Fu, Y. Hu, Y. Huang, X. Ding, J. Paisley. "PanNet: A deep network architecture for pan-sharpening", ICCV,2017.
# author: Junfeng Yang
"""
import tensorflow as tf
import numpy as np                      # numerical computing with arrays
import cv2                              # OpenCV
import tensorflow.contrib.layers as ly  # convolution layers in TensorFlow 1.x
import os
import h5py
import scipy.io as sio                  # read/write MATLAB .mat files

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # '0' selects the first GPU


# get high-frequency (high-pass) components
def get_edge(data):
    rs = np.zeros_like(data)
    N = data.shape[0]
    for i in range(N):
        if len(data.shape) == 3:
            rs[i, :, :] = data[i, :, :] - cv2.boxFilter(data[i, :, :], -1, (5, 5))  # data minus its low-frequency part
        else:
            rs[i, :, :, :] = data[i, :, :, :] - cv2.boxFilter(data[i, :, :, :], -1, (5, 5))
    return rs


# get training patches
def get_batch(train_data, bs):
    gt = train_data['gt'][...]      ## ground truth N*H*W*C
    pan = train_data['pan'][...]    #### Pan image N*H*W
    ms_lr = train_data['ms'][...]   ### low resolution MS image
    lms = train_data['lms'][...]    #### MS image interpolation to Pan scale
    gt = np.array(gt, dtype=np.float32) / 2047.  ### normalization, WorldView L = 11
    pan = np.array(pan, dtype=np.float32) / 2047.
    ms_lr = np.array(ms_lr, dtype=np.float32) / 2047.
    lms = np.array(lms, dtype=np.float32) / 2047.
    N = gt.shape[0]
    batch_index = np.random.randint(0, N, size=bs)
    gt_batch = gt[batch_index, :, :, :]
    pan_batch = pan[batch_index, :, :]
    ms_lr_batch = ms_lr[batch_index, :, :, :]
    lms_batch = lms[batch_index, :, :, :]
    pan_hp_batch = get_edge(pan_batch)
    pan_hp_batch = pan_hp_batch[:, :, :, np.newaxis]  # expand to N*H*W*1
    ms_hp_batch = get_edge(ms_lr_batch)
    return gt_batch, lms_batch, pan_hp_batch, ms_hp_batch


def vis_ms(data):
    # For display: the data has 8 bands, but a screen shows 3 (RGB),
    # so keep only the bands labelled b, g, r below.
    _, b, g, _, r, _, _, _ = tf.split(data, 8, axis=3)
    vis = tf.concat([r, g, b], axis=3)
    return vis


########## PanNet structures ################  # the core of the method
def PanNet(ms, pan, num_spectral=8, num_res=4, num_fm=32, reuse=False):
    weight_decay = 1e-5  # L2 weight-decay coefficient for the regularizers below
    # with tf.device('/gpu:0'):
    with tf.variable_scope('net'):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        # ms is 32x16x16x8; the transposed conv (kernel 8, stride 4) upsamples it 4x to 32x64x64x8
        ms = ly.conv2d_transpose(ms, num_spectral, 8, 4, activation_fn=None,  # 32 x 64 x 64 x 8
                                 weights_initializer=ly.variance_scaling_initializer(),
                                 weights_regularizer=ly.l2_regularizer(weight_decay))
        # ms + pan: concatenate along axis 3 (axes count from 0, so 3 is the 4th, channel, axis): 32 x 64 x 64 x 9
        ms = tf.concat([ms, pan], axis=3)
        # one convolution before the ResNet blocks: num_fm = 32 kernels of size 3x3
        rs = ly.conv2d(ms, num_outputs=num_fm, kernel_size=3, stride=1,
                       weights_regularizer=ly.l2_regularizer(weight_decay),
                       weights_initializer=ly.variance_scaling_initializer(),
                       activation_fn=tf.nn.relu)  # 32 x 64 x 64 x 32
        for i in range(num_res):  # ResNet blocks
            # first conv of the block (3x3x32) + ReLU nonlinearity
            rs1 = ly.conv2d(rs, num_outputs=num_fm, kernel_size=3, stride=1,
                            weights_regularizer=ly.l2_regularizer(weight_decay),
                            weights_initializer=ly.variance_scaling_initializer(),
                            activation_fn=tf.nn.relu)  # 32 x 64 x 64 x 32
            # second conv of the block (3x3x32), no ReLU
            rs1 = ly.conv2d(rs1, num_outputs=num_fm, kernel_size=3, stride=1,
                            weights_regularizer=ly.l2_regularizer(weight_decay),
                            weights_initializer=ly.variance_scaling_initializer(),
                            activation_fn=None)  # 32 x 64 x 64 x 32
            rs = tf.add(rs, rs1)  # skip connection: add the block input x, as in the ResNet figure
        # last conv (3x3, num_spectral = 8 kernels) maps back to the MS bands
        rs = ly.conv2d(rs, num_outputs=num_spectral, kernel_size=3, stride=1,
                       weights_regularizer=ly.l2_regularizer(weight_decay),
                       weights_initializer=ly.variance_scaling_initializer(),
                       activation_fn=None)  # 32 x 64 x 64 x 8
        return rs


###########################################################################
###### Main Function: input data from here! (like sub-functions in MATLAB before) ######
if __name__ == '__main__':
    tf.reset_default_graph()

    train_batch_size = 32  # training batch size
    test_batch_size = 32   # validation batch size
    image_size = 64        # patch size: 64 for 64x64x8 patches, 100 for 100x100x8 patches
    iterations = 100100    # total number of iterations to use.
    model_directory = './models'  # directory to save trained model to.
    train_data_name = './training_data/train.mat'      # training data
    test_data_name = './training_data/validation.mat'  # validation data
    restore = False  # load model or not
    method = 'Adam'  # optimizer used to minimize the loss: Adam or SGD

    ############## loading data
    train_data = sio.loadmat(train_data_name)  # for small data (not v7.3 data)
    test_data = sio.loadmat(test_data_name)
    # train_data = h5py.File(train_data_name)  # for large data (v7.3 data)
    # test_data = h5py.File(test_data_name)

    ############## placeholders for training (real data are fed in later) ###########
    gt = tf.placeholder(dtype=tf.float32, shape=[train_batch_size, image_size, image_size, 8])  # 32x64x64x8
    lms = tf.placeholder(dtype=tf.float32, shape=[train_batch_size, image_size, image_size, 8])
    ms_hp = tf.placeholder(dtype=tf.float32, shape=[train_batch_size, image_size // 4, image_size // 4, 8])  # 32x16x16x8
    pan_hp = tf.placeholder(dtype=tf.float32, shape=[train_batch_size, image_size, image_size, 1])

    ############# placeholders for testing ##############
    test_gt = tf.placeholder(dtype=tf.float32, shape=[test_batch_size, image_size, image_size, 8])
    test_lms = tf.placeholder(dtype=tf.float32, shape=[test_batch_size, image_size, image_size, 8])
    test_ms_hp = tf.placeholder(dtype=tf.float32, shape=[test_batch_size, image_size // 4, image_size // 4, 8])
    test_pan_hp = tf.placeholder(dtype=tf.float32, shape=[test_batch_size, image_size, image_size, 1])

    ######## network architecture (call: PanNet constructed before!) ######################
    mrs = PanNet(ms_hp, pan_hp)  # call pannet
    mrs = tf.add(mrs, lms)       # 32 x 64 x 64 x 8
    test_rs = PanNet(test_ms_hp, test_pan_hp, reuse=True)
    test_rs = test_rs + test_lms  # same as: test_rs = tf.add(test_rs, test_lms)

    ######## loss function ################
    # MSE (mean of squared errors). The L2 regularizers declared in PanNet go into
    # tf.GraphKeys.REGULARIZATION_LOSSES but are NOT added here, so weight decay has
    # no effect unless you add tf.losses.get_regularization_loss() to mse.
    mse = tf.reduce_mean(tf.square(mrs - gt))
    test_mse = tf.reduce_mean(tf.square(test_rs - test_gt))

    ##### Loss summary (for observation only) ################
    # images are clipped to [0, 1] because the data were normalized to that range
    mse_loss_sum = tf.summary.scalar("mse_loss", mse)
    test_mse_sum = tf.summary.scalar("test_loss", test_mse)
    lms_sum = tf.summary.image("lms", tf.clip_by_value(vis_ms(lms), 0, 1))
    mrs_sum = tf.summary.image("rs", tf.clip_by_value(vis_ms(mrs), 0, 1))
    label_sum = tf.summary.image("label", tf.clip_by_value(vis_ms(gt), 0, 1))
    all_sum = tf.summary.merge([mse_loss_sum, mrs_sum, label_sum, lms_sum])

    ############ optimizer: Adam or SGD ##################
    # with the loss defined, Adam or SGD (with momentum) learns the parameters θ
    t_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='net')
    if method == 'Adam':
        g_optim = tf.train.AdamOptimizer(0.001, beta1=0.9) \
            .minimize(mse, var_list=t_vars)
    else:
        global_steps = tf.Variable(0, trainable=False)
        lr = tf.train.exponential_decay(0.1, global_steps, decay_steps=50000, decay_rate=0.1)
        clip_value = 0.1 / lr
        optim = tf.train.MomentumOptimizer(lr, 0.9)
        gradient, var = zip(*optim.compute_gradients(mse, var_list=t_vars))
        gradient, _ = tf.clip_by_global_norm(gradient, clip_value)
        g_optim = optim.apply_gradients(zip(gradient, var), global_step=global_steps)

    ##### GPU setting
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    ###########################################################################
    #### Run the above (take real data into the Net, for training) ############
    # a Session feeds the data in and makes the whole graph "flow"
    init = tf.global_variables_initializer()  # initialization: must be done!
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:  # pass config so the GPU setting takes effect
        sess.run(init)
        if restore:
            print('Loading Model...')
            ckpt = tf.train.get_checkpoint_state(model_directory)
            saver.restore(sess, ckpt.model_checkpoint_path)

        #### read training data #####
        gt1 = train_data['gt'][...]     ## ground truth N*H*W*C
        pan1 = train_data['pan'][...]   #### Pan image N*H*W
        ms_lr1 = train_data['ms'][...]  ### low resolution MS image
        lms1 = train_data['lms'][...]   #### MS image interpolation to Pan scale
        gt1 = np.array(gt1, dtype=np.float32) / 2047.  ### [0, 1] normalization, WorldView L = 11
        pan1 = np.array(pan1, dtype=np.float32) / 2047.
        ms_lr1 = np.array(ms_lr1, dtype=np.float32) / 2047.
        lms1 = np.array(lms1, dtype=np.float32) / 2047.
        N = gt1.shape[0]

        #### read validation data #####
        gt2 = test_data['gt'][...]     ## ground truth N*H*W*C
        pan2 = test_data['pan'][...]   #### Pan image N*H*W
        ms_lr2 = test_data['ms'][...]  ### low resolution MS image
        lms2 = test_data['lms'][...]   #### MS image interpolation to Pan scale
        gt2 = np.array(gt2, dtype=np.float32) / 2047.  ### normalization, WorldView L = 11
        pan2 = np.array(pan2, dtype=np.float32) / 2047.
        ms_lr2 = np.array(ms_lr2, dtype=np.float32) / 2047.
        lms2 = np.array(lms2, dtype=np.float32) / 2047.
        N2 = gt2.shape[0]

        mse_train = []  # MSE history, used later to plot the error curves
        mse_valid = []

        for i in range(iterations):  # training loop
            ###################################################################
            # training phase!
            bs = train_batch_size
            batch_index = np.random.randint(0, N, size=bs)  # N = 100: draw bs = 32 random indices (with replacement)
            train_gt = gt1[batch_index, :, :, :]
            pan_batch = pan1[batch_index, :, :]
            ms_lr_batch = ms_lr1[batch_index, :, :, :]
            train_lms = lms1[batch_index, :, :, :]
            pan_hp_batch = get_edge(pan_batch)  # high-pass filtering: 32 x 64 x 64
            train_pan_hp = pan_hp_batch[:, :, :, np.newaxis]  # expand to 4-D N*H*W*1: 32 x 64 x 64 x 1
            train_ms_hp = get_edge(ms_lr_batch)  # 32 x 16 x 16 x 8
            # train_gt, train_lms, train_pan_hp, train_ms_hp = get_batch(train_data, bs=train_batch_size)

            # run one step: feed_dict maps each placeholder (key) to the loaded numpy batch (value)
            _, mse_loss, merged = sess.run([g_optim, mse, all_sum],
                                           feed_dict={gt: train_gt, lms: train_lms,
                                                      ms_hp: train_ms_hp, pan_hp: train_pan_hp})
            mse_train.append(mse_loss)  # record the training MSE after every step

            if i % 100 == 0:  # print the loss every 100 steps; it should keep decreasing
                print("Iter: " + str(i) + " MSE: " + str(mse_loss))  # print, e.g.,: Iter: 0 MSE: 0.18406609

            if i % 5000 == 0 and i != 0:  # every 5000 steps, save the model (the learned kernels) as .ckpt
                if not os.path.exists(model_directory):
                    os.makedirs(model_directory)
                saver.save(sess, model_directory + '/model-' + str(i) + '.ckpt')
                print("Save Model")

            ###################################################################
            # validation phase!
            bs_test = test_batch_size
            batch_index2 = np.random.randint(0, N2, size=bs_test)  # sample from the N2 validation patches
            test_gt_batch = gt2[batch_index2, :, :, :]
            test_lms_batch = lms2[batch_index2, :, :, :]
            ms_lr_batch = ms_lr2[batch_index2, :, :, :]
            test_ms_hp_batch = get_edge(ms_lr_batch)
            pan_batch = pan2[batch_index2, :, :]
            pan_hp_batch = get_edge(pan_batch)
            test_pan_hp_batch = pan_hp_batch[:, :, :, np.newaxis]  # expand to N*H*W*1
            '''
            if i % 1000 == 0 and i != 0:  # after 1000 iterations, re-set: get_batch
                test_gt_batch, test_lms_batch, test_pan_hp_batch, test_ms_hp_batch = get_batch(test_data, bs=test_batch_size)
            '''
            test_mse_loss, merged = sess.run([test_mse, test_mse_sum],
                                             feed_dict={test_gt: test_gt_batch, test_lms: test_lms_batch,
                                                        test_ms_hp: test_ms_hp_batch, test_pan_hp: test_pan_hp_batch})
            mse_valid.append(test_mse_loss)  # record the validation MSE

            if i % 1000 == 0 and i != 0:  # print the validation error every 1000 steps
                print("Iter: " + str(i) + " Valid MSE: " + str(test_mse_loss))

        ## finally write the mse info ##
        file = open('train_mse.txt', 'w')  # write the training error into train_mse.txt
        file.write(str(mse_train))
        file.close()
        file = open('valid_mse.txt', 'w')  # write the validation error into valid_mse.txt
        file.write(str(mse_valid))
        file.close()
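
The get_edge function above relies on cv2.boxFilter(x, -1, (5, 5)). That call replaces each pixel with the mean of its 5x5 neighbourhood, using OpenCV's default border mode, BORDER_REFLECT_101. The NumPy-only sketch below is intended to mirror that high-pass step without OpenCV. NumPy's 'reflect' padding corresponds to BORDER_REFLECT_101. The function names are hypothetical, and this is an illustration rather than a drop-in replacement validated against every OpenCV version.

```python
import numpy as np

def box_mean(img, k=5):
    """k x k mean filter on an H x W (x C) array with reflect padding."""
    r = k // 2
    pad = [(r, r), (r, r)] + [(0, 0)] * (img.ndim - 2)
    p = np.pad(img, pad, mode='reflect')
    h, w = img.shape[:2]
    out = np.zeros(img.shape, dtype=np.float64)
    for dy in range(k):
        for dx in range(k):
            out += p[dy:dy + h, dx:dx + w]   # accumulate the k*k shifted copies
    return out / (k * k)

def get_edge_np(img, k=5):
    """High-pass part: the image minus its local mean (the low-frequency part)."""
    return img - box_mean(img, k)

def get_edge_batch(batch, k=5):
    """Apply the filter to each sample of an N x H x W (x C) batch, like train.py's get_edge."""
    return np.stack([get_edge_np(sample, k) for sample in batch])

ramp = np.add.outer(np.arange(10.0), np.arange(12.0))  # a smooth (linear) image
print(np.allclose(get_edge_np(ramp)[2:-2, 2:-2], 0.0))  # True: no high frequencies away from the border
```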

test.py

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
# This is a re-implementation of training code of this paper:
# J. Yang, X. Fu, Y. Hu, Y. Huang, X. Ding, J. Paisley. "PanNet: A deep network architecture for pan-sharpening", ICCV,2017.
# author: Junfeng Yang
"""
import tensorflow as tf
import tensorflow.contrib.layers as ly
import numpy as np
import scipy.io as sio
import cv2
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'  # hide TensorFlow's C++ INFO, WARNING and ERROR log messages


def PanNet(ms, pan, num_spectral=8, num_res=4, num_fm=32, reuse=False):
    weight_decay = 1e-5
    with tf.variable_scope('net'):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        # NOTE: biases_initializer=None removes this layer's bias, whereas train.py keeps the
        # default bias; a checkpoint from train.py is then restored without that learned bias.
        ms = ly.conv2d_transpose(ms, num_spectral, 8, 4, activation_fn=None,
                                 weights_initializer=ly.variance_scaling_initializer(),
                                 biases_initializer=None,
                                 weights_regularizer=ly.l2_regularizer(weight_decay))
        ms = tf.concat([ms, pan], axis=3)
        rs = ly.conv2d(ms, num_outputs=num_fm, kernel_size=3, stride=1,
                       weights_regularizer=ly.l2_regularizer(weight_decay),
                       weights_initializer=ly.variance_scaling_initializer(),
                       activation_fn=tf.nn.relu)
        for i in range(num_res):
            rs1 = ly.conv2d(rs, num_outputs=num_fm, kernel_size=3, stride=1,
                            weights_regularizer=ly.l2_regularizer(weight_decay),
                            weights_initializer=ly.variance_scaling_initializer(),
                            activation_fn=tf.nn.relu)
            rs1 = ly.conv2d(rs1, num_outputs=num_fm, kernel_size=3, stride=1,
                            weights_regularizer=ly.l2_regularizer(weight_decay),
                            weights_initializer=ly.variance_scaling_initializer(),
                            activation_fn=None)
            rs = tf.add(rs, rs1)
        rs = ly.conv2d(rs, num_outputs=num_spectral, kernel_size=3, stride=1,
                       weights_regularizer=ly.l2_regularizer(weight_decay),
                       weights_initializer=ly.variance_scaling_initializer(),
                       activation_fn=None)
        return rs


def get_edge(data):  # get high-frequency
    rs = np.zeros_like(data)
    if len(rs.shape) == 3:
        for i in range(data.shape[2]):
            rs[:, :, i] = data[:, :, i] - cv2.boxFilter(data[:, :, i], -1, (5, 5))
    else:
        rs = data - cv2.boxFilter(data, -1, (5, 5))
    return rs

#################################################################
################# Main function #################################
# the functions before main are the same as in train.py (apart from the bias noted above)
if __name__ == '__main__':
    test_data = 'new_data.mat'
    model_directory = './models/'
    tf.reset_default_graph()  # start from a clean default graph
    data = sio.loadmat(test_data)

    ms = data['ms'][...]  # MS image 64x64x8
    ms = np.array(ms, dtype=np.float32) / 2047.
    lms = data['lms'][...]  # up-sampled LRMS image 256x256x8
    lms = np.array(lms, dtype=np.float32) / 2047.
    pan = data['pan'][...]  # PAN image 256x256
    pan = np.array(pan, dtype=np.float32) / 2047.

    ms_hp = get_edge(ms)  # high-frequency parts of MS image
    ms_hp = ms_hp[np.newaxis, :, :, :]  # 1x64x64x8: add the batch dimension
    pan_hp = get_edge(pan)  # high-frequency parts of PAN image: 256x256
    pan_hp = pan_hp[np.newaxis, :, :, np.newaxis]  # 1x256x256x1: add batch and channel dimensions
    h = pan.shape[0]  # height
    w = pan.shape[1]  # width
    lms = lms[np.newaxis, :, :, :]  # 1x256x256x8: add the batch dimension

    ##### placeholder for testing #######
    p_hp = tf.placeholder(shape=[1, h, w, 1], dtype=tf.float32)
    m_hp = tf.placeholder(shape=[1, h // 4, w // 4, 8], dtype=tf.float32)  # integer division
    lms_p = tf.placeholder(shape=[1, h, w, 8], dtype=tf.float32)

    rs = PanNet(m_hp, p_hp)  # feed ms_hp and pan_hp into the network; outputs the high-frequency parts
    mrs = tf.add(rs, lms_p)  # add the up-sampled MS image to get the high-resolution result
    output = tf.clip_by_value(mrs, 0, 1)  # final output: values > 1 become 1, values < 0 become 0

    ################################################################
    ################## Session Run #################################
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        # loading model
        if tf.train.get_checkpoint_state(model_directory):
            ckpt = tf.train.latest_checkpoint(model_directory)
            saver.restore(sess, ckpt)
            print("load new model")
        else:
            ckpt = tf.train.get_checkpoint_state(model_directory + "pre-trained/")
            # this model uses 128 feature maps and is for debug only;
            # restoring it requires building the graph with PanNet(m_hp, p_hp, num_fm=128)
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("load pre-trained model")

        final_output = sess.run(output, feed_dict={p_hp: pan_hp, m_hp: ms_hp, lms_p: lms})  # 1x256x256x8
        if not os.path.exists('./result'):  # sio.savemat does not create missing directories
            os.makedirs('./result')
        sio.savemat('./result/output.mat', {'output': final_output[0, :, :, :]})  # save the 256x256x8 result as .mat
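
test.py saves the result as floats in [0, 1]. To compare it with the original imagery, you may want the 11-bit range back. The following is a hypothetical post-processing step, not part of test.py. It undoes the division by 2047 (the "WorldView L = 11" normalization in the code) and rounds to unsigned 16-bit integers.

```python
import numpy as np

def to_11bit(output, max_value=2047):
    """Undo the / 2047 normalization: [0, 1] floats -> 11-bit values stored as uint16."""
    clipped = np.clip(output, 0.0, 1.0)  # test.py already clips; repeated here for safety
    return np.rint(clipped * max_value).astype(np.uint16)

out = np.array([[[0.0, 0.5, 1.0, 1.2]]], dtype=np.float32)  # dummy 1 x 1 x 4 result
print(to_11bit(out).tolist())  # [[[0, 1024, 2047, 2047]]]
```

np.rint rounds halves to the nearest even integer, which is why 0.5 * 2047 = 1023.5 becomes 1024.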