当前位置: 代码迷 >> 综合 >> faster rcnn inception_resnet_v2物品辨识比赛demo记录
  详细解决方案

faster rcnn inception_resnet_v2物品辨识比赛demo记录

热度:48   发布时间:2023-12-06 01:07:53.0

使用tensorflow model里面的object detection训练的,因为没有时间限制,使用faster rcnn inception_resnet_v2识别10个类别,在1080ti上训练了5个小时,在1050上测试的,2s一张图片。
定义 pascal_label_map.pbtxt

item {
    id: 1
    name: 'cola'
}
item {
    id: 2
    name: 'milk tea'
}
item {
    id: 3
    name: 'ice tea'
}
item {
    id: 4
    name: 'beer'
}
item {
    id: 5
    name: 'shampoo'
}
item {
    id: 6
    name: 'toothpaste'
}
item {
    id: 7
    name: 'soap'
}
item {
    id: 8
    name: 'pear'
}
item {
    id: 9
    name: 'apple'
}
item {
    id: 10
    name: 'orange'
}

测试代码

#-*-coding:utf-8-*-
# Imports and configuration for the detection / speech demo.
# Duplicate re-imports from the original listing are collapsed; the
# first-occurrence order is kept so import side effects are unchanged.
import sys
import argparse
from PIL import Image
import os
import cv2
import numpy as np
import speech_recognition as sr
import wave
import requests
import time
import base64
from pyaudio import PyAudio, paInt16
import webbrowser
import serial
import speech
import tensorflow as tf

# The parent directory must be on sys.path before `utils` can be imported
# (presumably the TensorFlow object_detection `utils` package - confirm).
sys.path.append("..")
from utils import label_map_util
from utils import visualization_utils as vis_util
from timeit import default_timer as timer

# --- audio recording parameters ---
framerate = 16000  # sample rate
num_samples = 2000  # samples per read
channels = 1  # number of channels
sampwidth = 2  # sample width: 2 bytes
FILEPATH = 'speech.wav'

# --- Baidu speech API credentials ---
base_url = "https://openapi.baidu.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s"
APIKey = "***"
SecretKey = "***"
HOST = base_url % (APIKey, SecretKey)

# --- detection model files ---
PATH_TO_CKPT = 'F:/python_project/比赛' + '/frozen_inference_graph.pb'
PATH_TO_LABELS = 'F:/python_project/比赛/pascal_label_map.pbtxt'
# Passed as max_num_classes when the label map is converted to categories.
# NOTE(review): the label map shown with this script defines 10 classes.
NUM_CLASSES = 80
# Load the frozen inference graph from PATH_TO_CKPT into its own tf.Graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        # name='' keeps the original tensor names (no import prefix).
        tf.import_graph_def(od_graph_def, name='')

# Build the category index that is handed to the visualisation helper.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
print(category_index)

# Seconds to wait for the Baidu HTTP endpoints before giving up.
REQUEST_TIMEOUT_S = 10


def detect():
    """Run the detector on camera 1 until a frame reports five or more objects.

    Each frame is converted to RGB, fed through the frozen graph, annotated by
    ``vis_util.visualize_boxes_and_labels_on_image_array_`` and shown in a
    window named "test". When the helper reports ``num >= 5`` the annotated
    frame is written to a time-stamped JPEG, a spoken notice is played and the
    loop ends. The per-frame inference time is printed.

    Raises:
        RuntimeError: if the camera does not deliver a frame.
    """
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            # Tensor handles do not change between frames; look them up once.
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            fetches = [
                detection_graph.get_tensor_by_name('detection_boxes:0'),
                detection_graph.get_tensor_by_name('detection_scores:0'),
                detection_graph.get_tensor_by_name('detection_classes:0'),
                detection_graph.get_tensor_by_name('num_detections:0'),
            ]
            cap = cv2.VideoCapture(1)
            try:
                state = True
                while state:
                    start = timer()
                    ok, frame = cap.read()
                    if not ok:
                        # Without this check a None frame reaches cvtColor
                        # and fails with an opaque OpenCV assertion.
                        raise RuntimeError("failed to read a frame from camera 1")
                    image_np = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    image_np_expanded = np.expand_dims(image_np, axis=0)
                    (boxes, scores, classes, num_detections) = sess.run(
                        fetches, feed_dict={image_tensor: image_np_expanded})
                    end = timer()
                    # Trailing-underscore helper returns the annotated image
                    # together with a count (not the stock two-value-less API).
                    image_np, num = vis_util.visualize_boxes_and_labels_on_image_array_(
                        image_np,
                        np.squeeze(boxes),
                        np.squeeze(classes).astype(np.int32),
                        np.squeeze(scores),
                        category_index,
                        use_normalized_coordinates=True,
                        line_thickness=8)
                    # Back to BGR for OpenCV display / encoding.
                    image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
                    print(num)
                    if num >= 5:  # five objects recognised
                        state = False
                        tmp = "wxy-TRY-" + time.strftime("%H%M", time.localtime()) + ".jpg"
                        # imencode + tofile instead of imwrite: needed when
                        # the path contains Chinese characters.
                        cv2.imencode('.jpg', image_np)[1].tofile(tmp)
                        speech.say("识别完成")
                        print("写入成功,停止检测")
                    cv2.imshow("test", image_np)
                    cv2.waitKey(1)
                    print(end - start)
            finally:
                cap.release()


def getToken(host):
    """POST to ``host`` and return the ``access_token`` field of the JSON reply.

    Raises:
        requests.RequestException: on timeout or a non-2xx response.
    """
    res = requests.post(host, timeout=REQUEST_TIMEOUT_S)
    res.raise_for_status()
    return res.json()['access_token']


def save_wave_file(filepath, data):
    """Write the byte chunks in ``data`` to ``filepath`` as a WAV file.

    Channel count, sample width and frame rate come from the module-level
    ``channels``, ``sampwidth`` and ``framerate`` settings.
    """
    with wave.open(filepath, 'wb') as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sampwidth)
        wf.setframerate(framerate)
        wf.writeframes(b''.join(data))


def my_record():
    """Record about four seconds from the default input and save it to FILEPATH."""
    my_buf = []
    pa = PyAudio()
    try:
        stream = pa.open(format=paInt16, channels=channels,
                         rate=framerate, input=True,
                         frames_per_buffer=num_samples)
        try:
            deadline = time.time() + 4  # seconds
            print('正在录音...')
            while time.time() < deadline:
                my_buf.append(stream.read(num_samples))
            print('录音结束.')
        finally:
            stream.stop_stream()
            stream.close()
    finally:
        # Release PortAudio resources even if recording failed.
        pa.terminate()
    save_wave_file(FILEPATH, my_buf)


def get_audio(file):
    """Return the raw bytes of ``file``."""
    with open(file, 'rb') as f:
        return f.read()


def speech2text(speech_data, token, dev_pid=1537):
    """Send ``speech_data`` to the Baidu speech endpoint and return the result.

    Args:
        speech_data: raw audio bytes; sent base64-encoded together with their
            unencoded length.
        token: access token, e.g. from ``getToken``.
        dev_pid: value forwarded as the ``dev_pid`` request field.

    Returns:
        The first entry of the reply's ``result`` list when present,
        otherwise the whole decoded JSON reply.
    """
    FORMAT = 'wav'
    RATE = '16000'
    CHANNEL = 1
    CUID = '*******'
    SPEECH = base64.b64encode(speech_data).decode('utf-8')
    data = {
        'format': FORMAT,
        'rate': RATE,
        'channel': CHANNEL,
        'cuid': CUID,
        'len': len(speech_data),
        'speech': SPEECH,
        'token': token,
        'dev_pid': dev_pid,
    }
    url = 'https://vop.baidu.com/server_api'
    headers = {'Content-Type': 'application/json'}
    print('正在识别...')
    r = requests.post(url, json=data, headers=headers,
                      timeout=REQUEST_TIMEOUT_S)
    Result = r.json()
    if 'result' in Result:
        return Result['result'][0]
    return Result


def detect_img(yolo):
    """Run ``yolo.detect_image_`` on camera 0 until it reports exactly two objects.

    Args:
        yolo: object exposing ``detect_image_(PIL image) -> (image, num)``.

    Raises:
        RuntimeError: if the camera does not deliver a frame.
    """
    cap = cv2.VideoCapture(0)
    try:
        state = True
        while state:
            ok, frame = cap.read()
            if not ok:
                raise RuntimeError("failed to read a frame from camera 0")
            show = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(show)
            image_, num = yolo.detect_image_(image)
            image = cv2.cvtColor(np.asarray(image_), cv2.COLOR_RGB2BGR)
            if num == 2:
                state = False
                cv2.imwrite("WXY-TRY-"+time.strftime("%H%M", time.localtime())+".jpg", image)
                speech.say("识别完成")
                print("写入成功,停止检测")
            cv2.imshow("test", image)
            cv2.waitKey(1)
    finally:
        cap.release()


def test():
    """Record and recognise speech until the start word is heard, then detect."""
    # HOST is constant, so request the token once instead of per attempt.
    # NOTE(review): assumes the token stays valid for the session - confirm.
    token = getToken(HOST)
    while True:
        my_record()
        result = speech2text(get_audio(FILEPATH), token, 1536)
        print(result)
        if result == "开始":
            break
    speech.say("开始识别")
    detect()


def test2():
    """Wait for any data on serial port COM4, then run ``detect`` once."""
    serialPort = "COM4"  # serial port
    baudRate = 115200  # baud rate
    ser = serial.Serial(serialPort, baudRate, timeout=0.5)
    try:
        print("参数设置:串口=%s ,波特率=%d" % (serialPort, baudRate))
        while True:
            # readline returns empty after the 0.5 s timeout; keep polling.
            if ser.readline():
                detect()
                break
    finally:
        ser.close()


def test3():
    """Speak the start notice and print the current HHMM time."""
    speech.say("开始识别")
    print(time.strftime("%H%M", time.localtime()))
if __name__ == '__main__':
    # Alternative entry points; enable one instead of detect():
    # test()   # start via Baidu speech recognition
    # test2()  # start via iFlytek voice wake-up (serial trigger)
    # test3()  # try the speech module on Windows
    detect()   # run detection directly
  相关解决方案