1、API
通过直接使用 tf.keras.Sequential() 函数可以轻松地构建网络,如:
mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
mobile.trainable = False
model = tf.keras.Sequential([simplified_mobile,tf.keras.layers.Dropout(0.5),tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dense(28, activation='softmax')
])
但是,通过 API 定义的方法并不容易自定义复杂的网络。
2、通过函数定义
mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
mobile.trainable = Falsedef MobileNetV2 (classes):img_input = tf.keras.layers.Input(shape=(224, 224, 3))x = mobile(img_input)x = tf.keras.layers.Dropout(0.5)(x)x = tf.keras.layers.GlobalAveragePooling2D()(x)x = tf.keras.layers.Dense(classes, activation='softmax')(x)model = tf.keras.Model(img_input, x)return model
使用函数定义网络时要注意以下几点:
- 1、一个网络中往往包含了多个自定义网络函数,如卷积-批归一化-激活函数等,但在最后构建网络的函数的开头必须定义输入该网络的形状,而不是直接在函数名后面定义一个输入。当然,对前面的函数来说是可以直接多定义一个输入的;
- 2、在函数结尾必须有:model = tf.keras.Model(img_input, x)。
3、通过类定义
mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
mobile.trainable = Falseclass MobileNetV2(tf.keras.Model):def __init__(self, classes):super().__init__()self.mob = mobileself.dropout = tf.keras.layers.Dropout(0.5)self.gap = tf.keras.layers.GlobalAveragePooling2D()self.dense = tf.keras.layers.Dense(classes, activation='softmax')def call(self, inputs):x = self.mob(inputs)x = self.dropout(x)x = self.gap(x)x = self.dense(x)return x