当前位置: 代码迷 >> 综合 >> pytorch-图像分类算法与小程序开发
  详细解决方案

pytorch-图像分类算法与小程序开发

热度:47   发布时间:2023-10-19 18:57:22.0

先上结果图,免得还没看就跑了

pytorch-图像分类算法与小程序开发

这是小程序显示的界面,上半部分通过调用手机摄像头,将图像传给后端算法,算法将识别结果,放回到这个页面黄色框这个部分。那么这个算法是怎么实现的呢。 

图像分类算法

图像分类常用的卷积神经网络算是人工智能CV篇的入门

常用的算法比如lenet,多用于mnist和cifar10分类分别是28*28,和32*32尺寸的数据集。

然后vgg,resnet,mobilenet,efficientnet等分类算法大多以224*224的图像尺寸制作数据集。

本文就以最经典的resnet50算法为例子,教你如何制作图像分类的数据集,如何使用迁移学习resnet50进行训练模型,以及如何调用模型进行预测,代码都在下面,

多多点赞加关注,更新不间断。

首先是数据集的制作

我通过爬虫,爬取了5种水果,包含橘子、火龙果、苹果、草莓、香蕉等。每种几十张图片,各自放在相应的文件夹下,文件的命名后缀_train是训练集,_test是测试集

pytorch-图像分类算法与小程序开发

 然后通过代码来生成数据集文本

这个代码运行完成后会在水果数据集文件夹下生成两个txt文本分别用于存放训练集和测试集的图片路径和对应标签

pytorch-图像分类算法与小程序开发

 有了这个txt数据集就可以写训练的代码了

这里用了迁移学习的resnet50,大家可以在model = models.resnet50(pretrained=True)这里修改

迁移学习其他网络也是可以的

最后训练得到的模型会测出测试集的准确率,生成的模型

pytorch-图像分类算法与小程序开发

然后我们来调用这个模型来对输入进来的图片进行测试吧

import torch
from PIL import Image
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms,modelsa='./1.jpg'      #输入要测试的图片名称
classes = ['橘子','火龙果','苹果','草莓','香蕉']  #标签序号对应类名
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  #是否使用gpu加速def test_mydata(a):   #定义预测函数# 调整图片大小im = plt.imread(a)     #读入图片images = Image.open(a)     #将图片存储到images里面images = images.resize((224, 224))   #调整图片的大小为224*224的大小images = images.convert('RGB')    #RGB化transform = transforms.ToTensor()images = transform(images)   #图像转化成tensor类型images = images.resize(1, 3, 224, 224)    #调整输入网络图片大小images = images.to(device)   #gpu加速path_model = "./model.ckpt"   #调用训练好的模型model = torch.load(path_model)model = model.to(device)model.eval()outputs = model(images)    #将图片传入模型预测values, indices = outputs.data.max(1)  #返回最大概率值和下标  output不是tensor类型所以要加.dataprint(classes[int(indices[0])])   #输出类名plt.title(int(indices[0]))   #输出标签plt.imshow(im)    #展示结果plt.show()test_mydata(a)   #调用函数

这个代码能对本地1.jpg名称的图片进行输入模型,测试后显示他的水果种类。

如果想测其他图片在a=‘./1.jpg’ 这里改成其他图片的路径即可

本期就到这里

别忘记点赞、关注加收藏

有什么想让博主先写的代码可以在评论区留言,点赞超10个,更新如何部署小程序

还有想知道如何爬虫、目标检测算法的实现,pyqt界面如何写,网页如何部署,都欢迎留言

往期传送门:

在linux远程服务器配置pytorch-gpu

在windows笔记本中安装tensorflow1.13.2版本的gpu环境

pytorch版Cyclegan

python-图片批量处理大小并删除原图片

tensoflow1.x版本训练CycleGAN

pytorch版空洞卷积,以及RFBNet应用

python-词云生成

pytorch单机多gpu训练cycleGAN模型

python-使用PIL工具包将图片分割成四等分再还原

如何Xshell退出服务器,服务器里的程序仍可以运行,,和如何通过终端关闭程序