YOLOV5根据屏幕图像实时给出预测结果
目前在用yolov5,在调试时需要检验当前模型是否匹配出了期望的实体类型。因为目前需要匹配的就是当前计算机显示出的各类图像,于是在想做一个实时的检测屏幕图像并给出预测结果,就像之前文章中的根据摄像头实时给出预测一样。(不过有个问题,记录在后面)
除了yolov5需要的pytorch环境外,还需要安装一个新的库:
pip install mss
先试试能不能实时展示屏幕图像,能随着我切换应用也能加载对应的实时图像:
import cv2
import numpy as np
from mss import mss
bounding_box = {'top': 100, 'left': 300, 'width': 640, 'height': 480}
sct = mss()
while True:
sct_img = sct.grab(bounding_box)
scr_img = np.array(sct_img)
cv2.imshow("Screen Realtime", scr_img)
if (cv2.waitKey(1) & 0xFF) == ord('q'):
cv2.destroyAllWindows()
break
此时能看到一个框显示的是实时的屏幕图像,随着我的拖动窗口而变化。
好了,接着试下torch.hub.load来加载模型处理后的图片。还是以之前训练的那个血细胞检测的模型为例,图集可以在此自行下载:https://public.roboflow.com/object-detection/bccd:
import torch
from PIL import Image
model = torch.hub.load('./', 'custom', path='path/ultralytics_yolov5/models/blood.pt', source='local')
model.autoshape()
img1 = Image.open("path/BCCD.v4-416x416_aug.yolov5pytorch/train/images/BloodImage_00001_jpg.rf.1a3206b15602db1d97193162a50bd001.jpg")
imgs = [img1]
results = model(imgs)
results.print()
results.show()
也是能弹出一个框,展示的是预测后的血细胞图像。
其中,torch.hub.load这里,第一个'./'是代表当前就是yolov5目录,如果给,第二个代表自定义,第三个是自定义模型的路径,否则会自动寻找和下载容易抛错。而这一句torch.hub.load('ultralytics/yolov5', 'yolov5s')这代表clone代码并下载预设的模型文件。
好了,到此把两者再集合在一起:
import torch
import cv2
import numpy as np
from mss import mss
model = torch.hub.load('./', 'custom', path='path/ultralytics_yolov5/models/blood.pt', source='local')
model.autoshape()
bounding_box = {'top': 340, 'left': 800, 'width': 640, 'height': 400}
sct = mss()
while True:
sct_img = sct.grab(bounding_box)
scr_img = np.array(sct_img)
scr_img = model(scr_img)
cv2.imshow("Screen Realtime", np.array(scr_img.render())[0])
if (cv2.waitKey(1) & 0xFF) == ord('q'):
cv2.destroyAllWindows()
break
此时当我打开不同的血细胞图片时,始终有个窗口能实时的截取并预测出结果,我只需要打开图片就能看到不同的匹配结果,也知道当前模型是否满足需求,比较方便了。
不过有一个问题,如果不是按照默认的yolov5的目录结构来训练的话,则会抛出NotImplementedError之类的错误。错误提示大致如下:
Traceback (most recent call last):
File "realtime.py", line 154, in <module>
scr_img = model(scr_img)
File "/workspace/yolo/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/workspace/yolo/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 201, in _forward_unimplemented
raise NotImplementedError
NotImplementedError
比如我现在是为了配合另一个项目而打的一个python包,训练时则是按照包import的目录结构来的,用上面的代码片段就会抛错。
此时不用torch.hub.load,可以换做下面的代码,则解决了这个问题。
import torch
import cv2
import time
import numpy as np
from mss import mss
from models.experimental import attempt_load
from utils.datasets import LoadImages
from utils.general import check_img_size, increment_path, non_max_suppression, scale_coords
from utils.torch_utils import select_device, load_classifier, time_synchronized
from utils.plots import colors, plot_one_box
weights_file = "path/self_customized_trained.pt"
conf_thres = 0.33
iou_thres = 0.45
classes = None
agnostic_nms = None
max_det = 1000
line_thickness = 3
bounding_box = {'top': 340, 'left': 800, 'width': 640, 'height': 400}
sct = mss()
device = torch.device("cpu")
model = attempt_load(weights_file, map_location=device)
stride = int(model.stride.max())
imgsz = check_img_size(640, s=stride)
names = model.module.names if hasattr(model, 'module') else model.names
while True:
sct_img = sct.grab(bounding_box)
scr_img = np.array(sct_img)
img_file = "./screen.png"
cv2.imwrite(img_file, scr_img)
dataset = LoadImages(img_file, img_size=imgsz, stride=stride)
bs = 1
vid_path, vid_writer = [None] * bs, [None] * bs
t0 = time.time()
for path, img, im0s, vid_cap in dataset:
img = torch.from_numpy(img).to(device)
img = img.float()
img /= 255.0
if img.ndimension() == 3:
img = img.unsqueeze(0)
t1 = time_synchronized()
pred = model(img, augment=False, visualize=False)[0]
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
t2 = time_synchronized()
for i, det in enumerate(pred):
p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)
if len(det):
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
for *xyxy, conf, cls in reversed(det):
c = int(cls)
label = f'{names[c]} {conf:.2f}'
plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=line_thickness)
cv2.imshow(str(p), im0)
if (cv2.waitKey(1) & 0xFF) == ord('q'):
cv2.destroyAllWindows()
break
这里挺像个草稿的,不过对于自定义目录结构训练的模型,用这个是可以加载并且实时的显示出预测效果的。部分的代码是来源于yolov5本身。
共有3条评论
添加评论
一本正经
2023年3月12日 21:49很好,终于解决问题
贝 姓
2023年1月16日 11:40很有用 感谢
Lyz
2022年11月9日 18:49Good