Pytorch深度学习实践(5)逻辑回归

逻辑回归

逻辑回归主要是解决分类问题

  • 回归任务:结果是一个连续的实数
  • 分类任务:结果是一个离散的值

分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 09),因为各个类别之间没有大小之差。

因此,对于分类问题,我们最终的输出是个概率,即属于某个类别的概率是多少,然后从概率集合里找最大值,作为当前预测的结果

下载MNIST数据集

import torchvision
train_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=True, download=True)
test_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=False, dowload=True)
  • 通过train参数来指定训练集和测试集

逻辑回归

将之前的学习时长—考试分数转化为二分类任务,即学习时长—是否通过考试

x(hours)y(pass/fail)
10(fail)
20(fail)
31(pass)
4?

其中, P ( y ^ = 1 ) + P ( y ^ = 0 ) = 1 P(\hat y = 1) + P(\hat y = 0) = 1 P(y^=1)+P(y^=0)=1

当输出的概率在 0.5 0.5 0.5附近时,即模型不确定,因此通常会输出一个不确定的值

对于二分类任务,逻辑回归会先使用回归,生成一个得分值,即落在实数集区间内,然后再使用 s i g m o i d sigmoid sigmoid函数,将得分值映射到 [ 0 , 1 ] [0, 1] [0,1]区间内,得到预测概率

s i g m o i d sigmoid sigmoid函数
σ ( x ) = 1 1 + e − x \sigma (x) = \frac{1}{1+e^{-x}} σ(x)=1+ex1
在这里插入图片描述

S i g m o i d Sigmoid Sigmoid常用来做二分类任务,常具备三个特征:

  • 饱和函数
  • 单调递增
  • 有极限

当我们使用线性回归来得到逻辑回归的得分值时,逻辑回归模型的函数定义就如下所示:
y ^ = σ ( x ∗ ω + b ) \hat y = \sigma (x*\omega + b) y^=σ(xω+b)

损失函数

线性回归使用的损失函数是计算预测值和真实值之差

而对于逻辑回归,由于我们得到的是概率,是一个 0 − 1 0-1 01分布,因此需要修改损失函数
l o s s = − ( y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ) loss = -(ylog\hat y + (1-y)log(1-\hat y)) loss=(ylogy^+(1y)log(1y^))
即我们比较的是分布之间的差异

交叉熵 c r o s s − e n t r o p y cross-entropy crossentropy

存在两个分布 P D 1 ( x ) P_{D1}(x) PD1(x) P D 2 ( x ) P_{D2}(x) PD2(x)

两个分布的差异程度使用公式: ∑ i = 1 n P D 1 ( x i ) l n P D 2 ( x i ) \sum_{i=1}^{n}P_{D1}(x_i)lnP_{D2}(x_i) i=1nPD1(xi)lnPD2(xi) 来衡量

上述公式越大时,两个分布的差异越小

模型的改变

模型构造的改变
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred

需要先将输入写入到linear()线性模型中,再使用Sigmoid()函数

模型损失函数的改变

使用交叉熵函数BCELoss

criterion = torch.nn.BCELoss(size_average=False)

整体代码

import torch
import matplotlib.pyplot as plt
import numpy as np########## 数据集准备 ##########
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])########## 模型定义 ##########
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = torch.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel()########## 损失函数和优化器的设置 ##########
criterion = torch.nn.BCELoss(size_average=False) # BCELoss -- 交叉熵函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)########## 模型训练 ##########
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()########## 模型测试 ##########
x = np.linspace(0, 10, 200)
x_test = torch.Tensor(x).view((200, 1)) # view()相当于reshape
y_test = model(x_test)
y = y_test.data.numpy()  # 转化为np类型
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], 'r--')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
plt.grid()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3269172.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

动态获取配置文件中的配置参数,当配置文件中的参数修改之后,不需要重启项目

这里写目录标题 一、本地开发环境二、nocas环境配置 一、本地开发环境 如果是在本地开发环境中,读取的配置文件是本地根目录下的application.properties文件: 路径为配置文件的绝对路径。 配置文件里面配置的参数需要和获取的参数名称相互对应 通过Au…

linux怎么创建python

第一步,创建一个test文件夹。 第二步,打开终端进入该文件。 第三步,vim test.py。 第四步,编写代码。 第五步,编辑好之后,按Esc键切换到命令模式,然后输入:wq,再按回车键即可自动保存…

Java突击复习小练习,附带讲解

练习: 1.下面的代码在主方法中可以正常执行吗,如果不能,为什么? 2.已知i的值为10,请问在情况一和情况二中j的值是否相同呢?如果不相同,它们的值分别是多少呢?为什么? 3.在主方法运…

3D打印:重塑模具制造业的创新引擎

在科技浪潮的推动下,3D打印技术正以前所未有的速度渗透到制造业的核心,尤其在模具制造领域,它正引领一场深刻的创新革命。该技术通过颠覆传统制造范式,显著优化了模具生产的复杂流程,实现了从设计到成品的一体化的高效…

系统架构设计师教程 第4章 信息安全技术基础知识-4.3 信息安全系统的组成框架4.4 信息加解密技术-解读

系统架构设计师教程 第4章 信息安全技术基础知识-4.3 信息安全系统的组成框架 4.3 信息安全系统的组成框架4.3.1 技术体系4.3.1.1 基础安全设备4.3.1.2 计算机网络安全4.3.1.3 操作系统安全4.3.1.4 数据库安全4.3.1.5 终端安全设备4.3.2 组织机构体系4.3.3 管理体系4.4 信息加…

【PostgreSQL 16】专栏日常

本专栏从 3 个月前开始着手准备&#xff0c;利用周末及节假日的时间来整理。 ldczzDESKTOP-HVJOUVN MINGW64 ~/mypostgres (dev) $ git lg |tee * 7a7f468 - (HEAD -> dev, origin/main, origin/dev, main) 完成服务端编程的初步整理 (6 minutes ago) <Laven Liu> * …

masscan 端口扫描——(Golang 简单使用总结)

1. 前言 最近要做一个扫描 ip 端口的功能 扫描的工具有很多&#xff0c;但是如何做到短时间扫描大量的 ip 是个相对困难的事情。 市场上比较出名的工具有 masscan和nmap masscan 支持异步扫描&#xff0c;对多线程的利用很好&#xff0c;同时仅仅支持 syn 半开扫描&#xff…

GraphHopper-map-navi_路径规划、导航(web前端页面版)

文章目录 一、项目地址二、踩坑环境三、问题记录3.1、graphhopper中地图问题3.1.1. getOpacity不存在的问题3.1.2. dispatchEvent不存在的问题3.1.3. vectorLayer.set(background-maplibre-layer, true)不存在set方法3.1.4. maplibre-gl.js.map不存在的问题3.1.5. Uncaught Ref…

聊聊基于Alink库的特征工程方法

独热编码 OneHotEncoder 是用于将类别型特征转换为独热编码的类。独热编码是一种常用的特征编码方式&#xff0c;特别适用于处理类别型特征&#xff0c;将其转换为数值型特征。 对于每个类别型特征&#xff0c;OneHotEncoder 将其编码成一个长度为类别数量的向量。 每个类别对…

在线教育数仓项目(数据采集部分1)

文章目录 数据仓库概念项目需求及架构设计项目需求分析系统数据流程设计框架版本选型集群规模估算集群资源规划设计 数据生成模块目标数据页面事件曝光启动播放错误 数据埋点主流埋点方式&#xff08;了解&#xff09;埋点数据上报时机埋点数据日志结构 服务器和JDK准备服务器准…

JMeter接口测试:测试中奖概率!

介绍 Apache JMeter 是 Apache 组织基于 Java 开发的压力测试工具&#xff0c;用于对软件做压力测试。JMeter 最初被设计用于 Web 应用测试&#xff0c;但后来扩展到了其他测试领域&#xff0c;可用于测试静态和动态资源&#xff0c;如静态文件、Java 小服务程序、CGI 脚本、J…

【机器学习】解开反向传播算法的奥秘

&#x1f308;个人主页: 鑫宝Code &#x1f525;热门专栏: 闲话杂谈&#xff5c; 炫酷HTML | JavaScript基础 ​&#x1f4ab;个人格言: "如无必要&#xff0c;勿增实体" 文章目录 解开反向传播算法的奥秘反向传播算法的概述反向传播算法的数学推导1. 前向传播2…

【计算机网络】WireShark和简单http抓包实验

一&#xff1a;实验目的 1&#xff1a;熟悉WireShark的安装流程和界面操作流程。 2&#xff1a;学会简单http的抓取和过滤&#xff0c;并分析导出结果。 二&#xff1a;实验仪器设备及软件 硬件&#xff1a; Windows 2019操作系统的计算机等。 软件&#xff1a;WireShark、…

草图也能秒变完整画稿?三星 Galaxy Z Fold6 、Flip6硬件升级

在科技的不断进步中&#xff0c;智能手机行业的竞争愈发激烈&#xff0c;各大厂商纷纷推出创新产品以吸引消费者。 最近&#xff0c;三星在 Galaxy Unpacked 发布会上就带来了 Galaxy Z Fold6 和 Flip6 两款手机新品&#xff0c;这两款设备不仅在硬件上有所突破&#xff0c;更…

docker dotnet-dump离线部署

1.下载指定dotnet版本的dotnet-dump 示例地址&#xff1a; https://www.nuget.org/packages/dotnet-dump/3.1.141901#dependencies-body-tab 我本地测试的是netcore 3.1 2. 在本地解压 将文件解压出来。看到any目录,能看到我们要用的dotnet-dump文件 3. 将tools/netcoreapp2.…

C++文件系统操作6 - 跨平台实现查找指定文件夹下的特定文件

1. 关键词 C 文件系统操作 查找指定文件夹下的特定文件 跨平台 2. fileutil.h #pragma once#include <string> #include <cstdio> #include <cstdint> #include "filetype.h" #include "filepath.h"namespace cutl {/*** brief The fi…

【吊打面试官系列-Dubbo面试题】服务调用是阻塞的吗?

大家好&#xff0c;我是锋哥。今天分享关于 【服务调用是阻塞的吗&#xff1f;】面试题&#xff0c;希望对大家有帮助&#xff1b; 服务调用是阻塞的吗&#xff1f; 默认是阻塞的&#xff0c;可以异步调用&#xff0c;没有返回值的可以这么做。 Dubbo 是基于 NIO 的非阻塞实现…

Axious的请求与响应

Axious的请求与响应 1.什么是Axious Axious是一个开源的可以用在浏览器和Node.js的异步通信框架&#xff0c;它的主要作用就是实现AJAX异步通信&#xff0c;其功能特点如下&#xff1a; 从浏览器中创建XMLHttpRequests ~从node.js创建Http请求 支持PromiseAPI 拦截请求和…

深入解析AI技术:从深度学习到GPT大模型的全面探索

深入解析AI技术&#xff1a;从深度学习到GPT大模型的全面探索 引言 在21世纪的科技浪潮中&#xff0c;人工智能&#xff08;AI&#xff09;无疑是最引人注目的领域之一。从简单的语音助手到复杂的自动驾驶系统&#xff0c;AI正以前所未有的速度改变着我们的世界。而深度学习&a…

【Python机器学习】朴素贝叶斯——使用朴素贝叶斯进行文档分类(理论基础)

机器学习的一个重要应用就是文档的自动分类。在文档分类中&#xff0c;整个文档&#xff08;比如电子邮件&#xff09;是实例&#xff0c;而电子邮件中的某些元素则构成特征。虽然电子邮件是一种会不断增加的文本&#xff0c;但我们同样也可以对新闻报道、用户流言、公文等其他…