卷积神经网络——LeNet——FashionMNIST

目录

  • 一、整体结构
  • 二、model.py
  • 三、model_train.py
  • 四、model_test.py

GitHub地址

一、整体结构

在这里插入图片描述

二、model.py

import torch
from torch import nn
from torchsummary import summaryclass LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.c1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)self.sig = nn.Sigmoid()self.s2 = nn.AvgPool2d(kernel_size=2,stride=2)self.c3 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)self.s4 = nn.AvgPool2d(kernel_size=2,stride=2)self.flatten = nn.Flatten()self.f5 = nn.Linear(in_features=5*5*16,out_features=120)self.f6 = nn.Linear(in_features=120,out_features=84)self.f7 = nn.Linear(in_features=84,out_features=10)def forward(self,x):x = self.sig(self.c1(x))x = self.s2(x)x = self.sig(self.c3(x))x = self.s4(x)x = self.flatten(x)x = self.f5(x)x = self.f6(x)x = self.f7(x)return x# if __name__ =="__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     model = LeNet().to(device)
#
#     print(summary(model,input_size=(1,28,28)))

三、model_train.py

# 导入所需的Python库
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch.utils.data as Data
import torch
from torch import nn
import time
import copy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from model import LeNet  # model.py中定义了LeNet模型
from tqdm import tqdm  # 导入tqdm库,用于显示进度条# 定义数据加载和处理函数
def train_val_data_process():# 加载FashionMNIST数据集,Resize到28x28尺寸,并转换为Tensortrain_data = FashionMNIST(root="./data",train=True,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)# 将加载的数据集分为80%的训练数据和20%的验证数据train_data, val_data = Data.random_split(train_data, lengths=[round(0.8 * len(train_data)), round(0.2 * len(train_data))])# 为训练数据和验证数据创建DataLoader,设置批量大小为32,洗牌,2个进程加载数据train_dataloader = Data.DataLoader(dataset=train_data,batch_size=32,shuffle=True,num_workers=2)val_dataloader = Data.DataLoader(dataset=val_data,batch_size=32,shuffle=True,num_workers=2)# 返回训练和验证的DataLoaderreturn train_dataloader, val_dataloader# 定义模型训练和验证过程的函数
def train_model_process(model, train_dataloader, val_dataloader, num_epochs):# 设置使用CUDA如果可用device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 打印使用的设备dev = "cuda" if torch.cuda.is_available() else "cpu"print(f'当前模型训练设备为: {dev}')# 初始化Adam优化器和交叉熵损失函数optimizer = torch.optim.Adam(model.parameters(), lr=0.001)criterion = nn.CrossEntropyLoss()# 将模型移动到选定的设备上model = model.to(device)# 复制模型权重用于后续更新最佳模型best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0  # 初始化最佳准确度# 初始化用于记录训练和验证过程中损失和准确度的列表train_loss_all = []val_loss_all = []train_acc_all = []val_acc_all = []# 记录训练开始时间start_time = time.time()# 迭代指定的训练轮数for epoch in range(1, num_epochs + 1):# 记录每个epoch开始的时间since = time.time()# 打印分隔符和当前epoch信息print("-" * 10)print(f"Epoch: {epoch}/{num_epochs}")# 初始化训练和验证过程中的损失和正确预测数量train_loss = 0.0train_corrects = 0val_loss = 0.0val_corrects = 0# 初始化批次计数器train_num = 0val_num = 0# 创建训练进度条progress_train_bar = tqdm(total=len(train_dataloader), desc=f'Training {epoch}', unit='batch')# 训练数据集的遍历for step, (b_x, b_y) in enumerate(train_dataloader):# 将数据移动到相应的设备上b_x = b_x.to(device)b_y = b_y.to(device)# 训练模型model.train()# 前向传播output = model(b_x)# 计算预测标签pre_label = torch.argmax(output, dim=1)# 计算损失loss = criterion(output, b_y)# 清空梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新权重optimizer.step()# 累加损失和正确预测数量train_loss += loss.item() * b_x.size(0)train_corrects += torch.sum(pre_label == b_y.data)# 更新批次计数器train_num += b_x.size(0)# 更新训练进度条progress_train_bar.update(1)# 关闭训练进度条progress_train_bar.close()# 创建验证进度条progress_val_bar = tqdm(total=len(val_dataloader), desc=f'Validation {epoch}', unit='batch')# 验证数据集的遍历for step, (b_x, b_y) in enumerate(val_dataloader):# 将数据移动到相应的设备上b_x = b_x.to(device)b_y = b_y.to(device)# 评估模型model.eval()# 前向传播output = model(b_x)# 计算预测标签pre_label = torch.argmax(output, dim=1)# 计算损失loss = criterion(output, b_y)# 累加损失和正确预测数量val_loss += loss.item() * b_x.size(0)val_corrects += torch.sum(pre_label == b_y.data)# 更新批次计数器val_num += b_x.size(0)# 更新验证进度条progress_val_bar.update(1)# 关闭验证进度条progress_val_bar.close()# 计算并记录epoch的平均损失和准确度train_loss_all.append(train_loss / train_num)train_acc_all.append(train_corrects.double().item() / train_num)val_loss_all.append(val_loss / val_num)val_acc_all.append(val_corrects.double().item() / val_num)# 打印训练和验证的损失与准确度print(f'{epoch} Train Loss: {train_loss_all[-1]:.4f} Train Acc: {train_acc_all[-1]:.4f}')print(f'{epoch} Val Loss: {val_loss_all[-1]:.4f} Val Acc: {val_acc_all[-1]:.4f}')# 计算并打印epoch训练耗费的时间time_use = time.time() - sinceprint(f'第 {epoch} 个 epoch 训练耗费时间: {time_use // 60:.0f}m {time_use % 60:.0f}s')# 若当前epoch的验证准确度为最佳,则更新最佳模型权重if val_acc_all[-1] > best_acc:best_acc = val_acc_all[-1]best_model_wts = copy.deepcopy(model.state_dict())# 训练结束,保存最佳模型权重torch.save(best_model_wts, 'D:/Pycharm/deepl/LeNet/weight/best_model.pth')# 如果当前epoch为总epoch数,则保存最终模型权重if epoch == num_epochs:torch.save(model.state_dict(), f'D:/Pycharm/deepl/LeNet/weight/{num_epochs}_model.pth')# 将训练过程中的统计数据整理成DataFrametrain_process = pd.DataFrame(data={"epoch": range(1, num_epochs + 1),"train_loss_all": train_loss_all,"val_loss_all": val_loss_all,"train_acc_all": train_acc_all,"val_acc_all": val_acc_all})# 打印总训练时间consume_time = time.time() - start_timeprint(f'总耗时:{consume_time // 60:.0f}m {consume_time % 60:.0f}s')# 返回包含训练过程统计数据的DataFramereturn train_process# 定义绘制训练和验证过程中损失与准确度的函数
def matplot_acc_loss(train_process):# 创建图形和子图plt.figure(figsize=(12, 4))# 绘制训练和验证损失plt.subplot(1, 2, 1)plt.plot(train_process["epoch"], train_process["train_loss_all"], 'ro-', label="train_loss")plt.plot(train_process["epoch"], train_process["val_loss_all"], 'bs-', label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("loss")# 保存损失图像plt.savefig('./result_picture/training_loss_accuracy.png', bbox_inches='tight')# 绘制训练和验证准确度plt.subplot(1, 2, 2)plt.plot(train_process["epoch"], train_process["train_acc_all"], 'ro-', label="train_acc")plt.plot(train_process["epoch"], train_process["val_acc_all"], 'bs-', label="val_acc")plt.legend()plt.xlabel("epoch")plt.ylabel("accuracy")# 保存准确率曲线图plt.savefig('./result_picture/training_accuracy.png', bbox_inches='tight')plt.show()if __name__ == "__main__":model = LeNet()train_dataloader, val_dataloader = train_val_data_process()train_process = train_model_process(model, train_dataloader, val_dataloader, num_epochs=20)matplot_acc_loss(train_process)

四、model_test.py

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# t代表testdef t_data_process():test_data = FashionMNIST(root="./data",train=False,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True,num_workers=0)return test_dataloaderdef t_model_process(model, test_dataloader):if model is not None:print('Successfully loaded the model.')device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)# 初始化参数test_corrects = 0.0test_num = 0all_preds = []  # 存储所有预测标签all_labels = []  # 存储所有实际标签# 只进行前向传播,不计算梯度with torch.no_grad():for test_x, test_y in test_dataloader:test_x = test_x.to(device)test_y = test_y.to(device)# 设置模型为验证模式model.eval()# 前向传播得到一个batch的结果output = model(test_x)# 查找最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 收集预测和实际标签all_preds.extend(pre_lab.tolist())all_labels.extend(test_y.tolist())# 计算准确率test_corrects += torch.sum(pre_lab == test_y.data)# 将所有的测试样本进行累加test_num += test_x.size(0)# 计算准确率test_acc = test_corrects.double().item() / test_numprint(f'测试的准确率:{test_acc}')# 绘制混淆矩阵conf_matrix = confusion_matrix(all_labels, all_preds)sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')plt.xlabel('Predicted labels')plt.ylabel('True labels')plt.title('Confusion Matrix')plt.show()plt.savefig('./result_picture/Confusion_Matrix.png', bbox_inches='tight')if __name__=="__main__":# 加载模型model = LeNet()print('loading model')# 加载权重model.load_state_dict(torch.load('D:/Pycharm/deepl/LeNet/weight/best_model.pth'))# 加载测试数据test_dataloader = t_data_process()# 加载模型测试的函数t_model_process(model,test_dataloader)device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)classes = ['T-shirt/top','Trouser','Pullover','Dress','coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']with torch.no_grad():for b_x,b_y in test_dataloader:b_x = b_x.to(device)b_y = b_y.to(device)model.eval()output = model(b_x)pre_lab = torch.argmax(output,dim=1)result = pre_lab.item()label = b_y.item()print(f'预测值:{classes[result]}',"-----------",f'真实值:{classes[label]}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3226267.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

第3章 计算机视觉基础

第3章 计算机视觉基础 3.1 计算机视觉 3.1.1 计算机视觉概述 计算机视觉(Computer Vision)又称机器视觉(Machine Vision),是一门让机器学会如何去“看”的学科,是深度学习技术的一个重要应用领域&#xff0…

会员运营体系设计及SOP梳理

一些做会员的经验和方法分享给大家,包括顶层思考、流程的梳理、组织的建立,后续会做成系列,最近几期主要围绕顶层策略方面,以下是核心内容的整理: 1、会员运营体系设计 顶层设计与关键业务定位:建立客户运营…

使用ResizeObserver观察DOM元素的尺寸变化

文章目录 关于ResizeObserver示例代码示例代码结果如下所示echarts自适应容器div大小示例代码结果如下所示echarts自适应容器大小的方式二 关于ResizeObserver 关于这个Web API,可以看mdn的官网,ResizeObserver - Web API | MDN (mozilla.org)&#xff…

CvT:微软提出结合CNN的ViT架构 | 2021 arxiv

CvT将Transformer与CNN在图像识别任务中的优势相结合,从CNN中借鉴了多阶段的层级结构设计,同时引入了Convolutional Token Embedding和Convolutional Projection操作增强局部建模能力,在保持计算效率的同时实现了卓越的性能。此外&#xff0c…

【R语言+Gephi】利用R语言和Gephi实现共发生网络的可视化

【R语言Gephi】利用R语言和Gephi实现共发生网络的可视化 注:本文仅作为自己的学习记录以备以后复习查阅 一 概述 Gephi是一款开源免费的多平台网络分析软件,在Windows、Linux和Mac os上均可以运行,像他们官网所说的,他们致力于…

前端/python脚本/转换-使用天地图下载的geojson(echarts4+如果直接使用会导致坐标和其他信息不全)

解决echarts4如果直接使用天地图下载的geojson会导致坐标和其他信息不全 解决方法是使用python脚本来补全其他信息:center,level,adcode等内容 前提是必须有一个之前使用的json文件(需要全一点的数据供echarts使用) …

python(3.7版本)安装mitmproxy

环境介绍:win11, python3.7 pip install mitmproxy5.0.0 命令行cmd下,输入 Mitmdump 查看结果是否报错 如果报错上面这样子,就是markupsafe版本问题 换个Markupsafe版本就可以了 成功了吧!!!,如有问题,欢迎留言

js 图片放大镜

写购物项目的时候&#xff0c;需要放大图片&#xff0c;这里用js写了一个方法&#xff0c;鼠标悬浮的时候放大当前图片 这个是class写法 <!--* Descripttion: * Author: 苍狼一啸八荒惊* LastEditTime: 2024-07-10 09:41:34* LastEditors: 夜空苍狼啸 --><!DOCTYPE …

市场流行的蜗牛星际NAS和Think Pad X250硬件CPU等比较

当前二手市场流行的Mini服务器硬件是蜗牛星际NAS套件&#xff0c;流行的笔记本是Think Pad X230以及Think Pad X250&#xff0c;这里就比较下蜗牛和X250,因为它们都是低功耗、低运算的架构&#xff0c;至于神机Think Pad X230&#xff0c;它的速度太快&#xff0c;功耗太高&…

PCI PTS 硬件安全模块(HSM)模块化安全要求 v5.0

符合条件的 PCI SSC 利益相关者在 30 天的意见征询 (RFC) 期间审查 PCI PTS 硬件安全模块 (HSM) 模块化安全要求 v5.0 草案并提供反馈。 PCI PTS 硬件安全模块(HSM)模块化安全要求 v5.0图 从 7 月 8 日到 8 月 8 日&#xff0c;邀请符合条件的 PCI SSC 利益相关者在 30 天的意见…

结构体案例1

代码 #include <iostream> using namespace std; #include <string> #include <ctime>//学生的结构体 struct Student {string sName;int score; }; //老师的结构体定义 struct Teacher {string tName;struct Student sArray[5]; };//给老师和学生赋值的函数…

用vite创建Vue3项目的步骤和文件解释

创建项目的原则是不能出现中文和特殊字符&#xff0c;最好为小写字母&#xff0c;数字&#xff0c;下划线组成 之后在visual studio code 中打开创建的这个项目 src是源代码文件 vite和webpack是有去别的&#xff0c;对于这个vite创建的工程来说index.js是入口文件 在终端里面输…

microblaze时钟更改出现时序问题

在使用microblaze时&#xff0c;我给的时钟是200MHz的时钟&#xff0c;但会在跑布线的时候出现时序上的问题&#xff0c;一开始是没有任何的头绪&#xff0c;知道我尝试更改时钟的频率才发现问题的所在。 当我把200MHz的时钟改为100MHz的时钟时&#xff0c;就不会出现时序上的…

nodejs实现文件的分片写入和读取

&#xff08;1&#xff09;创建 test.cjs 文件 &#xff08;2&#xff09;代码 const {readFileSync,writeFileSync} require(fs); const {dirname} require(path); const chunkSize 1024 * 8; // 切片大小 const path C:\\Users\\cat\\De…

如何查询并下载韩国签证

登录大韩民国签证门户网站&#xff08;https://www.visa.go.kr&#xff09;&#xff0c;点击“查询/签发”- “办理进度查询及打印”。 2) 输入护照号码、英文姓名及出生日期后点击查询。 3) 若签证通过&#xff0c;办理状态信息栏下面会显示签证信息。 4&#xff09;点击“签证…

man手册的安装和使用

man手册 - HQ 文章目录 man手册 - HQ[toc]man手册的使用Linux man中文手册安装man中文手册通过安装包安装通过apt安装 配置man中文手册README使用说明配置步骤 man手册的使用 首先man分为八个目录&#xff0c;每个目录用一个数字表示 1.可执行程序2.系统调用3.库函数4.特殊文…

【吊打面试官系列-MyBatis面试题】简述 Mybatis 的插件运行原理,以及如何编写一个插件?

大家好&#xff0c;我是锋哥。今天分享关于 【简述 Mybatis 的插件运行原理&#xff0c;以及如何编写一个插件?】面试题&#xff0c;希望对大家有帮助&#xff1b; 简述 Mybatis 的插件运行原理&#xff0c;以及如何编写一个插件? Mybatis 仅可以编写针对 ParameterHandler、…

【单片机毕业设计选题24053】-基于单片机的WiFi控制门禁系统设计

系统功能: 系统上电后OLED显示智能门禁系统 Door:xxxxxx 初始化ESP8266完成后显示 Door:Closed 短按按键SW4可打开电磁锁OLED显示Door:Open&#xff0c;约五秒后电磁锁自动关闭OLED 显示Door:Closed 根据“TCP调试助手使用说明”操作&#xff0c; 在调试助手界面发送Open后…

Backend - C# 操作PostgreSQL DB

目录 一、安装 Npgsql 插件 &#xff08;一&#xff09;作用 &#xff08;二&#xff09;操作 &#xff08;三&#xff09;注意 二、操作类 &#xff08;一&#xff09;操作类 1.NpgsqlConnection类 &#xff08;1&#xff09;作用 &#xff08;2&#xff09;引入 &a…

gltf模型旋转

const loader new GLTFloader()loader.load(cars.gltf,(gltf) > {scene.add(gltf.scene)let angle 0function animate() {requestAnimateFrame(animate)angle 0.01//0.01是旋转速度gltf.scene.rotation.y anglerenderer.render(scene, camera)}animate()} ) 效果如下&a…