简单的神经网络

一、softmax的基本概念

我们之前学过sigmoid、relu、tanh等等激活函数,今天我们来看一下softmax。

先简单回顾一些其他激活函数:

  1. Sigmoid激活函数:Sigmoid函数(也称为Logistic函数)是一种常见的激活函数,它将输入映射到0到1之间。它常用于二分类问题中,特别是在输出层以概率形式表示结果时。Sigmoid函数的优点是输出值限定在0到1之间,相当于对每个神经元的输出进行了归一化处理。
  2. Tanh激活函数:Tanh函数(双曲正切函数)将输入映射到-1到1之间。与Sigmoid函数相比,Tanh函数的中心点在零值附近,这意味着它的输出是以0为中心的。这种特性可以在某些情况下提供更好的性能。
  3. ReLU激活函数:ReLU(Rectified Linear Unit)函数是当前非常流行的一个激活函数,其表达式为f(x)=max(0, x)。ReLU函数的优点是计算简单,能够在正向传播过程中加速计算。此外,ReLU函数在正值区间内梯度为常数,有助于缓解梯度消失问题。但它的缺点是在负值区间内梯度为零,这可能导致某些神经元永远不会被激活,即“死亡ReLU”问题。

Softmax函数是一种在机器学习中广泛使用的函数,尤其是在处理多分类问题时。它的主要作用是将一组未归一化的分数转换成一个概率分布。Softmax函数的一个重要性质是其输出的总和等于1,这符合概率分布的定义。这意味着它可以将一组原始分数转换为概率空间,使得每个类别都有一个明确的概率值。

  • 二分类问题选择sigmoid激活函数

  • 多分类问题选择softmax激活函数

二、交叉熵损失函数

交叉熵损失函数的公式可以分为二分类和多分类两种情况。对于二分类问题,假设我们只考虑正类(标签为1)和负类(标签为0)在多分类问题中,交叉熵损失函数可以扩展为−∑𝑖=1𝐾𝑦𝑖⋅log⁡(𝑝𝑖)−∑i=1K​yi​⋅log(pi​),其中𝐾K是类别的总数,( y_i )是样本属于第𝑖i个类别的真实概率(通常用one-hot编码表示),而𝑝𝑖pi​是模型预测该样本属于第( i )个类别的概率。

import torch
from torch import nn# 确定随机数种子
torch.manual_seed(7)
# 自定义数据集
X = torch.rand((7, 2, 2))
target = torch.randint(0, 2, (7,))

定义网络结构

  • 一层全连接层 + Softmax层
  • x1𝑥1,x2𝑥2,x3𝑥3,x4𝑥4为 X
  • o1𝑜1,o2𝑜2,o3𝑜3为 target
class LinearNet(nn.Module):def __init__(self):super(LinearNet, self).__init__()# 定义一层全连接层self.dense = nn.Linear(4, 3)# 定义Softmaxself.softmax = nn.Softmax(dim=1)def forward(self, x):y = self.dense(x.view((-1, 4)))y = self.softmax(y)return ynet = LinearNet()
  •  nn.Softmax(dim=1)用于计算输入张量在指定维度上的softmax激活。dim=1表示沿着第二个维度(即列)进行softmax操作。

定义损失函数和优化函数

  • torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
  • 衡量模型输出与真实标签的差异,在分类时相当有用。
  • 结合了nn.LogSoftmax()和nn.NLLLoss()两个函数,进行交叉熵计算。
loss = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # 随机梯度下降法

训练模型

for epoch in range(70):train_l = 0.0y_hat = net(X)l = loss(y_hat, target).sum()# 梯度清零optimizer.zero_grad()# 自动求导梯度l.backward()# 利用优化函数调整所有权重参数optimizer.step()train_l += lprint('epoch %d, loss %.4f' % (epoch + 1, train_l))

三、自动微分模块

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False)  :自动求取梯度

  • grad_tensors:多梯度权重
  • create_graph:创建导数计算图,用于高阶求导
  • retain_graph:保存计算图
  • tensors:用于求导的张量,如 loss
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)y.backward(retain_graph=True)

 注意点:

  1. 梯度不自动清零
  2. 依赖于叶子节点的节点,requires_grad默认为True
  3. 叶子节点不可执行in-place

神经网络全连接层: 每个神经元都与前一层的所有神经元相连接。全连接层通常用于网络的最后几层,它将之前层(如卷积层和池化层)提取的特征进行整合,以映射到样本标记空间,即最终的分类或回归结果。

关于loss.backward()方法:

主要作用就是计算损失函数对模型参数的梯度,loss.backward()实现了反向传播算法,它通过链式法则计算每个模型参数相对于最终损失的梯度。这个过程从输出层开始,向后传递到输入层,逐层计算梯度。

过程:得到每个参数相对于损失函数的梯度,这些梯度信息会存储在对应张量的.grad属性中。loss.backward本身不负责更细权重,但它为权重更新提供了梯度值,方便配合optimizer.step()来更新参数。

前向传播过程中,数据从输入层流向输出层,并生成预测结果;而在反向传播过程中,误差(即预测值与真实值之间的差距,也就是损失函数的值)会从输出层向输入层传播,逐层计算出每个参数相对于损失函数的梯度。这些梯度指示了如何调整每一层中的权重和偏置,以最小化损失函数。

  • 损失函数衡量了当前模型预测与真实情况之间的不一致程度,而梯度则提供了损失函数减少最快的方向。

建立一个简单的全连接层:

import torch
import torch.nn as nn# 定义一个简单的全连接层模型
class SimpleFC(nn.Module):def __init__(self, input_size, output_size):super(SimpleFC, self).__init__()self.fc = nn.Linear(input_size, output_size)def forward(self, x):  return self.fc(x)# 创建输入数据和目标输出
input_data = torch.tensor([[1.0, 2.0, 3.0]])
target_output = torch.tensor([[4.0, 5.0]])# 实例化模型、损失函数和优化器
model = SimpleFC(input_size=3, output_size=2)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 前向传播
output = model(input_data)# 计算损失
loss = criterion(output, target_output)# 反向传播
loss.backward()# 更新参数
optimizer.step()

当调用loss.backward()时,PyTorch会自动计算损失值关于模型参数的梯度,并将这些梯度存储在模型参数的.grad属性中。然后优化器(torch.optim.SGD)可以使用这些梯度来更新模型参数,以最小化损失函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3031092.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

【回溯算法】【Python实现】符号三角形问题

文章目录 [toc]问题描述回溯法时间复杂性Python实现 问题描述 下图是由 14 14 14个“ ”和 14 14 14个“ − - −”组成的符号三角形, 2 2 2个同号下面都是” “, 2 2 2个异号下面都是“ − - −” 在一般情况下,符号三角形的第一行有 n…

机器学习-L1正则/L2正则

机器学习-L1正则/L2正则 目录 1.L1正则 2.L2正则 3.结合 1.L1正则 L1正则是一种用来约束模型参数的技术,常用于机器学习和统计建模中,特别是在处理特征选择问题时非常有用。 想象一下,你在装备行囊准备去旅行,但你的行囊有一…

第五十八节 Java设计模式 - 适配器模式

Java设计模式 - 适配器模式 我们在现实生活中使用适配器很多。例如,我们使用存储卡适配器连接存储卡和计算机,因为计算机仅支持一种类型的存储卡,并且我们的卡与计算机不兼容。 适配器是两个不兼容实体之间的转换器。适配器模式是一种结构模…

Ubuntu搭建VsCode C++ 开发环境

Ubuntu搭建VsCode C 开发环境 安装VS Code 使用命令来安装VS Code:他会下载vscode的最新版本。 sudo snap install --classic code如果不使用命令 的方式 在官网下载vscode安装包( 后缀为 .deb的包 )之后(可以选择版本 &#x…

YOLOv9独家原创改进: 特征融合创新 | 一种基于内容引导注意力(CGA)的混合融合 | IEEE TIP 2024 浙大

💡💡💡创新点:提出了一种基于内容引导注意力(CGA)的混合融合方案,将编码器部分的低级特征与相应的高级特征有效融合。 💡💡💡在多个数据集实现暴力涨点,适用于小目标,低对比度场景 💡💡💡如何跟YOLOv9结合:将backbone和neck的特征融合,改进结构图如下…

揭秘设计模式的魔法:打造高效、可维护的软件架构

设计模式是软件架构设计师的必修课,设计模式中蕴含的思想是架构设计师必须掌握的。毋庸置疑,良好的设计可以让系统更容易地被复用、被移植和维护,而如何快速进行良好的设计则离不开设计模式,尤其是面向对象设计和编程。 说到设计模…

用ps显示出淘宝裸眼3d立体画中的内容

淘宝前段时间在弄猜数字的游戏,其中有一题是3d立体画,如果我们把图片用ps处理一下,结果马上就出来了。打开原图,再复制进一个新图层,新图层混合模式选“差值”,左右移动新图层,就看到答案啦。 原…

Xilinx 千兆以太网TEMAC IP核用户接口信号

用户接口包括AX14-Stream发送接口和AX14-Stream接收接口,下文简称为用户发送接口和用户接收接口,数据案度可以是易位或16位,其中,8位接口主要针对标准的以太网应用,它利用一个125MHz的时钟产生1Gbps的数据率;当使用16位…

Redis20种使用场景

Redis20种使用场景 1缓存2抽奖3Set实现点赞/收藏功能4排行榜5PV统计(incr自增计数)6UV统计(HeyperLogLog)7去重(BloomFiler)8用户签到(BitMap)9GEO搜附近10简单限流11全局ID12简单分…

基于MWORKS 2024a的MIMO-OFDM 无线通信系统设计

一、引言 在终端设备和数据流量爆发式增长的今天,如何提升通信系统容量、能量效率和频谱利用率成为5G通信的关键问题之一。大规模天线阵列作为5G及B5G无线通信关键技术通过把原有发送端天线数量提升一个或多个数量级,实现波束聚集、控制波束转向&#x…

深入学习指针3

目录 前言 1.二级指针 2.指针数组 3.指针数组模拟二维数组 前言 Hello,小伙伴们我又来了,上期我们讲到了数组名的理解,指针与数组的关系等知识,那今天我们就继续深入到学习指针域数组的练联系,如果喜欢作者菌生产的内容还望不…

OmniPlan Pro 4 for Mac中文激活版:项目管理的新选择

OmniPlan Pro 4 for Mac作为一款专为Mac用户设计的项目管理软件,为用户提供了全新的项目管理体验。其直观易用的界面和强大的功能特性,使用户能够轻松上手并快速掌握项目管理要点。 首先,OmniPlan Pro 4 for Mac支持自定义视图,用…

Java框架精品项目【用于个人学习】

源码获取:私聊回复【项目关键字】获取 更多选题参考: Java练手项目 & 个人学习等选题参考 推荐菜鸟教程Java学习、Javatpoint学习 前言 大家好,我是二哈喇子,此博文整理了各种项目需求 此文下的项目用于博主自己学习&#x…

Kafka应用Demo:生产者自定义消息分区方法

背景 没有设置消息键时Kafka默认的分区算法是轮循,设置了消息键将按消息键的hashcode计算分区值。这种方法可以保证未设置消息键时各分区负载均衡。也可以保证设置消息键后的消息放到同一个分区发送,以保证消息按顺序消费。 但在某些业务场景下&#xff…

Java练手项目 个人学习等选题参考

难度系数说明: 难度系数用来说明项目本身进行分析设计的难度 难度系数大于1的项目是非常值得反复学习的,从项目中成长 前言 大家好,我是二哈喇子,此博文整理了各种项目需求 要从本篇文章下的项目中学习的思路: 用的…

大型动作模型 (LAM):AI 驱动的交互的下一个前沿

1.概述 现在人工智能中几个关键的领域,包括生成式人工智能(Generative AI)、大型动作模型(Large Action Models, LAM)、以及交互式人工智能(Interactive AI)。以下是对这些概念的简要解释和它们…

​​​【收录 Hello 算法】5.1 栈

目录 5.1 栈 5.1.1 栈的常用操作 5.1.2 栈的实现 1. 基于链表的实现 2. 基于数组的实现 5.1.3 两种实现对比 5.1.4 栈的典型应用 5.1 栈 栈(stack)是一种遵循先入后出逻辑的线性数据结构。 我们可以将栈类比为桌面上的一摞盘子…

hypack如何采集多波束数据?(上)

多波束设备有3种:多波束阵列,比如Seabat T50P;相干声纳,比如EdgeTeck 6205;多个单波束并列,比如Ross Sweep System,见下图。 辅助传感器主要有:罗经(提供航向&#xff09…

ubuntu server 22.04 安装docker、docker-compose

ubuntu server 22.04安装docker有两种方式,第一种是使用ubuntu镜像源的软件包进行安装,第二种使用官方GPG密钥手动添加Docker存储库方式进行安装,两种方式都可以,但第二种方式略复杂,这里介绍第一种比较简单的安装方式…

JavaScript基础(六)

break & continue continue跳出本次循环&#xff0c;继续下面的循环。 break跳出终止循环。 写个简单的例子: <script> for (var i1; i<5; i){ if (i3){ continue; } console.log(i); } </script> 结果就是跳过i等于3的那次循环&#xff0c;而break: f…