PyTorch+AlexNet代码实训

参考文章:https://blog.csdn.net/red_stone1/article/details/122974771

数据集:
在这里插入图片描述

打标签:

import os# os.path.join: 每个参数都是一个路径段,将它们连接起来形成有效的路径名。
train_txt_path = os.path.join("data", "catVSdog", "train.txt")
train_dir = os.path.join("data", "catVSdog", "train_data")
valid_txt_path = os.path.join("data", "catVSdog", "test.txt")
valid_dir = os.path.join("data", "catVSdog", "test_data")def gen_txt(txt_path, img_dir): # 标签,图像f = open(txt_path, 'w')  # 打开一个文件,创建一个file对象# os.walk: 遍历一个目录树,返回目录中的每个目录和文件# os.walk每次迭代都会返回一个元组:#(当前目录的路径字符串,当前目录中所有子目录名称,当前目录所有文件名称)for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # 获取 train文件下各文件夹名称# topdown用于决定遍历目录树的顺序# 以猫狗大战数据集为例,这里的s_dirs是cat和dog文件夹for sub_dir in s_dirs: # 对于猫或狗文件夹里的每个文件(每张图片)遍历i_dir = os.path.join(root, sub_dir)             # 获取各类的文件夹 绝对路径 ?img_list = os.listdir(i_dir)                    # 获取类别文件夹下所有png图片的路径 ? 应该是jpg# os.listdir: 用于返回指定目录中的所有文件和目录的名称列表for i in range(len(img_list)): # 遍历一个类别中的所有图片if not img_list[i].endswith('jpg'):         # 若不是png文件,跳过 ? 应该是jpgcontinue#label = (img_list[i].split('.')[0] == 'cat')? 0 : 1 label = img_list[i].split('.')[0] # 按.分割,并取点后的**第一个部分**# 将字符类别转为整型类型表示if label == 'cat':label = '0'else: # label == 'dog'label = '1'img_path = os.path.join(i_dir, img_list[i])line = img_path + ' ' + label + '\n'f.write(line) # 把打好的标签写在.txt里f.close()if __name__ == '__main__':# 共生成两个图片索引文件:train.txt和test.txtgen_txt(train_txt_path, train_dir)gen_txt(valid_txt_path, valid_dir)

构建数据集:

from PIL import Image
from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, txt_path, transform = None, target_transform = None):fh = open(txt_path, 'r') # 打开图片索引文件imgs = [] # 存储元组:(图片路径,类别(0或1))for line in fh: # 迭代读取文件的行line = line.rstrip() # 使用rstrip方法去除行末的空白符(包括\n)words = line.split() # 将字符串按空白符(空格、制表符等)进行分割imgs.append((words[0], int(words[1]))) # 类别转为整型intself.imgs = imgs # self.transform和self.target_transform:根据读入的参数赋值self.transform = transformself.target_transform = target_transform# __getitem__方法和__len__方法均继承自父类Datasetdef __getitem__(self, index):fn, label = self.imgs[index]img = Image.open(fn).convert('RGB') #img = Image.open(fn)if self.transform is not None:img = self.transform(img) # self.transform对图片进行处理,推测传入的是一个函数名return img, labeldef __len__(self):return len(self.imgs)

加载数据集&数据预处理:

from torchvision import transforms
# transforms.Compose接受一个列表或元组作为参数,列表中的每个元素都是一个数据转换操作
# transforms.Compose返回一个串行操作序列
pipline_train = transforms.Compose([#随机旋转图片transforms.RandomHorizontalFlip(),#将图片尺寸resize到227x227(这是AlexNet的要求)transforms.Resize((227,227)),#将图片转化为Tensor格式transforms.ToTensor(),#正则化(当模型出现过拟合的情况时,用来降低模型的复杂度,加快模型收敛速度)transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 均值为0.5,标准差为0.5#transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
])
pipline_test = transforms.Compose([#将图片尺寸resize到227x227transforms.Resize((227,227)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))#transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
])
train_data = MyDataset('./data/catVSdog/train.txt', transform=pipline_train)
test_data = MyDataset('./data/catVSdog/test.txt', transform=pipline_test)# train_data 和test_data包含多有的训练与测试数据,调用DataLoader批量加载
# batch_size: 每个小批量(batch)包含的样本数量。
# 在训练过程中,模型不会一次性处理整个数据集,而是分成多个小批量逐一输入模型进行训练
# shuffle: 数据洗牌-随机打乱数据集的顺序-使模型在训练时不会对数据顺序敏感
trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=32, shuffle=False)
# 类别信息也是需要我们给定的
classes = ('cat', 'dog') # 对应label=0,label=1

查看最终制作的数据集(图片&标签):

import numpy as np
examples = enumerate(trainloader) # 方便迭代trainloader中的每个批量数据并同时获取它们的索引
batch_idx, (example_data, example_label) = next(examples) # next: 获取枚举对象的下一个元素
# 批量展示图片
for i in range(4):plt.subplot(1, 4, i + 1) #plt.tight_layout()  #自动调整子图参数,使之填充整个图像区域img = example_data[i]img = img.numpy() # FloatTensor转为ndarrayimg = np.transpose(img, (1,2,0)) # 把channel那一维放到最后img = img * [0.5, 0.5, 0.5] + [0.5, 0.5, 0.5]plt.imshow(img)plt.title("label:{}".format(example_label[i]))plt.xticks([])plt.yticks([])
plt.show()

搭建AlexNet神经网络结构:

class AlexNet(nn.Module):"""Neural network model consisting of layers propsed by AlexNet paper."""def __init__(self, num_classes=2):"""Define and allocate layers for this neural net.Args:num_classes (int): number of classes to predict with this model"""super().__init__() # 继承父类的__init__方法# input size should be : (b x 3 x 227 x 227)# The image in the original paper states that width and height are 224 pixels, but# the dimensions after first convolution layer do not lead to 55 x 55.# nn.Sequential是一个用于构建神经网络的容器,它按顺序将各个模块(层)组合在一起,形成一个神经网络模型self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),  # (b x 96 x 55 x 55) # nn.Conv2d还可以继续添加参数:padding 表示边缘填充空白像素的宽度nn.ReLU(),nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),  # 局部响应归一化# 目前更多采用的是Batch Normalizationnn.MaxPool2d(kernel_size=3, stride=2),  # (b x 96 x 27 x 27)nn.Conv2d(96, 256, 5, padding=2),  # (b x 256 x 27 x 27)nn.ReLU(),nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),nn.MaxPool2d(kernel_size=3, stride=2),  # (b x 256 x 13 x 13)nn.Conv2d(256, 384, 3, padding=1),  # (b x 384 x 13 x 13)nn.ReLU(),nn.Conv2d(384, 384, 3, padding=1),  # (b x 384 x 13 x 13)nn.ReLU(),nn.Conv2d(384, 256, 3, padding=1),  # (b x 256 x 13 x 13)nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),  # (b x 256 x 6 x 6))# classifier is just a name for linear layersself.classifier = nn.Sequential(nn.Dropout(p=0.5, inplace=True), # nn.Linear(in_features=(256 * 6 * 6), out_features=500),nn.ReLU(),nn.Dropout(p=0.5, inplace=True),nn.Linear(in_features=500, out_features=20),nn.ReLU(),nn.Linear(in_features=20, out_features=num_classes),)def forward(self, x):"""Pass the input through the net.Args:x (Tensor): input tensorReturns:output (Tensor): output tensor"""x = self.net(x)x = x.view(-1, 256 * 6 * 6)  # reduce the dimensions for linear layer input# x.view: 改变张量形状。# -1表示自动计算,后面的256*6*6表示将第二个维度变成这个尺寸,以便作为全连接层的输入。return self.classifier(x)

将模型部署到GPU/CPU:

#创建模型,部署gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet().to(device) # 这里的AlexNet是类名,通过.to(device)方法将模型移动到指定的设备
#定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001) # PyTorch提供的Adam优化器
# model.parameters()返回模型中所有需要训练的参数迭代器
# lr是Adam优化器的学习率,控制每次参数更新的步长大小

定义训练过程:

def train_runner(model, device, trainloader, optimizer, epoch):#训练模型, 启用 BatchNormalization 和 Dropout, 将BatchNormalization和Dropout置为Truemodel.train()total = 0correct =0.0#enumerate迭代已加载的数据集,同时获取数据和数据下标for i, data in enumerate(trainloader, 0):inputs, labels = data#把模型部署到device上inputs, labels = inputs.to(device), labels.to(device)#初始化梯度optimizer.zero_grad()#保存训练结果outputs = model(inputs)#计算损失和#多分类情况通常使用cross_entropy(交叉熵损失函数), 而对于二分类问题, 通常使用sigmodloss = F.cross_entropy(outputs, labels)#获取最大概率的预测结果#dim=1表示返回每一行的最大值对应的列下标predict = outputs.argmax(dim=1)total += labels.size(0)correct += (predict == labels).sum().item()#反向传播loss.backward()#更新参数optimizer.step()if i % 100 == 0:#loss.item()表示当前loss的数值print("Train Epoch{} \t Loss: {:.6f}, accuracy: {:.6f}%".format(epoch, loss.item(), 100*(correct/total)))Loss.append(loss.item())Accuracy.append(correct/total)return loss.item(), correct/total

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3269451.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

浅谈HOST,DNS与CDN

首先这个是网络安全的基础,需得牢牢掌握。 1.什么是HOST HOSTS文件: 定义: HOSTS文件是一个操作系统级别的文本文件,通常位于操作系统的系统目录中(如Windows系统下的C:\Windows\System32\drivers\etc\hosts&#xf…

java数据结构(1):集合框架,时间,空间复杂度,初识泛型

目录 一 java数据结构的集合框架 1.什么是数据结构 2.集合框架 2.1什么是集合框架: 1. 接口 (Interfaces) 2. 实现类 (Implementations) 3. 算法 (Algorithms) 4. 并发集合 (Concurrent Collections) 2.2集合框架的优点: 二 时间和空间复杂度 …

请你谈谈:spring AOP的浅显认识?

在Java面向对象编程中,解决代码重复是一个重要的目标,旨在提高代码的可维护性、可读性和复用性。你提到的两个步骤——抽取成方法和抽取类,是常见的重构手段。然而,正如你所指出的,即使抽取成类,有时仍然会…

【Redis宕机啦!】Redis数据恢复策略:RDB vs AOF vs RDB+AOF

文章目录 Redis宕机了,如何恢复数据为什么要做持久化持久化策略RDBredis.conf中配置RDBCopy-On-Write, COW快照的频率如何把握优缺点 AOFAOF日志内容redis.conf中配置AOF写回策略AOF日志重写AOF重写会阻塞吗优缺点 RDB和AOF混合方式总结 Redis宕机了,如何…

Spring Bean - xml 配置文件创建对象

类型&#xff1a; 1、值类型 2、null &#xff08;标签&#xff09; 3、特殊符号 &#xff08;< -> < &#xff09; 4、CDATA <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/bea…

分布式锁的三种实现方式:Redis、基于数据库和Zookeeper

分布式锁的实现 操作共享资源&#xff1a;例如操作数据库中的唯一用户数据、订单系统、优惠券系统、积分系统等&#xff0c;这些系统需要修改用户数据&#xff0c;而多个系统可能同时修改同一份数据&#xff0c;这时就需要使用分布式锁来控制访问&#xff0c;防止数据不一致。…

最新爆火的开源AI项目 | LivePortrait 本地安装教程

LivePortrait 本地部署教程&#xff0c;强大且开源的可控人像AI视频生成 1&#xff0c;准备工作&#xff0c;本地下载代码并准备环境&#xff0c;运行命令前需安装git 以下操作不要安装在C盘和容量较小的硬盘&#xff0c;可以找个大点的硬盘装哟 2&#xff0c;需要安装FFmp…

项目开发实战案例 —— Spring Boot + MyBatis + Hibernate + Spring Cloud

作者简介 我是本书的作者&#xff0c;拥有多年Java Web开发经验&#xff0c;致力于帮助更多开发者快速掌握并运用Java Web技术栈中的关键框架和技术。本书旨在通过实战案例的方式&#xff0c;带领读者深入理解并实践Spring Boot、MyBatis、Hibernate以及Spring Cloud等热门技术…

2-46 基于matlab的声音信号的短时能量、短时过零率、端点检测

基于matlab的声音信号的短时能量、短时过零率、端点检测。通过计算计算短时能量、调整能量门限&#xff0c;然后开始端点检测。输出可视化结果。程序已调通&#xff0c;可直接运行。 2-46 短时能量 短时过零率 端点检测 - 小红书 (xiaohongshu.com)

Vue element ui分页组件示例

https://andi.cn/page/621615.html

Camera Raw:预设

Camera Raw 的预设 Presetss模块能够简化和加速照片编辑过程。预设不仅能大大提升工作效率&#xff0c;还能确保处理结果的一致性和专业性。 快捷键&#xff1a;Shift P 预设 Preset与配置文件、快照有其异同之处&#xff0c;它们都可以快速改变照片的影调和颜色。 不同是&…

SQL labs-SQL注入(三,sqlmap使用)

本文仅作为学习参考使用&#xff0c;本文作者对任何使用本文进行渗透攻击破坏不负任何责任。 引言&#xff1a; 盲注简述&#xff1a;是在没有回显得情况下采用的注入方式&#xff0c;分为布尔盲注和时间盲注。 布尔盲注&#xff1a;布尔仅有两种形式&#xff0c;ture&#…

python-学生排序(赛氪OJ)

[题目描述] 已有 a、b 两个链表&#xff0c;每个链表中的结点包括学号、成绩。要求把两个链表合并&#xff0c;按学号升序排列。输入格式&#xff1a; 输入共 NM1 行。 第一行&#xff0c;输入 a、b 两个链表元素的数量 N、M&#xff0c;中间用空格隔开。下来 N 行&#xff0c;…

全网爆火的AI老照片变视频项目来了,简单易上手,1单69,日入1000+

每天为大家带来一个可实操落地的副业项目&#xff0c;创业思维&#xff0c;只要你认真看完&#xff0c;多少都能够为你带来帮助或启发。 最近在短视频上看到很多怀旧视频流量真的大&#xff0c;同时也看到朋友圈很多人在培训这个项目。 既然有这么多人在做&#xff0c;就证明…

一天搞定React(5)——ReactRouter(下)【已完结】

Hello&#xff01;大家好&#xff0c;今天带来的是React前端JS库的学习&#xff0c;课程来自黑马的往期课程&#xff0c;具体连接地址我也没有找到&#xff0c;大家可以广搜巡查一下&#xff0c;但是总体来说&#xff0c;这套课程教学质量非常高&#xff0c;每个知识点都有一个…

C语言文件操作,文件读写

目录 为什么要使用文件&#xff1f; 文件概念 1. 什么是文件&#xff1f; 2. 程序文件 3. 数据文件 4. 文件名 文件的使用 1. 文件指针 2. 文件的打开与关闭 文件的顺序读写 1. 顺序读写函数 2. scanf系列与printf系列 文件的随机读写 1. fseek 2. ftell 3. …

B端:用弹框还是用抽屉,请说出你的依据。

选择浮层&#xff08;弹出框&#xff09;还是抽屉&#xff08;侧边栏&#xff09;作为B端系统的浮层&#xff0c;需要根据具体情况来决定。以下是一些依据供您参考&#xff1a; 1.功能需求&#xff1a; 浮层的选择应该符合系统的功能需求。如果需要在当前页面上提供一些额外的操…

C++ 基础(类和对象下)

目录 一. 再探构造函数 1.1. 初始化列表&#xff08;尽量使用列表初始化&#xff09; 二. static成员 2.1static成员初始化 三.友元 3.1友元&#xff1a;提供了⼀种 突破类访问限定符封装的方式. 四.内部类 4.1如果⼀个类定义在另⼀个类的内部&#xff0c;这个内部类就叫…

Google Android 2024年7月最新消息汇总

本文首发于公众号“AntDream”&#xff0c;欢迎微信搜索“AntDream”或扫描文章底部二维码关注&#xff0c;和我一起每天进步一点点 Google Android 2024年7月最新消息汇总 2024年7月&#xff0c;Google在Android生态系统中发布了多项更新和政策调整&#xff0c;涵盖了Google …

6万字,让你轻松上手的大模型 LangChain 框架

本文为我学习 LangChain 时对官方文档以及一系列资料进行一些总结&#xff5e;覆盖对Langchain的核心六大模块的理解与核心使用方法&#xff0c;全文篇幅较长&#xff0c;共计50000字&#xff0c;可先码住辅助用于学习Langchain。** 一、Langchain是什么&#xff1f; 如今各类…