李沐57_长短期记忆网络LSTM——自学笔记

LSTM

1.忘记门:将值朝着0减少

2.输入门:决定不是忽略掉输入数据

3.输出门:决定是不是使用隐状态

!pip install --upgrade d2l==0.17.5  #d2l需要更新

首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...

初始化模型参数:超参数num_hiddens定义隐藏单元的数量。 我们按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

定义模型:长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

训练和预测:引入的RNNModelScratch类来训练一个长短期记忆网络

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.3, 28797.1 tokens/sec on cuda:0
time traveller abt whther werl wechat said al somick and the a s
traveller whthe graid that ore seellon twatere whree time s

在这里插入图片描述

简洁实现

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 319759.1 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2981445.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

「React Native」为什么要选择 React Native 作为的跨端方案

文章目录 前言一、常见因素二、举个栗子2.1 项目背景2.2 为什么选择 React Native2.3 项目实施2.4 成果总结 前言 没有完美的跨端技术,只有适合的场景。脱离适用场景去谈跨端技术没有什么意义。 一、常见因素 共享代码库: React Native 允许开发者编写…

LeetCode78:子集

题目描述 给你一个整数数组 nums &#xff0c;数组中的元素 互不相同 。返回该数组所有可能的 子集 &#xff08;幂集&#xff09;。 解集 不能 包含重复的子集。你可以按 任意顺序 返回解集。 代码 class Solution { public:vector<vector<int>> res;vector<…

团队如何异地共享文件?

在当今全球化的办公环境中&#xff0c;团队成员往往分散在不同的地理位置上。为了更好地协同工作&#xff0c;团队之间需要快速、安全地共享文件。本文将介绍一种名为“团队异地共享文件”的解决方案&#xff0c;它能够帮助团队成员在不同地点方便地共享文件&#xff0c;提高工…

代码随想录第四天打卡笔记

链表part2 1&#xff09;两两交换链表中的节点 这道题按照三步走方式&#xff1a; &#xff08;1&#xff09;第一步&#xff0c;设置cur指针指向这两个元素的第一个&#xff08;这里一定要注意保存原结点&#xff01;&#xff09;&#xff0c;断开cur与第一个节点的链接&…

echarts实现云台控制按钮效果,方向按钮

效果图 代码 option {color: [#bfbfbf],tooltip: {show: false},series: [{name: ,type: pie,radius: [40%, 70%],avoidLabelOverlap: true,itemStyle: {// borderRadius: 10,borderColor: #fff,borderWidth: 2},label: {show: true,position: inside,fontSize: 36,color: #f…

CISAW应急服务:网络安全应急响应之路——从经验到认证的体会

随着信息技术的飞速发展&#xff0c;网络安全问题日益凸显。作为一名网络安全从业人员&#xff0c;我深知每一次安全事件给组织甚至国家带来的巨大损失和潜在影响。在多年的实际工作中&#xff0c;我积累了一些网络安全应急服务经验&#xff0c;并参加了信息安全保障人员认证&a…

火绒安全的应用介绍

火绒安全软件是一款集成了杀毒、防御和管控功能的安全软件&#xff0c;旨在为用户提供全面的计算机安全保障。以下是火绒安全软件的一些详细介绍&#xff1a; 系统兼容性强&#xff1a;该软件支持多种操作系统&#xff0c;包括Windows 11、Windows 10、Windows 8、Windows 7、…

深度学习| Attention U-Net(包含Attention Gate代码)

前言&#xff1a;最近在阅读一篇文章&#xff0c;用到了Attention Unet所以特地写了两篇文章&#xff0c;上一篇文章介绍了Attention的基础知识&#xff0c;这篇文化在哪个介绍Attention Unet相关知识以及代码。 Attention U-Net 基础注意力机制软注意力和硬注意力U-Net为什么…

【蓝牙协议栈】【BLE】低功耗蓝牙工作流程(角色\广播\扫描\连接等专业名词介绍)

1. 精讲蓝牙协议栈&#xff08;Bluetooth Stack&#xff09;&#xff1a;SPP/A2DP/AVRCP/HFP/PBAP/IAP2/HID/MAP/OPP/PAN/GATTC/GATTS/HOGP等协议理论 2. 欢迎大家关注和订阅&#xff0c;【精讲蓝牙协议栈】和【Android Bluetooth Stack】专栏会持续更新中.....敬请期待&#x…

✯✯✯宁波ISO45001认证:保障员工健康与安全✯✯✯

&#x1f308;&#x1f308;&#x1f308;最近&#xff0c;我们公司&#x1f970;通过了ISO 45001认证&#xff0c;&#x1f920;这让我感到非常&#x1f60e;兴奋和自豪&#xff01;&#x1f939;作为一个在&#x1f3d9;️宁波工作的员工&#xff0c;&#x1f913;我深深地感…

Golang | Leetcode Golang题解之第48题旋转图像

题目&#xff1a; 题解&#xff1a; func rotate(matrix [][]int) {n : len(matrix)// 水平翻转for i : 0; i < n/2; i {matrix[i], matrix[n-1-i] matrix[n-1-i], matrix[i]}// 主对角线翻转for i : 0; i < n; i {for j : 0; j < i; j {matrix[i][j], matrix[j][i]…

动力学重构/微分方程参数拟合 - 基于模型

这一篇文章&#xff0c;主要是给非线性动力学&#xff0c;对微分方程模型参数拟合感兴趣的朋友写的。笼统的来说&#xff0c;这与混沌系统的预测有关&#xff1b;传统的机器学习的模式识别虽然也会谈论预测结果&#xff0c;但他们一般不会涉及连续的预测。这里我们考虑的是&…

解决双击PDF文件出现打印的问题【Adobe DC】

问题描述 电脑安装Adobe Acrobat DC之后&#xff0c;双击PDF文件就会出现打印&#xff0c;而无法直接打开。 右键PDF文件就会发现&#xff0c;第一栏出现的不是用Adobe打开&#xff0c;而是打印。 重装软件多次仍然无法解决。 原因 右键菜单被改写了。双击其实是执行右键菜…

【氧化镓】影响Ga2O3器件沟道载流子迁移率的关键因素

总结 这篇文章对β-Ga2O3 MOSFETs中的通道载流子迁移率进行了深入研究&#xff0c;通过实验测量和理论分析&#xff0c;揭示了影响迁移率的关键因素&#xff0c;特别是库仑散射的影响。研究结果对于理解和改进β-Ga2O3 MOSFETs的性能具有重要意义&#xff0c;为未来的研究和器…

企业微信知识库为什么这么多企业都在用?原因就在这里

在科技迅猛发展的当下&#xff0c;企业信息化建设逐渐成为推动效率提升和核心竞争力增强的关键因素。而企业微信作为一个集团队沟通、项目管理、办公应用于一体的平台&#xff0c;其知识库功能为什么会受到广泛的欢迎和使用&#xff1f;让我们一探究竟。 首先&#xff0c;企业…

面试中关于 SpringCloud 都需要了解哪些基础?

在面试中&#xff0c;对于Spring Cloud的基础知识了解是至关重要的&#xff0c;因为Spring Cloud是构建分布式系统和微服务架构的关键技术栈之一。以下是在面试中可能会涉及到的相关问题。 1. 微服务架构基础 概念理解&#xff1a;理解微服务架构的概念&#xff0c;包括服务拆…

【C++】手撕list(list的模拟实现)

目录 01.节点 02.迭代器 迭代器运算符重载 03.list类 &#xff08;1&#xff09;构造与析构 &#xff08;2&#xff09;迭代器相关 &#xff08;3&#xff09;容量相关 &#xff08;4&#xff09;访问操作 &#xff08;5&#xff09;插入删除 我们在学习数据结构的时候…

【OceanBase诊断调优 】—— 如何快速定位SQL问题

作者简介&#xff1a; 花名&#xff1a;洪波&#xff0c;OceanBase 数据库解决方案架构师&#xff0c;目前负责 OceanBase 数据库在各大型互联网公司及企事业单位的落地与技术指导&#xff0c;曾就职于互联网大厂和金融科技公司&#xff0c;主导过多项数据库升级、迁移、国产化…

台灯太亮会影响视力吗?分享五款防近视护眼台灯

台灯作为一盏实用的桌面照明灯具&#xff0c;是孩子学习过程中必不可少的“好伴侣”&#xff0c;不过大多数家长只关注到台灯的光线是否够亮&#xff0c;认为只要亮度足够就可以了&#xff0c;那么台灯太亮会影响视力吗&#xff1f; 答案是会的。首先太暗的环境肯定是不够支撑孩…

用友NC Cloud importhttpscer接口任意文件上传漏洞

声明 本文仅用于技术交流&#xff0c;请勿用于非法用途 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;文章作者不为此承担任何责任。 一、漏洞描述 用友NC Cloud的importhttpscer接口如果存在任意文件上传…