pytorch 绘制模型的网络结构有很多中方法,个人比较喜欢 torchview 生成的 Graphviz 风格的图片。
Graphviz介绍
Graphviz是一款开源的图形可视化软件,其名称来源于“Graph Visualization Software”的缩写。它通过使用一种名为DOT的描述语言来定义图形,能够自动生成复杂的图形和网络,并具备多种功能和特点。它通过自动布局和灵活的渲染能力,帮助用户快速生成和定制各种复杂的图形和网络。无论是软件开发、网络设计还是数据分析等领域,Graphviz都能提供有效的可视化解决方案。Graphviz
安装也比较简单,参考:
mert-kurttutan/torchview: torchview: visualize pytorch models (github.com)https://github.com/mert-kurttutan/torchview
为了让graphviz的python接口工作,你需要在你的系统中运行dot layout命令。如果它还没有安装,我建议你在你的操作系统上运行以下命令。
安装 graphviz
以Windows系统为例,以管理员权限打开 PowerShell,然后运行该命令
choco install graphviz
安装pip包
接下来可以切到你的pytorch conda环境下,安装pip包
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple graphviz
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torchview
绘制 resent18 网络结构
以resent18来实验绘制网络结构
from torchvision.models import resnet18
from torchview import draw_graph# 将图表保存到本地
model_graph = draw_graph(resnet18(), input_size=(1,3,32,32), expand_nested=True, save_graph = True, filename = "resnet18"))
# 可视化显示
model_graph.visual_graph
将展示出resent18的网络结构
绘制 depth_anything_v2 vits 网络结构
depth_anything_v2 的安装这里就不讲了,重点看 网络结构的绘制代码
import torch
from torchview import draw_graph
from depth_anything_v2.dpt import DepthAnythingV2DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'model_configs = {'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}encoder = 'vits' # or 'vitl', 'vitb', 'vitg'model = DepthAnythingV2(**model_configs[encoder])
model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location='cpu'))# 绘制并保存模型架构图
model_graph = draw_graph(model, input_size=(1,3,518,518), expand_nested=True, save_graph = True, filename = "vits")
展示的这个便是 vits 的网络结构图