pytorch实现水果2分类(蓝莓,苹果)

1.数据集的路径,结构

dataset.py

目的:

        输入:没有输入,路径是写死了的。

        输出:返回的是一个对象,里面有self.data。self.data是一个列表,里面是(图片路径.jpg,标签)

        -data[item]返回的是(img_tensor , one-hot编码)。one-hot编码是[0,1]或者[1,0]

import glob
import os.pathimport cv2
import torch
from torch.utils.data import Dataset
from torchvision import transformsclass DtataAndLabel(Dataset):def __init__(self,path='fruits',is_train=True):self.tran=transforms.Compose([transforms.ToTensor(),transforms.Resize(size=(88,88))])is_train='train' if True else 'test'self.data=[]path=os.path.join(path,is_train)print('path=',path)print(os.path.join(path, '*', '*'))img_paths=glob.glob(os.path.join(path,'*','*'))for img_path in img_paths:label=0 if img_path.split('\\')[-2]=='blueberry' else 1self.data.append((img_path,label))def __getitem__(self, idx):#每一张图片返回一个img_tensor,one_hotimg_path,label =self.data[idx]img=cv2.imread(img_path)# img_gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)img_tensor=self.tran(img)img_tensor=img_tensor/255img_tensor=torch.flatten(img_tensor)one_hot=torch.zeros(2)one_hot[label]=1return img_tensor,one_hotdef __len__(self):return len(self.data)if __name__ == '__main__':# 测试data=DtataAndLabel()print(data[1][0].shape)print(data[1][1])

net.py

目的:将输入维度(k(k是加载进去的图片数),88,88,3)三通道的宽高是88,88,通过网络变化为(k,2)。

import torch.nn
import torch.nn as nnclass Net(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(88*88*3, 800),nn.ReLU(),nn.Linear(800, 500),nn.ReLU(),nn.Linear(500, 800),nn.ReLU(),nn.Linear(800, 200),nn.ReLU(),nn.Linear(200, 2),)self.softmax=nn.Softmax(dim=1)def forward(self,x):x=self.model(x)x=self.softmax(x)return x
if __name__ == '__main__':net=Net()#测试一下x=torch.randn(1,100*100)out=net(x)print(out.shape)

test_train.py

目的:将图像丢进模型,然后训练出最优模型

步骤:

       1.定义初始化

                -定义拿到data对象

                -定义加载器分批加载,这里可以变换维度

                -定义初始化网络

                -定义损失函数,这里采用了均方差函数

                -定义优化器

        2.实现训练

                -将每一批数据丢给网络,此时维度发生了变化,产生了升维

                -使用优化器        

                        ---自动梯度清0

                        ---自动求导更新参数

                -计算损失值和准确度

        ·~自己建一个文件夹

import torch.optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdmfrom net import Net
from dataset import DtataAndLabel
import torch.nn as nn
class TrainAndTest():def __init__(self):self.writer = SummaryWriter("logs")self.train_data=DtataAndLabel(is_train=True)self.test_data=DtataAndLabel(is_train=False)#使用加载器分批加载self.train_loader=DataLoader(self.train_data,batch_size=10,shuffle=True)self.test_loader=DataLoader(self.test_data,batch_size=10,shuffle=True)#初始化网络#损失函数#优化器net=Net()self.net=netself.loss=nn.MSELoss()self.opt=torch.optim.Adam(net.parameters(),lr=0.001)self.min_loss=100.0self.weight_path='weight/best.pt'def train(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)self.opt.zero_grad()loss.backward()self.opt.step()sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.train_loader)avg_acc = sum_acc / len(self.train_loader)print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"train_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"train_avg_acc": avg_acc}, epoch)def test(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.test_loader, desc="test...", total=len(self.test_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.test_loader)avg_acc = sum_acc / len(self.test_loader)print(f'test:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"test_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"test_avg_acc": avg_acc}, epoch)if avg_loss<self.min_loss:self.min_loss=min(self.min_loss,avg_loss)torch.save(self.net.state_dict(), self.weight_path)def run(self):for epo in range(100):self.train(epo)self.test(epo)if __name__ == '__main__':trainer=TrainAndTest()trainer.run()

精度的计算:

                比如通过网络出现的维度是(1,2),其数值是[[0.9 , 0.1]](0.9与0.1表示预测的两个类别的概率)。我们通过maxarg取到其中最大的索引0,与之前真实的标签0或者1做比较。从而可以得出结果

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3226862.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

02-图像基础-参数

在做有关图像和视频类的实际项目时&#xff0c;常常会涉及到图像的一些配置&#xff0c;下面对这些参数进行解释。 我们在电脑打开一张照片&#xff0c;可以看到一张完整的图像&#xff0c;比如一张360P的图片&#xff0c;其对应的像素点就是640*360&#xff0c;可以以左上角为…

python-25-零基础自学python-处理异常三兄弟try-except-else

学习内容&#xff1a;《python编程&#xff1a;从入门到实践》第二版第十章 知识点&#xff1a; 程序异常如何处理&#xff1f;try-except-else try-尝试可能引起错误的步骤 except-错误步骤发生&#xff0c;打印一些需要用户知道的信息&#xff0c;没有就pass else-错误不…

Java-常用API

1-Java API &#xff1a; 指的就是 JDK 中提供的各种功能的 Java类。 2-Scanner基本使用 Scanner&#xff1a; 一个简单的文本扫描程序&#xff0c;可以获取基本类型数据和字符串数据 构造方法&#xff1a; Scanner(InputStream source)&#xff1a;创建 Scanner 对象 Sy…

【保姆级教程】CenterNet的目标检测、3D检测、关键点检测使用教程

一、代码下载 仓库地址:https://github.com/xingyizhou/CenterNet?tab=readme-ov-file 二、目标检测 2.1 下载预训练权重 下载预训练权重ctdet_coco_dla_2x.pth放到models文件夹下 下载链接:https://drive.google.com/file/d/18Q3fzzAsha_3Qid6mn4jcIFPeOGUaj1d/edit …

13--memcache与redis

前言&#xff1a;数据库读取速度较慢一直是无法解决的问题&#xff0c;大型网站应对的方式主要是使用缓存服务器来缓解这种情况&#xff0c;减少数据库访问次数&#xff0c;以提高动态Web等应用的速度、提高可扩展性。 1、简介 Memcached/redis是高性能的分布式内存缓存服务器…

按模版批量生成定制合同

提出问题 一个仪器设备采购公司&#xff0c;商品合同采购需要按模版生成的固定的文件&#xff0c;模板是固定的&#xff0c;只是每次需要替换信息&#xff0c;然后打印出来寄给客户。 传统方法 如果手工来做这个事情&#xff0c;准备好数据之后&#xff0c;需要从Excel表格中…

亚马逊卖家告别熬夜!批量定时上下架,自动调价

必用能功三个&#xff0c;不限制上传商品。 大家好&#xff0c;今天来讲下这款erp的定时上下架功能。 打开工具这栏选择智能调价&#xff0c;点击添加智能调价选择店铺&#xff0c;选择定时上架的商品添加也可以全部添加。每个商品的价格都是不同的&#xff0c;可以点击保底价…

昇思学习打卡-12-Vision Transformer图像分类

文章目录 ViT模型学习构建模型Multi-Head AttentionTransformerEncoderpos_embeddingVit部分实现 推理结果 ViT模型学习 Vision Transformer&#xff08;ViT&#xff09;简介 ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下&#xff0c;依然可…

最简单详细的jwt用户登录校验教程(新手必看)

首先简单建张用户表。 DROP TABLE IF EXISTS user; CREATE TABLE user (id bigint NOT NULL AUTO_INCREMENT,name varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL,username varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL…

C++ 编译体系入门指北

前言 之从入坑C之后&#xff0c;项目中的编译构建就经常跟CMake打交道&#xff0c;但对它缺乏系统的了解&#xff0c;遇到问题又陷入盲人摸象。对C的编译体系是如何发展的&#xff0c;为什么要用CMake&#xff0c;它的运作原理是如何的比较感兴趣&#xff0c;所以就想系统学习…

云手机批量操作使用场景,从Amazon、TK等软件分析

云手机目前所具备的群控&#xff0c;批量操作&#xff0c;自动化等功能&#xff0c;对于电商&#xff0c;软测&#xff0c;办公&#xff0c;直播&#xff0c;营销等行业有很好的减负作用。 针对于具体的海外APP&#xff0c;云手机具体可以做哪些事情来帮助我们减轻压力&#x…

伺服【禾川X6】

驱动器&#xff1a; A&#xff1a;脉冲 B&#xff1a;EtherCAT // SV-X6 FB 040 AA 一套360 N&#xff1a;CANopen R&#xff1a;PROFINET 电机&#xff1a; SV-X6 MA 040A-B2 KA

MongoDB - 集合和文档的增删改查操作

文章目录 1. MongoDB 运行命令2. MongoDB CRUD操作1. 新增文档1. 新增单个文档 insertOne2. 批量新增文档 insertMany 2. 查询文档1. 查询所有文档2. 指定相等条件3. 使用查询操作符指定条件4. 指定逻辑操作符 (AND / OR) 3. 更新文档1. 更新操作符语法2. 更新单个文档 updateO…

Linux--线程ID封装管理原生线程

目录 1.线程的tid&#xff08;本质是线程属性集合的起始虚拟地址&#xff09; 1.1pthread库中线程的tid是什么&#xff1f; 1.2理解库 1.3phtread库中做了什么&#xff1f; 1.4线程的tid&#xff0c;和内核中的lwp 1.5线程的局部存储 2.封装管理原生线程库 1.线程的tid…

odoo视图继承

odoo视图继承 在模型时候&#xff0c;不对视图、菜单等进行修改&#xff0c;原视图和菜单等视图数据仍然可以使用&#xff0c;不需要重新构建 form视图继承案例 model&#xff1a;为对应模型 inherit_id&#xff1a;为继承的视图&#xff0c;ref:为继承视图的id&#xff0…

帝特(DTECH)USB转RS485/422串口线在Ubuntu系统中的安装

因为测试需要&#xff0c;买了一根帝特&#xff08;DTECH&#xff09;USB转RS485/422串口线&#xff0c;今天测试了一下在Ubuntu 22.04系统上的使用。帝特的网站上提供了驱动程序&#xff0c;下载以后发现接口芯片是CP2102&#xff0c;厂商只提供了Linux内核2.6和3.x版本的驱动…

新版FMEA培训未能达到预期效果怎么办?

在制造业的质量管理中&#xff0c;FMEA&#xff08;Failure Mode and Effects Analysis&#xff0c;失效模式与影响分析&#xff09;是一项至关重要的工具&#xff0c;它帮助企业识别和评估产品或过程中潜在的失效模式&#xff0c;以及这些失效模式可能导致的后果。然而&#x…

AIGC技术引领创意设计行业革新,“谁”能成职业发展新引擎?

随着科技的日新月异&#xff0c;生成式人工智能&#xff08;AIGC&#xff09;技术正迅速崛起&#xff0c;成为创意设计领域的一股强大新势力。该技术不仅显著提升了设计师的工作效率&#xff0c;更为他们打开了前所未有的创意空间。在这一波技术浪潮中&#xff0c;Adobe国际认证…

AutoMQ 与蚂蚁数科达成战略合作

近期&#xff0c;AutoMQ 与蚂蚁数科正式签署战略合作协议&#xff0c;将和蚂蚁数科云原生 PaaS 平台 SOFAStack 在产品研发、生态集成、市场合作、技术社区影响力等多方面开展深度合作。 AutoMQ 是业内领先的消息和流存储服务提供商&#xff0c;基于云原生基础设施重新设计了 …

如何整合生成的人工智能?(GenAI)为你未来的工作增加动力

生成人工智能(GenAI)它发展迅速&#xff0c;以前所未有的速度取得了突破。人工智能将继续改变各行各业&#xff0c;预计2023年至2030年的年增长率将达到37.3%。由于一种新的知识工作者现在面临被取代的风险&#xff0c;生成式人工智能的惊人崛起进一步加剧了这种紧迫性。据《未…