当前位置: 代码迷 >> 综合 >> HRNet-Semantic-Segmentation图像,视频推理
  详细解决方案

HRNet-Semantic-Segmentation图像,视频推理

热度:44   发布时间:2023-12-06 00:56:52.0

源码:https://github.com/HRNet/HRNet-Semantic-Segmentation/ ,我用的是 pytorch-v1.1 分支。
这么好的项目居然没有提供推理(inference)代码,于是我自己整理了一个简单的 demo。

JIT 与 ONNX 模型导出

JIT 模型需要 torch>=1.8。

import torch
import torchvision
import argparse
import _init_paths
from config import config
from config import update_config
import models
from utils.utils import create_logger, FullModel, get_rank
from onnxruntime.datasets import get_example
import onnxruntime
from onnx import shape_inference
import os
from torch.nn import functional as F
import cv2
import numpy as np

# Where jit_export() writes (and re-loads) the traced TorchScript model.
JIT_EXPORT_PATH = "export_models/export_model.pt"
# Directory onnx_export() writes the .onnx file into.
EXPORT_DIR = "export_models"


def _filter_state_dict(pretrained_dict, model_keys, prefix="model."):
    """Map checkpoint keys onto the bare segmentation model's keys.

    Keys starting with ``prefix`` have it stripped (the training wrapper saves
    the network under ``model.``); other keys are kept unchanged so plain,
    unwrapped state dicts also load. Only keys present in ``model_keys`` are
    returned.
    """
    filtered = {}
    for key, value in pretrained_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in model_keys:
            filtered[name] = value
    return filtered


def _load_pretrained(model, pth_file, map_location="cpu"):
    """Load every matching weight from ``pth_file`` into ``model`` in place."""
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pretrained_dict = torch.load(pth_file, map_location=map_location)
    # Full training checkpoints nest the weights under "state_dict".
    if isinstance(pretrained_dict, dict) and "state_dict" in pretrained_dict:
        pretrained_dict = pretrained_dict["state_dict"]
    model_dict = model.state_dict()
    matched = _filter_state_dict(pretrained_dict, model_dict.keys())
    if len(matched) < len(model_dict):
        # Previously a key mismatch silently exported a randomly initialised net.
        print("Warning: loaded {} of {} parameters from {}".format(
            len(matched), len(model_dict), pth_file))
    model_dict.update(matched)
    model.load_state_dict(model_dict)


def _dummy_input(random_fn=torch.rand):
    """Return a random (1, 3, H, W) tensor; config.TRAIN.IMAGE_SIZE is indexed as (W, H)."""
    return random_fn((1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]))


def _onnx_export_path(pth_file, out_dir=EXPORT_DIR):
    """Return ``out_dir/<checkpoint file stem>.onnx``."""
    stem = os.path.splitext(os.path.basename(pth_file))[0]
    return os.path.join(out_dir, stem + ".onnx")


def jit_export(model, pth_file, out_file=JIT_EXPORT_PATH):
    """Load ``pth_file`` into ``model``, trace it and save a TorchScript file.

    The saved file is re-loaded and run once as a smoke test.

    Args:
        model: the bare segmentation network (not the training wrapper).
        pth_file: checkpoint path.
        out_file: destination of the traced model.

    Returns:
        The path that was written.
    """
    _load_pretrained(model, pth_file)
    model.eval()
    dump_input = _dummy_input()
    print(model)
    print(dump_input.shape)
    traced_script_module = torch.jit.trace(model, dump_input)
    os.makedirs(os.path.dirname(out_file) or ".", exist_ok=True)
    traced_script_module.save(out_file)

    # Round-trip check: the saved file must load and run on its own.
    new_model = torch.jit.load(out_file)
    with torch.no_grad():
        out = new_model(_dummy_input())
    print(out.shape)
    return out_file


def onnx_export(model, pth_file, out_dir=EXPORT_DIR):
    """Export ``model`` (with weights from ``pth_file``) to ONNX and verify it.

    The ONNX file is named after the checkpoint. PyTorch and ONNX Runtime are
    then run on the same random input and their outputs printed side by side.

    Returns:
        The path of the written .onnx file.
    """
    _load_pretrained(model, pth_file)
    model = model.cpu()
    model.eval()
    os.makedirs(out_dir, exist_ok=True)
    # Uses the pth_file argument (the old code read the global args instead).
    export_onnx_file = _onnx_export_path(pth_file, out_dir)
    torch.onnx.export(model, _dummy_input(), export_onnx_file, verbose=True)

    x = _dummy_input(torch.randn)
    with torch.no_grad():
        torch_out = model(x).numpy()
    sess = onnxruntime.InferenceSession(os.path.abspath(export_onnx_file))
    onnx_out = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})
    print(torch_out.shape, torch_out[0, 0, 0, 0:10])
    print(onnx_out[0].shape, onnx_out[0][0, 0, 0, 0:10])
    print("max abs diff:", np.abs(torch_out - onnx_out[0]).max())
    return export_onnx_file


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train segmentation network')
    parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)
    parser.add_argument('--pth_file', type=str, required=True)
    parser.add_argument('--image_path', type=str)
    parser.add_argument('--export', choices=('jit', 'onnx'), default='jit',
                        help='export format (default: jit)')
    parser.add_argument('opts', help="Modify config options using the command-line",
                        default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    update_config(config, args)
    # getattr performs the same models.<NAME>.get_seg_model lookup as the old
    # eval() call, without executing a string taken from the config file.
    model = getattr(models, config.MODEL.NAME).get_seg_model(config)
    model.to("cpu")
    if args.export == 'onnx':
        onnx_export(model, args.pth_file)
    else:
        jit_export(model, args.pth_file)

JIT / pth 模型的图像与视频推理

JIT 模型需要 torch>=1.8,pth 模型对 torch 版本没有特别要求。注意把输入大小(这里是 473)以及 mean、std 修改为与训练配置一致。

import torch
import torchvision
import argparse
import _init_paths
from config import config
from config import update_config
import models
import os
from torch.nn import functional as F
import cv2
import numpy as np
import time

# Network input geometry and normalisation; keep these in sync with the
# training configuration.
LONG_SIZE = 473
CROP_SIZE = (473, 473)
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def _resize_long_side(image, long_size):
    """Resize so the longer side equals ``long_size``, keeping the aspect ratio."""
    h, w = image.shape[:2]
    if h > w:
        new_h, new_w = long_size, int(w * long_size / h + 0.5)
    else:
        new_h, new_w = int(h * long_size / w + 0.5), long_size
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)


def _input_transform(image, mean=MEAN, std=STD):
    """Convert a BGR HxWx3 image to RGB float32, scale to [0, 1], then normalise."""
    image = image.astype(np.float32)[:, :, ::-1] / 255.0
    image -= np.asarray(mean, dtype=np.float32)
    image /= np.asarray(std, dtype=np.float32)
    return image


def _pad_bottom_right(image, size, padvalue):
    """Pad the bottom/right edges with ``padvalue`` up to at least ``size`` = (h, w)."""
    h, w = image.shape[:2]
    pad_h = max(size[0] - h, 0)
    pad_w = max(size[1] - w, 0)
    if pad_h == 0 and pad_w == 0:
        return image
    return cv2.copyMakeBorder(image, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
                              value=[float(v) for v in padvalue])


def _window_spans(length, crop, stride):
    """Return (start, end) spans of ``crop``-sized windows covering ``length``.

    The last window is shifted back to end exactly at ``length`` so every
    window is full-sized whenever ``length >= crop``.
    """
    count = -(-max(length - crop, 0) // stride) + 1  # ceil division
    spans = []
    for i in range(count):
        end = min(i * stride + crop, length)
        spans.append((max(end - crop, 0), end))
    return spans


def _to_tensor(image, device):
    """HxWxC numpy array -> contiguous 1xCxHxW tensor on ``device``."""
    chw = np.ascontiguousarray(image.transpose((2, 0, 1))[np.newaxis])
    return torch.from_numpy(chw).to(device)


def _infer(model, image):
    """Run ``model`` and bilinearly resize its output to the input's H x W."""
    pred = model(image)
    # F.interpolate replaces the deprecated F.upsample; align_corners=False is
    # the value F.upsample used when it was left unset.
    pred = F.interpolate(pred, size=tuple(image.shape[-2:]), mode='bilinear',
                         align_corners=False)
    # exp() is monotonic, so it does not change the final argmax.
    return pred.exp()


def preprocess(img, model, device, long_size=LONG_SIZE, crop_size=CROP_SIZE,
               mean=MEAN, std=STD, stride_ratio=2.0 / 3.0):
    """Resize, normalise and run ``model`` on a BGR image.

    Images that fit in one crop are padded and run in a single pass. Larger
    ones are padded to at least ``crop_size`` and processed with overlapping
    sliding windows (stride = ``stride_ratio`` * crop) whose scores are
    averaged.

    Returns:
        Score tensor of shape (1, num_classes, h, w), where h x w is the
        resized image size (not the original size).
    """
    padvalue = -np.asarray(mean) / np.asarray(std)  # normalised value of a black pixel
    image = _input_transform(_resize_long_side(img, long_size), mean, std)
    height, width = image.shape[:2]
    padded = _pad_bottom_right(image, crop_size, padvalue)

    # no_grad: inference only; stops autograd keeping every activation alive.
    with torch.no_grad():
        if max(height, width) <= min(crop_size):
            preds = _infer(model, _to_tensor(padded, device))
            return preds[:, :, :height, :width]

        new_h, new_w = padded.shape[:2]
        stride_h = max(int(crop_size[0] * stride_ratio), 1)
        stride_w = max(int(crop_size[1] * stride_ratio), 1)
        preds = None
        count = torch.zeros((1, 1, new_h, new_w), device=device)
        for h0, h1 in _window_spans(new_h, crop_size[0], stride_h):
            for w0, w1 in _window_spans(new_w, crop_size[1], stride_w):
                pred = _infer(model, _to_tensor(padded[h0:h1, w0:w1], device))
                if preds is None:
                    # Size the accumulator from the model output (not a fixed 2 classes).
                    preds = torch.zeros((1, pred.shape[1], new_h, new_w),
                                        dtype=pred.dtype, device=device)
                preds[:, :, h0:h1, w0:w1] += pred
                count[:, :, h0:h1, w0:w1] += 1
        preds = preds / count
        return preds[:, :, :height, :width]


def _preds_to_mask(preds):
    """Argmax over classes; class 1 -> 255, everything else -> 0, as HxWx3 uint8."""
    labels = preds.detach().cpu().numpy().argmax(axis=1)[0]
    mask = np.where(labels == 1, 255, 0).astype(np.uint8)
    return np.repeat(mask[:, :, np.newaxis], 3, axis=2)


def _overlay(image, preds):
    """Blend the class-1 mask 50/50 over ``image`` resized to the mask size."""
    mask = _preds_to_mask(preds)
    resized = cv2.resize(image, (mask.shape[1], mask.shape[0]), interpolation=cv2.INTER_LINEAR)
    return cv2.addWeighted(resized, 0.5, mask, 0.5, 0)


def inference(model, pth_file, image_path, device):
    """Segment one image and show the overlay until a key is pressed.

    ``pth_file`` is unused and kept for call compatibility.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise IOError("Could not read image: {}".format(image_path))
    preds = preprocess(image, model, device)
    cv2.imshow(image_path, _overlay(image, preds))
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def inference_dir(model, pth_file, image_path, device):
    """Segment every readable image in directory ``image_path``; ESC stops.

    ``pth_file`` is unused and kept for call compatibility.
    """
    for name in sorted(os.listdir(image_path)):
        image = cv2.imread(os.path.join(image_path, name))
        if image is None:  # not an image (or unreadable): skip instead of crashing
            continue
        preds = preprocess(image, model, device)
        cv2.imshow(name, _overlay(image, preds))
        key = cv2.waitKey(0)
        cv2.destroyWindow(name)
        if key == 27:
            break


def _open_capture(video_path):
    """Open a file/stream URL; an empty value or a digit string selects a camera."""
    if not video_path:
        return cv2.VideoCapture(0)
    if video_path.isdigit():
        return cv2.VideoCapture(int(video_path))
    return cv2.VideoCapture(video_path)


def inference_video(model, pth_file, video_path, device, save=False):
    """Segment a video/stream frame by frame and display the overlay.

    Args:
        save: output path for the overlay video (mp4v), or False to disable.

    ESC stops early. The capture and writer are always released.
    """
    vc = _open_capture(video_path)
    writer = None
    frame_count = 0
    start_time = time.time()
    try:
        while vc.isOpened():
            rval, frame = vc.read()  # the first frame is no longer discarded
            if not rval:
                break
            image_merge = _overlay(frame, preprocess(frame, model, device))
            cv2.imshow("merge", image_merge)
            if save:
                if writer is None:
                    # Size the writer from the real output frame; mismatched sizes
                    # make cv2.VideoWriter silently drop every frame.
                    fps = vc.get(cv2.CAP_PROP_FPS) or 25.0  # cameras may report 0
                    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
                    writer = cv2.VideoWriter(save, fourcc, fps,
                                             (image_merge.shape[1], image_merge.shape[0]))
                writer.write(image_merge)
            frame_count += 1
            if frame_count % 30 == 0:
                print("Frame Per second: {} fps.".format(frame_count / (time.time() - start_time)))
            if cv2.waitKey(1) == 27:  # exit on ESC
                break
    finally:
        vc.release()
        if writer is not None:
            writer.release()
        cv2.destroyAllWindows()


def _filter_state_dict(pretrained_dict, model_keys, prefix="model."):
    """Strip ``prefix`` from checkpoint keys (unprefixed keys kept) and keep model keys only."""
    filtered = {}
    for key, value in pretrained_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in model_keys:
            filtered[name] = value
    return filtered


def _load_pretrained(model, pth_file, map_location):
    """Load the matching weights from ``pth_file`` into ``model`` in place."""
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pretrained_dict = torch.load(pth_file, map_location=map_location)
    if isinstance(pretrained_dict, dict) and "state_dict" in pretrained_dict:
        pretrained_dict = pretrained_dict["state_dict"]
    model_dict = model.state_dict()
    matched = _filter_state_dict(pretrained_dict, model_dict.keys())
    if len(matched) < len(model_dict):
        print("Warning: loaded {} of {} parameters from {}".format(
            len(matched), len(model_dict), pth_file))
    model_dict.update(matched)
    model.load_state_dict(model_dict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train segmentation network')
    parser.add_argument('--cfg', help='experiment configure file name', required=False, type=str)
    parser.add_argument('--pth_file', type=str)
    parser.add_argument('--jit_file', type=str, default="export_models/export_model.pt")
    parser.add_argument('--model_type', choices=('jit', 'pth'), default='jit',
                        help='TorchScript export (jit) or raw checkpoint (pth: needs --cfg and --pth_file)')
    parser.add_argument('--image_path', type=str)
    parser.add_argument('--video_path', type=str)
    parser.add_argument('--save', type=str, default=False,
                        help='optional output path for the video overlay (.mp4)')
    parser.add_argument('opts', help="Modify config options using the command-line",
                        default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    pth_file = args.pth_file

    if args.model_type == 'pth':
        if not (args.cfg and pth_file):
            parser.error("--model_type pth requires --cfg and --pth_file")
        update_config(config, args)
        model = getattr(models, config.MODEL.NAME).get_seg_model(config)
        model.to(device)
        _load_pretrained(model, pth_file, device)
    else:
        model = torch.jit.load(args.jit_file, map_location=device)
    model.eval()

    if args.image_path:
        if os.path.isfile(args.image_path):
            inference(model, pth_file, args.image_path, device)
        elif os.path.isdir(args.image_path):
            inference_dir(model, pth_file, args.image_path, device)
    elif args.video_path:
        inference_video(model, pth_file, args.video_path, device, save=args.save)

解决网络视频流阻塞问题(采集线程持续读帧,缓冲区只保留最近的 3 帧;推理线程每次取最新的一帧处理。这样推理速度慢于视频流时,帧不会堆积,延迟也不会越来越大)

import torch
import torchvision
import argparse
from config import config
from config import update_config
from  seg_hrnet import get_seg_model
import os
from torch.nn import functional as F
import cv2
import numpy as np
import time
import threading

# Network input geometry and normalisation; keep these in sync with the
# training configuration.
LONG_SIZE = 473
CROP_SIZE = (473, 473)
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def _resize_long_side(image, long_size):
    """Resize so the longer side equals ``long_size``, keeping the aspect ratio."""
    h, w = image.shape[:2]
    if h > w:
        new_h, new_w = long_size, int(w * long_size / h + 0.5)
    else:
        new_h, new_w = int(h * long_size / w + 0.5), long_size
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)


def _input_transform(image, mean=MEAN, std=STD):
    """Convert a BGR HxWx3 image to RGB float32, scale to [0, 1], then normalise."""
    image = image.astype(np.float32)[:, :, ::-1] / 255.0
    image -= np.asarray(mean, dtype=np.float32)
    image /= np.asarray(std, dtype=np.float32)
    return image


def _pad_bottom_right(image, size, padvalue):
    """Pad the bottom/right edges with ``padvalue`` up to at least ``size`` = (h, w)."""
    h, w = image.shape[:2]
    pad_h = max(size[0] - h, 0)
    pad_w = max(size[1] - w, 0)
    if pad_h == 0 and pad_w == 0:
        return image
    return cv2.copyMakeBorder(image, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
                              value=[float(v) for v in padvalue])


def _window_spans(length, crop, stride):
    """Return (start, end) spans of ``crop``-sized windows covering ``length``.

    The last window is shifted back to end exactly at ``length``.
    """
    count = -(-max(length - crop, 0) // stride) + 1  # ceil division
    spans = []
    for i in range(count):
        end = min(i * stride + crop, length)
        spans.append((max(end - crop, 0), end))
    return spans


def _to_tensor(image, device):
    """HxWxC numpy array -> contiguous 1xCxHxW tensor on ``device``."""
    chw = np.ascontiguousarray(image.transpose((2, 0, 1))[np.newaxis])
    return torch.from_numpy(chw).to(device)


def _infer(model, image):
    """Run ``model`` and bilinearly resize its output to the input's H x W."""
    pred = model(image)
    # F.interpolate replaces the deprecated F.upsample (same default corners).
    pred = F.interpolate(pred, size=tuple(image.shape[-2:]), mode='bilinear',
                         align_corners=False)
    return pred.exp()  # monotonic: the final argmax is unchanged


def preprocess(img, model, device, long_size=LONG_SIZE, crop_size=CROP_SIZE,
               mean=MEAN, std=STD, stride_ratio=2.0 / 3.0):
    """Resize, normalise and run ``model`` on a BGR frame.

    A frame that fits in one crop is fed to the network un-padded, at its
    aspect-preserving resized size. Larger frames are padded and processed
    with overlapping sliding windows whose scores are averaged.

    Returns:
        Score tensor of shape (1, num_classes, h, w) at the resized size.
    """
    image = _input_transform(_resize_long_side(img, long_size), mean, std)
    height, width = image.shape[:2]
    with torch.no_grad():  # inference only: no autograd graph
        if max(height, width) <= min(crop_size):
            return _infer(model, _to_tensor(image, device))

        padvalue = -np.asarray(mean) / np.asarray(std)  # normalised black pixel
        padded = _pad_bottom_right(image, crop_size, padvalue)
        new_h, new_w = padded.shape[:2]
        stride_h = max(int(crop_size[0] * stride_ratio), 1)
        stride_w = max(int(crop_size[1] * stride_ratio), 1)
        preds = None
        count = torch.zeros((1, 1, new_h, new_w), device=device)
        for h0, h1 in _window_spans(new_h, crop_size[0], stride_h):
            for w0, w1 in _window_spans(new_w, crop_size[1], stride_w):
                pred = _infer(model, _to_tensor(padded[h0:h1, w0:w1], device))
                if preds is None:
                    preds = torch.zeros((1, pred.shape[1], new_h, new_w),
                                        dtype=pred.dtype, device=device)
                preds[:, :, h0:h1, w0:w1] += pred
                count[:, :, h0:h1, w0:w1] += 1
        preds = preds / count
        return preds[:, :, :height, :width]


def _preds_to_mask(preds):
    """Argmax over classes; class 1 -> 255, everything else -> 0, as HxWx3 uint8."""
    labels = preds.detach().cpu().numpy().argmax(axis=1)[0]
    mask = np.where(labels == 1, 255, 0).astype(np.uint8)
    return np.repeat(mask[:, :, np.newaxis], 3, axis=2)


def _overlay(image, preds):
    """Blend the class-1 mask 50/50 over ``image`` resized to the mask size."""
    mask = _preds_to_mask(preds)
    resized = cv2.resize(image, (mask.shape[1], mask.shape[0]), interpolation=cv2.INTER_LINEAR)
    return cv2.addWeighted(resized, 0.5, mask, 0.5, 0)


def inference(model, pth_file, image_path, device):
    """Segment one image and show the overlay (``pth_file`` is unused)."""
    image = cv2.imread(image_path)
    if image is None:
        raise IOError("Could not read image: {}".format(image_path))
    cv2.imshow(image_path, _overlay(image, preprocess(image, model, device)))
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def inference_dir(model, pth_file, image_path, device):
    """Segment every readable image in a directory; ESC stops (``pth_file`` unused)."""
    for name in sorted(os.listdir(image_path)):
        image = cv2.imread(os.path.join(image_path, name))
        if image is None:  # skip non-image files
            continue
        cv2.imshow(name, _overlay(image, preprocess(image, model, device)))
        key = cv2.waitKey(0)
        cv2.destroyWindow(name)
        if key == 27:
            break


def _open_capture(video_path):
    """Open a file/stream URL; an empty value or a digit string selects a camera."""
    if not video_path:
        return cv2.VideoCapture(0)
    if video_path.isdigit():
        return cv2.VideoCapture(int(video_path))
    return cv2.VideoCapture(video_path)


def inference_video(model, pth_file, video_path, device, save=False):
    """Single-threaded video overlay; ``save`` is an output path or False."""
    vc = _open_capture(video_path)
    writer = None
    frame_count = 0
    start_time = time.time()
    try:
        while vc.isOpened():
            rval, frame = vc.read()
            if not rval:
                break
            image_merge = _overlay(frame, preprocess(frame, model, device))
            cv2.imshow("merge", image_merge)
            if save:
                if writer is None:
                    # Writer size must match the written frames exactly.
                    fps = vc.get(cv2.CAP_PROP_FPS) or 25.0
                    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
                    writer = cv2.VideoWriter(save, fourcc, fps,
                                             (image_merge.shape[1], image_merge.shape[0]))
                writer.write(image_merge)
            frame_count += 1
            if frame_count % 30 == 0:
                print("Frame Per second: {} fps.".format(frame_count / (time.time() - start_time)))
            if cv2.waitKey(1) == 27:  # exit on ESC
                break
    finally:
        vc.release()
        if writer is not None:
            writer.release()
        cv2.destroyAllWindows()


class Stack:
    """Bounded LIFO frame buffer shared by the capture and play threads.

    Callers guard access with an external lock. ``flag`` stays True until
    ``end()`` asks both threads to stop.
    """

    def __init__(self, stack_size):
        self.items = []
        self.stack_size = stack_size
        self.flag = True

    def is_empty(self):
        return len(self.items) == 0

    def pop(self):
        return self.items.pop()

    def pop_latest(self):
        """Return the newest item and drop the older, stale ones."""
        item = self.items.pop()
        self.items.clear()
        return item

    def peek(self):
        """Return the newest item without removing it, or None if empty."""
        if not self.is_empty():  # previously called a nonexistent isEmpty()
            return self.items[-1]
        return None

    def size(self):
        return len(self.items)

    def push(self, item):
        """Append ``item``, discarding the oldest entries to stay within stack_size."""
        overflow = self.size() - self.stack_size + 1
        if overflow > 0:
            del self.items[:overflow]
        self.items.append(item)

    def end(self):
        self.flag = False


def _capture_source(video_path):
    """None -> camera 0, digit string -> camera index, otherwise a path/URL."""
    if video_path is None:
        return 0
    return int(video_path) if video_path.isdigit() else video_path


def capture_thread(video_path, frame_buffer, lock):
    """Read frames as fast as possible into ``frame_buffer`` until stopped or EOF."""
    print("capture_thread start")
    vid = cv2.VideoCapture(_capture_source(video_path))
    if not vid.isOpened():
        frame_buffer.end()  # otherwise play_thread would wait forever
        raise IOError("Couldn't open webcam or video")
    try:
        while frame_buffer.flag:
            return_value, frame = vid.read()
            if not return_value:
                break
            with lock:
                frame_buffer.push(frame)
    finally:
        vid.release()
        frame_buffer.end()  # signals play_thread that no more frames will come


def play_thread(frame_buffer, lock, model, device=None):
    """Segment and display the newest buffered frame until ESC or end of stream.

    Args:
        device: torch device for inference. If None, the device of the model's
            first parameter is used (assumes the model has parameters).
    """
    print("detect_thread start")
    print("detect_thread frame_buffer size is", frame_buffer.size())
    if device is None:
        device = next(model.parameters()).device
    try:
        while True:
            with lock:
                running = frame_buffer.flag
                # Always take the newest frame and drop older ones so playback
                # never goes backwards in time and latency does not build up.
                frame = None if frame_buffer.is_empty() else frame_buffer.pop_latest()
            if frame is None:
                if not running:
                    break
                time.sleep(0.001)  # avoid busy-spinning while the buffer is empty
                continue
            # NOTE(review): cv2.imshow from a non-main thread works on Linux but
            # may fail on macOS — confirm on the target platform.
            cv2.imshow("merge", _overlay(frame, preprocess(frame, model, device)))
            if cv2.waitKey(1) == 27:  # exit on ESC
                break
    finally:
        frame_buffer.end()
        cv2.destroyAllWindows()


def _filter_state_dict(pretrained_dict, model_keys, prefix="model."):
    """Strip ``prefix`` from checkpoint keys (unprefixed keys kept) and keep model keys only."""
    filtered = {}
    for key, value in pretrained_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in model_keys:
            filtered[name] = value
    return filtered


def _load_pretrained(model, pth_file, map_location):
    """Load the matching weights from ``pth_file`` into ``model`` in place."""
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pretrained_dict = torch.load(pth_file, map_location=map_location)
    if isinstance(pretrained_dict, dict) and "state_dict" in pretrained_dict:
        pretrained_dict = pretrained_dict["state_dict"]
    model_dict = model.state_dict()
    matched = _filter_state_dict(pretrained_dict, model_dict.keys())
    if len(matched) < len(model_dict):
        print("Warning: loaded {} of {} parameters from {}".format(
            len(matched), len(model_dict), pth_file))
    model_dict.update(matched)
    model.load_state_dict(model_dict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train segmentation network')
    parser.add_argument('--cfg', help='experiment configure file name',
                        default="8dataset_custom_seg_hrnetv1_w18_473x473_sgd_lr7e-3_wd5e-4_bs_32_epoch100.yaml",
                        required=False, type=str)
    parser.add_argument('--pth_file', type=str,
                        default="8dataset_custom_seg_hrnetv1_w18_473x473_sgd_lr7e-3_wd5e-4_bs_32_epoch100.pth")
    parser.add_argument('--video_path', type=str)
    parser.add_argument('opts', help="Modify config options using the command-line",
                        default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    update_config(config, args)
    model = get_seg_model(config)
    model.to(device)
    _load_pretrained(model, args.pth_file, device)
    model.eval()

    frame_buffer = Stack(3)
    lock = threading.RLock()
    t1 = threading.Thread(target=capture_thread, args=(args.video_path, frame_buffer, lock))
    t2 = threading.Thread(target=play_thread, args=(frame_buffer, lock, model, device))
    t1.start()
    t2.start()
    t1.join()
    t2.join()

使用 imgviz 可视化(需先 pip install imgviz)

def vis(lbl, img):
    """Render label map ``lbl`` over a grayscale version of ``img`` with imgviz.

    Args:
        lbl: integer label map; presumably the same H x W as ``img`` (confirm).
        img: H x W grayscale or H x W x 3 image. A 2-D image is first replicated
            to 3 channels. NOTE(review): imgviz.rgb2gray assumes RGB order,
            while cv2.imread returns BGR; only the gray weighting is affected.

    Returns:
        The visualisation image produced by ``imgviz.label2rgb``.
    """
    # imgviz was used but never imported anywhere, so vis() raised NameError.
    import cv2
    import imgviz  # third-party: pip install imgviz

    if len(img.shape) == 2:
        img = cv2.merge([img, img, img])
    viz = imgviz.label2rgb(
        label=lbl,
        img=imgviz.rgb2gray(img),
        font_size=15,
        loc="rb",
    )
    return viz
  相关解决方案