pytorch前馈神经网络--手写数字识别

前言

具体内容就是:

输入一个图像,经过神经网络后,识别为一个数字。从而实现图像的分类。

资源:

https://download.csdn.net/download/fengzhongye51460/89578965

思路:

确定输入的图像:会单通道灰度的28*28的图像,

把图像平铺后,输送到784个神经元的输入层

输入层输送到隐藏层,提取特征

隐藏层输送到输出层,显示概率

初始化模型

import torch  # Import PyTorch
from torch import nn  # Import the neural network module from PyTorch# Define the neural network class, inheriting from nn.Module
class Network(nn.Module):def __init__(self):super().__init__()  # Call the initializer of the parent class nn.Moduleself.layer1 = nn.Linear(784, 256)  # Define the first linear layer (input size 784, output size 256)self.layer2 = nn.Linear(256, 10)  # Define the second linear layer (input size 256, output size 10)def forward(self, x):x = x.view(-1, 28*28)  # Flatten the input tensor to a 1D tensor of size 28*28x = self.layer1(x)  # Pass the input through the first linear layerx = torch.relu(x)  # Apply the ReLU activation functionreturn self.layer2(x)  # Pass the result through the second linear layer and return it

__init__中

在输入层和隐藏层之间,创建一个线性层1 ,784个神经元转为256个

在隐藏层和输出层之间,创建一个线性层2,把256个神经元转为10个

forward中

先把输入图像x展平,然后输送到layer1中,用relu激活,再输送至layer2

训练模型

import torch
from torch import nn
from torch import optim
from model import Network
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoaderif __name__ == '__main__':# Define the image transformations: convert to grayscale and then to tensortransform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# Load the training dataset from the specified directory and apply transformationstrain_dataset = datasets.ImageFolder(root='./mnist_train', transform=transform)# Load the test dataset from the specified directory and apply transformationstest_dataset = datasets.ImageFolder(root='./mnist_test', transform=transform)# Print the length of the training datasetprint("train_dataset length: ", len(train_dataset))# Print the length of the test datasetprint("test_dataset length: ", len(test_dataset))# Create a DataLoader for the training dataset with batch size of 64 and shuffling enabledtrain_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# Print the number of batches in the training DataLoaderprint("train_loader length: ", len(train_loader))# Iterate over the first few batches of the training DataLoaderfor batch_idx, (data, label) in enumerate(train_loader):# Uncomment the following lines to break after 3 batches# if batch_idx == 3:#     break# Print the batch indexprint("batch_idx: ", batch_idx)# Print the shape of the data tensorprint("data.shape: ", data.shape)# Print the shape of the label tensorprint("label.shape: ", label.shape)# Print the labelsprint(label)# Initialize the neural network modelmodel = Network()# Initialize the Adam optimizer with the model's parametersoptimizer = optim.Adam(model.parameters())# Define the loss function as cross-entropy losscriterion = nn.CrossEntropyLoss()# Train the model for 10 epochsfor epoch in range(10):# Iterate over the batches in the training DataLoaderfor batch_idx, (data, label) in enumerate(train_loader):# Forward pass: compute the model outputoutput = model(data)# Compute the lossloss = criterion(output, label)# Backward pass: compute the gradientsloss.backward()# Update the model parametersoptimizer.step()# Zero the gradients for the next iterationoptimizer.zero_grad()# Print the loss every 100 batchesif batch_idx % 100 == 0:print(f"Epoch {epoch + 1}/10 "f"| Batch {batch_idx}/{len(train_loader)} "f"| Loss: {loss.item():.4f}")# Save the trained model's state dictionary to a filetorch.save(model.state_dict(), 'mnist.pth')

1.数据的读取

        先把图像灰度化,然后转换为张量

    transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])

导入训练数据和测试数据,

    # Load the training dataset from the specified directory and apply transformationstrain_dataset = datasets.ImageFolder(root='./mnist_train', transform=transform)# Load the test dataset from the specified directory and apply transformationstest_dataset = datasets.ImageFolder(root='./mnist_test', transform=transform)# Print the length of the training datasetprint("train_dataset length: ", len(train_dataset))# Print the length of the test datasetprint("test_dataset length: ", len(test_dataset))# Create a DataLoader for the training dataset with batch size of 64 and shuffling enabledtrain_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# Print the number of batches in the training DataLoaderprint("train_loader length: ", len(train_loader))

会把文件夹名称作为数据的标签

,例如 名称为0的文件夹,下面所有的文件都是数字0的图片

打印信息

可以看到导入了6w张训练图片,1w张测试图片,和60000/64=938 组数据

2.数据的训练

创建模型,设置优化器和损失函数

    # Initialize the neural network modelmodel = Network()# Initialize the Adam optimizer with the model's parametersoptimizer = optim.Adam(model.parameters())# Define the loss function as cross-entropy losscriterion = nn.CrossEntropyLoss()

训练数据

训练10轮 ,

每次的步骤

1.计算神经网络的前向传播结果

2.计算output和标签label之间的损失loss

3.使用backward计算梯度

4.使用optimizer更新参数

5.将梯度清零

    # Train the model for 10 epochsfor epoch in range(10):# Iterate over the batches in the training DataLoaderfor batch_idx, (data, label) in enumerate(train_loader):# Forward pass: compute the model outputoutput = model(data)# Compute the lossloss = criterion(output, label)# Backward pass: compute the gradientsloss.backward()# Update the model parametersoptimizer.step()# Zero the gradients for the next iterationoptimizer.zero_grad()# Print the loss every 100 batchesif batch_idx % 100 == 0:print(f"Epoch {epoch + 1}/10 "f"| Batch {batch_idx}/{len(train_loader)} "f"| Loss: {loss.item():.4f}")

3.保存模型

    # Save the trained model's state dictionary to a filetorch.save(model.state_dict(), 'mnist.pth')

测试模型

代码

from model import Network  # Import the custom neural network model class
from torchvision import transforms  # Import torchvision transformations
from torchvision import datasets  # Import torchvision datasets
import torch  # Import PyTorchif __name__ == '__main__':# Define the image transformations: convert to grayscale and then to tensortransform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# Load the test dataset from the specified directory and apply transformationstest_dataset = datasets.ImageFolder(root='./mnist_test', transform=transform)# Print the length of the test datasetprint("test_dataset length: ", len(test_dataset))# Initialize the neural network modelmodel = Network()# Load the model's state dictionary from the saved filemodel.load_state_dict(torch.load('mnist.pth'))right = 0  # Initialize a counter for correctly classified images# Iterate over the test datasetfor i, (x, y) in enumerate(test_dataset):output = model(x.unsqueeze(0))  # Forward pass: add batch dimension and compute the model outputpredict = output.argmax(1).item()  # Get the index of the highest score as the predicted labelif predict == y:right += 1  # Increment the counter if the prediction is correctelse:img_path = test_dataset.samples[i][0]  # Get the path of the misclassified image# Print details of the misclassified caseprint(f"wrong case: predict = {predict} actual = {y} img_path = {img_path}")sample_num = len(test_dataset)  # Get the total number of samples in the test datasetacc = right * 1.0 / sample_num  # Calculate the accuracy as the ratio of correct predictions# Print the test accuracyprint("test accuracy = %d / %d = %.31f" % (right, sample_num, acc))

1.读取测试数据集

    # Define the image transformations: convert to grayscale and then to tensortransform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# Load the test dataset from the specified directory and apply transformationstest_dataset = datasets.ImageFolder(root='./mnist_test', transform=transform)# Print the length of the test datasetprint("test_dataset length: ", len(test_dataset))

查看打印信息,导入了1w张测试图片

2.导入模型

    # Initialize the neural network modelmodel = Network()# Load the model's state dictionary from the saved filemodel.load_state_dict(torch.load('mnist.pth'))

3.测试

将测试图片导入模型

output = model(x.unsqueeze(0))  # Forward pass: add batch dimension and compute the model output

选择概率最大的测试标签

predict = output.argmax(1).item()  # Get the index of the highest score as the predicted label

查看结果

可以看到,1w图片中9807张图片识别正确。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3266114.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

基于dcm4chee搭建的PACS系统讲解(三)服务端使用Rest API获取study等数据

文章目录 DICOMWeb Support模块主要数据结构ER查询信息基本信息metadata信息统计信息 实践查询API及参数解析API返回的json数组定义VRObjectNodeObjectMapper解析显示指定tag并解析 后记 前期预研的PACS系统,近期要在项目中上线了。因为PACS系统采用无权限认证&…

【初阶数据结构】8.二叉树(3)

文章目录 4.实现链式结构二叉树4.1 前中后序遍历4.1.1 遍历规则4.1.2 代码实现 4.2 结点个数以及高度等4.3 层序遍历4.4 判断是否为完全二叉树4.5层序遍历和判断是否为完全二叉树完整代码 4.实现链式结构二叉树 用链表来表示一棵二叉树,即用链来指示元素的逻辑关系…

减轻幻觉新SOTA,7B模型自迭代训练效果超越GPT-4,上海AI lab发布

LLMs在回答各种复杂问题时,有时会“胡言乱语”,产生所谓的幻觉。解决这一问题的初始步骤就是创建高质量幻觉数据集训练模型以帮助检测、缓解幻觉。 但现有的幻觉标注数据集,因为领域窄、数量少,加上制作成本高、标注人员水平不一…

php反序列化--前置知识

🎼个人主页:金灰 😎作者简介:一名简单的大一学生;易编橙终身成长社群的嘉宾.✨ 专注网络空间安全服务,期待与您的交流分享~ 感谢您的点赞、关注、评论、收藏、是对我最大的认可和支持!❤️ 🍊易编橙终身成长社群&#…

文件共享功能无法使用提示错误代码0x80004005【笔记】

环境情况: 其他电脑可以正常访问共享端,但有一台电脑访问提示错误代码0x80004005。 处理检查: 搜索里输入“启用或关闭Windows功能”按回车键,在“启用或关闭Windows功能”里将“SMB 1.0/CIFS文件共享支持”勾选后(故…

不同行情下算法的具体使用!

上一篇我们说到了不同公司算法交易的区分,有朋友提出了不同的行情下的算法交易应该怎么使用,小编今天就带大家了解下!当然具体实际状况百出,这种可以实际为准(韭菜修养全拼实际探讨交流)! 我们在…

qt做的分页控件

介绍 qt做的分页控件 如何使用 创建 Pagination必须基于一个QWidget创建,否则会引发错误。 Pagination* pa new Pagination(QWidget*);设置总页数 Pagination需要设置一个总的页数,来初始化页码。 pa->SetTotalItem(count);设置可选的每页数量…

Java 每日一题: for 与 foreach 的区别 ?

for 循环:是最基本的循环结构,可以通过初始化语句、循环条件和迭代语句来控制循环的执行。 foreach 循环(也称为增强型 for 循环):用于遍历集合或数组中的元素,简化了遍历过程,没有显式地控制索…

[STM32]HAL库实现自己的BootLoader-BootLoader与OTA-STM32CUBEMX

目录 一、前言 二、BootLoader 三、BootLoader的实现 四、APP程序 五、效果展示 六、拓展 一、前言 听到BootLoader大家一定很熟悉,在很多常见的系统中都会存在BootLoader。本文将介绍BootLoader的含义和简易实现,建议大家学习前掌握些原理基础。 …

全链路追踪 性能监控,GO 应用可观测全面升级

作者:古琦 01 介绍 随着 Kubernetes 和容器化技术的普及,Go 语言不仅在云原生基础组件领域广泛应用,也在各类业务场景中占据了重要地位。如今,越来越多的新兴业务选择 Golang 作为首选编程语言。得益于丰富的 RPC 框架&#xff…

编程类精品GPTs

文章目录 编程类精品GPTs前言种类ChatGPT - GrimoireProfessional-coder-auto-programming 总结 编程类精品GPTs 前言 代码类的AI, 主要看以下要点: 面对含糊不清的需求是否能引导出完整的需求面对完整的需求是否能分步编写代码完成需求编写的代码是否具有可读性和可扩展性 …

力扣算法题:矩阵(玄幻不变量法),链表(虚拟头节点,递归法)

20240725 一、矩阵54.螺旋矩阵(循环不变量) 二、链表1 移除链表元素1.1 原链表删除元素:1.2 虚拟头节点(!!!) 2. 设计链表206. 反转链表(双向指针和递归)双指针递归 交换链表中的元素虚拟头节点法递归法 删…

如何解决 Nginx 与边缘计算节点的集成问题?

🍅关注博主🎗️ 带你畅游技术世界,不错过每一次成长机会! 文章目录 如何解决 Nginx 与边缘计算节点的集成问题?一、理解集成的需求和目标二、解决网络配置问题三、优化 Nginx 配置四、处理安全与认证问题五、监控与调试…

STM32是使用的内部时钟还是外部时钟

STM32是使用的内部时钟还是外部时钟,经常会有人问这个问题。 1、先了解时钟树,见下图: 2、在MDK中,使用的是HSEPLL作为SYSCLK,因此需要对时钟配置寄存器(RCC_CFGR)进行配置,寄存器内…

Jacoco 单元测试配置

前言 编写单元测试是开发健壮程序的有效途径,单元测试写的好不好可以从多个指标考量,其中一个就是单元测试的覆盖率。单元测试覆盖率可以看到我们的单元测试覆盖了多少代码行、类、分支等。查看单元测试覆盖率可以使用一些工具帮助我们计算,…

在IDEA中切换分支没有反应

说明:记录一次在IDEA中切换分支没有反应的情况,新建一个分支后,准备暂存代码,切换到其他分支去,发现怎么切都没有反应,也没有切过去; 解决:首先,我想到是不是当前新分支…

如何解决 Nginx 与无服务器架构的集成问题?

🍅关注博主🎗️ 带你畅游技术世界,不错过每一次成长机会! 文章目录 如何解决 Nginx 与无服务器架构的集成问题? 如何解决 Nginx 与无服务器架构的集成问题? 在当今的云计算时代,无服务器架构因…

机器学习驱动的智能化电池管理技术与应用

目录 主要内容 电池管理技术概述 电池的工作原理与关键性能指标 电池管理系统的核心功能 SOC估计 SOH估计 寿命预测 故障诊断 人工智能机器学习 基础 人工智能的发展 机器学习的关键概念 机器学习在电池管理中的应用案例介绍 人工智能在电池荷电状态估计中的…

AttributeError: ‘list‘ object has no attribute ‘text‘

AttributeError: ‘list‘ object has no attribute ‘text‘ 目录 AttributeError: ‘list‘ object has no attribute ‘text‘ 【常见模块错误】 【解决方案】 示例代码 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英…

谷粒商城实战笔记-63-商品服务-API-品牌管理-OSS获取服务端签名

文章目录 一,创建第三方服务模块thrid-party1,创建一个名为gulimall-third-party的模块2,nacos上创建third-party命名空间,用来管理这个服务的所有配置3,配置pom文件4,配置文件5,单元测试6&…