用最简单线性回归理解梯度下降

上篇文章,我们已经理解了线性回归

现根据线性回归去理解梯度下降

初始化数据

import numpy as npnp.random.seed(42)  # to make this code example reproducible
m = 100  # number of instances
X = 2 * np.random.rand(m, 1)  # column vector
y = 4 + 3 * X + np.random.randn(m, 1)  # column vector

可视化

import matplotlib.pyplot as pltplt.figure(figsize=(6, 4))
plt.plot(X, y, "b.")
plt.xlabel("x")
plt.ylabel("y")
plt.axis([0, 2, 0, 15])
plt.grid()
plt.show()

在这里插入图片描述
由文章线性回归可知,我们要得到线性模型,就是想画出一条预测线,如图
在这里插入图片描述
也就是说,我们要得到两个参数,对应y=ax+b中的a,b

梯度下降就是随便假设a,b,通过不断的计算,去逼近a,b真实值

梯度下降计算方法

几个概念

  • 学习率eta,每次调整多少,逼近值
  • theta = np.random.randn(2, 1),这里对应对应y=ax+b中的a,b,随便设置
  • n_epochs 迭代次数
  • 计算方法不用管,矩阵点积相关的计算,不重要(因为方法可以多种多样)
from sklearn.preprocessing import add_dummy_featureX_b = add_dummy_feature(X)  # add x0 = 1 to each instanceeta = 0.1  # learning rate
n_epochs = 1000
m = len(X_b)  # number of instancesnp.random.seed(42)
theta = np.random.randn(2, 1)  # randomly initialized model parametersfor epoch in range(n_epochs):gradients = 2 / m * X_b.T @ (X_b @ theta - y)theta = theta - eta * gradients

这里直接输出theta ,通过计算,我们已经得到了y=ax+b中的a,b,就是这么神奇

在这里插入图片描述
画出理解梯度预测线

X_new = np.array([[0], [2]])
X_new_b = add_dummy_feature(X_new)  # add x0 = 1 to each instance
y_predict = X_new_b @ theta
y_predict

输出
在这里插入图片描述
我们得到两个坐标,(0, 4.21509616),(2, 9.75532293),根据这两个点,我们就可以画出一条线

import matplotlib.pyplot as pltplt.figure(figsize=(6, 4))plt.plot(X_new, y_predict, "r-")
plt.plot(X, y, "b.")plt.xlabel("x")
plt.ylabel("y")
plt.axis([0, 2, 0, 15])
plt.grid()
plt.show()

在这里插入图片描述

学习率理解

# extra code – generates and saves Figure 4–8import matplotlib as mpldef plot_gradient_descent(theta, eta):m = len(X_b)plt.plot(X, y, "b.")n_epochs = 50theta_path = []for epoch in range(n_epochs):y_predict = X_new_b @ thetacolor = mpl.colors.rgb2hex(plt.cm.OrRd(epoch / 50.15))plt.plot(X_new, y_predict, linestyle="solid", color=color)gradients = 2 / m * X_b.T @ (X_b @ theta - y)theta = theta - eta * gradientstheta_path.append(theta)plt.xlabel("$x_1$")plt.axis([0, 2, 0, 15])plt.grid()plt.title(fr"$\eta = {eta}$")return theta_pathnp.random.seed(42)
theta = np.random.randn(2, 1)  # random initializationplt.figure(figsize=(10, 4))
plt.subplot(131)
plot_gradient_descent(theta, eta=0.02)
plt.ylabel("$y$", rotation=0)
plt.subplot(132)
theta_path_bgd = plot_gradient_descent(theta, eta=0.1)
plt.gca().axes.yaxis.set_ticklabels([])
plt.subplot(133)
plt.gca().axes.yaxis.set_ticklabels([])
plot_gradient_descent(theta, eta=0.5)
plt.show()

输出图
在这里插入图片描述

如图所示,学习率设置分别有:0.02,0.1, 0.5也就是逼近真实值a, b的速度,也就是收敛速度,如果设置太大了,如0.5,直接就错过了真实值a, b,如果太小,如0.02,要得到真实值a, b要计算很多次,如果epoch (上述是50次)要计算很多次,不一定会得到真实值a, b,0.1就比较合适。

结论:学习率设置很有必要

随机梯度下降 SGD理解

如何理解随机,也就是学习率动态在调整,对应learning_schedule方法,与前面相比,不是一个固定值

theta_path_sgd = []  # extra code – we need to store the path of theta in the#              parameter space to plot the next figuren_epochs = 50
t0, t1 = 5, 50  # learning schedule hyperparametersdef learning_schedule(t):return t0 / (t + t1)np.random.seed(42)
theta = np.random.randn(2, 1)  # random initializationn_shown = 20  # extra code – just needed to generate the figure below
plt.figure(figsize=(6, 4))  # extra code – not needed, just formattingfor epoch in range(n_epochs):for iteration in range(m):# extra code – these 4 lines are used to generate the figureif epoch == 0 and iteration < n_shown:y_predict = X_new_b @ thetacolor = mpl.colors.rgb2hex(plt.cm.OrRd(iteration / n_shown + 0.15))plt.plot(X_new, y_predict, color=color)random_index = np.random.randint(m)xi = X_b[random_index : random_index + 1]yi = y[random_index : random_index + 1]gradients = 2 * xi.T @ (xi @ theta - yi)  # for SGD, do not divide by meta = learning_schedule(epoch * m + iteration)theta = theta - eta * gradientstheta_path_sgd.append(theta)  # extra code – to generate the figure# extra code – this section beautifies and saves Figure 4–10
plt.plot(X, y, "b.")
plt.xlabel("$x_1$")
plt.ylabel("$y$", rotation=0)
plt.axis([0, 2, 0, 15])
plt.grid()
plt.show()

在这里插入图片描述
如图所示,收敛速度可以做到快慢的很好结合

输出y=ax+b中的a,b

在这里插入图片描述
当然我们也可以使用sklearn提供的方法SGDRegressor简化得出y=ax+b中的a,b

from sklearn.linear_model import SGDRegressorsgd_reg = SGDRegressor(max_iter=1000, tol=1e-5, penalty=None, eta0=0.01,n_iter_no_change=100, random_state=42)
sgd_reg.fit(X, y.ravel())  # y.ravel() because fit() expects 1D targetssgd_reg.intercept_, sgd_reg.coef_

输出
在这里插入图片描述

画一个a、b参数的学习过程

多加入一个min-SGD计算,加深理解

# extra code – this cell generates and saves Figure 4–11from math import ceiln_epochs = 50
minibatch_size = 20
n_batches_per_epoch = ceil(m / minibatch_size)np.random.seed(42)
theta = np.random.randn(2, 1)  # random initializationt0, t1 = 200, 1000  # learning schedule hyperparametersdef learning_schedule(t):return t0 / (t + t1)theta_path_mgd = []
for epoch in range(n_epochs):shuffled_indices = np.random.permutation(m)X_b_shuffled = X_b[shuffled_indices]y_shuffled = y[shuffled_indices]for iteration in range(0, n_batches_per_epoch):idx = iteration * minibatch_sizexi = X_b_shuffled[idx : idx + minibatch_size]yi = y_shuffled[idx : idx + minibatch_size]gradients = 2 / minibatch_size * xi.T @ (xi @ theta - yi)eta = learning_schedule(iteration)theta = theta - eta * gradientstheta_path_mgd.append(theta)theta_path_bgd = np.array(theta_path_bgd)
theta_path_sgd = np.array(theta_path_sgd)
theta_path_mgd = np.array(theta_path_mgd)plt.figure(figsize=(7, 4))
plt.plot(theta_path_sgd[:, 0], theta_path_sgd[:, 1], "r-s", linewidth=1,label="SGD")
plt.plot(theta_path_mgd[:, 0], theta_path_mgd[:, 1], "g-+", linewidth=2,label="Mini-SGD")
plt.plot(theta_path_bgd[:, 0], theta_path_bgd[:, 1], "b-o", linewidth=3,label="Batch")
plt.legend(loc="upper left")
plt.xlabel(r"a")
plt.ylabel(r"b", rotation=0)
plt.axis([2.6, 4.6, 2.3, 3.4])
plt.grid()
plt.show()

输出

在这里插入图片描述
我们可以看出,SGD学习过程(计算a、b)是比较优秀的 Min-SGD > SGD < batch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2869242.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Vue+SpringBoot打造教学过程管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 教师端2.2 学生端2.3 微信小程序端2.3.1 教师功能如下2.3.2 学生功能如下 三、系统展示 四、核心代码4.1 查询签到4.2 签到4.3 查询任务4.4 查询课程4.5 生成课程成绩 六、免责说明 一、摘要 1.1 项目介绍 基于JAVAVu…

牛客网-SQL大厂面试题-2.平均播放进度大于60%的视频类别

题目&#xff1a;平均播放进度大于60%的视频类别 DROP TABLE IF EXISTS tb_user_video_log, tb_video_info; CREATE TABLE tb_user_video_log (id INT PRIMARY KEY AUTO_INCREMENT COMMENT 自增ID,uid INT NOT NULL COMMENT 用户ID,video_id INT NOT NULL COMMENT 视频ID,start…

如何保存缓存和MySQL的双写一致呢?

如何保存缓存和MySQL的双写一致呢&#xff1f; 所谓的双写一致指的是&#xff0c;在同时使用缓存(如Redis)和数据库(如MySQL)的场景下,确保数据在缓存和数据库中的更新操作保持一致。当对数据进行修改的时候&#xff0c;无论是先修改缓存还是先修改数据库&#xff0c;最终都要保…

Java中上传数据的安全性探讨与实践

✨✨谢谢大家捧场&#xff0c;祝屏幕前的小伙伴们每天都有好运相伴左右&#xff0c;一定要天天开心哦&#xff01;✨✨ &#x1f388;&#x1f388;作者主页&#xff1a; 喔的嘛呀&#x1f388;&#x1f388; 目录 引言 一. 文件上传的风险 二. 使用合适的框架和库 1. Spr…

MySQL数据导入的方式介绍

MySQL数据库中的数据导入是一个常见操作&#xff0c;它涉及将数据从外部源转移到MySQL数据库表中。在本教程中&#xff0c;我们将探讨几种常见的数据导入方式&#xff0c;包括它们的特点、使用场景以及简单的示例。 1. 命令行导入 使用MySQL命令行工具mysql是导入数据的…

企业计算机服务器中了halo勒索病毒怎么办,halo勒索病毒解密工具流程

随着网络技术的不断应用与发展&#xff0c;越来越多的企业开始利用网络开展各项工作业务&#xff0c;网络为企业的发展与生产生活提供了极大便利。但网络中的勒索病毒攻击企业服务器的事件频发&#xff0c;给企业的数据安全带来了严重威胁&#xff0c;数据安全问题一直是企业关…

Android U pipeline-statusbar

Android U - statusbar pipeline 写在前面 Android原生从T开始对SystemUI进行MVVM改造&#xff0c;U上状态栏部分进行了修改&#xff1b;第一次出现修改不会删除原有逻辑&#xff0c;而是两版并行&#xff0c;留给其他开发者适配的时间&#xff1b;在下一个大版本可能会删除原…

C#对ListBox控件中的数据进行的操作

目录 1.添加数据&#xff1a; 2.删除数据&#xff1a; 3.清空数据&#xff1a; 4.选择项&#xff1a; 5.排序&#xff1a; 6.获取选中的项&#xff1a; 7.获取ListBox中的所有项&#xff1a; 8.综合示例 C#中对ListBox控件中的数据进行的操作主要包括添加、删除、清空、…

STM32CubeMX与HAL库开发教程八(串口应用/轮询/中断/DMA/不定长数据收发)

目录 前言 初识串口-轮询模式 串口中断模式收发 串口DMA模式 蓝牙模块与数据包解析 前言 前面我们简单介绍过串口的原理和初步的使用方式&#xff0c;例如怎么配置和简单的收发&#xff0c;同时我们对串口有了一个初步的了解&#xff0c;这里我们来深入的来使用一下串口 …

LAMP架构部署--yum安装方式

这里写目录标题 LAMP架构部署web服务器工作流程web工作流程 yum安装方式安装软件包配置apache启用代理模块 配置虚拟主机配置php验证 LAMP架构部署 web服务器工作流程 web服务器的资源分为两种&#xff0c;静态资源和动态资源 静态资源就是指静态内容&#xff0c;客户端从服…

Golang实现Redis分布式锁(Lua脚本+可重入+自动续期)

Golang实现Redis分布式锁&#xff08;Lua脚本可重入自动续期&#xff09; 1 概念 应用场景 Golang自带的Lock锁单机版OK&#xff08;存储在程序的内存中&#xff09;&#xff0c;分布式不行 分布式锁&#xff1a; 简单版&#xff1a;redis setnx》加锁设置过期时间需要保证原…

CentOS部署 JavaWeb 实现 MySql业务

一、项目打war包 在eclispe或idea中找到项目&#xff0c;右键打war包。 二、上传项目到linux 2.1云服务器虚拟机均可 以tomcat为例 /usr/local/tomcat/webapps 将war包通过ssh连接上传到webapps目录下。 如果是root目录则不需要项目名即 ip或域名端口直接访问&#xff08…

游戏引擎中的动画基础

一、动画技术简介 视觉残留理论 - 影像在我们的视网膜上残留1/24s。 游戏中动画面临的挑战&#xff1a; 交互&#xff1a;游戏中的玩家动画需要和场景中的物体进行交互。实时&#xff1a;最慢需要在1/30秒内算完所有的场景渲染和动画数据。&#xff08;可以用动画压缩解决&am…

pytorch 实现线性回归(Pytorch 03)

一 线性回归框架 线性模型的四个模块&#xff1a;训练的数据集&#xff0c;线性模型&#xff0c;损失函数&#xff0c;优化算法。 1.1 数据集 使用房价预测数据集&#xff0c;我们希望根据房屋的面积和房龄等来估算房屋价格。 1.2 线性模型 预测公式&#xff0c; 价格 权重…

Spark相关

1.Hadoop主要有哪些缺点&#xff1f;相比之下&#xff0c;Spark具有哪些优点&#xff1f; Hadoop主要有哪些缺点&#xff1a;Hadoop虽然已成为大数据技术的事实标准&#xff0c;但其本身还存在诸多缺陷&#xff0c;最主要的缺陷是 MapReduce计算模型延迟过高&#xff0c;无法胜…

idea中database的一些用法

1、查看表结构 方法1&#xff0c;右键&#xff0c;选这个 方法2 双击表后&#xff0c;看到数据&#xff0c;点DDL 方法3 写SQL时&#xff0c;把鼠标放在表名上&#xff0c;可以快速查看表结构 2、表生成对应的实体类 表中右键&#xff0c;选择这2个&#xff0c;选择生成的路…

鸿蒙Harmony应用开发—ArkTS声明式开发(容器组件:Swiper)

滑块视图容器&#xff0c;提供子组件滑动轮播显示的能力。 说明&#xff1a; 该组件从API Version 7开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 可以包含子组件。 说明&#xff1a; 子组件类型&#xff1a;系统组件和自定义组…

数据结构的概念大合集03(栈)

概念大合集03 1、栈1.1 栈的定义和特点1.2 栈的基础操作1.3 栈的顺序存储1.3.1 顺序栈1.3.2 栈空&#xff0c;栈满&#xff0c;进栈&#xff0c;出栈的基本思想1.3.3 共享栈1.3.3.1 共享栈的4要素 1.4 栈的链式存储1.4.1 链栈的实现1.4.2 链栈的4个要素 1、栈 1.1 栈的定义和特…

高可用系统有哪些设计原则

1.降级 主动降级&#xff1a;开关推送 被动降级&#xff1a;超时降级 异常降级 失败率 熔断保护 多级降级2.限流 nginx的limit模块 gateway redisLua 业务层限流 本地限流 gua 分布式限流 sentinel 3.弹性计算 弹性伸缩—K8Sdocker 主链路压力过大的时候可以将非主链路的机器给…

T1.数据库MySQL

二.SQL分类 2.1 DDL 2.1.1数据库操作 1). 查询所有数据库 show databases ; 2). 查询当前数据库 select database(); 3)创建数据库 create database [if not exists] 数据库名 [default charset 字符集] [collate 排序规则] ; 4&#xff09;删除数据库 drop database …