深度学习中预训练模型库非常重要,它能够帮助我们非常方便的获取到模型的结构、模型的权重文件等,这大大降低了入门深度学习的门槛,如高性能的硬件设备(服务器、GPU),同时使用迁移学习的思想能够大大缩短我们开发可实用模型的时间。常见的预训练模型库包含有 torchvision.models(CV 模型)、transformers(CV 和 NLP 大模型相关)、timm(包含 CV 领域小模型和大模型,开发公司同 transformers 的 hugging face)。其中 timm 非常方便我们查看模型结构,同时可加载预训练的模型权重,且支持的模型比较多。本篇介绍 timm。
安装
timm 的官方网址是:https://github.com/huggingface/pytorch-image-models
使用
查看预训练模型
1 2 3 4
| import timm
avail_pretrained_models = timm.list_models(pretrained=True) len(avail_pretrained_models)
|
查看某类模型
1 2
| all_densnet_models = timm.list_models("*sam*") all_densnet_models
|
1 2 3 4
| ['samvit_base_patch16', 'samvit_base_patch16_224', 'samvit_huge_patch16', 'samvit_large_patch16']
|
查看模型结构
1
| timm.create_model(model_name="vit_huge_patch14_224")
|
加载模型
1 2
| model = timm.create_model("samvit_huge_patch16", pretrained=True) model.default_cfg
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| {'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', 'hf_hub_id': 'timm/samvit_huge_patch16.sa1b', 'architecture': 'samvit_huge_patch16', 'tag': 'sa1b', 'custom_load': False, 'input_size': (3, 1024, 1024), 'fixed_input_size': True, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 0, 'pool_size': None, 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', 'license': 'apache-2.0'}
|
保存模型权重
1 2 3 4 5 6 7 8 9 10
| import torch
torch.save(model.state_dict(),'./timm_model-state_dict.pth')
torch.save(model.state_dict(),'./timm_model.pth')
model.load_state_dict(torch.load('./timm_model-state_dict.pth'))
|
可视化模型结构
当我们有模型权重文件(*.pth)后,我们可以使用 netron 来可视化模型结构,更加直观。
netron 网址为:https://netron.app/
参考文献
- 6.3 模型微调 - timm
- 视觉 Transformer 优秀开源工作:timm 库 vision transformer 代码解读
- timm——pytorch下的迁移学习模型库·详细使用教程