2017 CS231n学习笔记(三)----损失函数和最优化(Loss Functions and Optimization )

video:https://study.163.com/course/courseMain.htm?courseId=1004697005
slides:http://cs231n.stanford.edu/slides/2017/cs231n_2017_lecture3.pdf
course notes:http://cs231n.github.io/
文章所有内容来自Stanford university 课程 CS231n 2017 spring

上一节课 讲了图像分类器的设计和线性分类器的解释,说明了如何通过参数 W 和 b 得到对应的类别分数,来判断图片究竟属于哪一个类别。但是没有说如何得到参数 W 和 b,这一节课通过讲解多元 SVM 损失函数、softmax loss 来衡量预测结果与ground truth label的差别和如何根据loss函数通过梯度下降(Gradient decent)来求解参数 W 和 b。

1. 损失函数(loss function)

损失函数(loss function)又叫代价函数(cost function)、目标函数(objective function)。是用来衡量预测labels 与 ground truth labels 一致性的度量指标。

1.1 多元支持向量机损失(Multiclass Support Vector Machine, SVM loss),

损失函数有很多种,本节课的第一个例子是多元支持向量机损失(Multiclass Support Vector Machine, SVM loss), SVM loss 目的是想要每一张图像的正确的类别分数比其他不正确的类分数高一个固定的阈值 Δ \Delta Δ

下面给出精确的数学公式,假设 x i x_i xi 是图片 i i i 的像素值, y i y_i yi 为图片的label(正确的类的索引)。 s s s f ( x i , W , b ) f(x_i, W, b) f(xi,W,b) 的结果, s j = f ( x i , W , b ) j s_j = f(x_i, W, b)_j sj=f(xi,W,b)j 代表分数向量 s s s 的第 j j j 个元素。由此,SVM loss为:
L i = ∑ j ≠ y i max ( 0 , s j − s y i + Δ ) L_i = \sum_{j\neq y_i} \text{max} (0, s_j - s_{y_i} + \Delta ) Li=j̸=yimax(0,sjsyi+Δ)
在这里插入图片描述
如果将 Δ = 1 \Delta = 1 Δ=1 ,则 loss function 如图中 “Hinge loss” 所示。一上图第一张猫的分数求对应的 loss function L i L_i Li 计算过程及结果如下图所示:
在这里插入图片描述
值得注意的是,最后整个数据集上的loss是所有单个图像loss的平均:
L = 1 N ∑ i = 1 N L i L = \frac 1N \sum_{i=1}^N L_i L=N1i=1NLi

针对SVM loss 这里有一些问题需要回答。

  • 当某张图片的loss为0时,如果图片改变一点点,loss 会发生怎样的变化?
    答:不会发生变化。SVM loss 只关注与正确的分数大于不正确的分数 Δ \Delta Δ, 在这种情况下,正确类的分数比其它分数都要大,如果该类的分数改变一点点,此时这个阈值 Δ \Delta Δ 依然有效,loss 并不会改变。
  • SVM loss 的最大值和最小值是多少?
    答: [ 0 , + ∞ ] [0, +\infty] [0,+]
  • 如果初始化 W 为一个很小的值以至于所有 s ≈ 0 s \approx0 s0, loss 等于多少?
    答: l o s s = ( C − 1 ) ∗ Δ loss = (C - 1) * \Delta loss=(C1)Δ ( C C C代表类别数量)
  • 如果loss在所有类上都求和(包括 j = y i j = y_i j=yi)会如何?
    答:loss会再加上 Δ \Delta Δ
  • 如果用平均代替求和会怎么样?
    答:没有影响。
  • 如果loss选择使用L2-SVM, 即 L i = ∑ j ≠ y i max ( 0 , s j − s y i + Δ ) L_i =\sum_{j\neq y_i}\text{max}(0, s_j - s_{y_i} + \Delta) Li=j̸=yimax(0,sjsyi+Δ) 会如何?
    答:L2-SVM 将会更加激烈的惩罚违反边界的行为,SVM loss显得更加稳定,具体选择需要根据不同的应用进行抉择,可以在cross-validation进行选择。

值得注意的是,一是 Δ \Delta Δ 的选择并没有什么特定的要求,你可以根据实际的数据集作出相应的调整。二是最优参数 W (将参数b扩展进W)的值并不唯一,比如2W也可以使得 L = 0 L = 0 L=0

在这里插入图片描述

我们知道,当我们在训练时,如果过度拟合训练集的数据,则很容易才测试的时候出现问题,这就是典型的过拟合(overfitting)问题。为解决这样的问题,可以添加正则项(Regularization term) 去惩罚权重W。于是得到整个loss函数,如下图所示:
在这里插入图片描述

正则项的意义

正则化的目的是为了解决模型过拟合的问题,提高模型的泛化性,添加正则项只是其中一种方法

假设权重 W 可以正确的分类所有的样本( l o s s = 0 loss = 0 loss=0),则会出现一个问题, W W W 不是唯一的, λ W ( λ > 1 ) \lambda W(\lambda > 1) λW(λ>1) 也可以是的 loss 为0。所以我们需要加入一些偏好使得消除这些模糊性,并使得参数 W W W 尽量简单,即 λ = 1 \lambda = 1 λ=1。通过在 loss 函数中加入一个正则项 R ( W ) R(W) R(W) 来惩罚参数 W W W 使得参数 W W W 尽量小。

除此之外,惩罚大权重也可以提高模型的泛化能力,如果存在输入向量 x = [ 1 , 1 , 1 , 1 ] x = [1, 1, 1, 1] x=[1,1,1,1], 假设有参数为 w 1 = [ 1 , 0 , 0 , 0 ] w _ { 1 } = [ 1,0,0,0 ] w1=[1,0,0,0] w 2 = [ 0.25 , 0.25 , 0.25 , 0.25 ] w_2 = [0.25, 0.25, 0.25, 0.25] w2=[0.25,0.25,0.25,0.25], 则 w 1 T x = W 2 T x = 1 w_1^Tx = W_2^Tx = 1 w1Tx=W2Tx=1, 很明显 w 2 w_2 w2 是一个更好的选择,因为它显得更加分散,不会过分依赖某一个特征。

最常见的正则项是 L2 正则:

R ( W ) = ∑ k ∑ l W k , l 2 R ( W ) = \sum _ { k } \sum _ { l } W _ { k , l } ^ { 2 } R(W)=klWk,l2

1.2 Softmax Classifier (Multinomial Logistic Regression)

另一个比较流行的 loss 函数是softmax, softmax 函数的本质就是将一个K维的任意实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于(0,1)之间。将最终计算的分数通过 softmax 函数。
softmax函数形式如下:

σ ( z ) j = e z j ∑ k = 1 K e z k \sigma ( z ) _ { j } = \frac { e ^ { z _ { j } } } { \sum _ { k = 1 } ^ { K } e ^ { z _ { k } } } σ(z)j=k=1Kezkezj

其中j=1,2,…,K。

然后将结果通过交叉熵作为loss, 即:

H ( p , q ) = − ∑ x p ( x ) log ⁡ q ( x ) H ( p , q ) = - \sum _ { x } p ( x ) \log q ( x ) H(p,q)=xp(x)logq(x)

p 代表真实的概率值,q代表预测的概率值。

2. Optimization

我们已经通过一个 score function 将输入数据映射到了类别分数,而且使用loss function 测量了预测结果与实际结果的偏差。

现在我们需要去找到最优的参数 W W W 去最小化 loss function。这一过程称之为 优化(Optimization)

这里总共有三种优化策略。

2.1 随机搜索(Random Search)

这是最简单的方法,效果极差,就是简单的选择很多随机的权重,然后就算他们的loss,选择最优的权重。程序如下:

# assume X_train is the data where each column is an example (e.g. 3073 x 50,000)
# assume Y_train are the labels (e.g. 1D array of 50,000)
# assume the function L evaluates the loss functionbestloss = float("inf") # Python assigns the highest possible float value
for num in xrange(1000):W = np.random.randn(10, 3073) * 0.0001 # generate random parametersloss = L(X_train, Y_train, W) # get the loss over the entire training setif loss < bestloss: # keep track of the best solutionbestloss = lossbestW = Wprint 'in attempt %d the loss was %f, best %f' % (num, loss, bestloss)# prints:
# in attempt 0 the loss was 9.401632, best 9.401632
# in attempt 1 the loss was 8.959668, best 8.959668
# in attempt 2 the loss was 9.044034, best 8.959668
# in attempt 3 the loss was 9.278948, best 8.959668
# in attempt 4 the loss was 8.857370, best 8.857370
# in attempt 5 the loss was 8.943151, best 8.857370
# in attempt 6 the loss was 8.605604, best 8.605604
# ... (trunctated: continues for 1000 lines)

优化的核心思想其实是迭代修正,想要直接找到最优的权重基本上是不可能的,所以采用的方式是随机初始化权重,然后迭代修正发现更好的权重。

2.2 随机局部搜索(Random Local Search)

这种方法是,基于当前的权重,选择一个随机的方向走一小步,观察结果。比如存在权重参数 W W W,然后我们随机的向一个方向走一步 δ W \delta W δW, 如果在 W + δ W W + \delta W W+δW 处的 loss 小于在 W W W 处的 loss,则更新 W W W。程序如下:

W = np.random.randn(10, 3073) * 0.001 # generate random starting W
bestloss = float("inf")
for i in xrange(1000):step_size = 0.0001Wtry = W + np.random.randn(10, 3073) * step_sizeloss = L(Xtr_cols, Ytr, Wtry) # compute the lossif loss < bestloss:W = Wtrybestloss = lossprint 'iter %d loss is %f' % (i, bestloss)

2.3 跟随梯度(Following the Gradient)

其实我们可以计算出权重更新最佳的方向 ------- 梯度的负方向

有两种方式计算梯度,一种是数值梯度法(numerical gradient)分析梯度发(analytic gradient).

  1. 数值梯度法(numerical gradient): 一种缓慢的,近似地估计,但却容易的方法。
  2. 分析梯度法(analytic gradient):一种快速的,精确的,但却容易在求积分的时候出错。

3. 梯度下降(Gradient Descent)

我们计算出了loss function的梯度,现在我可以迭代地更新权重参数了。该过程叫做梯度下降,最简单的程序如下:

# Vanilla Gradient Descentwhile True:weights_grad = evaluate_gradient(loss_fun, data, weights)weights += - step_size * weights_grad # perform parameter update

这里的 step_size 是模型的一个超参数,可以叫做步长或学习率(learning rate),它代表你要朝更新方向一步走的步长大小,越大代表更新的越快,得到结果越快。但是太大的步长会使得得不到最优解,因为你可能一步踏过了最优解,然后一直在最优解边上徘徊。

mini-batch gradient descent

首先解释一下什么是batch,batch是一批,一堆的意思,整个数据集的数据可以叫做一个batch,由于有些数据集的数据量很大,计算一次整个数据集上的梯度需要大量的时间,所以提出mini-batch的概念。比如我们可以将一个mini-batch的batch size设置为32,则计算32张图片后就可以进行一次梯度下降。程序如下:

# Vanilla Minibatch Gradient Descentwhile True:data_batch = sample_training_data(data, 256) # sample 256 examplesweights_grad = evaluate_gradient(loss_fun, data_batch, weights)weights += - step_size * weights_grad # perform parameter update

当 batch size 为 1 时,称之为随机梯度下降(Stochastic Gradient Descent ,SGD).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/1379811.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

时间和空间复杂度分析

前言 对于数据结构相关的博客文章&#xff0c;我是根据大学本科阶段《数据结构和算法》课程的内容和王争老师在即刻时间上的《数据结构和算法之美》系列课程&#xff08;墙裂推荐&#xff09;进行了一些排版参考和笔记性梳理。这些文章作为笔记总结&#xff0c;一方便是为了夯…

INT303 Big Data 个人笔记

又来到了经典半个月写一个学期内容的环节 目前更新至Lec{14}/Lec14 依旧是不涉及代码&#xff0c;代码请看学校的jupyter notebook~ Lec1 Introduction 介绍课程 TopicRangeTopic 1: Introduction to Big Data AnalyticsLec1~Lec3Topic2: Big data collection and visualiza…

日撸 Java 三百行(21 天: 二叉树及其基本操作)

注意&#xff1a;这里是JAVA自学与了解的同步笔记与记录&#xff0c;如有问题欢迎指正说明 目录 前言 一、一对多的结构&#xff1a;树形结构 二、二叉树 1.二叉树的体现运用 2.二叉树存储 三、二叉树遍历 1.树遍历的递归思想中的“三角抉择” 2.树的前、中、后序遍历…

C语言每日一练 —— 第21天:算法的应用

文章目录 前言一、算法简介1、推荐算法2、最短路算法3、最值算法4、排序算法5、压缩算法6、加密算法 二、为什么要学算法1、面试时2、工作中 三、算法能给我们带来什么能力的提升1、抽象问题的能力2、解决问题的能力3、编写代码的能力4、调试能力1&#xff09;画图2&#xff09…

C语言基础学习

**1.2 C语言程序设计入门三步骤 程序设计入门三步骤&#xff1a; &#xff08;1&#xff09;安装软件并开发HelloWorld程序。 &#xff08;2&#xff09;掌握基本的输入输出方法。 &#xff08;3&#xff09;理解该语言中程序的基本结构。 1.2.1 安装软件并开发第一个HelloWo…

BP算法Java实现

我们上次已经把公式给推导了出来。还举了例子&#xff0c;不懂的理论的点击这里&#xff0c;老师的代码   这回我们将要用Java进行初步实现&#xff0c;这个代码是我参考老师的&#xff0c;里面附带了详细的注解。要成功运行需要一些包&#xff0c;需要的可以联系我。 public…

关系代数和SQL语法

数据分析的语言接口 OLAP计算引擎是一架机器&#xff0c;而操作这架机器的是编程语言。使用者通过特定语言告诉计算引擎&#xff0c;需要读取哪些数据、以及需要进行什么样的计算。编程语言有很多种&#xff0c;任何人都可以设计出一门编程语言&#xff0c;然后设计对应的编译…

优雅的对象

最近一口气读完了二百多页的《Elegant Objects》。可能因为整理自博客所以排版一般,而且才二百多页定价却40多刀。但读过之后发现超值,甚至还想去买第二卷。作者观点大多比较激进,对自己的理念异常坚定,所以经常使用诸如“绝对不要使用XXX”、“记住XXX,就这样,句号”。但…

深入理解Java 8 Lambda

关于 深入理解 Java 8 Lambda&#xff08;语言篇——lambda&#xff0c;方法引用&#xff0c;目标类型和默认方法&#xff09;深入理解 Java 8 Lambda&#xff08;类库篇——Streams API&#xff0c;Collector 和并行&#xff09;深入理解 Java 8 Lambda&#xff08;原理篇——…

自然语言处理中注意力机制综述

https://www.toutiao.com/a6655120292144218637/ 目录 1.写在前面 2.Seq2Seq 模型 3.NLP中注意力机制起源 4.NLP中的注意力机制 5.Hierarchical Attention 6.Self-Attention 7.Memory-based Attention 8.Soft/Hard Attention 9.Global/Local Attention 10.评价指标 11.写在后面…

【深度学习基础】从零开始的炼丹生活00——机器学习数学基础以及数值计算数值优化方法

正值假期&#xff0c;决定恶补机器学习、深度学习及相关领域&#xff08;顺便开个博客&#xff09;。首先学习一下数学基础以及数值计算的方法&#xff08;主要参考《深度学习》&#xff09; 一、数学基础 这里简单复习一下机器学习相关的数学1.线性代数 范数 衡量一个向量的…

“泰迪杯”挑战赛 -利用非侵入式负荷检测进行高效率数据挖掘(完整数学模型)

目录 1 研究背景与意义 2 变量说明 3 问题分析 4 问题一 4.1 数据预处理 4.1.1 降噪处理 4.1.2 数据变换 4.2 负荷特征分析 4.2.1 暂态特征 4.2.2 稳态特征 5 问题二 5.1 相似度与权系数 5.2 模型建立 5.3 模型求解 6 问题三 6.1 事件检测算法 6.2 模型建立 6.3 模型求解…

37%原则如何优化我们做决定的时间

当需要百(千&#xff0c;万…)里挑一时&#xff0c;需要权衡最优解和效率&#xff0c;有一个37%原则比较有趣。 整个择优过程分为两个阶段&#xff1a; 观望&#xff1a;在前面 k k k个候选者中冒泡记录最优者 p p p&#xff0c;其分数为 V p V_p Vp​&#xff0c;但并不选择…

清风数学建模学习笔记——层次分析法

目录 一、模型简介 二、建模步骤 三、模型总结 一、层次分析法——模型简介 层次分析法&#xff0c;简称AHP&#xff0c;是指将与决策总是有关的元素分解成目标、准则、方案等层次&#xff0c;在此基础之上进行定性和定量分析的决策方法。该方法是美国运筹学家匹茨堡大学教授萨…

Attention is all you need ---Transformer

大语言模型已经在很多领域大显身手&#xff0c;其应用包括只能写作、音乐创作、知识问答、聊天、客服、广告文案、论文、新闻、小说创作、润色、会议/文章摘要等等领域。在商业上模型即产品、服务即产品、插件即产品&#xff0c;任何形态的用户可触及的都可以是产品&#xff0c…

you-get下载速度慢解决方法

Python版本&#xff1a;3.10 运行环境&#xff1a;Windows10 问题描述&#xff1a;在使用you-get下载X站视频时网速很慢&#xff0c;并一直限制在某个值,通过以下办法即可恢复正常网速 解决办法&#xff1a; 进入windows 安全中心-病毒和威胁防护-管理设置点击添加或删除排…

Microsoft store下载速度过慢

最开始是进入Microsoft store点击安装后一直无响应&#xff0c;后来知道这是因为Microsoft store下载速度过慢。下边几个步骤都尝试了&#xff0c;个人认为最重要的是Windows Update设置步骤&#xff0c;刚开始可能一直没有正确打开 修改DNS 右键任务栏网络图标->打开“网…

Linux网络编程 socket编程篇(一) socket编程基础

目录 一、预备知识 1.IP地址 2.端口号 3.网络通信 4.TCP协议简介 5.UDP协议简介 6.网络字节序 二、socket 1.什么是socket(套接字)&#xff1f; 2.为什么要有套接字&#xff1f; 3.套接字的主要类型 拓】网络套接字 三、socket API 1.socket API是什么&#xff1f; 2.为什么…

如何预防ssl中间人攻击?

当我们连上公共WiFi打开网页或邮箱时&#xff0c;殊不知此时可能有人正在监视着我们的各种网络活动。打开账户网页那一瞬间&#xff0c;不法分子可能已经盗取了我们的银行凭证、家庭住址、电子邮件和联系人信息&#xff0c;而这一切我们却毫不知情。这是一种网络上常见的“中间…

[保研/考研机试] KY3 约数的个数 清华大学复试上机题 C++实现

题目链接&#xff1a; KY3 约数的个数 https://www.nowcoder.com/share/jump/437195121691716950188 描述 输入n个整数,依次输出每个数的约数的个数 输入描述&#xff1a; 输入的第一行为N&#xff0c;即数组的个数(N<1000) 接下来的1行包括N个整数&#xff0c;其中每个…