使用PyTorch导出JIT模型:C++ API与libtorch实战

PyTorch导出JIT模型并用C++ API libtorch调用

本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型。

Step1:导出模型

首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署阶段,因此,模型的训练就不进行了,直接对 torchvision 中自带的 ResNet50 进行导出。在实际应用中,大家可以对自己训练好的模型进行导出。

# export_jit_model.py
import torch
import torchvision.models as modelsmodel = models.resnet50(pretrained=True)
model.eval()example_input = torch.rand(1, 3, 224, 224)jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')

导出 JIT 模型的方式有两种:trace 和 script。

我们采用
torch.jit.trace
的方式来导出 JIT 模型,这种方式会根据一个输入将模型跑一遍,然后记录下执行过程。这种方式的问题在于对于有分支判断的模型不能很好的应对,因为一个输入不能覆盖到所有的分支。但是在我们 ResNet50 模型中不会遇到分支判断,因此这里是合适的。关于两种导出 JIT 模型的方式各自优劣不是本文的中断,以后会再写一篇来分析。

在我们的工程目录
demo
下运行上面的
export_jit_model.py
,会得到一个 JIT 模型件:
resnet50_jit.pth

Step 2:安装libtorch

接下来我们要安装 PyTorch 的 C++ API:libtorch。这一步很简单,直接下载官方预编译的文件并解压即可:

wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip

也解压在我们的工程目录
demo
下即可。

Step 3:安装OpenCV

用 Python 或 C++ 做图像任务,OpenCV 是经常用到的。如果还没有安装的读者可以参考如下在工程目录
demo
下进行安装,构建的过程可能会比较久。已经安装的读者可跳过此步骤,一会儿在
CMakeLists.txt
文件中正确地指定本机的 OpenCV 地址即可。

git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git
mkdir demo/build && cd demo/build
cmake ..
make -j 6

Step 4:准备测试图像并用Python测试

我们先准备一张小猫的图像,并用 PyTorch ResNet50 模型正常跑一下,一会儿与我们 C++ 模型运行的结果对比来验证 C++ 模型是否被正确的部署。

kitten.jpg

写一个脚本用 PyTorch 运行一下模型:

# pytorch_test.pyimport torchvision.models as models
from torchvision.transforms import transforms
import torch
from PIL import Image# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
all_transforms = transforms.Compose([transforms.Resize(224),transforms.ToTensor()])# normalize])model = models.resnet50(pretrained=True)
model.eval()img = Image.open('kitten.jpg').convert('RGB')
img_tensor = all_transforms(img).unsqueeze(dim=0)
pred = model(img_tensor).squeeze(dim=0)
print(torch.argmax(pred).item())

输出结果是:282。通过查看
ImageNet 1K 类别名与索引的对应关系
,可以看到,结果为 tiger cat,模型预测正确。一会儿我们看一下部署后的 C++ 模型是否能正确输出结果 282。

Step 5:准备cpp源文件

我们下面准备一会要执行的 cpp 源文件,第一次使用 libtorch 的读者可以先借鉴下面的文件。

这里有几个点要说一下,不注意可能会犯错:

  1. cv::imread()
    默认读取为三通道BGR,需要进行B/R通道交换,这里采用
    cv::cvtColor()
    实现。
  2. 图像尺寸需要调整到

224

×

224

224\times 224

2

2

4

×

2

2

4

,通过
cv::resize()
实现。
3. opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是
np.transpose()
操作,这里使用
tensor.permut()
实现,效果是一样的。
4. 数据归一化,采用
tensor.div(255)
实现。

// test_model.cpp
#include <vector>#include <torch/torch.h>
#include <torch/script.h>#include <opencv2/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>int main(int argc, char* argv[]) {// 加载JIT模型auto module = torch::jit::load(argv[1]);// 加载图像auto image = cv::imread(argv[2], cv::ImreadModes::IMREAD_COLOR);cv::Mat image_transfomed;cv::resize(image, image_transfomed, cv::Size(224, 224));cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);// 图像转换为Tensortorch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols, 3},torch::kByte);tensor_image = tensor_image.permute({2, 0, 1});// tensor_image = tensor_image.toType(torch::kFloat);tensor_image = tensor_image.div(255.);// tensor_image = tensor_image.sub(0.5);// tensor_image = tensor_image.div(0.5);tensor_image = tensor_image.unsqueeze(0);// 运行模型torch::Tensor output = module.forward({tensor_image}).toTensor();// 结果处理int result = output.argmax().item<int>();std::cout << "The classifiction index is: " << result << std::endl;return 0;
}

Step 6:构建运行验证

我们先来写一下
CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet50)find_package(Torch REQUIRED PATHS ./libtorch)
find_package(OpenCV REQUIRED)add_executable(resnet50  test_model.cpp)
target_link_libraries(resnet50 "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")set_property(TARGET resnet50  PROPERTY CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

现在我们的工程目录
demo
下有以下文件:

CMakeLists.txt  export_jit_model.py  kitten.jpg  libtorch  pytorch_test.py  resnet50_jit.pth  test_model.cpp

然后开始用 CMake 构建工程:

mkdir build && cd build
OpenCV_DIR=[YOUR_PATH_TO_OPENCV]/opencv/build cmake ..
make

整个过程没有报错的话我们就已经构建完成了,会得到一个可执行文件
resnet50
在工程目录
demo
下。

接下来我们执行,并验证运行结果是否与 PyTorch 的结果一致:

./build/resnet50 resnet50_jit.pth kitten.jpg

输出:

The classifiction index is: 282

运行成功并且结果正确。

Ref:

https://www.jianshu.com/p/7cddc09ca7a4

https://blog.csdn.net/cxx654/article/details/115916275

https://zhuanlan.zhihu.com/p/370455320

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3268029.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Java-----栈

目录 1.栈&#xff08;Stack&#xff09; 1.1概念 1.2栈的使用 1.3栈的模拟实现 1.4栈的应用场景 1.5栈、虚拟机栈、栈帧有什么区别呢 1.栈&#xff08;Stack&#xff09; 1.1概念 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操…

Java基础入门14:常用API(Object(s)类、包装类、Math、Arrays、日期时间、Lambda表达式、方法引用)

Object类 Object类是Java中所有类的祖宗类&#xff0c;因此&#xff0c;Java中所有类的对象都可以直接使用Object类中提供的一些方法。 Object类的常见方法&#xff1a; package com.itchinajie.d12_api_object;public class Test {public static void main(String[] args) {…

mybatis-plus默认字段填充以及批量数据插入优化

日常开发中&#xff0c;我们需要设置一些数据库的默认字段填充&#xff0c;比如创建时间、创建人、更新时间、更新人等等&#xff0c;那么mybatis-plus给我们提供了一个这样的接口去做这件事情 MetaObjectHandler。 1、首先可以创建一个实现类来实现MetaObjectHandler接口 C…

DBMotion x Chat2DB:高效迁移,优雅同步,数据腾飞不再愁

DBMotion 基本介绍 数据传输服务DBMotion是一款轻量、绿色的数据库迁移、同步、校验工具。支持国产化数据迁移、支持容灾演练、支持两地三中心和异地多活&#xff1b;源库无感知、简单易集成、丝滑高性能。助您在多云之间随心迁移、自由容灾。 功能介绍 已支持的数据库 v1.…

7月22日JavaSE学习笔记

Collection接口&#xff0c;还有一个父级接口Iterable可迭代的 Collection继承树 Set 集合 Set的底层是用Map实现&#xff08;存储在key中&#xff0c;value中是空的Object对象&#xff09; 有序&#xff1a;取出的顺序和添加的顺序是一样的。 List是有序的&#xff0c;Set是…

“软件质量”,构筑企业值得信赖的护城河

引子 质量是产品的生命线&#xff0c;质量问题不仅会导致企业财产损失&#xff0c;还可能引发业务中断、客户满意度下降、企业品牌声誉受损等负面影响。如何在软件开发过程中全方位构建产品质量防护盾&#xff0c;是各行业保障产品高质量的重要课题。 如何保障软件质量&#…

五、SpringIoC/DI的使用

1. 类注解、方法注解 告诉spring管理bean—>bean的存储 1、类注解&#xff1a;五大注解 Controller&#xff08;控制器存储&#xff09;、 Service&#xff08;服务存储&#xff09;、 Repository&#xff08;仓库存储&#xff09;、 Component&#xff08;组件存储&#xf…

【Linux】管道通信和 system V 通信

文章目录 一、进程通信原理&#xff08;让不同进程看到同一份资源&#xff09;二、管道通信2.1 管道原理及其特点2.1 匿名管道和命名管道 三、共享内存通信3.1 共享内存原理3.2 创建和关联共享内存3.3 去关联、ipc 指令和删除共享内存 四、消息队列和信号量&#xff08;了解&am…

VirtualSurveyor9.0.3 无人机测绘软件功能介绍

Virtual Surveyor9.0.3中文版是功能强大的无人机测绘软件&#xff0c;使用旨在为用户提供完整的地理空间数据可视化和分析功能&#xff0c;带来提高的生产力&#xff0c;功能全面而强大&#xff0c;在无人机到CAD模型的过程中&#xff0c;使用Virtual Surveyor软件来拆卸输送机…

情绪稳定的人有什么特点?

第一部分&#xff1a;至纯之人&#xff0c;大器晚成 1.1 单纯&#xff0c;不是天真 你知道吗&#xff1f;那些能够成就大事的人&#xff0c;往往在人性上非常单纯。他们对外界的需求很低&#xff0c;更多的是向内寻求。这样的人&#xff0c;他们的内心世界像一片净土&#xff…

数据结构与算法--顺序表(Java)

&#x1f4dd;个人主页&#x1f339;&#xff1a;誓则盟约 ⏩收录专栏⏪&#xff1a;Java SE &#x1f921;往期回顾&#x1f921;&#xff1a;Java SE--基本数据类型&#xff08;详细讲解&#xff09; &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; 什么…

每日任务:TCP/IP模型和OSI模型的区别

介绍一下TCP/IP模型和OSI模型的区别&#xff1f; OSI模型由国标准化组织提出&#xff0c;而TCP/IP模型是由美国国防部开发的&#xff1b; OSI模型由七个层次组成&#xff0c;从下到上依次为物理层、数据链路层、网络层、传输层、会话层、表示层和应用层。而TCP/IP模型只有四层…

AI视频生成(即梦)

1.打开即梦网页版 https://jimeng.jianying.com/ai-tool/home 2.图片生成-导入参考图&#xff08;这里原本的红色或者灰度图都是可以的&#xff09;-精细度5&#xff08;最高图质量越高&#xff09; 注&#xff1a;根据需要&#xff0c;选择不同的生图模型&#xff0c;具有…

线上监控诊断 - Arthas

简介 Arthas 是一款线上监控诊断产品&#xff0c;通过全局视角实时查看应用 load、内存、gc、线程的状态信息&#xff0c;并且能在不修改应用代码的情况下&#xff0c;对业务问题进行诊断&#xff0c;包括查看方法调用的出入参、异常&#xff0c;监测方法执行耗时&#xff0c;…

SAPUI5基础知识20 - 对话框和碎片(Dialogs and Fragments)

1. 背景 在 SAPUI5 中&#xff0c;Fragments 是一种轻量级的 UI 组件&#xff0c;类似于视图&#xff08;Views&#xff09;&#xff0c;但它们没有自己的控制器&#xff08;Controller&#xff09;。Fragments 通常用于定义可以在多个视图中重用的 UI 片段&#xff0c;从而提…

项目实战1(30小时精通C++和外挂实战)

项目实战1&#xff08;30小时精通C和外挂实战&#xff09; 01-MFC1-图标02-MFC2-按钮、调试、打开网页05-MFC5-checkbox及按钮绑定对象06--文件格式、OD序列号08-暴力破解09-CE10-秒杀僵尸 01-MFC1-图标 这个外挂只针对植物大战僵尸游戏 开发这个外挂&#xff0c;首先要将界面…

FPGA:流水灯设计

本次基于FPGA实现流水灯&#xff0c;即让LED[0:7]从左到右依次电量&#xff0c;每个LED灯频闪周期为1s钟&#xff0c;在这里&#xff0c;给出下面三种实现思路&#xff1a; 1、实验思路 1、使用位运算符 在复位时令LED灯为LED8’b0000_0001&#xff0c;然后每过一秒钟&#x…

软考:软件设计师 — 7.软件工程

七. 软件工程 1. 软件工程概述 &#xff08;1&#xff09;软件生存周期 &#xff08;2&#xff09;软件过程 软件开发中所遵循的路线图称为 "软件过程"。 针对管理软件开发的整个过程&#xff0c;提出了两个模型&#xff1a;能力成熟度模型&#xff08;CMM&#…

嵌入式C++、STM32、MySQL、GPS、InfluxDB和MQTT协议数据可视化:智能物流管理系统设计思路流程(附代码示例)

目录 项目概述 系统设计 硬件设计 软件设计 系统架构图 代码实现 1. STM32微控制器与传感器代码 代码讲解 2. MQTT Broker设置 3. 数据接收与处理 代码讲解 4. 数据存储与分析 5. 数据分析与可视化 代码讲解 6. 数据可视化 项目总结 项目概述 随着电子商务的快…

简单小案例分析

一、容器和实例关系 <div class"app"><h1>Hello,{{name}}</h1> </div> <div class"app"><h1>Hello,{{name}}</h1> </div><script>//创建Vue实例new Vue({el:".app", //el用于指定当前V…