当前位置: 代码迷 >> 综合 >> tf2.0使用官方resnet50 finetune
  详细解决方案

tf2.0使用官方resnet50 finetune

热度:49   发布时间:2023-11-21 13:17:48.0

前言

没咋用tf了,入门学的tf,一堆sess整的最后不想看了…解除了pytorch还是torch好用,最近“机缘巧合”接了个tf的活,记录一下。

finetune官方模型

核心比较简单:

r50_finetune = tf.keras.models.Sequential()
model = tf.keras.applications.resnet50.ResNet50(input_shape = (img_size,img_size,3), weights="imagenet", include_top = False, pooling = 'avg')
r50_finetune.add(model)
r50_finetune.add(tf.keras.layers.Dense(num_classes, activation = 'softmax'))r50_finetune.compile(optimizer = "sgd",loss = 'categorical_crossentropy',metrics=['accuracy'])r50_finetune.summary()
  • 第一行用Sequential构建tf模型,方便增删model
  • 第二行调用tf库封装的resnet50,第一个参数设置输入的大小,一般是(224,224,3),第二个参数是否加载imagenet的pretrain,当然选是,选了是就要把resnet50的fc层删掉,于是第三个参数把top给删了,最后一个参数删了头出来是没有池化的feature map加一个平均池化把拉成一维。
  • 第三行,把刚定义的r50加入我们的模型
  • 第四行,平均池化后加一个fc层对应我们的分类类别数。
  • 后面的就是构建优化器和损失函数了
  • 最后一行打印模型

总结

怎么说呢,tf的高层api封装的还挺全,虽然简洁但是感觉还是不太习惯这样的写法。