R语言深度学习-5-深度前馈神经网络

本教程参考《RDeepLearningEssential》

本篇我们将学习如何建立并训练深度预测模型。我们将关注深度前馈神经网络


5.1 深度前馈神经网络

我们还是使用之前提到的H2O包,详细可以见之前的博客:R语言深度学习-1-深度学习入门(H2O包安装报错解决及接入/H2O包连接数据集)-CSDN博客

library('h2o')
cl <- h2o.init(max_mem_size = "20G",nthreads = 10,ip = "127.0.0.1", port = 54321)

深度前馈神经网络,也被称为前馈神经网络或多层感知机(MLP),是一种典型的多层神经网络,其中数据在神经元之间单向流动,从输入层经过一个或多个隐藏层传递到输出层。这种类型的网络不包含任何循环或反馈连接,意味着信息的流向是从上到下,不会从输出层返回至输入层。

深度前馈神经网络的核心优势在于其能够通过学习输入数据和目标输出之间的复杂映射关系来执行各种任务。这得益于它的层次结构,每一层都从前一层接收信息并产生输出,这些输出作为下一层的输入。随着网络层次的加深,它能够捕捉更抽象的特征,从而提升模型的性能和泛化能力。

如图所示,来自输入X到输出Y的全部映射是一个多层函数。第一个隐藏层是:

H_{1} = f^{(1)}(X,\omega _{1},\alpha _{1})

而在每一层中有多少隐藏神经元及使用什么激活函数,我们在5.2讨论,另一个关键是成本或损失函数,常用的是交叉熵(cross-entropy)和二次的函数均方差(MSE)。

5.2 激活函数

激活函数在神经网络中扮演着至关重要的角色。它们通常被嵌入到神经网络的隐藏层中,用以引入非线性因素,使得神经网络能够学习和模拟复杂的数据模式。没有激活函数,无论神经网络有多少层,最终都只是相当于一个线性变换,无法解决非线性问题。

激活函数的种类多样,每种都有其特定的用途和特性。以下是一些常见的激活函数及其特点:

1. Sigmoid函数:Sigmoid函数可以将任意实数映射到(0,1)区间内,这使得它可以用来做二分类问题的输出层。然而,当输入值较大或较小时,Sigmoid函数的梯度接近于0,容易导致梯度消失

2. Tanh函数:Tanh函数是Sigmoid函数的变种,它将实数映射到(-1,1)区间内,相比于Sigmoid函数,Tanh函数的输出以0为中心,但它同样存在梯度消失的问题。

3. ReLU函数:ReLU(Rectified Linear Unit)函数是目前最常用的激活函数之一。它在输入大于0时直接输出该值,小于等于0时输出0。ReLU函数解决了梯度消失问题,计算简单且加速了神经网络的训练。但ReLU函数也有缺点,比如当输入为负数时,梯度始终为0,可能导致神经元“死亡”。

4. Leaky ReLU函数:Leaky ReLU是对ReLU的改进,它在输入小于等于0时,梯度不为0,而是一个很小的正数。这样可以缓解ReLU的死神经元问题。

5. Softmax函数:Softmax函数常用于多分类问题的输出层,它可以将一组数值转化为概率分布。

6. Swish函数:Swish函数是一个平滑且非单调的激活函数,由谷歌提出,在某些情况下比ReLU表现更好。

7. Mish函数:Mish函数结合了ReLU和Swish的优点,具有更好的性能表现。

5.3 选取超参数

我们在之前的模型选择参数,一般是选取如权重或者截距。不过还有一些参数能够被学到或者能被优化,我们在进行模型选择的时候,也是一种超参数。我们还是使用之前的手写字数据进行实践:

R语言深度学习-2-训练预测模型-CSDN博客

我们使用H2O的深度学习算法来训练一个分类器,并比较不同学习率对模型性能的影响。以下是两个深度学习模型的配置和它们的运行时间分析。

options(width = 70, digits = 2)
#初始化
dig_train <- read.csv("C:\\Users\\Huzhuocheng\\Desktop\\digit-recognizer\\train.csv")
dim(dig_train) #数据维度查看
dig_train$label <- factor(dig_train$label, levels = 0:9)h2odigits <- as.h2o(dig_train,destination_frame = "h2odigits")
i <- 1:32000
h2odigits.train <- h2odigits[i, ]
itest <- 32001:42000
h2odigits.test <- h2odigits[itest, ]
xnames <- colnames(h2odigits.train)[-1]#训练模型
system.time(ex1 <- h2o.deeplearning(x = xnames,y = "label",training_frame= h2odigits.train,validation_frame = h2odigits.test,activation = "RectifierWithDropout",hidden = c(100),epochs = 10,adaptive_rate = FALSE,rate = .001,input_dropout_ratio = 0,hidden_dropout_ratios = c(.2)
))
system.time(ex2 <- h2o.deeplearning(x = xnames,y = "label",training_frame= h2odigits.train,validation_frame = h2odigits.test,activation = "RectifierWithDropout",hidden = c(100),epochs = 10,adaptive_rate = FALSE,rate = .01,input_dropout_ratio = 0,hidden_dropout_ratios = c(.2)
))

 我们选择了不同的学习率,ex1中学习率是0.001,在ex2中,学习率是0.01,我们发现ex1的运行时间长很多,但是就模型效果来说,ex1更好:

 深刻理解超参数,对我们在模型进行训练中有事半功倍的效果,有的时候不是模型不行,而是选择了错误的超参数,这很重要。

5.4 深度神经网络训练及预测

我们使用之前提到的UCI数据进行演示:UCI Machine Learning Repository

#数据导入
train_x <- read.table("C:/Users/Huzhuocheng/Desktop/UCI数据/UCI HAR Dataset/UCI HAR Dataset/train/X_train.txt")
train_Y <- read.table("C:/Users/Huzhuocheng/Desktop/UCI数据/UCI HAR Dataset/UCI HAR Dataset/train/y_train.txt")
test_x <- read.table("C:/Users/Huzhuocheng/Desktop/UCI数据/UCI HAR Dataset/UCI HAR Dataset/test/X_test.txt")
test_Y <- read.table("C:/Users/Huzhuocheng/Desktop/UCI数据/UCI HAR Dataset/UCI HAR Dataset/test/y_test.txt")
barplot(table(train_Y))
train_x <- as.data.frame(train_x)
train_Y <- as.data.frame(train_Y)train_Y <- factor(train_Y)  
test_Y <- factor(test_Y)   use.train <- cbind(train_x, Outcome = train_Y)
use.test <- cbind(test_x, Outcome = test_Y)use.labels <- read.table("C:\\Users\\Huzhuocheng\\Desktop\\UCI数据\\UCI HAR Dataset\\UCI HAR Dataset\\activity_labels.txt")
h2oactivity.train <- as.h2o(use.train,destination_frame = "h2oactivitytrain")
h2oactivity.test <- as.h2o(use.test,destination_frame = "h2oactivitytest")

 接下来,我们使用H2O的deeplearning包进行深度学习,使用的激活函数是线性整流器,并且使用了我们上次讲的丢弃正则化,带有输入变量20%丢弃和隐藏神经元50%丢弃,并且我们建立的是一个50神经元和10次迭代的浅层网络,损失函数是交叉熵。

mt1 <- h2o.deeplearning(x = colnames(train_x),y = "Outcome",training_frame= h2oactivity.train,activation = "RectifierWithDropout",hidden = c(50),epochs = 10,loss = "CrossEntropy",input_dropout_ratio = .2,hidden_dropout_ratios = c(.5),export_weights_and_biases = TRUE
)

显示了层数及每个层中单元的个数,单元的类型,丢弃百分比和其他正则信息。

这个则显示了模型的性能,包括均方误,对数损失等。

混淆矩阵显示了预测与真实值的差距。

5.5 小结

我们本次使用H2O包对深度神经网络进行了学习应用,不过我们在例子中构建的都是浅层的神经网络,大家可以自己调参数实现更好的理解与应用。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2869489.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

[scikit-learn] 第一章 初识scikit-learn及内置数据集介绍

文章目录 菜鸡镇贴&#xff01;&#xff01;&#xff01;scikit-learn 简要介绍scikit-learn 安装scikit-learn 数据集介绍数据集API介绍LoadersSamples generator 导入数据集demo 菜鸡镇贴&#xff01;&#xff01;&#xff01; scikit-learn 简要介绍 ​ Scikit learn是一个…

RK3568平台开发系列讲解(基础篇)内核是如何发送事件到用户空间

🚀返回专栏总目录 文章目录 一、相关接口函数二、udevadm 命令三、实验沉淀、分享、成长,让自己和他人都能有所收获!😄 一、相关接口函数 kobject_uevent 是 Linux 内核中的一个函数, 用于生成和发送 uevent 事件。 它是 udev 和其他设备管理工具与内核通信的一种方式。…

SDN网络简单认识(1)——概述

一、概述 软件定义网络&#xff08;Software Defined Networking&#xff0c;SDN&#xff09;是一种网络架构理念&#xff0c;旨在使网络灵活和可编程&#xff0c;从而更好地支持动态和高度可扩展的计算环境。SDN通过抽象网络的控制层&#xff08;决策层&#xff09;和数据层&a…

面试经典-MySQL篇

一、MySQL组成 MySQL数据库的连接池&#xff1a;由一个线程来监听一个连接上请求以及读取请求数据&#xff0c;解析出来一条我们发送过去的SQL语句SQL接口&#xff1a;负责处理接收到的SQL语句查询解析器&#xff1a;让MySQL能看懂SQL语句查询优化器&#xff1a;选择最优的查询…

OpenCV 图像重映射函数remap()实例详解

OpenCV 图像重映射函数remap()对图像应用通用几何变换。其原型如下&#xff1a; void remap(InputArray src, OutputArray dst, InputArray map1, InputArray map2, int interpolation&#xff0c; int borderMode BORDER_CONSTANT&#xff0c; const Scalar & borde…

LeetCode 189.轮转数组

题目&#xff1a;给定一个整数数组 nums&#xff0c;将数组中的元素向右轮转 k 个位置&#xff0c;其中 k 是非负数。 思路&#xff1a; 代码&#xff1a; class Solution {public void rotate(int[] nums, int k) {int n nums.length;k k % n;reverse(nums, 0, n);revers…

吴恩达deeplearning.ai:使用多个决策树随机森林

以下内容有任何不理解可以翻看我之前的博客哦&#xff1a;吴恩达deeplearning.ai专栏 文章目录 为什么要使用树集合使用多个决策树(Tree Ensemble)有放回抽样随机森林XGBoost(eXtream Gradient Boosting)XGBoost的库实现何时使用决策树决策树和树集合神经网络 使用单个决策树的…

Spark-Scala语言实战(2)(在IDEA中安装Scala,超详细配图)

之前的文章中&#xff0c;我们学习了如何在windows下下载及使用Scala&#xff0c;但那对一个真正想深入学习Scala的人来说&#xff0c;是不够的&#xff0c;今天我会给大家带来如何在IDEA中安装Scala。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的…

Javaweb--CSS

一&#xff1a;概述 CSS &#xff08;Cascading Style Sheet&#xff08;层叠样式表&#xff09;&#xff09;是一门语言&#xff0c;用于控制网页表现。 W3C标准规定了网页是由以下组成&#xff1a; 结构&#xff1a;HTML 表现&#xff1a;CSS 行为&#xff1a;JavaScrip…

分布式文件存储与数据缓存(一)| FastDFS

目录 分布式文件系统FastDFS概述_简介FastDFS特性&#xff1a;分布式文件服务提供商 FastDFS概述_核心概念trackerstorageclientgroup FastDFS概述_上传机制内部机制如下 FastDFS概述_下载机制内部机制如下 FastDFS环境搭建_Linux下载安装gcc下载安装FastDFS下载安装FastDFS依赖…

sqllab第二十五A关通关笔记

知识点&#xff1a; 数值型注入双写绕过 oorranand这里不能用错误注入&#xff08;固定错误回显信息&#xff09;联合注入 测试发现跟25关好像一样&#xff0c;就是过滤了and or # 等东西 构造payload:id1/0 发现成功运算了&#xff0c;这是一个数值型的注入 构造payload:id…

音频的录制及播放

在终端安装好pip install pyaudio&#xff0c;在pycharm中敲入录音的代码&#xff0c;然后点击运行可以在10s内进行录音&#xff0c;录音后的音频会保存在与录音代码同一路径项目中&#xff0c;然后再新建项目敲入播放的代码&#xff0c;点击运行&#xff0c;会把录入的录音进行…

Java学习笔记------常用API(五)

爬虫 从网站中获取 import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.net.MalformedURLException; import java.net.URL; import java.net.URLConnection; import java.util.regex.Matcher; import java.util.reg…

浏览器如何进行静态资源缓存?—— 强缓存 协商缓存

在平时使用浏览器排查问题的过程中&#xff0c;我们有时会看到浏览器网络请求中出现304状态码&#xff0c;那么是什么情况下出现304呢&#xff1f;下面是关于这一现象的解释&#xff1a; 浏览器如何进行静态资源缓存&#xff1f;—— 强缓存 & 协商缓存 状态码 304浏览器如…

python的opencv最最基础初学

localhost中详解OpenCV的函数imread()和函数imshow(),并利用它们实现对图像的读取和显示_opencv imshow-CSDN博客 其实以下均为numpy 显示一张图片 import cv2 ####opencv读取的格式是BGR import matplotlib.pyplot as plt import numpy as np %matplotlib inline imgcv2.…

k8s之图形界面DashBoard【九】

文章目录 9. DashBoard9.1 部署Dashboard9.2 使用DashBoard 镇场 9. DashBoard 之前在kubernetes中完成的所有操作都是通过命令行工具kubectl完成的。其实&#xff0c;为了提供更丰富的用户体验&#xff0c;kubernetes还开发了一个基于web的用户界面&#xff08;Dashboard&…

java小型人事管理系统

开发工具&#xff1a; MyEclipseJdkTomcatSQLServer数据库 运行效果视频&#xff1a; https://pan.baidu.com/s/1hshFjiG 定制论文&#xff0c;联系下面的客服人员

Mac版Jmeter安装与使用模拟分布式环境

Mac版Jmeter安装与使用&模拟分布式环境 1 安装Jmeter 1.1 安装Java环境 国内镜像地址&#xff1a;https://repo.huaweicloud.com/java/jdk/11.0.29/jdk-11.0.2_osx-x64_bin.dmg 下载dmg后&#xff0c;双击进行安装。 配置环境变量&#xff1a; # 1 打开环境变量配置文件…

微信小程序关闭首页广告

由于之前微信小程序默认开启了首页广告位。导致很多老人误入广告页的内容&#xff0c;所以想着怎么屏蔽广告。好家伙&#xff0c;搜索一圈&#xff0c;要么是用户版本的屏蔽广告&#xff0c;或者是以下一个模棱两可的答案&#xff0c;要开发者设置一下什么参数的&#xff0c;如…

牛客网-SQL大厂面试题-1.各个视频的平均完播率

题目&#xff1a;各个视频的平均完播率 DROP TABLE IF EXISTS tb_user_video_log, tb_video_info; CREATE TABLE tb_user_video_log (id INT PRIMARY KEY AUTO_INCREMENT COMMENT 自增ID,uid INT NOT NULL COMMENT 用户ID,video_id INT NOT NULL COMMENT 视频ID,start_time dat…