llm的inference(一)

文章目录

  • 前提
  • LLM
    • LLM结构
    • 1.Encoder-only
    • 2.Encoder-Decoder
    • 3.Decoder-only
  • 宏观层面的LLM推理过程
  • 宏观推理过程的进一步详细说明
    • 从字符串输入到网络的输出
  • 总结
  • 参考链接

前提

对LLM(大语言模型)的推理不太清楚,自己把遇到的和推理相关的知识做个总结,如有错误,还请大家多多指导。

LLM

LLM可以分为:闭源和半开源。闭源就是什么都没有,只有接口供用户使用,例如GPT系列;半开源就是提供了权重文件,但是没有训练的代码,我们可以使用这个权重进行推理,例如llama系列。

LLM结构

现代的LLM模型基本上都是在Transformer结构上进行展开的。原生Transfromer结构如下:
在这里插入图片描述

在《Attention is All Your Need》这篇文章中,Encoder和Decoder的层数 N = 6 N=6 N=6
不同的LLM的架构架构有所区别,主要分为以下三种:

1.Encoder-only

只使用Encoder部分,没有Decoder部分,这样就可以看成是我们之前了解的自编码器。它可以对数据进行维度压缩、特征提取等等。Bert就使用的是Encoder-only的架构。

2.Encoder-Decoder

同时使用编码-解码器,编码器将输入编码为中间表达,然后解码器对中间表达进行解码。输入是序列,输出也是序列,所以它可以应对翻译,生成这样的任务。使用这种架构的llm有GLM、T5、Bart等。

3.Decoder-only

只使用Decoder部分,没有Encoder。通常用于条件生成任务,给定一些条件信息作为输入,模型通过解码器生成相应的输出,例如图像描述生成,文本生成任务。使用这种架构的llm有GPT系列,llama系列等。常见的chat模型都是基于Decoder-only。

宏观层面的LLM推理过程

先从宏观层面了解一下llm的推理过程,llm的推理是一个token一个token往出崩的,也就是串行输出的。总的来说:LLM推理的过程是一个自回归的过程,也就是说前i次的token会作为第i+1次的预测数据送入模型,拿到第i+1次的推理token。如下图所示:
在这里插入图片描述

起始输入的token是china,输出的token是is;在接下里的这次回归过程中,会将上一次的输出ischina进行拼接后再一起送入模型,以此类推,直到遇到条件中止。
宏观上LLM就是这么进行推理的。

宏观推理过程的进一步详细说明

上面的过程是宏观的过程,那么这个过程中大概的数据流是什么样呢?接下来我们对此做个简单的介绍。

从字符串输入到网络的输出

假设我们正在和GPT3进行对话,我们输入"Lionel Messi is a",则input="Lionel Messi is a"。我们知道神经网络只能输入数字(也就是tensor),那么从input开始是如何变为tensor的,它的维度又是多少,网络输出的维度又是多少呢?在下文我们一一解答。

  1. 首先会对input进行tokenizer,也就是将input划分为不同的token,并将其转换为对应的ID。在openAI 网站中可以给我们一个直观的解释。
    在这里插入图片描述
    tokenizerinput划分为6个Tokens,并为每个Token赋予了一个ID:
    在这里插入图片描述
    所以经过tokenizer后,input变为了[43, 295, 417, 36128, 318, 257]
  2. 进行one-hot编码。可以确定的是在GPT3中总共使用了50257个词汇,因此将上面的input进行one-hot编码,编码后的维度为[6,50257]。
  3. embedding(词嵌入)。可以看到50257维度有些太大了,并且大部分都为0,浪费了很多空间。所以可以试图将输入向量进行压缩(也就是映射到低维空间),减小维度。GPT压缩后的维度为12288,具体的过程就是:[6,50257]*[50257,12888] = [6,12288]。其中[50257,12288]这个矩阵的权重已经提前训练好了。
  4. 最终网络的输入为[6,12888],经过网络的推理后维度仍为[6,12288],然后我们对其进行反映射(乘以另一个权重矩阵,大小为[12888,50257]),最终得到的输出大小为[6,50257],然后我们取最后一个行向量即为网络预测的结果,然后从其中取概率最大的token作为网络的输出。

注意:这里我们的输入只有6个token,但是在实际的代码运行过程中会将这6个token填补到2048大小(GPT3),使用"空"值填补其它位置。参考一些解答:说是主要为了positional embeding,因此位置信息编码是固定的。

相关代码如下:

# copy from https://zhuanlan.zhihu.com/p/630832593
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()# tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))
print(in_tokens)# inference
token_eos = torch.tensor([198]) # line break symbol
out_token = None
i = 0
with torch.no_grad():while out_token != token_eos:print("input tokens shape:", in_tokens.shape)logits, _ = model(in_tokens)out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)in_tokens = torch.cat((in_tokens, out_token), 0)text = tokenizer.decode(in_tokens)print(f'step {i} input: {text}', flush=True)i += 1out_text = tokenizer.decode(in_tokens)
print(f'Input: {in_text}')
print(f'Output: {out_text}')

可以看到,大模型预测结果是一个token一个token串行进行预测的,它只会拿预测概率最高的token。

总结

上面我们首先介绍了llm的几种架构,从宏观层面分析了LLM的推理过程,并对其中的一些数据流做了简单的分析,接下来我们要从工程方面分析大模型如何进行推理以及推理过程中的一些指标。

参考链接

  1. https://zhuanlan.zhihu.com/p/630832593
  2. https://zhuanlan.zhihu.com/p/174782647

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2805416.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Autoencoder深度学习中的无监督学习神经网络

在当今的深度学习领域中,自动编码器(Autoencoder)是一种常见的无监督学习神经网络模型,用于学习有效的数据表示。自动编码器在许多领域都有广泛的应用,包括特征提取、降维、图像去噪、生成模型等。 自动编码器的基本原…

Servlet使用Cookie和Session

一、会话技术 当用户访问web应用时,在许多情况下,web服务器必须能够跟踪用户的状态。比如许多用户在购物网站上购物,Web服务器为每个用户配置了虚拟的购物车。当某个用户请求将一件商品放入购物车时,web服务器必须根据发出请求的…

Danswer-开源统一搜索,用AI与您的文档聊天

简介 Danswer允许您以自然语言提问并根据您团队的特定文档获取答案。如果 ChatGPT 能够访问您团队的独特知识。连接到所有常见的工作场所工具,例如 Slack、Google Drive、Confluence 等。 优势 加快客户支持和升级时间。通过使文档和代码变更日志易于查找来提高工…

openGauss学习笔记-228 openGauss性能调优-系统调优-LLVM使用建议

文章目录 openGauss学习笔记-228 openGauss性能调优-系统调优-LLVM使用建议 openGauss学习笔记-228 openGauss性能调优-系统调优-LLVM使用建议 目前LLVM在数据库内核侧已默认打开,用户可结合上述的分析进行配置,总体建议如下: 设置合理的wor…

谷歌搜索引擎关键词优化,竞价排名怎么做?大舍传媒

公司 大舍传媒成立于2005年,并从那时开始专注于谷歌搜索引擎优化(SEO)。如今,我们已经拥有了十八年的海外数字营销经验。我们为全球数千个国际知名品牌客户提供服务,是一家专注于技术的公司。 谷歌排名成果 在谷歌&…

Python读取.nc数据并提取指定时间、经纬度维度对应的变量数值

本文介绍基于Python语言的netCDF4库,读取.nc格式的数据文件,并提取指定维(时间、经度与纬度)下的变量数据的方法。 我们之前介绍过.nc格式的数据,其是NetCDF(Network Common Data Form)文件的扩…

完全增量式PID应用介绍(详细框图算法分析)

PID系列算法和代码可以订阅PID专栏查看更多应用介绍,常用链接如下: 1、增量式PID的抗扰 https://rxxw-control.blog.csdn.net/article/details/136253663https://rxxw-control.blog.csdn.net/article/details/1362536632、线性化功能块S_RTR https://rxxw-control.blog.cs…

普中51单片机学习(红外通信)

红外通信 红外线系统的组成 外线遥控器已被广泛使用在各种类型的家电产品上,它的出现给使用电器提供了很多的便利。红外线系统一般由红外发射装置和红外接收设备两大部分组成。红外发射装置又可由键盘电路、红外编码芯片、电源和红外发射电路组成。红外接收设备可由…

数学家的趣闻轶事65则

目录 前言趣闻轶事65则参考文献 前言 有人的地方就有江湖,有江湖的地方就有故事。数学本身就是一个江湖,这个江湖也充满着血雨腥风和侠骨柔情,至今流传着各种各样的传说,其中不乏”马踏江湖潇潇事“,也有"何当共…

【openGL教程08】关于着色器(02)

LearnOpenGL - Shaders 一、说明 着色器是openGL渲染的重要内容,客户如果想自我实现渲染灵活性,可以用着色器进行编程,这种程序小脚本被传送到GPU的显卡内部,起到动态灵活的着色作用。 二、着色器简述 正如“Hello Triangle”一章…

【JavaEE】_tomcat的安装与使用

目录 1. Tomcat简介 2. Tomcat安装 2.1 下载Tomcat并解压缩 2.2 启动Tomcat 2.2.1 Tomcat乱码问题 2.2.2 Tomcat闪退问题 2.3 访问Tomcat欢迎页面 3. 使用Tomcat部署前端代码 3.1 路径匹配 3.2 文件路径访问与网络访问 4. 静态页面与动态页面 5. 基于tomcat的网站后…

如何成交国外大客户拿下大单?

点线面 作为外贸人,很多人都会感慨,拿下客户订单不容易,拿下大客户的大订单更不容易,因为大客户的采购必须顾及更多因素和风险。 这就要求我们在面对大客户时,必须综合点、线、面相结合为切入点,充分挖掘…

互联网加竞赛 机器视觉 opencv 深度学习 驾驶人脸疲劳检测系统 -python

文章目录 0 前言1 课题背景2 Dlib人脸识别2.1 简介2.2 Dlib优点2.3 相关代码2.4 人脸数据库2.5 人脸录入加识别效果 3 疲劳检测算法3.1 眼睛检测算法3.2 打哈欠检测算法3.3 点头检测算法 4 PyQt54.1 简介4.2相关界面代码 5 最后 0 前言 🔥 优质竞赛项目系列&#x…

iOS调用系统已安装地图及内置地图实现

info.plist要添加scheme: 1.地图列表: NSArray *mapKeys=[[NSArray alloc] initWithObjects:@"com.autonavi.minimap",@"com.baidu.BaiduMap",@"com.google.android.apps.maps",@"com.tencent.map", nil]; NSArray *mapSchemes=[[NS…

深度学习基础(二)卷积神经网络(CNN)

之前的章节我们初步介绍了深度学习相关基础知识和训练神经网络: 深度学习基础(一)神经网络基本原理-CSDN博客文章浏览阅读924次,点赞13次,收藏19次。在如今的科技浪潮中,神经网络作为人工智能的核心技术之…

HMI界面:是工业自动化的“窗口”,大有用武之地。

Hello,我是大千UI工场,本期分享HMI人机交互界面在工业自动化领域的应用,关注大千,学习N多UI干货,有设计需求,我们也可以接单。 HMI(Human Machine Interface,人机界面)在…

java: warning: source release 11 requires target release 11 解决办法

遇到问题 运行项目时报如下错 java: warning: source release 11 requires target release 11 原因:创建项目的时候选择的java11版本,现在用java8版本运行就会报这个错 查看项目的iml文件中LANGUAGE_LEVEL“JDK_xx”是多少 .iml 文件是 IntelliJ ID…

JAVA工程师面试专题-Mysql篇

一、基础 1、mysql可以使用多少列创建索引? 16 2、mysql常用的存储引擎有哪些 存储引擎Storage engine:MySQL中的数据、索引以及其他对象是如何存储的,是一套文件系统的实现。常用的存储引擎有以下: Innodb引擎:In…

How to implement multiple file uploads based on Swagger 3.x in Spring boot 3.x

How to implement multiple file uploads based on Swagger 3.x in Spring boot 3.x Projectpom.xmlOpenAPIConfigFileUploadControllerapplication.yaml Project pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://…

爬取m3u8视频

网址&#xff1a;https://www.bhlsm.com/cupfoxplay/609-3-1/ 相关代码&#xff1a; #采集网址&#xff1a;https://www.bhlsm.com/cupfoxplay/609-3-1/ #正常视频网站&#xff1a;完整视频内容 # pip install pycryptodomex #流媒体文件&#xff1a;M3U8&#xff08;把完整的…