使用Pytorch实现图像花朵分类

基于pytorch-classifier这个源码进行实现的图像分类

代码的介绍在这个链接里面,这篇博客主要是为了带着大家通过实践的方式熟悉一下代码的使用,并且了解相关功能。

1. 下载相关资料

这里我提供了一个花朵数据集,里面总共有十个类别的花朵作为本次实验的数据集。
我们下载代码和数据集到本地,然后我们在下图创建一个名字为dataset的文件夹,然后把花朵数据集放到里面并重命名为train,具体如下:
在这里插入图片描述
至此,完成第一步。

2. 配置环境

  1. 首先推荐使用anaconda作为你的python环境,代码工具可以使用vscode或者pycharm,这个根据使用者爱好,这边我使用的是pycharm,那么这里默认各位已经准备好anaconda和(vscode或者pycharm),不会安装的话可以百度一下,这方面的教程都非常丰富。
  2. 安装torch和torchvision
    你可以在这个pytorch官网中找到对应的安装命令,这里版本要求torch==1.12.0+,下面贴出torch==1.12.0的各项安装命令,各位看官可以根据自己的电脑情况进行选择
    CUDA 11.6
    pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
    CUDA 11.3
    pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
    CUDA 10.2
    pip install torch==1.12.0+cu102 torchvision==0.13.0+cu102 --extra-index-url https://download.pytorch.org/whl/cu102
    CPU only
    pip install torch==1.12.0+cpu torchvision==0.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu
  3. 安装成功后,进入到代码路径进行pip install -r requirements.txt

至此,完成第二步。

3.分割数据集

在第一步中我们已经放好数据集,但是还没有划分验证集和测试集,这个代码中的processing.py提供了分割数据集的功能,其主要参数如下
在这里插入图片描述
其中–val_size --test_size就是验证集和测试集的比例,这里的分割思想是,先在全部数据集上分出–test_size比例的数据,然后再从剩下的分出-val_size比例的数据。具体参数解释可以看Readme.md文件。
在这里插入图片描述
那么我们默认参数即可,如果各位看官需要修改,可以自行修改即可,那么我们的运行命令就是:

python processing.py

如果需要更改参数,可以直接在命令后面指定,比如我想验证集比例是0.1:

python processing.py --val_size 0.1

运行成功后你可以看到下图所示:
在这里插入图片描述
至此,完成第三步。

4. 训练模型

在本次实验中我们训练ghostnet,其中训练的代码在main.py中,具体参数解释请看Readme.md,都已经解释得比较详细了。那么这里我们直接上命令并附带一些解释:

python main.py --model_name ghostnet --pretrained --config config/config.py --save_path runs/ghostnet_flower --lr 1e-4 --warmup --amp --imagenet_meanstd  --Augment AutoAugment 

运行后,你可以看到以下信息:
在这里插入图片描述
其中会输出所有参数的设置信息,可以通过这个表格观看你的参数设置是否有误。我们也可以看到其在下载预训练权重到本地路径,有些看官的网络不好的话,其可能会卡在这里不动,那我们可以手动复制这个链接:

https://github.com/z1069614715/pretrained-weights/releases/download/ghostnet_1x_v1.0/ghostnet_1x-f97d70db.pth

到迅雷或者浏览器上进行下载到:

C:\Users\Admin/.cache\torch\hub\checkpoints\ghostnet_1x-f97d70db.pth

上述的下载链接和路径在各位看官的输出信息中都可以找到。
然后就会开始训练:
在这里插入图片描述
我们可以看到显示了你选择的模型是ghostnet,也显示了这个模型的flops和参数量,然后就开始训练,并显示训练的进度条和每个epoch的log信息。然后我们就可以等待训练结束。
在这里插入图片描述
我们从图中可以看到其训练到83个epoch就停止,是因为代码中有一个–patience参数,默认值为30,也就是经过30个epoch没有提升,模型就会认为已经收敛,就停止训练。训练完成后我们可以打开刚设定好的–save_path路径:
在这里插入图片描述
其中我们可以看到以下的可视化:

  1. 曲线迭代图
    在这里插入图片描述
  2. 学习率曲线变化图
    在这里插入图片描述
  3. 训练过程中的图像可视化
    在这里插入图片描述
    在这里插入图片描述

这里就只展示两张图,默认是生成五张。
当然你对上述曲线可视化不满意,可以自行读取–save_path中的train.log文件进行自定义可视化。这个文件相当于记录了曲线上的值,方便后期各位美工图像等等…
在这里插入图片描述

我们还看到有一个best.pt和last.pt,后续的测试和预测的步骤,都会读取best.pt作为我们的模型进行测试和预测,其余文件有兴趣可以看看Readme.md中的Some explanation的第八点。
至此,完成第四步

5. 测试

我们在第四步中使用了训练集进行训练,验证集测试,那么剩下测试集就是用训练好的模型进行检测,那么这里我演示如何使用我们的metrice.py进行检测。

python metrice.py --task test --save_path runs/ghostnet_flower

–task 就是任务的意思 其支持train、val、test、fps。这里我们演示的是test,就是使用测试集进行测试计算指标。运行后的截图如下:
在这里插入图片描述
我们可以看到其会显示你当前的模型类型,显示你这个模型训练过程最好精度的指标,然后下面就是显示精度,类别平均精度,以及每个类别的preciesion、recall、f1-score、Kappa、accuracy.
你还可以在–save_path中的test文件夹中找到混淆矩阵,其中csv文件就是数据的存储,如果有作图、美工的需求,可以自行读取。
在这里插入图片描述
在这里插入图片描述
当然我们的测试过程是支持tta(详情可看Readme.md文件中的第10点)的,只需要额外加一个参数即可

python metrice.py --task test --save_path runs/ghostnet_flower --test_tta

在这里插入图片描述
我们可以看到使用tta可以增加精度,但是会增加一点时间。
我们的metrice.py中还支持可视化数据集的识别情况和tsne可视化,也是只需要添加两个参数即可。

python metrice.py --task test --save_path runs/ghostnet_flower --visual --tsne

当然这个过程也支持test_tta,但是为了节省时间就不加了。运行结束后,我们可以在–save_path路径中找到对应的数据文件:
在这里插入图片描述
在这里插入图片描述
其中tsne的坐标信息也保存到tsne.csv中,方便后期美工。
还有生成了correct.csv,incorrect.csv这两个csv文件,其中里面记录了文件路径,预测的类别,正确的类别,预测的类别对应的概率,以方便后期进行错误预测的分析。
至此,完成第五步。

6. 预测

预测的代码文件是predict.py,其支持输入单张图片或者一个文件夹,那么这里我们就展示一个文件夹的预测。
假如我们把测试集中其中一类的文件进行预测,我们可以运行以下命令:

python predict.py --source dataset/test/00 --save_path runs/ghostnet_flower

当然此过程也是支持test_tta,也是只需要在后面添加参数即可。
这是运行成功的截图:
在这里插入图片描述
我们可以打开–save_path的predict文件夹:
在这里插入图片描述
我们随便打开一张图像:
在这里插入图片描述
我们可以看到图像的预测类别和对应预测类别的概率。
你以为结束了吗?并没有,我们的predict.py文件结合pytorch_grad_cam库实现了热力图可视化,并支持多种热力图计算方法详情请看–cam_type参数和Some explanation第十二点,也是只需要加一个参数:

python predict.py --source dataset/test/00 --save_path runs/ghostnet_flower --cam_visual --cam_type GradCAMPlusPlus

运行成功后依然在–save_path的predict文件夹中可以找到对应保存的图像,这里我们也是随便打开一张图像:
在这里插入图片描述
我们可以看到在图像中添加了这个热力图可视化。
你们在文件夹中还可以找到一个result.csv文件,其记录了文件的路径,预测的类别,预测的类别对应的概率信息。
至此,完成第六步。

总结

整个程序的功能演示就到此结束,当然程序的功能不仅仅于此,具体可以看Readme.md中的具体解释,其还支持知识蒸馏,有兴趣的可以看一下代码中的Knowledge_Distillation.md,使用起来也是比较简单的,但是知识蒸馏的参数设置比较吃经验,需要使用者自行尝试。
如果遇到bug等等问题可以留言或者私信作者。
本次实验的代码数据模型全部文件:百度云链接

最后致敬bubbliiing的开源精神!

如果内容对你有帮助,麻烦点个赞,谢谢!

有计算机视觉合作项目可以私信作者!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/1620489.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

“花朵分类“ 手把手搭建【卷积神经网络】

前言 本文介绍卷积神经网络的入门案例,通过搭建和训练一个模型,来对几种常见的花朵进行识别分类; 使用到TF的花朵数据集,它包含5类,即:“雏菊”,“蒲公英”,“玫瑰”,“向日葵”,“郁金香”;共 3670 张彩色图片;通过搭建和训练卷积神经网络模型,对图像进行分类,…

(笔记一)利用open_cv在图像上进行点标记,文字注记,画圆、多边形、椭圆

(1)CV2中的绘图函数: cv2.line() 绘制线条cv2.circle() 绘制圆cv2.rectangle() 绘制矩形cv2.ellipse() 绘制椭圆cv2.putText() 添加注记 (2)注释 img表示需要绘制的图像color表示线条的颜色,采用颜色矩阵…

桌面图标不显示

问题 桌面图标不显示 解决办法 鼠标 右击->选择-查看->显示桌面图标

今天去看看俺姐(老婆)新开的超市

首发博客地址 https://blog.zysicyj.top/ 1 昨晚写博客到12点多,今天困死了,比较意外的是,早上老爸没有叫我,今天早上是老爸和小舅送的葡萄。 所以呢,今早睡得很晚,然后6点多才醒,睡得真舒服&am…

java恶魔之怒太平洋_熊猫人之怒恶魔降临手游辅助下载_熊猫人之怒恶魔降临修改器安卓版V3.1下载(暂未上线)_预约_飞翔下载...

熊猫人之怒恶魔降临修改器是一款简单好用的安卓游戏修改神器。通过修改正在运行的游戏的内存数据,达到修改游戏中的金钱、血量、得分、道具数量、攻击、防御、魔法等参数值。既简单又实用,让你想怎么改,就怎么改,你的游戏你做主。…

LeetCode-738-单调递增的数字

题目描述&#xff1a; 当且仅当每个相邻位数上的数字 x 和 y 满足 x < y 时&#xff0c;我们称这个整数是单调递增的。 给定一个整数 n &#xff0c;返回 小于或等于 n 的最大数字&#xff0c;且数字呈 单调递增 。 解题思路&#xff1a; 先将int变成char[]&#xff0c;获取…

UG+PRESSCAD五金连续模 成型模 复合模具设计视频教程

UGPRESSCAD五金连续模 成型模 复合模具设计视频教程 链接&#xff1a;https://pan.baidu.com/s/1MEQdf3DkmHAEHYOrP1USBQ 提取码&#xff1a;r9f0

教程 参数设置_UG教程之非切削参数设置

转移/快速 转移/快速指定如何从一个切削刀路移动到另一个切削刀路。通常情况下,刀具需要进行以下3个动作: (1)从其当前位置移动到指定的平面。 (2)移动到指定平面内高于进刀运动起点的位置。 (3)最后,刀具将从指定平面移动到进刀的起始处。 1.安全设置 功能:安全设置用于指…

在 WSL2 中使用 NVIDIA Docker 进行全栈开发和深度学习 TensorFlow pytorch GPU 加速

WSL2使用NVIDIA Docker进行全栈开发和深度学习 1. 前置条件 1.1. 安装系统 Windows 10 版本 2004 及更高版本&#xff08;内部版本 19041 及更高版本&#xff09;或 Windows 11 跳过 1.2. 处理好网络环境 安装过程中需要访问国际网络&#xff0c;自行处理好。建议开启 tu…

UML四大关系

文章目录 引言UML的定义和作用UML四大关系的重要性和应用场景关联关系继承关系聚合关系组合关系 UML四大关系的进一步讨论UML四大关系的实际应用软件开发中的应用其他领域的应用 总结 引言 在软件开发中&#xff0c;统一建模语言&#xff08;Unified Modeling Language&#x…

python+协同过滤算法实现简单的图书推荐系统

背景介绍 当我们做一些推荐系统网站时&#xff0c;通常需要合适的推荐算法&#xff0c;下面给大家介绍推荐系统中经典的推荐算法——协同过滤算法。在本文中通过Python语言&#xff0c;以一个图书推荐系统为案例&#xff0c;最终实现一个基于用户对图书的评分而对指定的用户个…

如何使用腾讯云服务器搭建网站?新手建站教程

使用腾讯云服务器搭建网站全流程&#xff0c;包括轻量应用服务器和云服务器CVM建站教程&#xff0c;轻量可以使用应用镜像一键建站&#xff0c;云服务器CVM可以通过安装宝塔面板的方式来搭建网站&#xff0c;腾讯云服务器网分享使用腾讯云服务器建站教程&#xff0c;新手站长搭…

代码随想录算法训练营第四十八天|LeetCode 583,72,编辑距离总结篇

目录 LeetCode 583.两个字符串的删除操作 动态规划五步曲&#xff1a; 1.确定dp[i][j]的含义 2.找出递推公式 3.初始化dp数组 4.确定遍历方向 5.打印dp数组 LeetCode 72.编辑距离 动态规划五步曲&#xff1a; 1.确定dp[i][j]的含义 2.找出递推公式 3.初始化dp数组 4.确定遍历方…

JAVA rs232

JAVA rs232 全套资源提供 全套项目资源环境都在我发布的资源里环境 MAVEN 依赖代码贴出 全套项目资源环境都在我发布的资源里 环境 Configure Virtual Serial Port Driver 模拟串口 友善串口工具调试 MAVEN 依赖 <dependency><groupId>org.bidib.jbidib.org.qba…

java输出hello world_java输出Hello World

一、输出“Hello World!” 1、新建一个java项目,点击File->New->Java Project,创建java项目的界面之后,输入项目名称wly,点击finish。 2、创建好java项目之后,鼠标右键项目,选择New->Class,创建一个类,mypackage为包名,Name类名Hello,首字母大写,点击fini…

二,java输出hello

1&#xff0c;创建文件Hello.java 2, 文件里输入 public class Hello{public static void main(String[] args){System.out.print("hello world!");} } 3&#xff0c; javac Hello.java 会生成一个class文件 4&#xff0c; 然后java Hello 注意&#xff1a; 1…

Go语言入门记录:从基础到变量、函数、控制语句、包引用、interface、panic、go协程、Channel、sync下的waitGroup和Once等

程序入口文件的包名必须是main&#xff0c;但主程序文件所在文件夹名称不必须是main&#xff0c;即我们下图hello_world.go在main中&#xff0c;所以感觉package main写顺理成章&#xff0c;但是如果我们把main目录名称改成随便的名字如filename也是可以运行的&#xff0c;所以…

C语言练习5(巩固提升)

C语言练习5 选择题 选择题 1&#xff0c;下面代码的结果是&#xff1a;( ) #include <stdio.h> #include <string.h> int main() {char arr[] { b, i, t };printf("%d\n", strlen(arr));return 0; }A.3 B.4 C.随机值 D.5 &#x1f4af;答案解析&#…