Bob's Blog

Web开发、测试框架、自动化平台、APP开发、机器学习等

返回上页首页

pytorch载入模型出现no module named models的解决办法



我目前在使用yolov5,它本身是用了pytorch。我在训练了数据集后,就把pt模型文件拷贝到其他项目中去了。可是在加载该pt文件时,却出现了类似'no module named models'的错误提示。在翻阅了资料后找到了原因和解决的办法之一。

假设训练的模型叫做my_model, yolov5是用了torch.save(my_model, path), 而yolov5本身和自己尝试的都是用的torch.load(path); 而这样的save会把当前目录结构以及py文件class都写入模型中保存下来,于是当把pt文件迁移到其他项目中使用,而其他项目的关于模型相关的目录结构有所变化,就会报no module named models的错误了。错误信息来自serialization.py类似如下:

Traceback (most recent call last):
  ......
    ckpt = torch.load(weights, map_location=map_location)
  File "/venv/lib/python3.8/site-packages/torch/serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/venv/lib/python3.8/site-packages/torch/serialization.py", line 882, in _load
    result = unpickler.load()
  File "/venv/lib/python3.8/site-packages/torch/serialization.py", line 875, in find_class
    return super().find_class(mod_name, name)
ModuleNotFoundError: No module named 'models'

假如是把yolov5的models的目录里几个py文件都放在models目录下且models目录在根目录下,也不会报错。但这样有局限性,往往不能满足实际需求。

按照pytorch的官方文档,推荐的方式是将state_dict保存下来,state_dict相当于训练结果中的权重和各种参数,这样在加载时不会受到目录和class名等等的限制。

保存时用:

torch.save(my_model.state_dict(), PATH)

加载时用:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

但这里有一个TheModelClass,是需要自己定义的,相当于把训练时用到的yaml文件里的anchor、backbone、head等都用python的语法创建一个类并声明相关属性,这个对于非专业人士来说就花更多时间导致更多错误了。

于是我用另一个方式解决了当前需要:

按照另一个项目中的目录结构来,在yolov5中新增相关的目录结构,并把yolov5中的models里的文件复制过去,可以不用在根目录下,也可以修改相关的文件名,只是需要记得在train.py和test.py中修改from models.xx import xxx里的models改成新的路径和文件,新目录里的py文件的互相调用的目录也需要调整一下,然后重新训练,新的pt文件便可用于新项目了。虽然不算好的方式,但可以快速的解决当前问题并继续后续的调试。如果有了好的方法,我再来更新。

update:

用官方的state dict完整解决了该问题:YOLOV5中使用pytorch的state dict来加载模型参数

参考链接:

https://github.com/pytorch/pytorch/issues/18325

https://pytorch.org/tutorials/beginner/saving_loading_models.html

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

https://github.com/ultralytics/yolov5/issues/1680

下一篇:  网页WCAG的检查以及工具尝试
上一篇:  个人服务器降低mysql内存占用

共有0条评论

添加评论

暂无评论