pytorch-RNN实战-正弦曲线预测

目录

  • 1. 正弦数据生成
  • 2. 构建网络
  • 3. 训练
  • 4. 预测
  • 5. 完整代码
  • 6. 结果展示

1. 正弦数据生成

曲线如下图:
在这里插入图片描述
代码如下图:

  • 50个点构成一个正弦曲线
  • 随机生成一个0~3之间的一个值(随机的原因是防止每次都从相同的点开始,50个点的正弦曲线一样,被模型记住),值的范围区间是[start, start+10]
  • 输入x范围[0,48],预测值y范围是[1,49]

在这里插入图片描述

2. 构建网络

下图是构建的网络,注意out维度扩展出一个维度,是为了和y维度一致
在这里插入图片描述

3. 训练

loss计算采用均方差MSE,优化器采用Adam
注意:hidden_prev的自更新
在这里插入图片描述

4. 预测

预测是循环一个点一个点的预测,每次预测的点的结果作为下次点的输入,直到预测出全部点,放到predictions中。
input = x[:,0,:] 去掉了x[1,seq,1]中的seq维度,变成[1,1]
在这里插入图片描述

5. 完整代码

import  numpy as np
import  torch
import  torch.nn as nn
import  torch.optim as optim
from    matplotlib import pyplot as pltnum_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr=0.01class Net(nn.Module):def __init__(self, ):super(Net, self).__init__()self.rnn = nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=1,batch_first=True,)for p in self.rnn.parameters():nn.init.normal_(p, mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# [b, seq, h]out = out.view(-1, hidden_size)out = self.linear(out)out = out.unsqueeze(dim=0)return out, hidden_prevmodel = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr)hidden_prev = torch.zeros(1, 1, hidden_size)for iter in range(6000):start = np.random.randint(3, size=1)[0]time_steps = np.linspace(start, start + 10, num_time_steps)data = np.sin(time_steps)data = data.reshape(num_time_steps, 1)x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)output, hidden_prev = model(x, hidden_prev)hidden_prev = hidden_prev.detach()loss = criterion(output, y)model.zero_grad()loss.backward()# for p in model.parameters():#     print(p.grad.norm())# torch.nn.utils.clip_grad_norm_(p, 10)optimizer.step()if iter % 100 == 0:print("Iteration: {} loss {}".format(iter, loss.item()))start = np.random.randint(3, size=1)[0]
time_steps = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_steps)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)predictions = []
input = x[:, 0, :]
for _ in range(x.shape[1]):input = input.view(1, 1, 1)(pred, hidden_prev) = model(input, hidden_prev)input = predpredictions.append(pred.detach().numpy().ravel()[0])x = x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())plt.scatter(time_steps[1:], predictions)
plt.show()

6. 结果展示

图中黄色点是预测点,蓝色为实际点,前面的曲线是start不随机预测的效果,说明曲线已经被模型记住了;后面的曲线是start随机预测的效果,基本趋势和真实点是一致的。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3226365.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

JavaSE 面向对象程序设计进阶 IO流 字节流详解 抛出异常

input output 像水流一样读取数据 存储和读取数据的解决方案 内存中数据不能永久化存储 程序停止运行 数据消失 File只能对文件本身进行操作 不能读写文件里存储的数据 读写数据必须要有IO流 可以把程序中的数据保存到文件当中 还可以把本地文件中的数据读取到数据当中 分…

白酒营销策划全攻略:从市场调研到执行落地的实战指南!

为白酒品牌做营销策划,那可得像给自家的孩子挑衣服一样,得量身定制,得考虑孩子的身材、喜好,还得看看衣服的款式和布料。 这里可以分享一点自己多年的实战干货给你,希望对你有所帮助。 首先,得做好“侦查…

【常见开源库的二次开发】一文学懂CJSON

简介: JSON(JavaScript Object Notation)是一种轻量级的数据交换格式。它基于JavaScript的一个子集,但是JSON是独立于语言的,这意味着尽管JSON是由JavaScript语法衍生出来的,它可以被任何编程语言读取和生成…

CentOS7系统上安装MySQL8.0(rpm-bundle.tar)详细过程

一、MySQL官网下载安装包 1.进入官网MySQL :: Download MySQL Community Server 2.查看自己的版本和架构 uname -mcat /etc/redhat-release 3.选择对应版本并下载 4.查看linux自带的mariadb数据库,有就卸载掉。 rpm -qa | grep mariadbrpm -e mariadb-libs…

【卡尔曼滤波】高斯白噪声

生成高斯白噪声并将其应用于信号处理 生成高斯白噪声并将其应用于信号处理 #以下是一个生成高斯白噪声并将其应用于信号处理的示例代码:import numpy as np import matplotlib.pyplot as plt import matplotlib.font_manager ## not work#notice matplotlibrc is a file, not…

学生选课管理系统(Java+MySQL)

技术栈 Java: 用于实现系统的核心业务逻辑。MySQL: 作为关系型数据库,用于存储系统中的数据。JDBC: 用于Java程序与MySQL数据库之间的连接和交互。Swing GUI: 用于创建图形用户界面,提升用户体验。 系统功能 我们的学生选课管理系统主要针对学生和管理…

突破传统:实现智慧校园实习单位变更

在智慧校园的实习管理系统设计中,充分考虑到了实习阶段学生可能遇到的实际需求,特别是实习单位变更这一灵活性要求,系统特设了一套完善的在线处理机制,旨在促进学生、学校与企业间的顺畅沟通与协调,确保实习过程的平稳…

Gmail邮件提醒通知如何设置?有哪些方法?

Gmail邮件提醒通知功能怎么样?通知邮件怎么有效发送? Gmail作为全球广泛使用的电子邮件服务,提供了多种邮件提醒通知功能,帮助用户不错过重要信息。AokSend将详细介绍如何设置Gmail邮件提醒通知,确保您不会错过任何重…

IT审计必看!对比旧版,CISA考试改版升级亮点和重点内容是什么?

官方通知,今年8月1日,CISA新版考纲正式上线,旧版在7月23日后就无法约考了。 艾威培训邀请了国内知名的IT审计CISA授课老师吴老师来为大家详细讲解CISA新版考纲的变化 目前第28th版教材只有英文版,中文版尚未发布。我们艾威经验丰…

【NOI-题解】1108 - 正整数N转换成一个二进制数1290 - 二进制转换十进制1386 - 小丽找半个回文数1405 - 小丽找潜在的素数?

文章目录 一、前言二、问题问题:1108 - 正整数N转换成一个二进制数问题:1290 - 二进制转换十进制问题:1386 - 小丽找半个回文数问题:1405 - 小丽找潜在的素数? 三、感谢 一、前言 本章节主要对进制转换的题目进行讲解…

【UNI-APP】阿里NLS一句话听写typescript模块

阿里提供的demo代码都是javascript,自己捏个轮子。参考着自己写了一个阿里巴巴一句话听写Nls的typescript模块。VUE3的组合式API形式 startClient:开始听写,注意下一步要尽快开启识别和传数据,否则6秒后会关闭 startRecognition…

javascript高级部分笔记

javascript高级部分 Function方法 与 函数式编程 call 语法:call([thisObj[,arg1[, arg2[, [,.argN]]]]]) 定义:调用一个对象的一个方法,以另一个对象替换当前对象。 说明:call 方法可以用来代替另一个对象调用一个方法。cal…

侯捷C++面向对象高级编程(下)-2-non-explicit one argument constructor

1.构造函数 构造函数: Fraction(int num, int den 1) 初始化分子和分母,允许指定分子 num 和可选的分母 den。默认情况下,分母为 1。 加法运算符重载: Fraction operator(const Fraction& f) 重载了加法运算符 。这使得两个 Fraction 对象可以通过 …

NodeJS校园快递智能互助平台-计算机毕业设计源码58554

摘 要 随着校园人口的增加和生活节奏的加快,校园快递成为一个重要的服务需求。然而,传统的校园快递方式存在一些问题,例如无法满足快速和高效的需求,易发生丢失或损坏的情况,同时也给快递人员和用户带来不便。因此&am…

成功登上主要中心化交易所 (CEX) 的终极指南:从准备到上市的全面策略

对于区块链项目的创始人而言,成功的代币发行是项目发展的关键一步。尤其是在主要中心化交易所 (CEX) 上上市代币,可以极大地提高项目的曝光度和流动性。然而,CEX 上市过程复杂且充满挑战,需要创始人提前做好充分准备。本文将详细介…

JavaSE语法 | 初识Java!!!

初识Java 一、Java开发环境二、初步认识Java的main方法2.1 main方法的实现2.2 运行Java程序 三、注释四、标识符五、关键字 一、Java开发环境 IDEA版本:IntelliJ IDEA Community Edition 2022.3.3 JDK17 Windows 11 二、初步认识Java的main方法 2.1 main方法的实…

comsol multiphysics在岩土工程中的应用

comsol教程推荐: comsol multiphysics在岩土工程中的应用 [comsol multiphysics在岩土工程中的应用](https://download.csdn.net/download/qq_36980284/89529402) 出版发行项: 北京:中国建筑工业出版社,2014 ISBN及定价: 978-7-112-16188-1 CNY42.00 载体形态项:…

【Python】一文向您详细介绍 argparse中 action=‘store_false’ 的作用

【Python】一文向您详细介绍 argparse中 action‘store_false’ 的作用 下滑即可查看博客内容 🌈 欢迎莅临我的个人主页 👈这里是我静心耕耘深度学习领域、真诚分享知识与智慧的小天地!🎇 🎓 博主简介&#xff1a…

邮件推送服务的自动化流程设置与优化技巧?

邮件推送服务如何定制化?邮件推送的安全性如何保障? 无论是大型企业还是小型企业,通过精准的邮件推送服务,可以实现客户关系管理的有效增强,提升品牌认知度和销售转化率。AokSend将探讨如何通过自动化流程设置与优化技…

天翼云高级运维工程师202407回忆题库 最新出炉

备考天翼云高级运维工程师 必须备考天翼云 之前觉得外企牛批 然后民企,拔地而起,民企也不错,工资高,有钱途 现在看来看去,还是国企好,体制内的,有保障,树大根深 有必要备考下天…