Bob's Blog

Web开发、测试框架、自动化平台、APP开发、机器学习等

返回上页首页

YOLOV5中使用pytorch的state dict来加载模型参数



前些时候在使用yolov5训练出模型但遇到了加载时提示no module named models的问题,当时用取巧的方式绕过了这个问题,让训练出的模型能在自己的项目中被成功加载,但当时的解决方式只是个临时的,以后当目录结构有变化时容易导致继续修改,于是看了yolov5的代码和pytorch的官方链接。

https://www.byincd.com/bobjiang/article-01214/

https://pytorch.org/tutorials/beginner/saving_loading_models.html

pytorch官方推荐的是使用state dict。

除了在训练完成时直接保存为state dict,还可以加载用save保存的模型再保存为state dict。

import torch

m = torch.load("path/yolov5.pt")
torch.save(m["model"].state_dict(), "path/state.pth")

按照yolov5中train.py的参照可以这样加载,还好不用额外再写一份model class。但是必须用到自己项目的参数和配置定义,如下只是个例子。实际修改会比这个多,并不通用。

import torch
from models.yolo import Model

model = Model("path/ultralytics_yolov5/models/yolov5s.yaml", 3, 12, None).to("cpu")
model.load_state_dict(torch.load("path/state.pth"))
model.eval()

在attempt_load中的:

ckpt = torch.load(attempt_download(w), map_location=map_location)
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())

如果此时加载的是state dict的文件内容,需要换成或加上条件分支:

ckpt = Model("path/yolov5s.yaml", 3, 12, None).to("cpu")
ckpt.load_state_dict(torch.load(weights, map_location=map_location))
ckpt.eval()
model.append(ckpt)

如果用的是自定义的训练配置,则需要指定到对应的yaml文件或者dict内容, 否则会得到非常多的size mismatch或missing key in state_dict的错误。

这里还有一个坑,是因为state dict未加上训练类别名,因此需要额外指定一下,否则能识别但无法给出具体的类别

ckpt.names = ['xx', 'xxx'...]

此时再度加载时就能跟使用pt文件一样的效果。解决了目录层次绑定的问题。

下一篇:  微信小程序开发(二)增加底部菜单
上一篇:  Selenium做自动化时截取浏览器全屏页面

共有3条评论

添加评论

ethan
2023年5月26日 19:02
感谢分享,间接解决了困扰我两天的大bug
wfy
2023年2月20日 15:56
感谢,解决了我的问题,右下角的猫咪我很喜欢,是用了插件还是什么吗,可以介绍一下吗
zj
2022年9月20日 21:10
可以展示更多的代码嘛