动手学深度学习——多层感知机

1. 感知机

感知机本质上是一个二分类问题。给定输入x、权重w、偏置b,感知机输出:

以猫和狗的分类问题为例,它本质上就是找到下面这条黑色的分割线,使得所有的猫和狗都能被正确的分类。

与线性回归和softmax的不同点:

  • vs 线性回归:输出的都是一个数,但线性回归输出的是实数,而感知机输出的是离散的分类。
  • vs softmax: softmax是一个多分类(如果有n个分类,softmax就会输出n个元素),而感知机只输出一个元素。

感知机存在的问题: 它只能产生线性分割面,对于XOR(异或)函数,无法拟合(一条线不论怎么分割,都无法将绿色和红色分类正确)。

2. 多层感知机(MLP)

对于上面单层感知机的问题,一个改进思想是:一层函数如果做不了,就用多层函数来做,而多层就带来了网络,用不同层解决不同的问题,多层配合来解决更复杂的问题。

可以使用蓝线对所有数据进行x轴方向的正负分类,再使用黄线对所有数据进行y轴方向的正负分类,最后再将两次分类结果进行xor运算就能得到结果。

多层感知机使用隐藏层和激活函数来得到非线性模型。

在softmax基础上多了隐藏层。可选超参:

  • 隐藏层数
  • 每个隐藏层的宽度,通常选择2的若干次冥作为层的宽度

这两个参数的选择取决于输入和输出的复杂度

对复杂的输入,输入维度一般比较高,输出一般会比较少,有两种处理办法:

  1. 做单隐藏层,把模型做平,层的大小设大一点
  2. 做多隐藏层,把模型做深,层的大小可以设小一点,每层的维度逐步减少(如果每层维度都高,则会导致模型太大)

复杂输入到简单输出本质上是一个信息压缩的过程,多层逐步压缩能避免一次压缩太大导致信息损失太严重,例如:128->64->32->16->8
也可以先expand,从128->256->64->32->16->8

3. 激活函数

作用:在神经网络中引入非线性,可以理解为一个开关,当输入信号超过一定阀值时,神经元会被激活并产生输出,而未超过阀值时神经元将会被抑制。

在没有激活函数的情况下,神经网络只能表示线性映射,无法处理复杂的非线性关系。激活函数的作用就是线性结果映射到一个非线性的输出,以帮助神经网络更好的适应输入数据,提高非线性拟合能力。

举例:一个邮件过滤模型中的神经元,负责对输入邮件的特征(长度、关键词等)进行加权求和,但这个结果只是一个连续的数值我们交

激活函数不能是线性函数,否则会变成单层感知机,依然会存在线性分割面无法处理XOR的问题。

激活函数主要作用于隐藏层。

激活函数的几种选择:

  1. sigmoid: 对于任意输入x,都能投影到0~1区间内。

  2. tanh(x): 将输入投影到[-1,1]区间内

  1. ReLU: 就是一个Max函数(常用),特点是计算很快,相比前面基于指数运算的sigmoid和tanh函数都快很多(一次指数运算要100个时钟周期)

对ReLU函数求导,小于等于0时都是0,大于0时都是1,最终结果就是一个二分类。

4. 代码实现

4.1 初始化参数

我们将实现一个具有单隐藏层的多层感知机, 这个隐藏层包含128个隐藏单元。

对于每一层我们都要记录一个权重矩阵和一个偏置向量,并指定requires_grad=True来记录参数梯度。

import torch
from torch import nn
from d2l import torch as d2lnum_inputs, num_outputs, num_hiddens = 784, 10, 128W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))params = [W1, b1, W2, b2]

通常,我们选择2的若干次幂作为层的宽度。 因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效。

4.2 加载数据集

这里继续使用Fashion-MNIST图像分类数据集。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

4.3 激活函数

Relu函数的实现比较简单,就是一个max函数的调用, 它将输入的负值部分截断为0,保留正值部分不变。

def relu(X):a = torch.zeros_like(X)return torch.max(X, a)
  • torch.zeros_like(X): 创建了一个与X具有相同形状的全零张量a。
  • torch.max(X, a): 对于输入X中的每个元素,如果它是正值,则该元素保留不变;如果它是负值,则将其替换为0。

4.4 模型

def net(X):X = X.reshape((-1, num_inputs))    H = relu(X@W1 + b1)  # 隐藏层,这里“@”代表矩阵乘法return (H@W2 + b2)   # 输出层
  1. 使用reshape将输入的二维图像转换为一个长度为num_inputs=784的向量;
  2. 用ReLu函数对隐藏层的线性输出进行激活,得到输出张量H;
  3. 最后,由张量H和权重矩阵W2进行矩阵乘法操作,将偏置向量b2加到结果上,得到预测输出结果。

4.5 损失函数

这里直接使用pytorch中内置的交叉熵损失函数。

loss = nn.CrossEntropyLoss(reduction='none')

4.6 训练

多层感知机的训练过程与softmax的训练过程完全相同,可以直接调用之前定义过的train_ch3函数。

# 将迭代周期数设置为10,并将学习率设置为0.1.
num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

训练过程中的模型损失和精度的收敛变化:

epoch: 1, loss: 1.1021366075515746, test_acc: 0.7544
epoch: 2, loss: 0.6142196039199829, test_acc: 0.8004
epoch: 3, loss: 0.5257990721384684, test_acc: 0.8061
epoch: 4, loss: 0.4842481053034465, test_acc: 0.7988
epoch: 5, loss: 0.4575055497487386, test_acc: 0.8266
epoch: 6, loss: 0.4389862974802653, test_acc: 0.8382
epoch: 7, loss: 0.42252545185089113, test_acc: 0.8443
epoch: 8, loss: 0.40933472124735515, test_acc: 0.8458
epoch: 9, loss: 0.3975078603744507, test_acc: 0.8467
epoch: 10, loss: 0.38488629398345947, test_acc: 0.8527

基于之前softmax模型上定义的预测函数,在测试数据集上使用这个模型做验证:

predict_ch3(net, test_iter)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3029763.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

一文彻底读懂信息安全等级保护:包含等保标准、等保概念、等保对象、等保流程及等保方案(附:等保相关标准文档)

1. 什么是等级保护? 1.1. 概念 信息安全等级保护是指根据我国《信息安全等级保护管理办法》的规定,对各类信息系统按照其重要程度和保密需求进行分级,并制定相应的技术和管理措施,确保信息系统的安全性、完整性、可用性。根据等…

通俗的理解网关的概念的用途(四):什么是网关设备?(网络层面)

任何一台Windows XP操作系统之后的个人电脑、Linux操作系统电脑都可以简单的设置,就可以成为一台具备“网关”性质的设备,因为它们都直接内置了其中的实现程序。MacOS有没有就不知道,因为没用过。 简单的理解,就是运行了具备第二…

串口初始化自己独立的见解--第九天

1.SM0,SM1 我们一般用 8位UART,波特率可变 (方式1的工作方式) SCON :SM2 一般不用,SM0 0 ,SM1 1 PCON : 有两位 我们不动它,不加速,初始值 TMOD:8位自动重装定时器&#xff0…

Linux 安装JDK和Idea

安装JDK 下载安装包 下载地址: Java Downloads | Oracle (1) 使用xshell 上传JDK到虚拟机 (2) 移动JDK 包到/opt/environment cd ~ cd /opt sudo mkdir environment # 在 /opt下创建一个environment文件夹 ls# 复制JDK包dao /opt/environment下 cd 下载 ls jd…

信息系统架构模型_1.单机应用模式和客户机/服务器模式

1.单机应用模式(Standalone) 单机应用系统是最简单的软件结构,是指运行在一台物理机器上的独立应用程序。这些软件系统,从今天的软件架构上来讲,是很简单,是标准的单机系统。当然至今,这种复杂的…

Blazor入门-基础知识+vs2022自带例程的理解

参考: Blazor 教程 - 生成首个应用 https://dotnet.microsoft.com/zh-cn/learn/aspnet/blazor-tutorial/intro Blazor基础知识:Visual Studio 2022 中的Blazor开发入门_vs2022 blazor webassembly-CSDN博客 https://blog.csdn.net/mzl87/article/detail…

如何应对Android面试官 -> WindowManagerService 启动流程分析

前言 本章主要从上面几个角度来讲解 WindowManagerService; 相关概念 介绍 WMS 之前,我们先来介绍几个相关的概念; WMS 存在于 system_server 系统服务进程,view 存在于 app 进程,所有的窗口最终都是通过 wms 来进行…

【算法与数据结构】数组

文章目录 前言数组数组的定义数组的基本操作增加元素删除元素修改元素查找元素 C STL 中的数组arrayvector Python3 中的列表访问更改元素值遍历列表检查列表中是否存在某元素增加元素删除元素拷贝列表总结 Python3 列表的常用操作 参考资料写在最后 前言 本系列专注更新基本数…

uniapp的app端推送功能,不使用unipush

1&#xff1a;推送功能使用htmlPlus实现&#xff1a;地址HTML5 API Reference (html5plus.org) 效果图&#xff1a; 代码实现&#xff1a; <template><view class"content"><view class"text-area"><button click"createMsg&q…

跨界内容营销:Kompas.ai如何帮助你的品牌打破行业边界

在当今多元化的市场环境中&#xff0c;跨界营销已成为品牌拓展影响力和用户基础的重要策略。通过跨界合作&#xff0c;品牌能够打破行业界限&#xff0c;创造独特的用户体验&#xff0c;从而提升品牌形象和市场竞争力。本文将深入分析跨界营销的作用&#xff0c;详细介绍Kompas…

AI-powered的搜索引擎:Perplexity 与知识工作者

Perplexity是一款AI-powered的搜索引擎&#xff0c;通过与OpenAI合作&#xff0c;利用GPT模型提供高速、准确的搜索结果&#xff0c;特别针对知识工作者的需求进行优化。 知识工作者通常需要进行复杂的研究和决策&#xff0c;他们希望能够快速获取准确的信息来支持他们的工作。…

buuctf-misc题目练习三

荷兰宽带数据泄露 BIN 文件&#xff0c;也称为二进制文件&#xff0c;是一种压缩文件格式&#xff0c;可以 包含图像和视频等信息 , 并被许多应用程序用于各种目的。 RouterPassView是一个找回路由器密码的工具。 大多数现代路由器允许备份到一个文件路由器的配置&#xff0c…

Unity TileMap入门

概述 相信很多同学学习制作游戏都是从2D游戏开始制作的吧&#xff0c;瓦片地图相信大家都有接触&#xff0c;那接下来让我们学习一下这部分的内容吧&#xff01; Tilemap AnimationFrameRate:设置每帧动画的播放速率。Color:瓦片地图的颜色TileAnchor:锚点&#xff0c;&#x…

100000订单直接拒掉,君子爱财,取之有道

近一个月询盘可谓寥寥无几&#xff0c;成交率为0&#xff0c;今天好不容易接了一个客户询盘&#xff0c;订单总价高达100000&#xff0c;听完细节直接拒掉&#xff0c;至于原因懂的都懂&#xff0c;不懂得等我慢慢道来。 前两天有2个询盘&#xff0c;其中一个是二次开发&#x…

正点原子[第二期]Linux之ARM(MX6U)裸机篇学习笔记-15.4讲 GPIO中断实验-IRQ中断服务函数详解

前言&#xff1a; 本文是根据哔哩哔哩网站上“正点原子[第二期]Linux之ARM&#xff08;MX6U&#xff09;裸机篇”视频的学习笔记&#xff0c;在这里会记录下正点原子 I.MX6ULL 开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了正点原子教学视频和链接中的内容。…

如何利用代理IP进行SEO优化?

“SEO”这个词相信对于做在线业务的朋友来说一定不陌生。 在网络营销中&#xff0c;SEO是至关重要的一环&#xff0c;对于增加有机流量、提升品牌知名度、增加网站的信任度和权威性非常有效。而代理IP在SEO优化中有着不可或缺的作用&#xff0c;它可以帮助网站管理员和SEO专家…

数据中心法

数据中心法是实现词法分析器的结构化方法。通过设计主表和子表分开存储状态转移信息&#xff0c;实现词法分析器的控制逻辑和数据结构分离。 主要解决了状态爆炸、难以维护和复杂性的问题。 状态爆炸是指当状态和转移较多时&#xff0c;单一使用一个表来存储所有的信息的话会导…

这3种深拷贝实现,你都知道吗?

目录&#xff1a; 1、JSON.parse 2、structuredClone 3、cloneDeep

Leetcode—138. 随机链表的复制【中等】(cend函数)

2024每日刷题&#xff08;129&#xff09; Leetcode—138. 随机链表的复制 实现代码 /* // Definition for a Node. class Node { public:int val;Node* next;Node* random;Node(int _val) {val _val;next NULL;random NULL;} }; */class Solution { public:Node* copyRan…