《动手学深度学习(PyTorch版)》笔记8.6

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter8 Recurrent Neural Networks

8.6 Concise Implementation of RNN

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
import matplotlib.pyplot as pltbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)state = torch.zeros((1, batch_size, num_hiddens))
print(state.shape)X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)#Y不涉及输出层的计算
print(Y.shape, state_new.shape)class RNNModel(nn.Module):#@save"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数),它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):#nn.GRU以张量作为隐状态#GRU为门控循环单元(Gated Recurrent Unit),是一种流行的循环神经网络变体。#GRU使用了一组门控机制来控制信息的流动,包括更新门(update gate)和重置门(reset gate),以更好地捕捉长期依赖关系return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:#nn.LSTM以元组作为隐状态#LSTM代表长短期记忆网络(Long Short-Term Memory),是另一种常用的循环神经网络类型。#相比于简单的循环神经网络,LSTM引入了三个门控单元:输入门(input gate)、遗忘门(forget gate)和输出门(output gate),以及一个记忆单元(cell state),可以更有效地处理长期依赖性。return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)
plt.show()

训练结果:
在这里插入图片描述

与上一节相比,由于pytorch的高级API对代码进行了更多的优化,该模型在较短的时间内达到了较低的困惑度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2780050.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

网络安全工程师技能手册(附学习路线图)

关键词:网络安全入门、渗透测试学习、零基础学安全、网络安全学习路线 安全是互联网公司的生命,也是每位网民的基本需求。现在越来越多的人对网络安全感兴趣,愿意投奔到网络安全事业之中,这是一个很好的现象。 很多对网络安全感…

蓝牙 - BTSnoop File Format

1, Overview [ 概览 ] BTSnoop 文件格式适用于存储 Bluetooth HCI 通讯数据。它与 RFC 1761 中记录的 snoop 格式非常相似。 The BTSnoop file format is suitable for storing Bluetooth HCI traffic. It closely resembles the snoop format, as documented in RFC 1761. 2, …

flask+python企业产品订单管理系统938re

在设计中采用“自下而上”的思想,在创新型产品提前购模块实现了个人中心、个体管理、发布企业管理、投资企业管理、项目分类管理、产品项目管理、个体投资管理、企业投资管理、个体订单管理、企业订单管理、系统管理等的功能性进行操作。最终,对基本系统…

CRNN介绍:用于识别图中文本的深度学习模型

CRNN:用于识别图中文本的深度学习模型 CRNN介绍:用于识别图中文本的深度学习模型CRNN的结构组成部分工作原理 CRNN结构分析卷积层(Convolutional Layers)递归层(Recurrent Layers)转录层(Transc…

Unity Meta Quest MR 开发(四):使用 Scene API 和 Depth API 实现深度识别和环境遮挡

文章目录 📕教程说明📕Scene API 实现遮挡📕Scene API 实现遮挡的缺点📕Depth API 实现遮挡⭐导入 Depth API⭐修改环境配置⭐添加 EnvironmentDepthOcclusion 预制体⭐给物体替换遮挡 Shader⭐取消现实手部的遮挡效果 此教程相关…

vue常用9个事件修饰符

第075个 查看专栏目录: VUE ------ element UI 专栏目标 在vue和element UI联合技术栈的操控下,本专栏提供行之有效的源代码示例和信息点介绍,做到灵活运用。 提供vue2的一些基本操作:安装、引用,模板使用,computed&a…

备战蓝桥杯---搜索(进阶4)

话不多说,直接看题: 下面是分析: (ab)%c(a%cb%c)%c; (a*b)%c(a%c*b%c)%c; 因此,如果两个长度不一样的值%m为相同值,那就舍弃长的(因为再加1位只不过是原来值*10那位值,因此他们得出的%m还是同…

【JavaScript 漫游】【014】正则表达式通关

文章简介 JS 语言中的 RegExp 对象提供正则表达式的功能。本篇文章旨在对该对象的相关知识点进行总结。内容包括: 正则表达式概述RegExp 对象的实例属性RegExp 对象的实例方法字符串与正则表达式相关的实例方法正则表达式匹配规则 概述 正则表达式的概念 正则表…

尚硅谷 Vue3+TypeScript 学习笔记(下)

目录 五、组件通信 5.1. 【props】 5.2. 【自定义事件】 5.3. 【mitt】 5.4.【v-model】 5.5.【$attrs】 5.6. 【$refs、$parent】 5.7. 【provide、inject】 5.8. 【pinia】 5.9. 【slot】 1. 默认插槽 2. 具名插槽 3. 作用域插槽 六、其它 API 6.1.【shallowR…

使用client-only 解决组件不兼容SSR问题

目录 前言 一、解决方案 1.基于Nuxt 框架的SSR应用 2.基于vue2框架的应用 3.基于vue3框架的应用 二、总结 往期回顾 前言 最近在我的单页面SSR应用上开发JSON编辑器功能,在引入组件后直接客户端跳转OK,但是在直接加载服务端渲染的时候一直报这…

操作系统——内存管理(附带Leetcode算法题LRU)

1.内存管理主要用来干什么? 操作系统的内存管理主要负责内存的分配与回收、内存扩充(虚拟技术)、地址转换(逻辑-物理)、内存保护(保证各进程在自己的内存空间运行,不会越界访问)..... 2.什么是内存碎片? 内存碎片是内存的申请和释放产生的…

第77讲用户管理功能实现

用户管理功能实现 前端&#xff1a; views/user/index.vue <template><el-card><el-row :gutter"20" class"header"><el-col :span"7"><el-input placeholder"请输入用户昵称..." clearable v-model"…

java学习(面向对象高级部分)

一、类变量 用一个例子引出类变量&#xff08;静态变量&#xff09; package object;public class temp {public static void main(String[] args) {child child1 new child("王");child1.count;child child2 new child("丽");child2.count;child child3…

python+flask+django医院预约挂号系统6nrhh

医院预约挂号系统主要有管理员、用户和医生三个功能模块。以下将对这三个功能的作用进行详细的剖析。 技术栈 后端&#xff1a;python 前端&#xff1a;vue.jselementui 框架&#xff1a;django/flask Python版本&#xff1a;python3.7 数据库&#xff1a;mysql5.7 数据库工具…

备战蓝桥杯---数学基础3

本专题主要围绕同余来讲&#xff1a; 下面介绍一下基本概念与定理&#xff1a; 下面给出解这方程的一个例子&#xff1a; 下面是用代码实现扩展欧几里得算法&#xff1a; #include<bits/stdc.h> using namespace std; int gcd(int a,int b,int &x,int &y){if(b…

读千脑智能笔记11_保存人类遗产

1. 智能生物通常能延续多久 1.1. SETI和METI计划的可行性在很大程度上取决于智能生物通常能延续多久 1.1.1. 搜寻地外文明&#xff08;以下简称SETI&#xff09;计划的目标 1.1.1.1. 这是一个力图寻找宇宙其他地方智能生物存在证据的研究项目 1.1.1.2. SETI计划旨在寻找含有…

AI跟踪报道第28期-新加坡内哥谈技术-本周AI新闻:Gemini Ultra 来了

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

【PWN · heap | Arbitrary Alloc】2015_9447ctf_search-engine

和【PWN heap | House Of Spirit】2014_hack.lu_oreo-CSDN博客略有区别&#xff0c;但都是通过malloc一块fake_chunk到指定区域&#xff0c;获得对该区域的写权限 目录 零、简单介绍 一、题目分析 1.主要功能 2.index_sentence(): 增添一条语句到“库”中 3.search_word(…

第四章 数学知识(一)(质数,质因子,)

一、数论 (一&#xff09;判断质数 1、质数合数都是针对>2的整数 2、质数的判断 一些因数的判断是不必要的。 #include<bits/stdc.h> //判断素数O&#xff08;n^0.5&#xff09; using namespace std; int n; bool is_Prime(int n) {if(n<2)return false;for(i…

pytorch常用激活函数笔记

1. relu函数&#xff1a; 公式&#xff1a; 深层网络内部激活函数常用这个 import matplotlib.pyplot as pltdef relu_fun(x):if x>0:return xelse:return 0x np.random.randn(10) y np.arange(10)plt.plot(y,x)for i ,t in enumerate(x):x[i] relu_fun(t) plt.p…