1. Network Structure
The DiCNN1 network structure from the paper is shown below:
2. Code Implementation
2.1 Core code based on TensorFlow 1.14.0
########## DiCNN1Net structures ################
import tensorflow as tf
import tensorflow.contrib.layers as ly

def DiCNN1Net(lms, pan, num_spectral=8, num_res=4, num_fm=32, reuse=False):
    # num_res is kept from the original signature but is not used in this function
    weight_decay = 1e-5
    with tf.variable_scope('net'):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        # Concatenate the multispectral input (lms) and the PAN input along the channel axis (NHWC)
        ms_1 = tf.concat([lms, pan], axis=3)
        # Conv 1: (num_spectral + 1) -> num_fm channels, ReLU
        rs = ly.conv2d(ms_1, num_outputs=num_fm, kernel_size=3, stride=1,
                       weights_regularizer=ly.l2_regularizer(weight_decay),
                       weights_initializer=ly.variance_scaling_initializer(),
                       activation_fn=tf.nn.relu)
        # Conv 2: num_fm -> num_fm channels, ReLU
        rs = ly.conv2d(rs, num_outputs=num_fm, kernel_size=3, stride=1,
                       weights_regularizer=ly.l2_regularizer(weight_decay),
                       weights_initializer=ly.variance_scaling_initializer(),
                       activation_fn=tf.nn.relu)
        # Conv 3: num_fm -> num_spectral channels, no activation
        rs = ly.conv2d(rs, num_outputs=num_spectral, kernel_size=3, stride=1,
                       weights_regularizer=ly.l2_regularizer(weight_decay),
                       weights_initializer=ly.variance_scaling_initializer(),
                       activation_fn=None)
        # Skip connection: add the multispectral input back onto the predicted residual
        rs = tf.add(rs, lms)
        return rs
2.2 Core code based on PyTorch 1.7.1
########## DiCNN1Net structures ################
import torch
import torch.nn as nn

class DiCNN1Net(nn.Module):
    def __init__(self):
        super(DiCNN1Net, self).__init__()
        channel = 32
        spectral_num = 8
        self.conv1 = nn.Conv2d(in_channels=spectral_num + 1, out_channels=channel, kernel_size=3, stride=1, padding=1,
                               bias=True)  # 9 in, 32 out
        self.conv2 = nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, stride=1, padding=1,
                               bias=True)  # 32 in, 32 out
        self.conv3 = nn.Conv2d(in_channels=channel, out_channels=spectral_num, kernel_size=3, stride=1, padding=1,
                               bias=True)  # 32 in, 8 out
        self.relu = nn.ReLU(inplace=True)
        # init_weights(self.conv1, self.conv2, self.conv3)  # state initialization, important!

    def forward(self, x, y):  # x = lms; y = pan. This method defines the actual data flow of the network.
        inputs = torch.cat([x, y], 1)        # Bs x 9 x 64 x 64; channels are at dim 1
        rs1 = self.relu(self.conv1(inputs))  # Bs x 32 x 64 x 64
        rs2 = self.relu(self.conv2(rs1))     # Bs x 32 x 64 x 64
        rs3 = self.conv3(rs2)                # Bs x 8 x 64 x 64
        output = torch.add(x, rs3)           # Bs x 8 x 64 x 64
        return output
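As a quick sanity check on the layer sizes annotated above, the number of learnable parameters can be worked out by hand. The sketch below is plain Python arithmetic, not part of the original code; the helper name `conv2d_params` is made up for illustration, and the channel counts (9 -> 32 -> 32 -> 8, 3x3 kernels, bias enabled) are taken from the PyTorch listing.

```python
def conv2d_params(in_ch, out_ch, k=3, bias=True):
    """Parameter count of one k x k convolution (hypothetical helper)."""
    weights = out_ch * in_ch * k * k
    return weights + (out_ch if bias else 0)

spectral_num, channel = 8, 32

# (in_channels, out_channels) for conv1, conv2, conv3 as in the listing
layers = [(spectral_num + 1, channel), (channel, channel), (channel, spectral_num)]

counts = [conv2d_params(i, o) for i, o in layers]
total = sum(counts)

print(counts)  # [2624, 9248, 2312]
print(total)   # 14184
```

With these settings the network has roughly 14k parameters, so it is small compared with most deep pansharpening models.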
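The convolution-layer initialization (init_weights) can be left out at first: torch initializes each layer by default (a Kaiming-style init). PanNet calls it because that network needed it during tuning; in general the default is sufficient and there is no need to add complexity. If you do want a different initialization in practice, you can bring that code in at this point.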
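If a custom initialization is wanted, one possible shape for it is sketched below. The original `init_weights` helper is not shown in this post, so this is a hypothetical stand-in that only matches the call signature seen in the commented-out line. Kaiming-normal weights with zero biases are an assumption, not necessarily what PanNet uses.

```python
import torch
import torch.nn as nn

def init_weights(*modules):
    """Hypothetical stand-in for the init_weights helper referenced in the listing."""
    for module in modules:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                # Assumed choice: Kaiming-normal weights suited to ReLU activations
                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

# Usage on a layer shaped like conv1 (9 input channels, 32 output channels)
conv = nn.Conv2d(9, 32, kernel_size=3, stride=1, padding=1, bias=True)
init_weights(conv)
```

Inside the model, the same call would go at the end of `__init__`, where the commented-out line sits.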
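2.3 MATLAB image display code
After the model is trained, the test step outputs a .mat file containing 256x256x8 image data. Everyday images normally have three RGB channels, so, following the code already in the program, we pick three of the eight channels to serve as RGB…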
def vis_ms(data):
    _, b, g, _, r, _, _, _ = tf.split(data, 8, axis=3)
    vis = tf.concat([r, g, b], axis=3)
    return vis
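From this existing code, the R channel is the 5th channel of the data, the G channel is the 3rd, and the B channel is the 2nd (counting from 1). With the three RGB channels in hand, displaying the image is straightforward. The MATLAB code is as follows: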
clc;
clear;
close all;
load('output.mat');
img_r = output(:,:,5);
img_g = output(:,:,3);
img_b = output(:,:,2);
img(:,:,1) = img_r;
img(:,:,2) = img_g;
img(:,:,3) = img_b;
imshow(img)
subplot(1,4,1), imshow(img);   title('RGB image');
subplot(1,4,2), imshow(img_r); title('R channel');
subplot(1,4,3), imshow(img_g); title('G channel');
subplot(1,4,4), imshow(img_b); title('B channel');
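The same band selection can also be sketched in Python with NumPy, which makes the index shift explicit: MATLAB channels 5, 3, 2 become 0-based indices 4, 2, 1. The function name `ms_to_rgb` and the synthetic test array are assumptions for illustration. A real run would first load the array from output.mat, for example with `scipy.io.loadmat`.

```python
import numpy as np

def ms_to_rgb(ms, rgb_bands=(4, 2, 1)):
    """Select three bands of an H x W x C array and stack them as R, G, B.

    rgb_bands holds 0-based indices; (4, 2, 1) corresponds to
    MATLAB channels 5, 3, 2 used above.
    """
    return np.stack([ms[:, :, b] for b in rgb_bands], axis=2)

# Synthetic 2 x 2 x 8 array standing in for the 256 x 256 x 8 network output
ms = np.arange(2 * 2 * 8, dtype=np.float64).reshape(2, 2, 8)
rgb = ms_to_rgb(ms)

print(rgb.shape)  # (2, 2, 3)
```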