李沐56_门控循环单元——自学笔记

关注每一个序列

1.不是每个观察值都是同等重要

2.想只记住的观察需要:能关注的机制(更新门 update gate)、能遗忘的机制(重置门 reset gate)

!pip install --upgrade d2l==0.17.5  #d2l需要更新
import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...

下一步是初始化模型参数。 我们从标准差为0.01的高斯分布中提取权重, 并将偏置项设为0,超参数num_hiddens定义隐藏单元的数量, 实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

将定义隐状态的初始化函数init_gru_state。此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

准备定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

训练结束后,我们分别打印输出训练集的困惑度, 以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 31831.9 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

在这里插入图片描述

简洁实现

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 255484.2 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
traveller with a slight accession ofcheerfulness really thi

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2980221.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

集群工具之HAProxy

集群工具之HAProxy HAProxy简介 它是一款实现负载均衡的调度器适用于负载特别大的web站点HAProxy的工作模式 mode http:只适用于web服务mode tcp:适用于各种服务mode health:仅做健康检查,很少使用 配置HAProxy client&#x…

Datawhale |【独家】万字长文带你梳理Llama开源家族:从Llama-1到Llama-3

本文来源公众号“Datawhale”,仅用于学术分享,侵权删,干货满满。 原文链接:【独家】万字长文带你梳理Llama开源家族:从Llama-1到Llama-3 0. 引言 在AI领域,大模型的发展正以前所未有的速度推进技术的边界…

4(第三章,数据治理)

目录 概述 业务驱动因素 目标和原则 1、可持续发展 2、嵌入式 3、可度量 基本概念 数据治理与数据管理的关系 数据治理组织 数据治理运营模型类型 数据管理岗位的类型 数据治理的成果体现 国内的数据治理 什么是数据治理 为什么进行数据治理 数据治理的必要性 …

Linux 操作系统的引导过程

Linux系统开机引导过程: 开机自检 检测硬件设备,找到能够引导系统的设备,比如硬盘MBR引导 运行MBR扇区里的主引导程序GRUB启动GRUB菜单 系统读取GRUB配置文件(/boot/grub2/grub.cfg)获取内核的设置和…

《内向者优势》:不要低估一个内向的人

#世界读书日 作者主页: 🔗进朱者赤的博客 精选专栏:🔗经典算法 作者简介:阿里非典型程序员一枚 ,记录在大厂的打怪升级之路。 一起学习Java、大数据、数据结构算法(公众号同名) ❤…

[RTOS 学习记录] 复杂工程项目的管理

[RTOS 学习记录] 复杂工程项目的管理 这篇文章是我阅读《嵌入式实时操作系统μCOS-II原理及应用》后的读书笔记,记录目的是为了个人后续回顾复习使用。 前置内容: 工程管理工具make及makefile 文章目录 1 批处理文件与makefile的综合使用1.1 批处理文件…

C语言学习/复习29--内存操作函数memcpy/memmove/memset/memcmp

一、内存操作函数 1.memcpy()函数 注意事项1:复制的数目以字节为单位 注意事项2:一定要保证有足够空间复制 模拟实现1 拷贝字符案例:由于拷贝时函数本事就以字节为单位拷贝所以该例子也可用于其他类型数据的拷贝。 模拟实现2 将自身的…

YOLOv8 关键点检测模型训练部署

文章目录 1、YOLOv8安装及使用1.2、命令行使用1.3、使用python-API模型预测1.4、pt转换ONNX 2、训练三角板关键点检测模型2.1、训练命令 3、ONNX Runtime部署 1、YOLOv8安装及使用 参考链接: 同济子豪兄视频 github原文链接 # 安装yolov8 pip install ultralytics --upgrade …

Linux-LVM与磁盘配额

一、LVM概述 Logical Volume Manager,逻辑卷管理 能够在保持现有数据不变的情况下动态调整磁盘容量,从而提高磁盘管理的灵活性 /boot分区用于存放引导文件,不能基于LVM创建 LVM机制的基本概念 PV(物理卷)&#xff…

情感识别——情感计算的模型和数据集调查

概述 情感计算指的是识别人类情感、情绪和感觉的工作,已经成为语言学、社会学、心理学、计算机科学和生理学等领域大量研究的主题。 本文将概述情感计算的重要性,涵盖思想、概念和方法。 情感计算是皮卡德于 1997 年提出的一个想法,此后出…

生产数据采集系统

在数字化浪潮的推动下,生产数据采集系统已经成为企业提升生产效率、优化运营管理的关键工具。那么,什么是生产数据采集系统呢?简单来说,生产数据采集系统是指通过一系列技术手段,实时收集、处理和分析生产线上的各类数…

STM32 I²C通信

一、IC总线通信 1.1 IC总线特点 IC(Inter Integrated Circuit,集成电路总线),通过串行数据线SDA(Serial Data)和串行时钟线SCL(Serial Clock)来完成数据的传输。 特点:…

java泛型介绍

Java 泛型是 JDK 5 引入的一个特性,它允许我们在定义类、接口和方法时使用类型参数,从而使代码更加灵活和类型安全。泛型的主要目的是在编译期提供类型参数,让程序员能够在编译期间就捕获类型错误,而不是在运行时才发现。这样做提…

(ICML-2021)从自然语言监督中学习可迁移的视觉模型

从自然语言监督中学习可迁移的视觉模型 Title:Learning Transferable Visual Models From Natural Language Supervision paper是OpenAI发表在ICML 21的工作 paper链接 Abstract SOTA计算机视觉系统经过训练可以预测一组固定的预定目标类别。这种受限的监督形式限制…

[笔试训练](四)

010 Fibonacci数列_牛客题霸_牛客网 (nowcoder.com) 题目: 题解: 1.创建一个数组fib[],保存范围内的所有斐波那契数,再求离N最近的斐波那契数。 2.创建3个数a,b,c,依次先后滚动,可得出所有的斐波那契数&#xff0c…

椋鸟数据结构笔记#11:排序·下

文章目录 外排序(外部排序)文件拆分并排序归并文件两个文件归并多文件归并优化 萌新的学习笔记,写错了恳请斧正。 外排序(外部排序) 当数据量非常庞大以至于无法全部写入内存时,我们应该怎么排序这些数据呢…

贪吃蛇(C语言版)

在我们学习完C语言 和单链表知识点后 我们开始写个贪吃蛇的代码 目标:使用C语言在Windows环境的控制台模拟实现经典小游戏贪吃蛇 贪吃蛇代码实现的基本功能: 地图的绘制 蛇、食物的创建 蛇的状态(正常 撞墙 撞到自己 正常退出&#xf…

SpringCloud系列(11)--将微服务注册进Eureka集群

前言:在上一章节中我们介绍并成功搭建了Eureka集群,本章节则介绍如何把微服务注册进Eureka集群,使服务达到高可用的目的 Eureka架构原理图 1、分别修改consumer-order80模块和provider-payment8001模块的application.yml文件,使这…

刷题之Leetcode242题(超级详细)

242.有效的字母异位词 力扣题目链接(opens new window)https://leetcode.cn/problems/valid-anagram/ 给定两个字符串 s 和 t ,编写一个函数来判断 t 是否是 s 的字母异位词。 示例 1: 输入: s "anagram", t "nagaram" 输出: true 示例 2…

使用kali进行DDos攻击

使用kali进行DDos攻击 1、打开命令提示符,下载DDos-Attack python脚本 git clone https://github.com/Elsa-zlt/DDos-Attack 2、下载好之后,cd到DDos-Attack文件夹下 cd DDos-Attack 3、修改(设置)对ddos-attack.py文件执行的权…