用code去探索理解Llama架构的简单又实用的方法

除了白月光我们也需要朱砂痣

      我最近也在反思,可能有时候算法和论文也不是每个读者都爱看,我也会在今后的文章中加点code或者debug模型的内容,也许还有一些好玩的应用demo,会提升这部分在文章类型中的比例

      今天带着大家通过代码角度看一下Llama,或者说看一下Casual-LLM的Transfomer到底长啥样

     对Transfomer架构需要更了解的读者,可以先看这个系列

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(1) (qq.com)

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(2) (qq.com)

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(3) (qq.com)

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(4) (qq.com)

       友情提示,看代码和debug都不需要GPU,在你的PC上就可以做

       首先先安装transfomer库

       pip install transfomers

       然后进入到库下面,一般在这

图片

       进去就能找到Transfomers的库,往下拉到models,可以发现各种模型都在里面

图片

找到Llama,包含以下文件

图片

      点进modeling_llama.py发现1200多行,根本没法看(一般人没那么多耐心,但是其实仔细看两遍还是很有收获的),然后点击左边outline 大纲,(或者ctrl+shift+o),就可以有选择的看你想要研究的具体网络层,这样感觉压力瞬间小了百分之80以上

图片

      我想查某个函数的代码块,想对它加深了解,举个例子,看起来比较怪异命名的,这个线性扩展RoPE embedding的函数

图片

       然后ctl按住函数名字就能查到它作用于attetion的机制里面,在这种情况下即使我不知道它到底干啥,也能猜个89不离10,至少和什么主模块相关我清楚了

图片

     本章的话,我们先从模型主体部分看起,点LlamaModel,就能看到非常清晰的逻辑

    

图片

    主体部分包含3个子模块:

  • 先要embedding token

  • 再要包含一个通过for循环,不断的持续经过的decoder层

  • 还要包含一个Normal(RMSNorm)

     

     从大面上看也就这么3个操作

     初始化部分,初始了哪些模块我们看完了,再看看细节,从forward看起(大家看任何网络都要重点看forward)

图片

 forward输入部分,要求输入的参数:

  • input_ids:输入序列标记的索引,形状为(batch_size, sequence_length)的torch.LongTensor。

  • attention_mask:避免对填充标记进行注意力计算的掩码,形状为(batch_size, sequence_length)的torch.Tensor

  • position_ids:输入序列标记在位置嵌入中的索引,形状为(batch_size, sequence_length)的torch.LongTensor

  • past_key_values:包含预先计算的隐藏状态(自注意力块和交叉注意力块中的键和值),用于加速顺序解码的tuple,推理用的,当use_cache=True时,会返回这个参数,训练不用管

  • inputs_embeds:直接传入embedding表示而不是input_ids,形状为(batch_size, sequence_length, hidden_size)的torch.FloatTensor

  • use_cache:是否使用缓存加速解码的布尔值,当设置为True时,past_key_values的键值状态将被返回,用于加速解码

  • output_attentions:是否返回所有注意力层的注意力张量,布尔值,当设置为True时,会在返回的结果中包含注意力张量

  • output_hidden_states:是否返回所有层的隐藏状态,布尔值,当设置为True时,会在返回的结果中包含隐藏状态

  • return_dict:是否返回utils.ModelOutput对象而不是普通的元组,布尔值,当设置为True时,会返回一个ModelOutput对象

         我们继续看,下面就是一些操作步骤,输入的input_id,会被向量化,生成hidden_states

图片

    hidden_states然后就被扔进了若干个hidden_layer被for循环来回的操作,比如Llama7B的32层

   

图片

     我们简单写一段逻辑描述上述的代码

     比如在把"我爱你"已经分词的情况下 我=100,爱=200,你=300

  input_ids = [100,200,300]

  input_ids  -> nn.Emebdding(dims=3) -> hidden_states

  hidden_states = [[0.1,0.2,0.3],[0.4,0.5,0.6],[0.7.0.8.1.1]]

  hidden_states ->layer1 ->layer2 ------>layer32

  最终还是hidden_states

  也就是最终的shape和初始的hidden_states的shape的相同的

   hidden_states -> Norm -> hidden_states

   其实没必要把hidden_states理解的那么悬,就当它是个中间变量就可以了

     Layer里面都有什么呢?我们点进去Layer里面就能看到

图片

     包含了我们在Transfomer里学到的attention层,MLP层,LayerNorm这些层

     继续往下

图片

     我们可以看到首先定义了残差,然后hidden_states在没过layer之前就先被Normal了一下(这个知识点以前也讲过,Llama RMS LN是前LN),然后过attetion,过MLP,最后hidden_states=residual+hidden_states, 这样就把位置编码啥的也都带过来了

     然后我们点击进入attetion这块,看看它是咋做的

图片

      这块其实感觉都不用讲了,我觉得全部代码就这块写的逻辑是最清楚的

图片

,QKV矩阵就是下面前三个线性层,然后分别主管生成qkv,上面的参数也解释了多头,和多头dim是咋算出来的,GQA咋算,这块代码写的实在清晰,属实没啥可说的

       最后看一眼MLP层,就是FFN

图片

      这一层是由3个线性层组成的gate+up+down,和标准transfomer的FFN是真不一样,主要有两点不同:

1- 原始transfomer是由两个线性层组成,这里有3个,原因就在于 SwiGLU 激活函数比类似 ReLU 的激活函数,需要多一个 Linear 层(gate)进行门控,所以你说SwiGLU是比ReLU效率高了,还是低了呢?

图片

(当然这个门控可以并行操作)

2- 原始Transfomer第一个线性层先将维度映射为4h维,第二个线性层再映射回h维,接着进行激活函数操作。而llama则是将原有4h变成一个常量作为输入,且计算方式也略有不同,可能是因为4h这个说法如果模型太大会罩不住,就用一个常量来代替了(我瞎猜的)

        这基本网络就讲完了,其实大家看看也没啥玩意,比较简单

      再比如说我们像看一下Llama是咋做下游任务的,modeling里面有2个,我们挑CLM任务(主管生成的任务,CasualLM)来看

图片

     如上图所示,就是在Llama的基础模型之上,加了一个线性层

     然后我们展开LlamaForCausalLM

图片

      可以看到它包含很多内容,这里面重点肯定还是forward(前向),我们就看一下它的前向是怎么进行的,输入是啥,输出是啥?

图片

      我们先看forward输入部分,要求输入的参数和前面都一样唯一的区别就是Label

  • Labels:没法解释,就是字面上的Labels

    图片

    图片

     输出这块我们就关注hidden_states就行了,这也是最重要的

图片

      然后hiden_states和你的labels进行匹配求交叉熵损失,最后损失最小的就是最高概率生成的token 

      下面那个文本分类的其实也是一样的,只不过它计算Loss的时候会有说明,单标签就是回归任务,多标签就变分类任务了

图片

     通过这些简单的方法,就能把读者们把以前理解不太透彻的网络模块都理一遍,加深印象      

     本章完

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2780299.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

HTTP 超文本传送协议

1 超文本传送协议 HTTP HTTP 是面向事务的 (transaction-oriented) 应用层协议。 使用 TCP 连接进行可靠的传送。 定义了浏览器与万维网服务器通信的格式和规则。 是万维网上能够可靠地交换文件(包括文本、声音、图像等各种多媒体文件)的重要基础。 H…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Divider组件

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Divider组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、Divider组件 提供分隔器组件,分隔不同内容块/内容元素。 子组件 …

设计模式学习笔记05(小滴课堂)

讲解Adapeter设计模式和应用场景 接口的适配器案例实战 代码: 定义一个接口: 编写适配器: 写我们的商品类: 会员类: 这样我们不同的需求可以根据需要去实现不同的接口方法,而不用实现全部接口方法。 适配…

python+django咖啡网上商城网站

全网站共设计首页、咖啡文化、咖啡商城、个人信息、联系我们5个栏目以及登录、注册界面,让用户能够全面的了解中国咖啡咖啡文化宣传网站以及一些咖啡知识、文化。 栏目一首页,主要放置咖啡的起源及发展进程的图文介绍;栏目二咖啡文化&#xf…

《Linux 简易速速上手小册》第2章: 命令行的艺术(2024 最新版)

文章目录 2.1 基本 Linux 命令2.1.1 重点基础知识2.1.2 重点案例:整理下载文件夹2.1.3 拓展案例 1:批量重命名文件2.1.4 拓展案例 2:查找并删除特定文件 2.2 文件和目录管理2.2.1 重点基础知识2.2.2 重点案例:部署一个简单的网站2…

中国电子学会2020年9月份青少年软件编程Scratch图形化等级考试试卷三级真题(编程题)

编程题(共3题,共30分) 36.题目:魔术表演“开花” 1.准备工作 (1)将舞台设置为"Party"; (2)删除默认角色,自行绘制椭圆花瓣角色; (3&#xf…

fast.ai 机器学习笔记(一)

机器学习 1:第 1 课 原文:medium.com/hiromi_suenaga/machine-learning-1-lesson-1-84a1dc2b5236 译者:飞龙 协议:CC BY-NC-SA 4.0 来自机器学习课程的个人笔记。随着我继续复习课程以“真正”理解它,这些笔记将继续更…

【Django】Django项目部署

项目部署 1 基本概念 项目部署是指在软件开发完毕后,将开发机器上运行的软件实际安装到服务器上进行长期运行。 在安装机器上安装和配置同版本的环境[python,数据库等] django项目迁移 scp /home/euansu/Code/Python/website euansuxx.xx.xx.xx:/home…

C#系列-Entity Framework 架构(18)

下图展示了EF的整体架构。现在让我们逐个地看看架构的各个组件: EF组件图 EDM(Entity Data Mode 实体数据模型):EDM 由三个主要部分组成:概念模型,映射和存储模型。 Conceptual Model(概念模型&#xff0…

【Langchain Agent研究】SalesGPT项目介绍(二)

【Langchain Agent研究】SalesGPT项目介绍(一)-CSDN博客 上节课,我们介绍了SalesGPT他的业务流程和技术架构,这节课,我们来关注一下他的项目整体结构、poetry工具和一些工程项目相关的设计。 项目整体结构介绍 我们把…

【安装记录】安装 netperf 和 perf

这是一篇发疯随笔X.X 我的环境是虚拟机debian12,出于种种原因,之前直接使用apt-get install netperf apt-get install perf指令直接安装,报错找不到包 然后上网搜了一堆教程,有说下载netperf源码编译的,那些教程里面有…

Guitarpro 8.1.1.17中文解锁版2024最新安装激活图文教程

Guitarpro 8.1.1.17中文解锁版一直备受用户喜爱和关注,但也存在一个被诟病的问题,即不支持中国专属的简谱功能。作为国人为了方便学习音乐独创的一种谱写方式,简谱在国内广受欢迎,然而在国际上使用的却很少。为了解决这一问题&…

CrossOver虚拟机软件功能相似的软件

与 CrossOver 功能相似的软件有: Wine:Wine 是一款在 Unix 和 Unix-like 系统(如 Linux、macOS)上运行 Windows 应用程序的兼容层。与 CrossOver 类似,Wine 通过模拟 Windows 的 API 来实现应用程序的兼容性。它支持大…

【新书推荐】7.4节 寄存器间接和相对寻址方式

本节内容:当指令操作数为内存操作数,且内存操作数的地址使用指针寄存器表示时,称为寄存器间接寻址方式。 ■寄存器间接寻址方式:在地址表达式中,只能使用BX、SI、DI、BP四个指针寄存器用来寻址。 7.4.1 寄存器间接寻…

《深入浅出OCR》第八章:文档任务多模态预训练

✨专栏介绍: 经过几个月的精心筹备,本作者推出全新系列《深入浅出OCR》专栏,对标最全OCR教程,具体章节如导图所示,将分别从OCR技术发展、方向、概念、算法、论文、数据集等各种角度展开详细介绍。 👨‍💻面向对象: 本篇前言知识主要介绍深度学习知识,全面总结知知识…

【Rust】使用Rust实现一个简单的shell

一、Rust Rust是一门系统编程语言,由Mozilla开发并开源,专注于安全、速度和并发性。它的主要目标是解决传统系统编程语言(如C和C)中常见的内存安全和并发问题,同时保持高性能和底层控制能力。 Rust的特点包括&#x…

C++2024寒假J312实战班2.5

题目列表: #1多项式输出 #2龙虎斗 #3表达式求值 #4解密 #1多项式输出 这是第一个题目很简单,我也作对了。 我们下来看一下题目: 我们先来看一下样例: 5 100 -1 1 -3 0 10 首先100是第一项,所以不输出加号&…

【Linux技术宝典】Linux入门:揭开Linux的神秘面纱

文章目录 官网Linux 环境的搭建方式一、什么是Linux?二、Linux的起源与发展三、Linux的核心组件四、Linux企业应用现状五、Linux的发行版本六、为什么选择Linux?七、总结 Linux,一个在全球范围内广泛应用的开源操作系统,近年来越来…

枚举(C/C++)

没有什么成套的算法&#xff0c;直接上例题&#xff01;&#xff01; 例题1&#xff1a;赢球票 代码&#xff1a; #include <bits/stdc.h> using namespace std;const int maxn 105; int n,num1[maxn],num2[maxn],cnt,cnt1,sum,ans;int check1()//检查剩余个数 {cnt1…

单片机学习笔记---蜂鸣器播放提示音音乐(天空之城)

目录 蜂鸣器播放提示音 蜂鸣器播放音乐&#xff08;天空之城&#xff09; 准备工作 主程序 中断函数 上一节讲了蜂鸣器驱动原理和乐理基础知识&#xff0c;这一节开始代码演示&#xff01; 蜂鸣器播放提示音 先创建工程&#xff1a;蜂鸣器播放提示音 把我们之前模块化的…