本文探讨了一种新的深度学习算法——深度残差膨胀网络(Deep Residual Shrinkage Network),退出了笔者本人的了解。
1.深度残差膨胀网络的相干根底
从名字就可能看出,深度残差膨胀网络是深度残差网络的一种改良办法。其特色是“膨胀”,在这里指的是软阈值化,而软阈值化简直是当初信号降噪算法的必备步骤。
因而,深度残差膨胀网络是一种面向强噪数据的深度学习算法,是信号处理里的经典内容和深度学习、注意力机制的又一种联合。
深度残差膨胀网络的根本模块如下图(a)所示,通过一个小型子网络,学习失去一组阈值,而后进行特色的软阈值化。同时,该模块还退出了恒等门路,以升高模型训练难度。深度残差膨胀网络的整体构造如下图(b)所示,与通常的深度残差网络是一样的。
那么为何要进行膨胀呢?膨胀有什么益处呢?本文尝试从删除冗余特色的灵便度的角度,开展了探讨。
2.膨胀(这里指软阈值化)
不理解软阈值化的同学能够去搜一下Soft Threshlding,在谷歌学术上会搜到这一篇:DL Donoho. De-noising by soft-thresholding[J]. IEEE transactions on information theory, 1995.
De-noising by soft-thresholding这一篇论文,目前的援用次数是12893次。能够看进去,软阈值化是一种经典的办法,尤其在信号降噪方面是十分罕用的。
软阈值函数的式子如下:
其中t是阈值,是一个负数。从公式能够看出,软阈值化将[-t,t]区间内的特色置为0,将大于t的特色减t,将小于-t的特色加t。
如果用图片示意软阈值函数,就如下图所示:
3.膨胀(这里指软阈值化)与ReLU激活函数的比照
软阈值化在深度残差膨胀网络中是作为非线性映射,而当初深度学习最罕用的非线性映射是ReLU激活函数。所以上面进行了两者的比照。
3.1 独特长处
咱们首先剖析一下,膨胀(这里指软阈值化)和ReLU激活函数的独特长处。
首先,软阈值化和ReLU都能够将局部区间的特色置为0,相当于删除局部特色/信息。(可了解为,后面的层将冗余特色转换到某个取值区间,而后用软阈值化或ReLU进行删除)
其次,软阈值化和ReLU的梯度都要么为0,要么为1,都有利于梯度的反向流传。
3.2 膨胀(这里指软阈值化)与ReLU的初步比照
相较于ReLU,软阈值化可能更加灵便地设置“待删除(置为0)”的特色取值区间。
咱们首先独立地看ReLU,以下图为例。ReLU将低于0的特色,全副删除(置为0);大于0的特色,全副保留(放弃不变)。
软阈值函数呢?它将某个区间,也就是[-阈值,阈值]这一区间内的特色删除(置为零);将这个区间之外的局部,包含大于阈值和小于-阈值的局部,保留下来(尽管朝向0进行了膨胀)。下图展现了阈值t=10的状况:
在深度残差膨胀网络中,阈值是能够通过注意力机制主动设置的。也就是说,[-阈值,阈值]的区间,是能够依据样本本身状况、主动调整的。
3.3 膨胀(这里指软阈值化)与ReLU的深层比照
如果咱们把ReLU和之前(卷积层或者批标准化外面的)偏置b,放在一起看呢?那么ReLU可能删除的特色取值空间,是能够变动的。比如说,将偏置b和ReLU作为一个整体的话,函数模式就变成了max(x+b,0)或者ReLU(x+b)。当偏置b为负数的时候,特色x会沿y轴向上平移,而后再将负特色置为0。例如,当b=20的时候,如下图所示:
或者当偏置b为正数的时候,特色x会沿y轴向下平移,而后再将负特色置为0。例如,当b=-20的时候,如下图所示:
接下来,咱们来探讨软阈值函数。将偏置b和软阈值化作为一个整体的话,函数模式就变成了sign(x+b)•max(abs(x+b)-t,0)。当偏置b为负数的时候,首先特色x会沿y轴向上平移,而后再将零左近的特色置为0。例如,当偏置b=20、阈值t=10的时候,如下图所示:
当偏置b为负时,特色x会沿y轴向下平移,而后再将零左近的特色置为0。例如,当偏置b=-20、阈值t=10的时候,如下图所示:
在深度残差膨胀网络中,因为偏置b和阈值t都是能够训练失去的参数,所以当偏置b和阈值t取值适合的时候,软阈值化是能够实现与ReLU雷同的性能的。也就是,在现有的这些特色的[最小值,最大值]的范畴内(不思考无穷的状况,个别咱们采集的数据不会有无穷),将低于某个值的特色全置为0,或者将高于某个值的特色全置为0。例如,在下图的数据中,如果咱们将偏置b设置为20,将阈值t也设置为20,就将所有小于0的特色全副置为0了。因为没有小于-40的特色,所以“偏置+软阈值化”就相当于实现了ReLU的性能(将低于0的特色置为0)。
当然,因为[-阈值,阈值]区间和偏置b都是可调的,也能够是这样(b=40,t=20)(是不是和“偏置+ReLU”很类似):
然而,反过来的话,不论“偏置+ReLU”怎么组合,都无奈实现下图中软阈值函数能够实现的性能。也就是,“偏置+ReLU”无奈将某个区间内特色的置为0,并且同时保留大于上界和小于下界的特色。
从这个角度看的话,当和前一层的偏置放在一起看的时候,软阈值化比ReLU可能更加灵便地设置“待删除特色的取值区间”。
4.注意力机制的加持
更重要地,深度残差膨胀网络采纳了注意力机制(相似于Squeeze-and-Excitation Network)主动设置阈值,防止了人工设置阈值的麻烦。(人工设置阈值始终是一个大麻烦,而深度残差膨胀网络用注意力机制解决了这个大麻烦)。
在注意力机制中,深度残差膨胀网络采纳了非凡的网络结构,保障了阈值不仅为负数,而且不会太大。因为如果阈值过大的话,就可能呈现下图的状况,也就是所有特色都被置为0了。深度残差膨胀网络的阈值,其实是(特色图的绝对值的平均值)×(0到1之间的系数),很好地防止了阈值太大的状况。
同时,深度残差膨胀网络的阈值,是在注意力机制下,依据每个样本的状况,独自设置的。也就是,每个样本,都有本人的一组独特的阈值。因而,深度残差膨胀网络实用于各个样本中噪声含量不同的状况。
5.深度残差膨胀网络只实用于强噪声的数据吗?
咱们在应用深度残差膨胀网络的时候,仿佛不须要思考数据中是否真的含有很多噪声。换言之,深度残差膨胀网络应该能够用于弱噪声的数据。
这是因为,深度残差膨胀网络中的阈值,是依据样本本身的状况,通过一个小型子网络主动取得的。如果样本所含噪声很少,那么阈值能够被主动设置得很低(靠近于0),从而“软阈值化”就进化成了“间接相等”。在这种状况下,软阈值化,就相当于不存在了。
6.恒等连贯升高了训练难度
相较于一般的残差网络,深度残差膨胀网络的构造较为简单,所以恒等门路是有必要存在的。
7. 论文网址
M. Zhao, S. Zhong, X. Fu, et al., Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, DOI: 10.1109/TII.2019.2943898
https://ieeexplore.ieee.org/document/8850096
https://github.com/zhao62/Deep-Residual-Shrinkage-Networks
8. Keras示例代码
#!/usr/bin/env python3# -*- coding: utf-8 -*-"""Created on Sat Dec 28 23:24:05 2019Implemented using TensorFlow 1.0.1 and Keras 2.2.1 M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898@author: super_9527"""from __future__ import print_functionimport kerasimport numpy as npfrom keras.datasets import mnistfrom keras.layers import Dense, Conv2D, BatchNormalization, Activationfrom keras.layers import AveragePooling2D, Input, GlobalAveragePooling2Dfrom keras.optimizers import Adamfrom keras.regularizers import l2from keras import backend as Kfrom keras.models import Modelfrom keras.layers.core import LambdaK.set_learning_phase(1)# Input image dimensionsimg_rows, img_cols = 28, 28# The data, split between train and test sets(x_train, y_train), (x_test, y_test) = mnist.load_data()if K.image_data_format() == 'channels_first': x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) input_shape = (1, img_rows, img_cols)else: x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) input_shape = (img_rows, img_cols, 1)# Noised datax_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1])x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1])print('x_train shape:', x_train.shape)print(x_train.shape[0], 'train samples')print(x_test.shape[0], 'test samples')# convert class vectors to binary class matricesy_train = keras.utils.to_categorical(y_train, 10)y_test = keras.utils.to_categorical(y_test, 10)def abs_backend(inputs): return K.abs(inputs)def expand_dim_backend(inputs): return K.expand_dims(K.expand_dims(inputs,1),1)def sign_backend(inputs): return K.sign(inputs)def pad_backend(inputs, in_channels, out_channels): pad_dim = (out_channels - in_channels)//2 inputs = K.expand_dims(inputs,-1) inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last') return K.squeeze(inputs, -1)# Residual Shrinakge Blockdef residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False, downsample_strides=2): residual = incoming in_channels = incoming.get_shape().as_list()[-1] for i in range(nb_blocks): identity = residual if not downsample: downsample_strides = 1 residual = BatchNormalization()(residual) residual = Activation('relu')(residual) residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(residual) residual = BatchNormalization()(residual) residual = Activation('relu')(residual) residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(residual) # Calculate global means residual_abs = Lambda(abs_backend)(residual) abs_mean = GlobalAveragePooling2D()(residual_abs) # Calculate scaling coefficients scales = Dense(out_channels, activation=None, kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(abs_mean) scales = BatchNormalization()(scales) scales = Activation('relu')(scales) scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales) scales = Lambda(expand_dim_backend)(scales) # Calculate thresholds thres = keras.layers.multiply([abs_mean, scales]) # Soft thresholding sub = keras.layers.subtract([residual_abs, thres]) zeros = keras.layers.subtract([sub, sub]) n_sub = keras.layers.maximum([sub, zeros]) residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub]) # Downsampling using the pooL-size of (1, 1) if downsample_strides > 1: identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity) # Zero_padding to match channels if in_channels != out_channels: identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity) residual = keras.layers.add([residual, identity]) return residual# define and train a modelinputs = Input(shape=input_shape)net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)net = residual_shrinkage_block(net, 1, 8, downsample=True)net = BatchNormalization()(net)net = Activation('relu')(net)net = GlobalAveragePooling2D()(net)outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)model = Model(inputs=inputs, outputs=outputs)model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))# get resultsK.set_learning_phase(0)DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)print('Train loss:', DRSN_train_score[0])print('Train accuracy:', DRSN_train_score[1])DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)print('Test loss:', DRSN_test_score[0])print('Test accuracy:', DRSN_test_score[1])
9. TFLearn程序
#!/usr/bin/env python3# -*- coding: utf-8 -*-"""Created on Mon Dec 23 21:23:09 2019Implemented using TensorFlow 1.0 and TFLearn 0.3.2 M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898 @author: super_9527""" from __future__ import division, print_function, absolute_import import tflearnimport numpy as npimport tensorflow as tffrom tflearn.layers.conv import conv_2d # Data loadingfrom tflearn.datasets import cifar10(X, Y), (testX, testY) = cifar10.load_data() # Add noiseX = X + np.random.random((50000, 32, 32, 3))*0.1testX = testX + np.random.random((10000, 32, 32, 3))*0.1 # Transform labels to one-hot formatY = tflearn.data_utils.to_categorical(Y,10)testY = tflearn.data_utils.to_categorical(testY,10) def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False, downsample_strides=2, activation='relu', batch_norm=True, bias=True, weights_init='variance_scaling', bias_init='zeros', regularizer='L2', weight_decay=0.0001, trainable=True, restore=True, reuse=False, scope=None, name="ResidualBlock"): # residual shrinkage blocks with channel-wise thresholds residual = incoming in_channels = incoming.get_shape().as_list()[-1] # Variable Scope fix for older TF try: vscope = tf.variable_scope(scope, default_name=name, values=[incoming], reuse=reuse) except Exception: vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse) with vscope as scope: name = scope.name #TODO for i in range(nb_blocks): identity = residual if not downsample: downsample_strides = 1 if batch_norm: residual = tflearn.batch_normalization(residual) residual = tflearn.activation(residual, activation) residual = conv_2d(residual, out_channels, 3, downsample_strides, 'same', 'linear', bias, weights_init, bias_init, regularizer, weight_decay, trainable, restore) if batch_norm: residual = tflearn.batch_normalization(residual) residual = tflearn.activation(residual, activation) residual = conv_2d(residual, out_channels, 3, 1, 'same', 'linear', bias, weights_init, bias_init, regularizer, weight_decay, trainable, restore) # get thresholds and apply thresholding abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True) scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling') scales = tflearn.batch_normalization(scales) scales = tflearn.activation(scales, 'relu') scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling') scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1) thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales)) # soft thresholding residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0)) # Downsampling if downsample_strides > 1: identity = tflearn.avg_pool_2d(identity, 1, downsample_strides) # Projection to new dimension if in_channels != out_channels: if (out_channels - in_channels) % 2 == 0: ch = (out_channels - in_channels)//2 identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch]]) else: ch = (out_channels - in_channels)//2 identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch+1]]) in_channels = out_channels residual = residual + identity return residual # Real-time data preprocessingimg_prep = tflearn.ImagePreprocessing()img_prep.add_featurewise_zero_center(per_channel=True) # Real-time data augmentationimg_aug = tflearn.ImageAugmentation()img_aug.add_random_flip_leftright()img_aug.add_random_crop([32, 32], padding=4) # Build a Deep Residual Shrinkage Network with 3 blocksnet = tflearn.input_data(shape=[None, 32, 32, 3], data_preprocessing=img_prep, data_augmentation=img_aug)net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)net = residual_shrinkage_block(net, 1, 16)net = residual_shrinkage_block(net, 1, 32, downsample=True)net = residual_shrinkage_block(net, 1, 32, downsample=True)net = tflearn.batch_normalization(net)net = tflearn.activation(net, 'relu')net = tflearn.global_avg_pool(net)# Regressionnet = tflearn.fully_connected(net, 10, activation='softmax')mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')# Trainingmodel = tflearn.DNN(net, checkpoint_path='model_cifar10', max_checkpoints=10, tensorboard_verbose=0, clip_gradients=0.) model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500, show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10') training_acc = model.evaluate(X, Y)[0]validation_acc = model.evaluate(testX, testY)[0]