文章目录
- DRSN 原理
-
- 残差网络
- 自注意力网络
- 软阈值化
- 代码实现
DRSN 原理
DRSN 由三部分组成:残差网络、自注意力网络和软阈值化。
残差网络
残差网络(或称深度残差网络、深度残差学习,英文ResNet)属于一种卷积神经网络。相较于普通的卷积神经网络,残差网络采用了跨层恒等连接,以减轻卷积神经网络的训练难度。其具体说明可以参考文章:Tensorflow2.0之自定义ResNet。
自注意力网络
在 DRSN 中,SE 模块被选用作为自注意力机制中的主要部分。它可以通过一个小型的子网络,自动学习得到一组权重,对特征图的各个通道进行加权。其含义在于,某些特征通道是较为重要的,而另一些特征通道是信息冗余的;那么,我们就可以通过这种方式增强有用特征通道、削弱冗余特征通道。其具体说明可以参考文章:SENet 网络结构的原理与 Tensorflow2.0 实现。
软阈值化
软阈值化是许多信号降噪方法的核心步骤。它的用处是将绝对值低于某个阈值的特征置为零,将其他的特征也朝着零进行调整,也就是收缩。在这里,阈值是一个需要预先设置的参数,其取值大小对于降噪的结果有着直接的影响。
软阈值化的公式和导数为:
软阈值化的输入与输出之间的关系如下图所示。
从图中可以看出,软阈值化是一种非线性变换,有着与ReLU激活函数非常相似的性质:梯度要么是0,要么是1。因此,软阈值化也能够作为神经网络的激活函数。事实上,一些神经网络已经将软阈值函数作为激活函数进行了使用。
代码实现
总体来说,DRSN 的网络主体为:
根据该结构编写代码如下:
import tensorflow as tf
import numpy as npdef residual_shrinkage_block(inputs, out_channels, downsample_strides=1):in_channels = inputs.shape[-1]residual = tf.keras.layers.BatchNormalization()(inputs)residual = tf.keras.layers.Activation('relu')(residual)residual = tf.keras.layers.Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), padding='same')(residual)residual = tf.keras.layers.BatchNormalization()(residual)residual = tf.keras.layers.Activation('relu')(residual)residual = tf.keras.layers.Conv2D(out_channels, 3, padding='same')(residual)residual_abs = tf.abs(residual)abs_mean = tf.keras.layers.GlobalAveragePooling2D()(residual_abs)scales = tf.keras.layers.Dense(out_channels, activation=None)(abs_mean)scales = tf.keras.layers.BatchNormalization()(scales)scales = tf.keras.layers.Activation('relu')(scales)scales = tf.keras.layers.Dense(out_channels, activation='sigmoid')(scales)thres = tf.keras.layers.multiply([abs_mean, scales])sub = tf.keras.layers.subtract([residual_abs, thres])zeros = tf.keras.layers.subtract([sub, sub])n_sub = tf.keras.layers.maximum([sub, zeros])residual = tf.keras.layers.multiply([tf.sign(residual), n_sub])out_channels = residual.shape[-1]if in_channels != out_channels:identity = tf.keras.layers.Conv2D(out_channels, 1, strides=(downsample_strides, downsample_strides), padding='same')(inputs)residual = tf.keras.layers.add([residual, identity])return residualinputs = np.zeros((1, 224, 224, 3), np.float32)
residual_shrinkage_block(inputs, 8).shape
TensorShape([1, 224, 224, 8])