Bob's Blog

Web开发、测试框架、自动化平台、APP开发、机器学习等

返回上页首页

YOLOV5根据屏幕图像实时给出预测结果



目前在用yolov5,在调试时需要检验当前模型是否匹配出了期望的实体类型。因为目前需要匹配的就是当前计算机显示出的各类图像,于是在想做一个实时的检测屏幕图像并给出预测结果,就像之前文章中的根据摄像头实时给出预测一样。(不过有个问题,记录在后面)

除了yolov5需要的pytorch环境外,还需要安装一个新的库:

pip install mss

先试试能不能实时展示屏幕图像,能随着我切换应用也能加载对应的实时图像:

import cv2
import numpy as np
from mss import mss
bounding_box = {'top': 100, 'left': 300, 'width': 640, 'height': 480}
sct = mss()

while True:
    sct_img = sct.grab(bounding_box)
    scr_img = np.array(sct_img)
    cv2.imshow("Screen Realtime", scr_img)
    if (cv2.waitKey(1) & 0xFF) == ord('q'):
        cv2.destroyAllWindows()
        break

此时能看到一个框显示的是实时的屏幕图像,随着我的拖动窗口而变化。

好了,接着试下torch.hub.load来加载模型处理后的图片。还是以之前训练的那个血细胞检测的模型为例,图集可以在此自行下载:https://public.roboflow.com/object-detection/bccd

import torch
from PIL import Image

model = torch.hub.load('./', 'custom', path='path/ultralytics_yolov5/models/blood.pt', source='local')
model.autoshape()


img1 = Image.open("path/BCCD.v4-416x416_aug.yolov5pytorch/train/images/BloodImage_00001_jpg.rf.1a3206b15602db1d97193162a50bd001.jpg")
imgs = [img1]

results = model(imgs)
results.print()
results.show()

也是能弹出一个框,展示的是预测后的血细胞图像。

其中,torch.hub.load这里,第一个'./'是代表当前就是yolov5目录,如果给,第二个代表自定义,第三个是自定义模型的路径,否则会自动寻找和下载容易抛错。而这一句torch.hub.load('ultralytics/yolov5', 'yolov5s')这代表clone代码并下载预设的模型文件。

好了,到此把两者再集合在一起:

import torch
import cv2
import numpy as np
from mss import mss

model = torch.hub.load('./', 'custom', path='path/ultralytics_yolov5/models/blood.pt', source='local')
model.autoshape()

bounding_box = {'top': 340, 'left': 800, 'width': 640, 'height': 400}
sct = mss()

while True:
    sct_img = sct.grab(bounding_box)
    scr_img = np.array(sct_img)
    scr_img = model(scr_img)
    cv2.imshow("Screen Realtime", np.array(scr_img.render())[0])
    if (cv2.waitKey(1) & 0xFF) == ord('q'):
        cv2.destroyAllWindows()
        break

此时当我打开不同的血细胞图片时,始终有个窗口能实时的截取并预测出结果,我只需要打开图片就能看到不同的匹配结果,也知道当前模型是否满足需求,比较方便了。

不过有一个问题,如果不是按照默认的yolov5的目录结构来训练的话,则会抛出NotImplementedError之类的错误。错误提示大致如下:

Traceback (most recent call last):
  File "realtime.py", line 154, in <module>
    scr_img = model(scr_img)
  File "/workspace/yolo/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/yolo/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 201, in _forward_unimplemented
    raise NotImplementedError
NotImplementedError

比如我现在是为了配合另一个项目而打的一个python包,训练时则是按照包import的目录结构来的,用上面的代码片段就会抛错。

此时不用torch.hub.load,可以换做下面的代码,则解决了这个问题。

import torch
import cv2
import time
import numpy as np
from mss import mss
from models.experimental import attempt_load
from utils.datasets import LoadImages
from utils.general import check_img_size, increment_path, non_max_suppression, scale_coords
from utils.torch_utils import select_device, load_classifier, time_synchronized
from utils.plots import colors, plot_one_box


weights_file = "path/self_customized_trained.pt"
conf_thres = 0.33
iou_thres = 0.45
classes = None
agnostic_nms = None
max_det = 1000
line_thickness = 3

bounding_box = {'top': 340, 'left': 800, 'width': 640, 'height': 400}
sct = mss()

device = torch.device("cpu")
model = attempt_load(weights_file, map_location=device)
stride = int(model.stride.max())
imgsz = check_img_size(640, s=stride)
names = model.module.names if hasattr(model, 'module') else model.names

while True:
    sct_img = sct.grab(bounding_box)
    scr_img = np.array(sct_img)
    img_file = "./screen.png"
    cv2.imwrite(img_file, scr_img)

    dataset = LoadImages(img_file, img_size=imgsz, stride=stride)
    bs = 1
    vid_path, vid_writer = [None] * bs, [None] * bs
    t0 = time.time()
    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.float()
        img /= 255.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        t1 = time_synchronized()
        pred = model(img, augment=False, visualize=False)[0]
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
        t2 = time_synchronized()
        for i, det in enumerate(pred):
            p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)
            if len(det):
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    c = int(cls)
                    label = f'{names[c]} {conf:.2f}'
                    plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=line_thickness)
            cv2.imshow(str(p), im0)
            if (cv2.waitKey(1) & 0xFF) == ord('q'):
                cv2.destroyAllWindows()
                break

这里挺像个草稿的,不过对于自定义目录结构训练的模型,用这个是可以加载并且实时的显示出预测效果的。部分的代码是来源于yolov5本身。

下一篇:  Pytesseract配置以检测非英语
上一篇:  使用gunicorn+gevent+nginx

共有3条评论

添加评论

一本正经
2023年3月12日 21:49
很好,终于解决问题
贝 姓
2023年1月16日 11:40
很有用 感谢
Lyz
2022年11月9日 18:49
Good