用PyTorch从零开始编写DeepSeek-V2

DeepSeek-V2是一个强大的开源混合专家(MoE)语言模型,通过创新的Transformer架构实现了经济高效的训练和推理。该模型总共拥有2360亿参数,其中每个令牌激活21亿参数,支持最大128K令牌的上下文长度。

在开源模型中,DeepSeek-V2实现了顶级性能,成为最强大的开源MoE语言模型。在MMLU(多模态机器学习)上,DeepSeek-V2以较少的激活参数实现了顶尖的性能。与DeepSeek 67B相比,DeepSeek-V2显著提升了性能,降低了42.5%的训练成本,减少了93.3%的KV缓存,并将最大生成吞吐量提高了5.76倍。

我们这里主要实现DeepSeek的主要改进:多头隐性注意力、细粒度专家分割和共享的专家隔离

架构细节

DeepSeek-V2整合了两种创新架构,我们将详细讨论:

  1. 用于前馈网络(FFNs)的DeepSeekMoE架构。
  2. 用于注意力机制的多头隐性注意力(MLA)。

DeepSeekMoE

在标准的MoE架构中,每个令牌被分配给一个(或两个)专家,每个MoE层都有多个在结构上与标准前馈网络(FFN)相同的专家。这种设置带来了两个问题:指定给令牌的专家将试图在其参数中聚集不同类型的知识,但这些知识很难同时利用;其次,被分配给不同专家的令牌可能需要共同的知识,导致多个专家在各自的参数中趋向于收敛,获取共享知识。

为了应对这两个问题,DeepSeekMoE引入了两种策略来增强专家的专业化:

  1. 细粒度专家分割:为了在每个专家中更有针对性地获取知识,通过切分FFN中的中间隐藏维度,将所有专家分割成更细的粒度。
  2. 共享专家隔离:隔离某些专家作为始终被激活的共享专家,旨在捕获不同上下文中的共同知识,并通过将共同知识压缩到这些共享专家中,减少其他路由专家之间的冗余。

让我们来定义DeepSeekMoE中第t个令牌的专家分配。如果u_t是该令牌的FFN输入,其输出h`_t将会是:

其中𝑁𝑠和𝑁𝑟分别是共享专家和路由专家的数量;FFN(𝑠)*𝑖和FFN(𝑟)*𝑖分别表示𝑖-th共享专家和𝑖-th路由专家。

对于路由专家而言,g_i,t 是第i个路由专家的门控值,s_i,t 是令牌到专家的亲和分数,Topk(., Kr) 包含了Kr个最高的亲和分数,其中Kr是活跃的路由专家的数量。

有了以上的公式,我们就来使用代码实现

门控模型实现:

 classMoEGate(torch.nn.Module):def__init__(self, num_experts_per_tok: int, n_routed_experts: int, routed_scaling_factor: int, topk_method: str, n_group: int, topk_group: int, hidden_size: int):super().__init__()self.top_k=num_experts_per_tokself.n_routed_experts=n_routed_expertsself.routed_scaling_factor=routed_scaling_factorself.topk_method=topk_methodself.n_group=n_groupself.topk_group=topk_groupself.weight=torch.nn.Parameter(torch.empty((self.n_routed_experts, hidden_size)))torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))defforward(self, x: torch.Tensor):batch, seq_len, h=x.shapehidden_states=x.view(-1, h)logits=torch.nn.functional.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)scores=logits.softmax(dim=-1, dtype=torch.float32)ifself.topk_method=="greedy":topk_weight, topk_idx=torch.topk(scores, k=self.top_k, dim=-1, sorted=False)elifself.topk_method=="group_limited_greedy":group_scores= (scores.view(batch*seq_len, self.n_group, -1).max(dim=-1).values)group_idx=torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]group_mask=torch.zeros_like(group_scores)  # [n, n_group]group_mask.scatter_(1, group_idx, 1)  # [n, n_group]score_mask= (group_mask.unsqueeze(-1).expand(batch*seq_len, self.n_group, self.n_routed_experts//self.n_group).reshape(batch*seq_len, -1))  # [n, e]tmp_scores=scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]topk_weight, topk_idx=torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)returntopk_idx, topk_weight

MoE

 classMoE(torch.nn.Module):def__init__(self, dim: int, routed_scaling_factor: int, topk_method: str, n_group: int, topk_group: int, hidden_dim: int|None=None, n_routed_experts: int=12, num_experts_per_tok: int=4, n_shared_experts: int=2, mlp: str="swiglu"):super().__init__()self.experts_per_rank=n_routed_expertsself.num_experts_per_tok=num_experts_per_tokself.n_shared_experts=n_shared_expertsmlp_block=SwiGLUself.experts=torch.nn.ModuleList([mlp_block(dim, hidden_dim) foriinrange(n_routed_experts)])self.gate=MoEGate(num_experts_per_tok, n_routed_experts, routed_scaling_factor, topk_method, n_group, topk_group, dim)self.shared_experts=mlp_block(dim, hidden_dim*n_shared_experts)defforward(self, x: torch.Tensor):identity=xorig_shape=x.shapetopk_idx, topk_weight=self.gate(x)x=x.view(-1, x.shape[-1])flat_topk_idx=topk_idx.view(-1)x=x.repeat_interleave(self.num_experts_per_tok, dim=0)y=torch.empty_like(x)y=y.type(x.dtype)fori, expertinenumerate(self.experts):y[flat_topk_idx==i] =expert(x[flat_topk_idx==i]).to(dtype=x.dtype)y= (y.view(*topk_weight.shape, -1) *topk_weight.unsqueeze(-1)).sum(dim=1)y=y.view(*orig_shape)output=y+self.shared_experts(identity)returnoutput

多头隐性注意力(MLA)

多头隐性注意力(MLA)相较于标准的多头注意力(MHA)实现了更优的性能,并且显著减少了KV缓存,提高了推理效率。与多查询注意力(MQA)和分组查询注意力(GQA)中减少KV头的方法不同,MLA将键(Key)和值(Value)共同压缩成一个潜在向量。

MLA不是缓存键(Key)和值(Value)矩阵,而是将它们联合压缩成一个低秩向量,这使得缓存的项目数量更少,因为压缩维度远小于多头注意力(MHA)中输出投影矩阵的维度。

标准的RoPE(旋转位置嵌入)与上述的低秩KV压缩不兼容。解耦RoPE策略使用额外的多头查询q_t和共享键k_t来实现RoPE。

下面总结了完整的MLA计算过程:

MLA实现

 classMLA(torch.nn.Module):def__init__(self, model_args: DeepseekConfig):super().__init__()d_model=model_args.d_modelself.num_heads=model_args.num_headsself.head_dim=model_args.d_model//model_args.num_headsself.attn_dropout=torch.nn.Dropout(model_args.dropout)self.res_dropout=torch.nn.Dropout(model_args.dropout)self.flash_attn=hasattr(torch.nn.functional, "scaled_dot_product_attention")self.q_lora_rank=model_args.q_lora_rankself.qk_rope_head_dim=model_args.qk_rope_head_dimself.kv_lora_rank=model_args.kv_lora_rankself.v_head_dim=model_args.v_head_dimself.qk_nope_head_dim=model_args.qk_nope_head_dimself.q_head_dim=model_args.qk_nope_head_dim+model_args.qk_rope_head_dimself.q_a_proj=torch.nn.Linear(d_model, model_args.q_lora_rank, bias=False)self.q_a_layernorm=RMSNorm(model_args.q_lora_rank)self.q_b_proj=torch.nn.Linear(model_args.q_lora_rank, self.num_heads*self.q_head_dim, bias=False)self.kv_a_proj_with_mqa=torch.nn.Linear(d_model,model_args.kv_lora_rank+model_args.qk_rope_head_dim,bias=False,)self.kv_a_layernorm=RMSNorm(model_args.kv_lora_rank)self.kv_b_proj=torch.nn.Linear(model_args.kv_lora_rank,self.num_heads* (self.q_head_dim-self.qk_rope_head_dim+self.v_head_dim),bias=False,)self.o_proj=torch.nn.Linear(self.num_heads*self.v_head_dim,d_model, bias=False,)defforward(self, x: torch.Tensor, mask: torch.Tensor, freqs_cis) ->torch.Tensor:batch, seq_len, d_model=x.shapeq=self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))q=q.view(batch, seq_len, self.num_heads, self.q_head_dim).transpose(1, 2)q_nope, q_pe=torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)compressed_kv=self.kv_a_proj_with_mqa(x)compressed_kv, k_pe=torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)k_pe=k_pe.view(batch, seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)kv= (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(batch, seq_len, self.num_heads, self.qk_nope_head_dim+self.v_head_dim).transpose(1, 2))k_nope, value_states=torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)q_pe, k_pe=apply_rope(q_pe, k_pe, freqs_cis)k_pe=k_pe.transpose(2, 1)q_pe=q_pe.transpose(2, 1)query_states=k_pe.new_empty(batch, self.num_heads, seq_len, self.q_head_dim)query_states[:, :, :, : self.qk_nope_head_dim] =q_nopequery_states[:, :, :, self.qk_nope_head_dim :] =q_pekey_states=k_pe.new_empty(batch, self.num_heads, seq_len, self.q_head_dim)key_states[:, :, :, : self.qk_nope_head_dim] =k_nopekey_states[:, :, :, self.qk_nope_head_dim :] =k_peattn_mtx=torch.matmul(query_states, key_states.transpose(2, 3)) /math.sqrt(self.head_dim)attn_mtx=attn_mtx+mask[:, :, :seq_len, :seq_len]attn_mtx=torch.nn.functional.softmax(attn_mtx.float(), dim=-1).type_as(key_states)attn_mtx=self.attn_dropout(attn_mtx)output=torch.matmul(attn_mtx, value_states)  # (batch, n_head, seq_len, head_dim)output=output.transpose(1, 2).contiguous().view(batch, seq_len, self.num_heads*self.v_head_dim)output=self.o_proj(output)output=self.res_dropout(output)returnoutput

总结

本文详细介绍了DeepSeek-V2语言模型,这是一个强大的开源混合专家(MoE)语言模型,采用创新的架构来提高训练和推理的经济性和效率。DeepSeek-V2采用了两种核心技术:细粒度专家分割和共享专家隔离,这两种策略显著提高了专家的专业化水平。此外,文章还介绍了多头隐性注意力(MLA),这是一种改进的注意力机制,通过低秩键值联合压缩和解耦旋转位置嵌入,优化了模型的存储和计算效率。

除了理论探讨,我们通过编写代码实现DeepSeek-V2,可以更深入地理解其架构和工作原理。可以帮助你账务如何实现先进的混合专家(MoE)模型,还能深化对多头隐性注意力(MLA)和低秩键值压缩等关键技术的理解。通过实践,读者将能够验证理论的有效性,并对模型的性能和效率有直观的认识。

https://avoid.overfit.cn/post/317a967c8dac42ee98f96d8390851476

作者:Zain ul Abideen

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3270272.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Godot入门 02玩家1.0版

添加Node2D节点,重命名Game 创建玩家场景,添加CharacterBody2D节点 添加AnimatedSprite2D节点 从精灵表中添加帧 选择文件 设置成8*8 图片边缘模糊改为清晰 设置加载后自动播放,动画循环 。动画速度10FPS,修改动画名称idle。 拖动…

数据结构之探索“堆”的奥秘

找往期文章包括但不限于本期文章中不懂的知识点: 个人主页:我要学编程(ಥ_ಥ)-CSDN博客 所属专栏:数据结构(Java版) 目录 堆的概念 堆的创建 时间复杂度分析: 堆的插入与删除 优先级队列 PriorityQ…

学习大数据DAY23 Linux基本指令4与ngnix安装以及Shell,python编写环境配置

目录 其他扩展类 echo 输出字符串 date 显示当前日期 (用于日期转字符串) date -d 日期解析(用于字符串转日期) date 设置日期 linux 网络对时 cal 查看日历 wget 命令 seq 命令 Linux 定时执行计划 特殊符号说明 linux 添加硬盘分区挂载 上…

【QT】QT 系统相关(事件、文件、多线程、网络、音视频)

一、Qt 事件 1、事件介绍 事件是应用程序内部或者外部产生的事情或者动作的统称。在 Qt 中使用一个对象来表示一个事件。所有的 Qt 事件均继承于抽象类 QEvent。事件是由系统或者 Qt 平台本身在不同的时刻发出的。当用户按下鼠标、敲下键盘,或者是窗口需要重新绘制…

初阶数据结构完结 图解所有初阶数据结构 顺序表

1数据结构 1.线性表 线性表(linear list)是n个具有相同特性的数据元素的有限序列。 线性表是⼀种在实际中⼴泛使 ⽤的 数据结构,常⻅的线性表:顺序表、链表、栈、队列、字符串… 线性表在逻辑上是线性结构,也就说是连…

Centos7_Minimal安装Cannot find a valid baseurl for repo: base/7/x86_6

问题 运行yum报此问题 就是没网 解决方法 修改网络信息配置文件,打开配置文件,输入命令: vi /etc/sysconfig/network-scripts/ifcfg-网卡名字把ONBOOTno,改为ONBOOTyes 重启网卡 /etc/init.d/network restart 网路通了

SSRF中伪协议学习

SSRF常用的伪协议 file:// 从文件系统中获取文件内容,如file:///etc/passwd dict:// 字典服务协议,访问字典资源,如 dict:///ip:6739/info: ftp:// 可用于网络端口扫描 sftp:// SSH文件传输协议或安全文件传输协议 ldap://轻量级目录访问协议 tftp:// 简单文件传输协议 gopher…

Python | TypeError: ‘float’ object is not subscriptable

Python | TypeError: ‘float’ object is not subscriptable 在Python编程中,遇到“TypeError: ‘float’ object is not subscriptable”这一错误通常意味着你尝试对浮点数(float)使用了下标访问(如数组或列表那样的访问方式&a…

Typecho仿百度响应式主题Xaink源码

新闻类型博客主题,简洁好看,适合资讯类、快讯类、新闻类博客建站,响应式设计,支持明亮和黑暗模式 直接下载 zip 源码->解压后移动到 Typecho 主题目录->改名为xaink->启用。 源码下载:https://download.csdn…

【秋招笔试题】小Q的树

解析&#xff1a;分析易得走过的路中至多存在一个分叉&#xff0c;则维护每个结点接下来的路的最大值与次大值然后相加即可。 #include <iostream> #include <vector> #include <algorithm> using namespace std; #define int long long const int MAXN 1…

09 算术运算符

① 运算符除了用于算数加法以外&#xff0c;还可以用于列表、元组、字符串的连接&#xff0c;但不支持不同类型的对象之间的相加或连接。 print([1, 2, 3] [4, 5, 6]) # 连接两个列表 print((1, 2, 3) (4,)) # 连接两个元组 print(hello 123) # 连接字符串 print(Fa…

c语言第四天笔记

关于 混合操作&#xff0c;不同计算结果推理 第一种编译结果&#xff1a; int i 5; int sum (i) (i) 6 7 13 第二种编译结果&#xff1a; int i 5; int sum (i) (i) 6 7 7 7 前面的7是因为后面i的变化被影响后&#xff0c;重新赋值 14 第一种编译结果&#xff…

【Linux网络】应用层协议:HTTP 与 HTTPS

本篇博客整理了 TCP/IP 分层模型中应用层的 HTTP 协议和 HTTPS协议&#xff0c;旨在让读者更加深入理解网络协议栈的设计和网络编程。 目录 一、协议是什么 1&#xff09;结构化数据的传输 2&#xff09;序列化和反序列化 补&#xff09;网络版计算器 .1- 协议定制 .2- …

OpenAI推出SearchGPT:革新搜索体验的新工具

引言 原文链接 在信息爆炸的时代&#xff0c;搜索引擎已经成为人们日常生活中不可或缺的工具。然而&#xff0c;传统的搜索引擎在理解复杂查询和提供准确答案方面仍有许多不足。为了解决这一问题&#xff0c;OpenAI与20240725推出了SearchGPT原型&#xff0c;将生成式AI与传统…

【Golang 面试基础题】每日 5 题(九)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/UWz06 &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 Golang 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏…

【Android】Fragment与Activity间通信知识总结

文章目录 一、Activity向Fragment通信1.1 通过方法1.1.1 构造方法1.1.1 普通public方法 1.2 通过setArguments方法1.3 通过接口 二、Fragment向Activity通信2.1 通过getActivity2.2 通过接口 三、Fragment之间传递数据通过Activity中转 一、Activity向Fragment通信 1.1 通过方…

聊聊基于Alink库的主成分分析(PCA)

概述 主成分分析&#xff08;Principal Component Analysis&#xff0c;PCA&#xff09;是一种常用的数据降维和特征提取技术&#xff0c;用于将高维数据转换为低维的特征空间。其目标是通过线性变换将原始特征转化为一组新的互相无关的变量&#xff0c;这些新变量称为主成分&…

关于链表、顺序表、栈和队列的一些总结

关于链表、顺序表、栈和堆的一些总结 1.顺序表2.链表2.1 单向链表2.1 带哨兵位双向循环链表 3.栈4.队列 1.顺序表 2.链表 2.1 单向链表 2.1 带哨兵位双向循环链表 3.栈 4.队列

【Matlab】绘图时使用字母控制线型和颜色(内含多图对比示例)

概要 测试了英文字母a-z不同输入下线条的颜色和线型&#xff0c;供参考选择。 语法 plot(x, y, 颜色); 如 plot(x, y, b); 测试 以下测试设置线宽为1.5&#xff0c;代码 x 0: 0.01: 2*pi; y sin(x); plot(x, y, b, LineWidth, 1.5);修改时把 b 改成不同字母即可 ‘a’…

基于关联规则的分类算法(CBA) | 项集、频繁项集、关联规则 | arulesCBA库

基于关联规则的分类算法 目前使用较多且较为简洁的关联规则分类算法是基于关联规则的分类算法&#xff08;Classification Based on Association, CBA&#xff09;&#xff0c;下面将从该算法的相关概念开始介绍。 这部分笔记参考论文&#xff1a;孙菡悦.基于多因素交互效应的…