基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码

 第一步:准备数据

5种鸟类行为数据:self.class_indict =

 ["bowing_status", "grooming", "headdown", "vigilance_status", "walking"]

,总共有23790张图片,每个文件夹单独放一种数据

第二步:搭建模型

简介:

DenseNet(Dense Convolutional Network)稠密卷积网络
CVPR2017的优秀文章
从feature入手,通过对feature的极致利用达到更好的效果和更少的参数。


优点:

减轻了vanishing-gradient(梯度消失)
加强了feature的传递
更有效地利用了feature
一定程度上较少了参数数量


在深度学习网络中,随着网络深度的加深,梯度消失问题会愈加明显,解决方法是创建浅层与深层之间的短路径。在DenseNet中,在保证网络中层与层之间最大程度的信息传输的前提下,直接将所有层连接起来。

在传统卷积神经网络中,如果你有L层,那么就会有L个连接,但是在DenseNet中,会有(L+1)/2个连接。简单来说,就是每一层的输入来自前面所有层的输出。如下图是dense block的结构图,x是数据,H是网络层。

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_schedulerfrom model import densenet121, load_state_dict
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print(args)print('Start Tensorboard with "cd", view at http://localhost:6006/')tb_writer = SummaryWriter()if os.path.exists("./weights") is False:os.makedirs("./weights")train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# 如果存在预训练权重则载入model = densenet121(num_classes=args.num_classes).to(device)if args.weights != "":if os.path.exists(args.weights):load_state_dict(model, args.weights)else:raise FileNotFoundError("not found weights file: {}".format(args.weights))# 是否冻结权重if args.freeze_layers:for name, para in model.named_parameters():# 除最后的全连接层外,其他权重全部冻结if "classifier" not in name:para.requires_grad_(False)pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4, nesterov=True)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# trainmean_loss = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateacc = evaluate(model=model,data_loader=val_loader,device=device)print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))tags = ["loss", "accuracy", "learning_rate"]tb_writer.add_scalar(tags[0], mean_loss, epoch)tb_writer.add_scalar(tags[1], acc, epoch)tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=16)parser.add_argument('--lr', type=float, default=0.001)parser.add_argument('--lrf', type=float, default=0.1)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"E:\20240717\data")# densenet121 官方权重下载地址# https://download.pytorch.org/models/densenet121-a639ec97.pthparser.add_argument('--weights', type=str, default='densenet121.pth',help='initial weights path')parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

视频演示地址:基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码_哔哩哔哩_bilibili

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码见:基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3269865.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Github个人网站搭建详细教程【Github+Jekyll模板】

文章目录 前言一、介绍1 Github Pages是什么2 静态网站生成工具3 Jekyll简介Jekyll 和 GitHub 的关系 4 Mac系统Jekyll的安装及使用安装Jekyll的简单使用 二、快速搭建第一个Github Pages网站三、静态网站模板——Chirpy1 个人定制 四、WordPress迁移到Github参考资料 前言 23…

Opencv学习项目4——手部跟踪

上一篇博客我们介绍了mediapipe库和对手部进行了检测,这次我们进行手部关键点的连线 代码实现 import cv2 import mediapipe as mpcap cv2.VideoCapture(1) mpHands mp.solutions.hands hands mpHands.Hands() mpDraw mp.solutions.drawing_utilswhile True:…

IDEA-安装插件 驼峰下划线转换

第一步:安装 file-settings-plugins-在marketplace搜索“CamelCase”-点击安装 第二步:设置 file-settings-editor-camel_case 第三步:使用 选中想转换的遍历 使用快捷键 Alt Shift U

Windows环境下安装docker、配置Ubuntu容器并使用vscode ssh连接到容器

目录 一、Windows环境下安装docker二、配置Ubuntu三、在容器中安装ssh服务参考文章 一、Windows环境下安装docker 在任务栏中搜索**“Windows功能”** -将适用于Linux的Windows子系统和虚拟机平台选上 然后按照提示重启电脑。然后开始安装WSL。通过cmd以管理员身份打开命令提…

知识图谱 networkx、pyvis页面简单可视化

参考: https://huggingface.co/learn/cookbook/en/information_extraction_haystack_nuextract networkx 构建保存节点和边,pyvis保存为html 注意点 1、pyvis vscode的jupyer不支持,会显示不正常,直接还是本地jupyter环境使用 2…

【算法】布隆过滤器

一、引言 在现实世界的计算机科学问题中,我们经常需要判断一个元素是否属于一个集合。传统的做法是使用哈希表或者直接遍历集合,但这些方法在数据量较大时效率低下。布隆过滤器(Bloom Filter)是一种空间效率极高的概率型数据结构&…

二叉树以及堆的实现

树 树的定义及概念 树是⼀种非线性的数据结构,它是由n(n>0) 个有限结点组成⼀个具有层次关系的集合。把它叫做树是因为它看起来像⼀棵倒挂的树,也就是说它是根朝上,而叶朝下的。 有⼀个特殊的结点,称…

不支持jdk8的jenkins部署jdk8项目

1、背景 目前最新的jenkins必须基于jdk8以上,才能安装。jenkins最新的插件部分也不支持jdk8了。 2、全局工具配置 配置一个jdk8 配置一个jdk8以上的版本,如jdk17 3、部署maven项目 jdk17项目 可以直接使用maven插件,部署。 jdk8项目 由…

01-调试开发k8s

使用 Docker 构建 Kubernete 官方 release 是使用 Docker 容器构建的。要使用 Docker 构建 Kubernetes,请遵循以下说明: Requirements docker Key scripts 以下脚本位于 build/ 目录中。请注意,所有脚本都必须从 Kubernetes 根目录运行 build/run.…

MySQL环境的配置文件json

突然了解到,使用json文件去进行环境的配置,这样修改参数的时候就只需要去改json文件中的内容,不需要去修改代码中的内容,其他人的MySQL和我的MySQL也不同,这时其他人只需要修改json文件中的内容,清晰明了&a…

代码随想录||day25 非递减子序列,全排列问题

491非递减子序列 力扣题目链接 题目描述: 给你一个整数数组 nums ,找出并返回所有该数组中不同的递增子序列,递增子序列中 至少有两个元素 。你可以按 任意顺序 返回答案。 数组中可能含有重复元素,如出现两个整数相等&#x…

双盲插便携屏方案:LDR6020系列便携显示

在当今这个追求高效与便携的时代,电子设备尤其是便携式显示器(Portable Monitor)的需求日益增长,成为商务人士、设计师、游戏玩家及移动办公族的理想伴侣。其中,6020系列便携屏以其卓越的显示效果、紧凑的机身设计赢得…

基于YOLO系列便捷式代码创新

目标检测 YOLOv5 与 YOLOv7 系列详细介绍 YOLOv5 详细介绍 版本与特点 网络结构 技术亮点 YOLOv7 详细介绍 主要贡献 网络结构 技术亮点 性能对比 基于YOLOv5和YOLOv7系列的多方面创新方法 融合BiFormer注意力机制 融合SImAM注意力机制 CBAM注意力机制 DBB多分枝模块 LSKA注意力…

NumpyPandas:pandas库的安装,不同对象的建立,文件的导入和了解数据

目录 前言 一、Pandas库的安装 二、不同对象的建立 1.Series对象的创建 1.用index方法指定索引 2.在创建的时候就指定索引 3.使用字典的方式创建 4.将一个常量与index一起传入创建 5.输出值和索引 2.DataFrame对象的创建 1.不指定列名则以键当列名 行索引为默认值 …

Hadoop3.3.5的安装与单机/伪分布式配置

文章目录 一、安装须知二、安装jdk三、安装shh四、安装配置hadoop五、运行hadoop 一、安装须知 本次安装的Hadoop版本为hadoop3.3.5。 在这之前完成了VMware虚拟软件的安装,并安装了Ubuntu22.04,在这基础上进行相关配置。 二、安装jdk 在Ubuntu中使用…

MySQL查询执行(一):count执行慢

查询处理器 MySQL查询处理器是MySQL数据库服务器的组件,它负责执行SQL查询。查询处理器的主要任务是解析查询(把用户提交的SQL查询转换为可以被数据库引擎理解和执行的数据操作指令序列),生成查询计划,然后执行该计划。…

C++程序的UI界面闪烁问题的解决办法总结

Windows C++程序复杂的UI界面要使用多种绘图技术(使用GDI、GDI+、ddraw、D3D等绘图),并要贴图去美化,在窗口移动或者改变大小的时候可能会出现闪烁。下面罗列一下UI界面产生闪烁的几种可能的原因,并给出相应的解决办法。 1、原因一 如果熟悉显卡原理的话,调用GDI函数向屏…

JVM系列(三) -类加载器及双亲委派模型介绍

在之前的文章中,介绍了类的加载过程中,我们有提到在加载阶段,通过一个类的全限定名来获取此类的二进制字节流操作,其实类加载器就是用来实现这个操作的。 在虚拟机中,任何一个类,都需要由加载它的类加载器…

《Milvus Cloud向量数据库指南》——Milvus Cloud不同场景下的部署形态选型

不同场景下的部署形态选型 一般说选型肯定离不开阶段。用到向量数据库的应用基本有这么几个阶段: AI 应用的快速原型构建。比如你在做一个 AI 个人助手、一个小的搜索引擎原型、一个端到端的 RAG 原型,这类项目的迭代速度是很关键的,而且原型构建期不需要关心性能或者稳定性…

暑假第二周任务——3Gshare的仿写

3GShare的仿写 登陆注册页面 这个界面的UI比较简单,比较困难的地方在于限制我们的输入长度以及我们输入的字符类型。 在这里我是通过在textField的代理中设计限定输入字符的内容,从而实现限制输入长度和限制输入字符的内容,下面给出相关的代…