8-pytorch-损失函数与反向传播

b站小土堆pytorch教程学习笔记

根据loss更新模型参数
1.计算实际输出与目标之间的差距
2.为我们更新输出提供一定的依据(反向传播)

在这里插入图片描述

1 MSEloss

import torch
from torch.nn import L1Loss
from torch import nninputs=torch.tensor([1,2,3],dtype=torch.float32)
targets=torch.tensor([1,2,5],dtype=torch.float32)inputs=torch.reshape(inputs,(-1,1,1,3))
targets=torch.reshape(targets,(-1,1,1,3))loss=L1Loss()
result=loss(inputs,targets)loss_mse=nn.MSELoss()
result_mse=loss_mse(inputs,targets)print(result)
print(result_mse)

tensor(0.6667)
tensor(1.3333)

2 Cross EntropyLoss

在这里插入图片描述

x=torch.tensor([0.1,0.2,0.3])#需要reshape为要求的(batch_size,class)
y=torch.tensor([1])#target已经为要求的batch_size无需reshape
x=torch.reshape(x,(-1,3))
loss_cross=nn.CrossEntropyLoss()
result_cross=loss_cross(x,y)
print(result_cross)

tensor(1.1019)

3 在具体的神经网络中使用loss

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset,batch_size=1)class Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model1=Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,x):x=self.model1(x)return xloss=nn.CrossEntropyLoss()
han=Han()
for data in dataloader:imgs,target=dataoutput=han(imgs)# print(target)# print(output)result_loss=loss(output,target)print(result_loss)

*tensor([7])
tensor([[ 0.0057, -0.0201, -0.0796, 0.0556, -0.0625, 0.0125, -0.0413, -0.0056,
0.0624, -0.1072]], grad_fn=)…

tensor(2.2664, grad_fn=)…

4 反向传播 优化器

  1. 定义优化器
  2. 将待更新的每个参数梯度清零
  3. 调用损失函数的反向传播函数求出每个节点的梯度
  4. 使用step函数对模型的每个参数调优
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset,batch_size=64)class Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model1=Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,x):x=self.model1(x)return xloss=nn.CrossEntropyLoss()
han=Han()
optim=torch.optim.SGD(han.parameters(),lr=0.01)for epoch in range(5):running_loss=0.0#一个epoch结束的loss和for data in dataloader:imgs,target=dataoutput=han(imgs)result_loss=loss(output,target)#每次迭代的lossoptim.zero_grad()#将网络中每个可调节参数对应的梯度调为0result_loss.backward()#优化器需要每个参数的梯度,使用反向传播获得optim.step()#对每个参数调优running_loss=running_loss+result_lossprint(running_loss)

Files already downloaded and verified
tensor(361.0316, grad_fn=)
tensor(357.6938, grad_fn=)
tensor(343.0560, grad_fn=)
tensor(321.8132, grad_fn=)
tensor(313.3173, grad_fn=)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2808189.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

nginx(二)

nginx的验证模块 输入用户名和密码 第一步先下载httpd 这个安装包 第二步编辑子配置文件 然后去网页访问192.168.68.3/admin/ 连接之后,会出现404,404出现是因为没给网页写页面 如果要写页面,则在/opt/html,建立一个admin&#x…

吴恩达deeplearning.ai:矩阵运算代码实战

神经网络向量化指的是将输入数据转化为向量形式,以便于神经网络的处理。向量化的作用包括以下几点: 提高计算效率:使用向量化的输入数据可以进行并行计算,加速神经网络的训练和推断过程。 减少存储空间:向量化可以将…

C#与VisionPro联合开发——TCP/IP通信

TCP/IP(传输控制协议/互联网协议)是一组用于在网络上进行通信的通信协议。它是互联网和许多局域网的基础,为计算机之间的数据传输提供了可靠性、有序性和错误检测。在软件开发中,TCP/IP 通信通常用于实现网络应用程序之间的数据交…

改进Yolov5目标检测与单目测距 yolo速度测量-pyqt界面-yolo添加注意力机制

当设计一个结合了 YOLOv5 目标检测、单目测距与速度测量以及 PyQt 界面的毕业设计时,需要考虑以下几个方面的具体细节: 计算机视觉、图像处理、毕业辅导、作业帮助、代码获取,私聊会回复! YOLOv5 目标检测: 首先,选择…

汇编反外挂

在软件保护领域,尤其是游戏保护中,反外挂是一个重要的议题。外挂通常指的是一种第三方软件,它可以修改游戏数据、操作游戏内存或提供其他作弊功能,从而给玩家带来不公平的优势。为了打击外挂,游戏开发者会采取一系列措…

H5元素形变

H5元素形变 一、缩放 语法: ​ transform:scale(缩放倍率) //整体缩放 ​ transform:scale(水平缩放倍率,垂直缩放倍率) //单独设置水平和垂直方向的缩放 ​ transform: scaleX(缩放倍率) //沿X轴缩放 ​ transform: scaleY(缩放倍率) //沿Y轴缩放…

Unity类银河恶魔城学习记录7-8 P74 Pierce sword源代码

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释,可供学习Alex教程的人参考 此代码仅为较上一P有所改变的代码 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili Sword_Skill.cs using System; using System.Collections; using System.C…

杰理701N可视化SDK之LED的配置和代码浅析

杰理701N可视化SDK LED的配置 LED硬件配置LED状态配置LED状态情景配置LED在SDK中相关代码 杰理可视化工具中可以配置LED的硬件配置和LED状态配置, 在可视化工具中的LED配置选项中设置 LED硬件配置 硬件配置可设置LED名, 推LED使用的IO口以及LED的点亮方式 SDK发布的标准原理…

Ubuntu中添加和修改Apt Repository

使用Ubuntu Software Center或 apt/apt-get等命令行工具安装软件包时,软件包是从一个或多个 apt 软件库(software repositories)下载的。APT repository是一个网络服务器或本地目录,其中包含可被 APT 工具读取的 deb 软件包和元数…

Linux之项目部署与发布

目录 一、Nginx配置安装(自启动) 1.一键安装4个依赖 2. 下载并解压安装包 3. 安装Nginx 4. 启动 nginx 服务 5. 对外开放端口 6. 配置开机自启动 7.修改/etc/rc.d/rc.local的权限 二、后端部署tomcat负载均衡 1. 准备2个tomcat 2. 修改端口 3…

Linux笔记之LD_LIBRARY_PATH详解

Linux笔记之LD_LIBRARY_PATH详解 code review! 文章目录 Linux笔记之LD_LIBRARY_PATH详解1.常见使用命令来设置动态链接库路径2.LD_LIBRARY_PATH详解设置 LD_LIBRARY_PATH举例注意事项 3.替代方案使用标准路径编译时指定链接路径优先使用 rpath 还是 runpath?注意…

嵌入式软件分层设计的思想分析

“嵌入式开发&#xff0c;点灯一路发” 那今天我们就以控制LED闪烁为例&#xff0c;来聊聊嵌入式软件分层: ——————————— | | | P1.1 |-----I<|--------------<| | | | P2.1 |-------------/ ---------…

【JavaEE】_synchronized关键字——监视器锁monitor lock

目录 1. synchronized的特性 2. synchronized的使用 3. Java标准库中的线程安全类 1. synchronized的特性 &#xff08;1&#xff09;互斥&#xff1a; 前文已经介绍&#xff0c;某个线程执行到某个对象的synchronized中时&#xff0c;其他线程如果也执行到同一个对象&…

Git笔记——4

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言 一、操作标签 二、推送标签 三、多人协作一 完成准备工作 协作开发 将内容合并进master 四、多人协作二 协作开发 将内容合并进master 五、解决 git branch -a…

第十二章 Linux——日志管理

第十二章 Linux——日志管理 基本介绍系统常用日志日志管理服务日志轮替基本介绍日志轮替文件命名logrotate配置文件自定义加入日志轮转应用实例 日志轮替机制原理查看内存日志 基本介绍 日志文件是重要的系统信息文件&#xff0c;其中记录了许多重要的系统事件&#xff0c;包…

【操作系统】磁盘文件管理系统

实验六 磁盘文件管理的模拟实现 实验目的 文件系统是操作系统中用来存储和管理信息的机构&#xff0c;具有按名存取的功能&#xff0c;不仅能方便用户对信息的使用&#xff0c;也有效提高了信息的安全性。本实验模拟文件系统的目录结构&#xff0c;并在此基础上实现文件的各种…

[c++] 工厂模式 + cyberrt 组件加载器分析

使用对象的时候&#xff0c;可以直接 new 一个&#xff0c;为什么还需要工厂模式 &#xff1f; 工厂模式属于创建型设计模式&#xff0c;将对象的创建和使用进行解耦&#xff0c;对用户隐藏了创建逻辑。 个人感觉上边的表述并没有说清楚为什么需要使用工厂模式。因为使用 new 创…

12个的无时间限制的录屏软件详细比较

您可能尝试过许多录制程序&#xff0c;但大多数都会在30 分钟后停止录制萤幕。如果您需要录制较长的内容&#xff0c;特别是为公司会议或简报进行录制&#xff0c;您将必须找到最好的没有时间限制的录屏软件。这款录音软体可以让您长时间录音而没有任何麻烦。下面列出了12 款无…

亚马逊产品数据抓取

抓取数据 启动抓取 &#xff0c;亚马逊平台前台网站中可以查看、抓取、分析的一系列数据源&#xff0c;其数据种类繁多&#xff0c;本系统主要抓取产品列表&#xff08;包含主图、标题、价格、review分值、prime服务信息等&#xff09;、Listing详情信息&#xff08;包含5点描…

MyBatis---初阶

一、MyBatis作用 是一种更简单的操作和读取数据库的工具。 二、MyBatis准备工作 1、引入依赖 2、配置Mybatis(数据库连接信息) 3、定义接口 Mapper注解是MyBatis中用来标识接口为Mapper接口的注解。在MyBatis中&#xff0c;Mapper接口是用来定义SQL映射的接口&#xff0c;通…