BP算法Java实现

  我们上次已经把公式给推导了出来。还举了例子,不懂的理论的点击这里,老师的代码
  这回我们将要用Java进行初步实现,这个代码是我参考老师的,里面附带了详细的注解。要成功运行需要一些包,需要的可以联系我。

	public static void main(String[] args) {int[] tempLayerNodes = { 4, 8, 8, 3 };SimpleAnn tempNetwork = new SimpleAnn("D:/data/iris.arff", tempLayerNodes, 0.01,0.6);for (int round = 0; round < 5000; round++) {tempNetwork.train();} // Of for ndouble tempAccuracy = tempNetwork.test();System.out.println("The accuracy is: " + tempAccuracy);}// Of main

  我们来看主函数,开始就是层次数组。共三层,4,8,8,3 的意思是输入层4个神经元,两个隐层各8个神经元,输出层3个神经元。用权重的线连接起来,由于节点多连起来不好看,就自己脑补了哈。
在这里插入图片描述  接下来就是读入数据,然后一个for循环对数据进行训练,这里我们用的BP算法。后面两句就是把预测精度打印出来。
这里我们用了两个类,一个抽象的父类GeneralAnn和具体实现的子类SimpleAnn.两部分的代码在这里,子类涉及BP算法的核心代码,注释我已经给好了。但是一定要要对公式熟悉,不行的话就看着公式慢慢对下去。硬骨头很难啃下去。

package machinelearning.ann;import java.io.FileReader;
import java.security.PublicKey;
import java.util.Arrays;
import java.util.Random;import weka.core.Instances;
import weka.datagenerators.Test;public abstract class GeneralAnn {/*** 整个数据集*/Instances dataset;/*** 层数,根据结点计算*/int numLayers;/*** 每层的节点数, e.g., [3, 4, 6, 2] 意思是*三个输出层结点 (也是条件判断的属性), 两个隐层分别有4和6个结点* , 两个类别属性 (就是二元分类,只有是与非).*/int[] layerNumNodes;/*** 动量系数。*/public double mobp;/*** 学习率*/public double learningRate;/*** 用于生成随机数*/Random random = new Random();/*********************** 第一个构造函数* * @param paraFilename*            arff 类型的文件* @param paraLayerNumNodes*            每层结点的数目(可能不相同)* @param paraLearningRate*            学习率* @param paraMobp*            动量系数*********************/public GeneralAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,double paraMobp) {// Step 1. 读数据try {FileReader tempReader = new FileReader(paraFilename);dataset = new Instances(tempReader);// 最后一个类别是一个做抉择的类别dataset.setClassIndex(dataset.numAttributes() - 1);tempReader.close();} catch (Exception ee) {System.out.println("Error occurred while trying to read \'" + paraFilename+ "\' in GeneralAnn constructor.\r\n" + ee);System.exit(0);} // Of try// Step 2. 接受参数layerNumNodes = paraLayerNumNodes;// 把参数层的结点数传过来numLayers = layerNumNodes.length;// 必要时进行调整layerNumNodes[0] = dataset.numAttributes() - 1;layerNumNodes[numLayers - 1] = dataset.numClasses();//意思就是把空数组填入[3, 4, 6, 2]数据learningRate = paraLearningRate;// 学习率传参mobp = paraMobp;// 学习动量系数传参	}//Of the first constructor	/*********************** 前瞻性预测* * @param paraInput*            一个实例的输入数据。* @return 输出端的数据。*********************/public abstract double[] forward(double[] paraInput);/*********************** 反向传播。* * @param paraTarget*            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].*            *********************/public abstract void backPropagation(double[] paraTarget);/*********************** 使用数据集进行训练。*********************/public void train() {//是一个一个数据进行训练的double[] tempInput = new double[dataset.numAttributes() - 1];// 输入层+隐层double[] tempTarget = new double[dataset.numClasses()];//类别的数目,比如我们这里是二元的,那么就是2for (int i = 0; i < dataset.numInstances(); i++) {// 填装数据for (int j = 0; j < tempInput.length; j++) {tempInput[j] = dataset.instance(i).value(j);} // Of for j// 填装数据标签Arrays.fill(tempTarget, 0);//使决策值全部为零tempTarget[(int) dataset.instance(i).classValue()] = 1;//我们把这个目标训练的对象类型取整数作为数组索引,值为一。用于判断// Train with this instance.forward(tempInput);//第一次就相当于初始化。backPropagation(tempTarget);} // Of for i}// Of train/*********************** 获取与数组的最大值对应的索引。* * @return the index.*********************/public static int argmax(double[] paraArray) {int resultIndex = -1;double tempMax = -1e10;for (int i = 0; i < paraArray.length; i++) {if (tempMax < paraArray[i]) {tempMax = paraArray[i];resultIndex = i;} // Of if} // Of for ireturn resultIndex;}// Of argmax/*********************** 使用数据集进行测试。* * @return The precision.*********************/public double test() {double[] tempInput = new double[dataset.numAttributes() - 1];//输入层double tempNumCorrect = 0;double[] tempPrediction; //预测数组int tempPredictedClass = -1; //被预测的类型先置为0for (int i = 0; i < dataset.numInstances(); i++) { //一个一个数据进行训练// 填装数据for (int j = 0; j < tempInput.length; j++) {tempInput[j] = dataset.instance(i).value(j);} // Of for j// 训练这个数据tempPrediction = forward(tempInput);System.out.println("prediction: " + Arrays.toString(tempPrediction));tempPredictedClass = argmax(tempPrediction);//选择预测最大的哪个if (tempPredictedClass == (int) dataset.instance(i).classValue()) {tempNumCorrect++;} // Of if} // Of for iSystem.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());return tempNumCorrect / dataset.numInstances();}// Of test}//Of class GeneralAnn

下面是子类

package machinelearning.ann;// 第二天 固定激活函数的BP神经网络
public class SimpleAnn extends GeneralAnn{/*** 在转发过程中更改的每个节点的值。 * 第一个维度代表层,第二个维度代表节点。*/public double[][] layerNodeValues;/*** 在反向传播过程中每个节点上更改的错误。* 第一个维度代表层,第二个维度代表节点。* */public double[][] layerNodeErrors;/*** 边的权重。第一个维度代表层, * 第二个代表层的节点索引, * 第三个维度代表下一层的节点索引值。*/public double[][][] edgeWeights;/*** 边权重的变化. 它的大小和edgeWeights一致.*/public double[][][] edgeWeightsDelta;/*********************** 第一个构造函数* * @param paraFilename*            文件名* @param paraLayerNumNodes*            每层的结点数目* @param paraLearningRate*            学习率* @param paraMobp*            动量系数*********************/public SimpleAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,double paraMobp) {super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);// Step 1. 跨层初始化。layerNodeValues = new double[numLayers][];layerNodeErrors = new double[numLayers][];// 第一个维度长度是层数edgeWeights = new double[numLayers - 1][][];// 不要输入层?edgeWeightsDelta = new double[numLayers - 1][][];// Step 2. 内层初始化。for (int l = 0; l < numLayers; l++) {layerNodeValues[l] = new double[layerNumNodes[l]];//给每一层分配空间layerNodeErrors[l] = new double[layerNumNodes[l]];// 误差层和layerNodeValues一致// 减少一层,因为每条边跨越两层。if (l + 1 == numLayers) {// 到三停止break;} // of if // In layerNumNodes[l] + 1, 最后一个是为偏移量保留的,多增加了一个,偏移项不会指向自己edgeWeights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];// 每一个元素对应8个结点edgeWeightsDelta[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];for (int j = 0; j < layerNumNodes[l] + 1; j++) {//加一是因为多了一个偏移量for (int i = 0; i < layerNumNodes[l + 1]; i++) {// 初始化权重edgeWeights[l][j][i] = random.nextDouble();//这个也太随意了。} // Of for i} // Of for j} // Of for l}// Of the constructor/*********************** 前瞻性预测。* * @param paraInput*            一个一个实例输入* @return The data at the output end.*********************/public double[] forward(double[] paraInput) {// 初始化输入层for (int i = 0; i < layerNodeValues[0].length; i++) {// 输出层在第一层,这里初始化它的数据layerNodeValues[0][i] = paraInput[i];} // Of for i// 计算每个层的节点值。double z;for (int l = 1; l < numLayers; l++) {for (int j = 0; j < layerNodeValues[l].length; j++) {// 根据偏移量初始化,偏移量总是为+1,因为edgeWeights没有包含偏移量所以要先算。z = edgeWeights[l - 1][layerNodeValues[l - 1].length][j];//此节点是所有边上的加权和。for (int i = 0; i < layerNodeValues[l - 1].length; i++) {// 对i进行循环z += edgeWeights[l - 1][i][j] * layerNodeValues[l - 1][i];// 这个边乘以权的和,edgeWeights[第几层][第几个结点][权值]} // Of for i// Sigmoid 函数激活处理// 对于其他激活功能,应更改此行。layerNodeValues[l][j] = 1 / (1 + Math.exp(-z));} // Of for j} // Of for lreturn layerNodeValues[numLayers - 1];//就会把每个结点的值算出来。}// Of forward/*********************** 反向传播和更改边权重。(BP)* * @param paraTarget*            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].*********************/public void backPropagation(double[] paraTarget) {// Step 1. 初始化输出层的误差int l = numLayers - 1;//反向的所以索引在输出层for (int j = 0; j < layerNodeErrors[l].length; j++) {// 求导之后layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j])* (paraTarget[j] - layerNodeValues[l][j]);// 这个输出层gi 没有涉及偏移量,他没有偏移项} // Of for j 先初始化了输出层// Step 2.即使对于l(L)==0,反向传播也是如此while (l > 0) {l--;//开始减一指向倒数第一个隐层。// 层l(L),用于每个节点。for (int j = 0; j < layerNumNodes[l]; j++) {// 第l层第j个结点和后一层第i个结点的权重,double z = 0.0;// 对于下一层的每个节点。for (int i = 0; i < layerNumNodes[l + 1]; i++) {// 这个循环建立在输出层if (l > 0) {z += layerNodeErrors[l + 1][i] * edgeWeights[l][j][i]; // sigma(求和) gi*bh +1选择后一层第一次是输出层的gi,这里是算隐层所以有求和} // Of if// 权重调整edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i]+ learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];// 后一层的giedgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];// 这个公式还不一样,v=mv+n△vif (j == layerNumNodes[l] - 1) {// 当j等于层的节点数,除去偏移项// 调整偏置零件的权值。顺带更新bedgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]+ learningRate * layerNodeErrors[l + 1][i];// j+1表示偏执项索引edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];} // Of if} // Of for i 把一层的权重搞完了// 根据Sigmoid的微分记录误差。// 对于其他激活功能,应更改此行.layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * z;// 什么意思?->隐层公式:} // Of for j} // Of while}// Of backPropagation/*********************** Test the algorithm.*********************/public static void main(String[] args) {int[] tempLayerNodes = { 4, 8, 8, 3 };SimpleAnn tempNetwork = new SimpleAnn("D:/data/iris.arff", tempLayerNodes, 0.01,0.6);for (int round = 0; round < 5000; round++) {tempNetwork.train();} // Of for ndouble tempAccuracy = tempNetwork.test();System.out.println("The accuracy is: " + tempAccuracy);}// Of main
}// Of class SimpleAnn

  这里有一个问题我们得提出来,我的理论讲的需要一个阈值。这里把阈值换成了线性函数,激活函数的自变量本应该是 输入值-阈值 。理论中我们举例也用的 y= x+b 这种形式 ,多了一个结点。实际上隐层的结点有九个了。

  还有一点就是在更新b的时候,我没有在理论将,代码在这里。当然你还是要根据自己的有一个大局观才看得懂,哈哈哈。

if (j == layerNumNodes[l] - 1) {// 当j等于层的节点数,更新到了b结点的前一个。// 调整偏置零件的权值。顺带更新bedgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]+ learningRate * layerNodeErrors[l + 1][i];// j+1表示偏执项索引edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];} // Of if

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/1379803.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

关系代数和SQL语法

数据分析的语言接口 OLAP计算引擎是一架机器&#xff0c;而操作这架机器的是编程语言。使用者通过特定语言告诉计算引擎&#xff0c;需要读取哪些数据、以及需要进行什么样的计算。编程语言有很多种&#xff0c;任何人都可以设计出一门编程语言&#xff0c;然后设计对应的编译…

优雅的对象

最近一口气读完了二百多页的《Elegant Objects》。可能因为整理自博客所以排版一般,而且才二百多页定价却40多刀。但读过之后发现超值,甚至还想去买第二卷。作者观点大多比较激进,对自己的理念异常坚定,所以经常使用诸如“绝对不要使用XXX”、“记住XXX,就这样,句号”。但…

深入理解Java 8 Lambda

关于 深入理解 Java 8 Lambda&#xff08;语言篇——lambda&#xff0c;方法引用&#xff0c;目标类型和默认方法&#xff09;深入理解 Java 8 Lambda&#xff08;类库篇——Streams API&#xff0c;Collector 和并行&#xff09;深入理解 Java 8 Lambda&#xff08;原理篇——…

自然语言处理中注意力机制综述

https://www.toutiao.com/a6655120292144218637/ 目录 1.写在前面 2.Seq2Seq 模型 3.NLP中注意力机制起源 4.NLP中的注意力机制 5.Hierarchical Attention 6.Self-Attention 7.Memory-based Attention 8.Soft/Hard Attention 9.Global/Local Attention 10.评价指标 11.写在后面…

【深度学习基础】从零开始的炼丹生活00——机器学习数学基础以及数值计算数值优化方法

正值假期&#xff0c;决定恶补机器学习、深度学习及相关领域&#xff08;顺便开个博客&#xff09;。首先学习一下数学基础以及数值计算的方法&#xff08;主要参考《深度学习》&#xff09; 一、数学基础 这里简单复习一下机器学习相关的数学1.线性代数 范数 衡量一个向量的…

“泰迪杯”挑战赛 -利用非侵入式负荷检测进行高效率数据挖掘(完整数学模型)

目录 1 研究背景与意义 2 变量说明 3 问题分析 4 问题一 4.1 数据预处理 4.1.1 降噪处理 4.1.2 数据变换 4.2 负荷特征分析 4.2.1 暂态特征 4.2.2 稳态特征 5 问题二 5.1 相似度与权系数 5.2 模型建立 5.3 模型求解 6 问题三 6.1 事件检测算法 6.2 模型建立 6.3 模型求解…

37%原则如何优化我们做决定的时间

当需要百(千&#xff0c;万…)里挑一时&#xff0c;需要权衡最优解和效率&#xff0c;有一个37%原则比较有趣。 整个择优过程分为两个阶段&#xff1a; 观望&#xff1a;在前面 k k k个候选者中冒泡记录最优者 p p p&#xff0c;其分数为 V p V_p Vp​&#xff0c;但并不选择…

清风数学建模学习笔记——层次分析法

目录 一、模型简介 二、建模步骤 三、模型总结 一、层次分析法——模型简介 层次分析法&#xff0c;简称AHP&#xff0c;是指将与决策总是有关的元素分解成目标、准则、方案等层次&#xff0c;在此基础之上进行定性和定量分析的决策方法。该方法是美国运筹学家匹茨堡大学教授萨…

Attention is all you need ---Transformer

大语言模型已经在很多领域大显身手&#xff0c;其应用包括只能写作、音乐创作、知识问答、聊天、客服、广告文案、论文、新闻、小说创作、润色、会议/文章摘要等等领域。在商业上模型即产品、服务即产品、插件即产品&#xff0c;任何形态的用户可触及的都可以是产品&#xff0c…

you-get下载速度慢解决方法

Python版本&#xff1a;3.10 运行环境&#xff1a;Windows10 问题描述&#xff1a;在使用you-get下载X站视频时网速很慢&#xff0c;并一直限制在某个值,通过以下办法即可恢复正常网速 解决办法&#xff1a; 进入windows 安全中心-病毒和威胁防护-管理设置点击添加或删除排…

Microsoft store下载速度过慢

最开始是进入Microsoft store点击安装后一直无响应&#xff0c;后来知道这是因为Microsoft store下载速度过慢。下边几个步骤都尝试了&#xff0c;个人认为最重要的是Windows Update设置步骤&#xff0c;刚开始可能一直没有正确打开 修改DNS 右键任务栏网络图标->打开“网…

Linux网络编程 socket编程篇(一) socket编程基础

目录 一、预备知识 1.IP地址 2.端口号 3.网络通信 4.TCP协议简介 5.UDP协议简介 6.网络字节序 二、socket 1.什么是socket(套接字)&#xff1f; 2.为什么要有套接字&#xff1f; 3.套接字的主要类型 拓】网络套接字 三、socket API 1.socket API是什么&#xff1f; 2.为什么…

如何预防ssl中间人攻击?

当我们连上公共WiFi打开网页或邮箱时&#xff0c;殊不知此时可能有人正在监视着我们的各种网络活动。打开账户网页那一瞬间&#xff0c;不法分子可能已经盗取了我们的银行凭证、家庭住址、电子邮件和联系人信息&#xff0c;而这一切我们却毫不知情。这是一种网络上常见的“中间…

[保研/考研机试] KY3 约数的个数 清华大学复试上机题 C++实现

题目链接&#xff1a; KY3 约数的个数 https://www.nowcoder.com/share/jump/437195121691716950188 描述 输入n个整数,依次输出每个数的约数的个数 输入描述&#xff1a; 输入的第一行为N&#xff0c;即数组的个数(N<1000) 接下来的1行包括N个整数&#xff0c;其中每个…

wsl2安装mysql环境

安装完mysql后通过如下命令启动mysql service mysql start 会显示如下错误&#xff1a; mysql: unrecognized service 实际上上面显示的错误是由于mysql没有启动成功造成的 我们要想办法成功启动mysql才可以 1.通过如下操作就可以跳过密码直接进入mysql环境 2.如果想找到my…

nodejs+vue+elementui美食网站的设计与实现演示录像2023_0fh04

本次的毕业设计主要就是设计并开发一个美食网站软件。运用当前Google提供的nodejs 框架来实现对美食信息查询功能。当然使用的数据库是mysql。系统主要包括个人信息修改&#xff0c;对餐厅管理、用户管理、餐厅信息管理、菜系分类管理、美食信息管理、美食文化管理、系统管理、…

【百度翻译api】中文自动翻译为英文

欸&#xff0c;最近想做一些nlp的项目&#xff0c;做完了中文的想做做英文的&#xff0c;但是呢&#xff0c;国内爬虫爬取的肯定都是中文 &#xff0c;爬取外网的技术我没有尝试过&#xff0c;没有把握。所以我决定启用翻译&#xff0c;在这期间chatGPT给了我非常多的方法&…

关于电脑连接好WiFi却无法使用浏览器上网的一种解决方法

如果你的电脑的网络设置里选项是自动获取ip地址的话&#xff0c;那么大概率适用此方法。&#xff08;我这个已经是填好的&#xff0c;之前是自动获取&#xff09; 方法步骤&#xff1a;这里分两步 &#xff08;1&#xff09;首先确定无法使用浏览器上网的原因。&#xff08;比…

windows11连接上WiFi但是无法上网

电脑经常会出现网络等问题&#xff0c;win11在连接到WiFi&#xff0c;但是无法正常上网。进行网络诊断显示“该设备或资源&#xff08;Web代理&#xff09;未设置为接收端口7890”。借鉴过网络上许多方法都没有解决。可以尝试使用以下这种方式解决&#xff0c;本人亲测已解决。…

电脑显示wifi连接但是不能上网(dns无法连接)

网络问题 电脑显示wifi连接但是不能上网 1.使用手机等其它设备&#xff0c;连接同一个wifi&#xff0c;检查是否出现问题 如果其它设备也不能使用&#xff0c;则为网络本身的问题 如果不是&#xff0c;在继续检查电脑的问题 2.诊断问题 打开 “网络和Internet设置”找到下…