当前位置: 代码迷 >> 综合 >> Tensorflow2.0 定义模型的三种方法
  详细解决方案

Tensorflow2.0 定义模型的三种方法

热度:106   发布时间:2023-10-28 13:24:18.0

1、API

通过直接使用 tf.keras.Sequential() 函数可以轻松地构建网络,如:

mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3)) 
mobile.trainable = False
model = tf.keras.Sequential([simplified_mobile,tf.keras.layers.Dropout(0.5),tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dense(28, activation='softmax')
])

但是,通过 API 定义的方法并不容易自定义复杂的网络。

2、通过函数定义

mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3)) 
mobile.trainable = Falsedef MobileNetV2 (classes):img_input = tf.keras.layers.Input(shape=(224, 224, 3))x = mobile(img_input)x = tf.keras.layers.Dropout(0.5)(x)x = tf.keras.layers.GlobalAveragePooling2D()(x)x = tf.keras.layers.Dense(classes, activation='softmax')(x)model = tf.keras.Model(img_input, x)return model

使用函数定义网络时要注意以下几点:

  • 1、一个网络中往往包含了多个自定义网络函数,如卷积-批归一化-激活函数等,但在最后构建网络的函数的开头必须定义输入该网络的形状,而不是直接在函数名后面定义一个输入。当然,对前面的函数来说是可以直接多定义一个输入的;
  • 2、在函数结尾必须有:model = tf.keras.Model(img_input, x)。

3、通过类定义

mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3)) 
mobile.trainable = Falseclass MobileNetV2(tf.keras.Model):def __init__(self, classes):super().__init__()self.mob = mobileself.dropout = tf.keras.layers.Dropout(0.5)self.gap = tf.keras.layers.GlobalAveragePooling2D()self.dense = tf.keras.layers.Dense(classes, activation='softmax')def call(self, inputs):x = self.mob(inputs)x = self.dropout(x)x = self.gap(x)x = self.dense(x)return x
  相关解决方案