对话模型Demo解读(使用代码解读原理)

文章目录

  • 前言
  • 一、数据加工
  • 二、模型搭建
  • 三、模型训练
    • 1、构建模型
    • 2、优化器与损失函数定义
    • 3、模型训练
  • 四、模型推理
  • 五、所有Demo源码


前言

对话模型是一种人工智能技术,旨在使计算机能够像人类一样进行对话和交流。这种模型通常基于深度学习和自然语言处理技术,能够理解自然语言并做出相应的回应。然而现有博客很少介绍对话模型内容,也很少用一个简单代码带领大家理解其原理。因此,我创建一个简单的对话模型,在不适用Hugging Face或LSTM结构,旨在使用一个简单的全连接神经网络来实现这个模型,且代码基于PyTorch框架搭建,意在帮助读者构建对话模型知识。当然,模型仅是一个简单模型,旨在帮助理解原理,不具备很好效果能力。


一、数据加工

文本数据最终都是转为对应字典索引代表其文本内容,输入模型加工,实现nlp任务,对话模型也不列外。因此,我们需要构建一个字典映射(可参考:点击这里)与文本数据,并按照字典映射转换为对应索引id,其代码如下:

# 定义一个简单的对话数据集
data = [("hi", "hello"),("how are you?", "I'm fine, thank you."),("what's your name?", "I'm a chatbot.")
]# 构建词汇表
vocab = list(set(" ".join([x[0] + " " + x[1] for x in data])))
vocab.append("<SOS>")
vocab.append("<EOS>")word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}# 将对话数据集转换为索引序列
def to_idx_seq(sentence):return [word_to_idx[word] for word in sentence]data_x = [to_idx_seq(x[0]) for x in data]
data_y = [to_idx_seq(x[1]) for x in data]data_y =[[word_to_idx["<SOS>"]]+list(x)+[word_to_idx["<EOS>"]] for x in data_y]

其字典内容如下:
在这里插入图片描述

二、模型搭建

这里,也是最重要内容,如何搭建对话模型,我是使用transformer结构搭建(之前模型使用LSTM模型搭建),创建一个简单的对话生成模型,也使用一个基于全连接层的神经网络来实现这个模型,其代码如下:


# 定义一个简单的Transformer生成式对话模型
class ChatbotTransformer(nn.Module):def __init__(self, input_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers):super(ChatbotTransformer, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.embedding = nn.Embedding(len(vocab), input_dim)self.transformer = nn.Transformer(d_model=input_dim, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)self.fc = nn.Linear(input_dim, output_dim)def forward(self, src, tgt):src = self.embedding(src).permute(1, 0, 2)  # 调整输入张量维度tgt = self.embedding(tgt).permute(1, 0, 2)  # 调整输入张量维度output = self.transformer(src, tgt)output = self.fc(output)return output

从代码上看,我们需要输入src为前面提问数据,而答案生成输入,是每个tgt字输入,按照顺序输出预测。

三、模型训练

1、构建模型

我大概试了一下,使用更多层对于简单数据反而效果不佳,我的数据又比较简单,我构建了较少的层来预测模型。其模型构建代码如下:

# 创建模型实例
# model = ChatbotTransformer(input_dim=256, output_dim=len(vocab), nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = ChatbotTransformer(input_dim=16, output_dim=len(vocab), nhead=8, num_encoder_layers=1, num_decoder_layers=1)

2、优化器与损失函数定义

优化器定义我将不考虑介绍,只说下文本预测是交叉熵方式,实际文本基本都采用该方法作为loss计算。其代码如下:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

3、模型训练

接下来,我们解读如何训练对话模型,我们获得对应输入数据与生成预测数据,我们开始训练模型,其代码如下:

# 模型训练
epochs = 800
for epoch in range(epochs):total_loss = 0for i in range(len(data_x)):optimizer.zero_grad()input_seq = torch.tensor(data_x[i])target_seq = torch.tensor(data_y[i])for j in range(len(target_seq)-1):if random.random() < 0.5:j = random.randint(0, len(target_seq)-2)output = model(input_seq.unsqueeze(0), target_seq[:j+1].unsqueeze(0))  # 使用目标序列的前n-1个词预测后n-1个词loss = criterion(output.view(-1, len(vocab)), target_seq[1:j+2].view(-1))  # 计算损失,需要将输出形状转换成二维loss.backward()optimizer.step()total_loss += loss.item()if (epoch+1) % 10 == 0:print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, total_loss / len(data_x)))

假设input_seq =tensor([17, 6, 0, 15, 1, 8, 9, 15, 7, 6, 11, 4]),target_seq=tensor([20, 12, 18, 13, 15, 5, 16, 10, 9, 2, 15, 3, 17, 1, 10, 19, 15, 7, 6, 11, 14, 21]),vocab为字典映射,共22个映射。我将试图模型连续2轮解释一下,训练时候相关变化。
第二轮输出结果:
在这里插入图片描述
第三轮输出结果:
在这里插入图片描述
后面以此类推迭代。

很明显,提问每次都是全部输入,而输出则是第一个20开始与输入共同进模型,分别是模型src与tgt,不断重复与预测完成训练。而loss计算都是往后取一个target文本,也发现并不会计算20索引""文本。

四、模型推理

最后,让我们使用训练好的模型进行推理,实际和上面训练讲到方法类似,我们开始""文本开始,不断给出生成对应文本,也就是对话内容。

# 进行推理
def generate_response(input_sentence):model.eval()input_seq = torch.tensor(to_idx_seq(input_sentence))target_seq = torch.tensor([word_to_idx["<SOS>"]])  # 在开始时使用特殊的起始标记with torch.no_grad():for i in range(20):  # 限制生成的句子长度为20个词output = model(input_seq.unsqueeze(0), target_seq.unsqueeze(0))output_token = output.argmax(2)[-1].item()print(idx_to_word[output_token], end=" ")target_seq = torch.cat((target_seq, torch.tensor([output_token])), dim=0)res = [idx_to_word[int(k)] for k in target_seq]print(res[1:-1])return res[1:-1]# 进行对话生成
generate_response("hi")

其结果如下:
在这里插入图片描述


五、所有Demo源码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
# 定义一个简单的对话数据集
data = [# ("hi", "hello"),("how are you?", "I'm fine, thank you."),# ("what's your name?", "I'm a chatbot.")
]# 构建词汇表
vocab = list(set(" ".join([x[0] + " " + x[1] for x in data])))
vocab.append("<SOS>")
vocab.append("<EOS>")word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}# 将对话数据集转换为索引序列
def to_idx_seq(sentence):return [word_to_idx[word] for word in sentence]data_x = [to_idx_seq(x[0]) for x in data]
data_y = [to_idx_seq(x[1]) for x in data]data_y =[[word_to_idx["<SOS>"]]+list(x)+[word_to_idx["<EOS>"]] for x in data_y]# 定义一个简单的Transformer生成式对话模型
class ChatbotTransformer(nn.Module):def __init__(self, input_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers):super(ChatbotTransformer, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.embedding = nn.Embedding(len(vocab), input_dim)self.transformer = nn.Transformer(d_model=input_dim, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)self.fc = nn.Linear(input_dim, output_dim)def forward(self, src, tgt):src = self.embedding(src).permute(1, 0, 2)  # 调整输入张量维度tgt = self.embedding(tgt).permute(1, 0, 2)  # 调整输入张量维度output = self.transformer(src, tgt)output = self.fc(output)return output# 创建模型实例
# model = ChatbotTransformer(input_dim=256, output_dim=len(vocab), nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = ChatbotTransformer(input_dim=16, output_dim=len(vocab), nhead=8, num_encoder_layers=1, num_decoder_layers=1)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 模型训练
epochs = 800
for epoch in range(epochs):total_loss = 0for i in range(len(data_x)):optimizer.zero_grad()input_seq = torch.tensor(data_x[i])target_seq = torch.tensor(data_y[i])for j in range(len(target_seq)-1):if random.random() < 0.5:j = random.randint(0, len(target_seq)-2)output = model(input_seq.unsqueeze(0), target_seq[:j+1].unsqueeze(0))  # 使用目标序列的前n-1个词预测后n-1个词loss = criterion(output.view(-1, len(vocab)), target_seq[1:j+2].view(-1))  # 计算损失,需要将输出形状转换成二维loss.backward()optimizer.step()total_loss += loss.item()if (epoch+1) % 10 == 0:print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, total_loss / len(data_x)))# 进行推理
def generate_response(input_sentence):model.eval()input_seq = torch.tensor(to_idx_seq(input_sentence))target_seq = torch.tensor([word_to_idx["<SOS>"]])  # 在开始时使用特殊的起始标记with torch.no_grad():for i in range(20):  # 限制生成的句子长度为20个词output = model(input_seq.unsqueeze(0), target_seq.unsqueeze(0))output_token = output.argmax(2)[-1].item()print(idx_to_word[output_token], end=" ")target_seq = torch.cat((target_seq, torch.tensor([output_token])), dim=0)res = [idx_to_word[int(k)] for k in target_seq]print(res[1:-1])return res[1:-1]# 进行对话生成
generate_response("how are you?")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2778962.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

七、热身仪式(Warm-Up Rituals)

5.Warm Up Rituals 五、热身仪式 A warm up ritual is your per flight checklist you go through before you start focusing for a big session.It may be checking that you have water, that you don’t need to use the bathroom, that your phone is turned off or you’…

基于微信小程序的校园故障维修管理系统的研究与实现

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

Sodinokibi(REvil)勒索病毒黑客组织攻击姿势全解

前言 2021年6月 11日&#xff0c;国外媒体 threatpost 发布文章宣称美国能源部 (DOE) 的分包商同时也是美国国家核安全局 (NNSA) 核武器开发合作商的 Sol Oriens 公司遭受到网络攻击&#xff0c;并且 Sol Oriens 公司人员已证实该公司于上月发现被勒索病毒攻击&#xff0c;而国…

Java图形化界面编程——组件绘图原理 笔记

2.8 绘图 ​ 很多程序如各种小游戏都需要在窗口中绘制各种图形&#xff0c;除此之外&#xff0c;即使在开发JavaEE项目时&#xff0c; 有 时候也必须"动态"地向客户 端生成各种图形、图表&#xff0c;比如 图形验证码、统计图等&#xff0c;这都需要利用AWT的绘图功…

深入理解Netty及核心组件使用—下

目录 ChannelHandler ChannelHandler 接口 ChannelInboundHandler 接口 ChannelHandler 的适配器 Handler 的共享和并发安全性 资源管理和 SimpleChannelInboundHandler Bootstrap ChannelInitializer ChannelOption ChannelHandler ChannelHandler 接口 从开发人员的…

重构利器:如何用 Immer 优雅地管理应用状态

1. immer immer 是一个 JavaScript 库,用于处理不可变数据的状态更新。不可变数据意味着一旦创建,数据结构就不能被修改。在编写复杂的应用程序时,不可变性可以带来一系列好处,比如更容易追踪数据的改变、更容易实现撤销/重做功能以及更简单的状态管理。 然而,处理不可变…

更新win11后无法上网

问题描述 系统提示可以更新win11了&#xff0c;然后我就想着更新一下试试。等了好久终于下载完了准备更新&#xff0c;结果提示更新失败&#xff0c;再次更新时下载到一半就停了&#xff0c;然后就发现连不上网了&#xff08;真服了&#xff0c;win11没更新成功&#xff0c;还…

MATLAB环境下一维时间序列信号的同步压缩小波包变换

时频分析相较于目前的时域、频域信号处理方法在分析时变信号方面&#xff0c;其主要优势在于可以同时提供时域和频域等多域信号信息&#xff0c;并清晰的刻画了频率随时间的变化规律&#xff0c;已被广泛用于医学工程、地震、雷达、生物及机械等领域。 线性时频分析方法是将信…

「C++ 类和对象篇 10」初始化列表

目录 一、什么是初始化列表&#xff1f; 二、为什么需要初始化列表&#xff1f; 三、初始化列表怎么使用&#xff1f; 3.1 在构造函数中使用初始化列表 3.2 注意 3.3 结论 3.4 应用场景 四、初始化列表的初始化顺序 五、另一种初始化成员变量的方法 【总结】 一、什么是初始化…

快速幂的应用

1.非递归的解法 #include <iostream> using namespace std; int main(){int a,b,c,t1;cin>>a>>b>>c;if(a>2&&a<1e3&&b>0&&a<1e7&&c>2&&c<1e5)for(int i0;i<b;i)tt*a%c;cout<<t;r…

Python算法题集_随机链表的复制

Python算法题集_随机链表的复制 题138&#xff1a;随机链表的复制1. 示例说明2. 题目解析- 题意分解- 优化思路- 测量工具 3. 代码展开1) 标准求解【双层循环】2) 改进版一【字典哈希】3) 改进版二【单层哈希】4) 改进版三【递归大法】 4. 最优算法 本文为Python算法题集之一的…

Verilog刷题笔记24

题目&#xff1a; Verilog has a ternary conditional operator ( ? : ) much like C: (condition ? if_true : if_false) This can be used to choose one of two values based on condition (a mux!) on one line, without using an if-then inside a combinational alwa…

K8s环境下rook-v1.13.3部署Ceph-v18.2.1集群

文章目录 1.K8s环境搭建2.Ceph集群部署2.1 部署Rook Operator2.2 镜像准备2.3 配置节点角色2.4 部署operator2.5 部署Ceph集群2.6 强制删除命名空间2.7 验证集群 3.Ceph界面 1.K8s环境搭建 参考&#xff1a;CentOS7搭建k8s-v1.28.6集群详情&#xff0c;把K8s集群完成搭建&…

C#,欧拉常数(Euler Constant)的算法与源代码

1 欧拉常数 欧拉常数最先由瑞士数学家莱昂哈德 欧拉 (Leonhard Euler) 在1735年发表的文章《De Progressionibus harmonicus observationes》中定义。欧拉曾经使用γ作为它的符号&#xff0c;并计算出了它的前6位&#xff0c;1761年他又将该值计算到了16位 。 欧拉常数最先由瑞…

java实现文件随机加密

1、引言 有时候我们需要对我们的某些文件数据进行加密&#xff0c;并且不希望被轻易破译&#xff0c;此时最好不要使用已知的加密方法&#xff0c;这里我就给大家提供一种数据加密的方式&#xff0c;用以实现文件数据的加密&#xff0c;我称之为随机加密&#xff0c;即使是对相…

setState的参数

目录 1、setState的第一个参数 2、setState的第二个参数 3、在 React 底层主要做了那些事呢&#xff1f; 4、类组件如何限制 state 更新视图 React 项目中的 UI 的改变来源于 State 改变&#xff0c;类组件中 setState 是更新组件&#xff0c;渲染视图的 1、setState的第一个参…

vs2013与对应的opencv3.2下载地址和环境配置

Visual Studio community 2013 • 链接&#xff1a; https://pan.baidu.com/s/1ye-EDDyUGsy8EaZKaq9Ksw 提取码&#xff1a; 06lw -- 来自百度网盘超级会员 V7 的分享 • 也可以从下面官网下载 https://learn.microsoft.com/zh-tw/visualstudio/releasenotes/vs2013-communit…

【JAVA WEB】 百度热榜实现 新闻页面 Chrome 调试工具

目录 百度热榜 新闻页面 Chrome 调试工具 --查看css属性 打开调试工具的方式 标签页含义 百度热榜 实现效果&#xff1a; 实现代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"vi…

骑砍MOD天芒传奇-天芒使用方法

骑砍1战团mod天芒传奇-使用红色天芒碎片开P51战斗机_单机游戏热门视频 (bilibili.com)https://www.bilibili.com/video/BV1nm41197iA/ 一.黄色天芒碎片 天芒盒子 野外战斗H键-召唤徐天地 二.绿色天芒碎片 天芒盒子 野外战斗H键-站在巨人肩膀上战斗 三.蓝色天芒碎片 天芒盒…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之StepperItem组件

鸿蒙&#xff08;HarmonyOS&#xff09;项目方舟框架&#xff08;ArkUI&#xff09;之StepperItem组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、StepperItem组件 用作Stepper组件的页面子组件。 子组件 无。 接口 St…