from keras.callbacks import TensorBoard, ModelCheckpointfrom keras.utils import multi_gpu_model # 导入keras多卡函数class ParallelModelCheckpoints(ModelCheckpoint): # 在保存模型时,由于存在两个模型,所以需要指定model,\# 继承ModelCheckpoint,重写init()def __init__(self, model, # 需要保存的模型filepath='./log/epoch-{epoch:02d}_loss-{loss:.4f}_acc-{val_acc:.4f}_lr-{lr:.5f}.h5',monitor='val_acc',verbose=1,save_best_only=True,save_weights_only=False,mode='auto',period=1):self.single_model = modelsuper(ParallelModelCheckpoints, self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only, mode, period)def set_model(self, model):super(ParallelModelCheckpoints, self).set_model(self.single_model)# 首先在cpu上创建原来的模型with tf.device('/cpu:0'):model = MobileNet(...)# 创建多卡模型parallel_model = multi_gpu_model(model, gpus=4) # 其中 4 是gpu的数量parallel_model.load_weights(h5_path, by_name=True) # 继续训练的时候导入参数是用的parallel_model模型,而不是modelparallel_model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])model_checkpoint = ParallelModelCheckpoints(model) # 设置需要保存h5的模型print("Start training the model") # 然后就可以训练了training_history = parallel_model.fit_generator(train_generator,steps_per_epoch=step_size_train,validation_data=validation_generator,validation_steps=step_size_valid,epochs=epoch_list[-1],verbose=1,callbacks=[TensorBoard(log_dir='./tb'), model_checkpoint, stepDecayLR])print("Model training finished")
详细解决方案
keras--多GPU训练
热度:129 发布时间:2023-10-27 03:00:34.0
相关解决方案
- Keras-如何将学习到的Embedding()层用于输入和输出?
- keras GRU不会采用我的简单二维数组
- Keras - 如何在 CPU 上运行加载的模型
- Keras - MS-SSIM 作为损失函数 计算 SSIM 的 c*s 值的函数: 获得高斯核的函数 主损失函数
- Keras 自编码器AutoEncoder(五)
- keras.datasets.fashion_mnist下载数据集时TimeoutError: [WinError 10060]解决方法
- MNIST图像分类 - Keras
- 'ValueError: Tensor Tensor is not an element of this graph'(Keras, backend:TensorFlow)
- TensorFlow/Keras InternalError: Dst tensor is not initialized.
- keras实现FCN代码问题记录-Keras implementation of FCN for Semantic Segmentation
- 【机器学习15】keras-yolo4
- 吴恩达Deep Learning编程作业 Course4- 卷积神经网络-第二周作业:Keras tutorial-the Happy House
- 【深度学习】 Keras 实现Minst数据集上经典网络结构(DeepDense、LeNet、AlexNet、ZFNet)分类
- keras--多GPU训练
- Keras--动态调整学习率
- tensorflow(Keras)模型导入opencv 之版本问题dnn module Unknown layer type Shape in conv2d_transpose/Shape in
- tf.keras 模型的保存与加载(六)
- tf.keras 网络搭建相关基础函数(三)
- Tensorflow2.0 中 tf.keras.layers.Conv2D 里的初始化方法 'glorot_uniform' 到底是个啥?
- Tensorflow2.0之tf.keras.applacations迁移学习
- 在Faster R-CNN VOC2012的基础上训练widerface错误解决(keras)
- 报错:module 'tensorflow.python.keras.backend' has no attribute 'get_graph'
- Python报错:module 'tensorflow.python.keras.backend' has no attribute 'get_graph'
- 【Keras Mnist】手写数字识别数据集
- tensorflow2中tensorflow.keras.layers.Activation函数使用方法
- keras mnist图像分类任务(服饰)
- 深度学习TensorFlow---保存和加载 Keras 模型
- python基础知识及数据分析工具安装及简单使用(Numpy/Scipy/Matplotlib/Pandas/StatsModels/Scikit-Learn/Keras/Gensim))
- keras 函数式API
- 树莓派3b+构建Docker深度学习环境(ubuntu16.04+tensorflow+keras+opencv)