Gemma模型论文详解(附源码)

原文链接:Gemma模型论文详解(附源码)

1. 背景介绍

Gemma模型是在2023.2.21号Google新发布的大语言模型, Gemma复用了Gemini相同的技术(Gemini也是Google发布的多模态模型),Gemma这次发布了了2B和7B两个版本的参数,不仅提供了预训练的checkpoints,还提供了用于对话、指令跟随等fine-tune的checkpoints。在QA问答、常识。在11

在这里插入图片描述

2. 模型介绍

2.1 模型结构

Gemma模型使用了transformer decoder结构进行训练,训练的上下文大小为8192个token,模型参数如下:
在这里插入图片描述

相比原始transformer结构的区别:

  • Multi-Query Attention:7B模型使用了multi-head attention,2B模型使用了multi-query attention (with 𝑛𝑢𝑚_𝑘𝑣_ℎ𝑒𝑎𝑑𝑠 = 1)。对比llama2中用了group-query attention
    在这里插入图片描述

  • RoPE Embeddings: 不使用绝对位置编码,在每一层前加下RoPE Embedding,同时共享输入与输出层的embedding权重。

  • GeGLU Activations: ReLU的激活替换为GeGLU的激活。对比llama中用了swiglu。

  • Normalizer Location: 在transformer的每一层layer的前后都进行规一化,这里使用RMSNorm做为规一化层。

2.2 训练搭建

Gemma使用TPUv5e进行训练;一个pod中有256块TPUv5e芯片,256块芯片被设计为16X16的2D拓扑;Gemma-7B使用16个pods(4096块卡)进行训练,Gemma-2B使用2个pods(512块卡)。7B模型在一个pod内使用16路模型并行和16路数据并行,2B模型在一个pod内使用256路数据并行。优化器状态使用ZeRO-3进行切分,减少显存占用。在pod外使用类似Pathways的方式减少数据复制的成本。

和Gemini模型训练一样,综合了Jax和Pathways的单控制器single controller编程范式,使用单个python进程编排整个训练; 使用GSPMD partitioner用于训练step的计算,使用XLA compiler减少中间结果的大小。

2.3 训练数据

Gemma 2B和7B分别基于2T和6T个token进行训练,token来源于纯英文的文本,内容包括网页、数学、代码等。使用SentencePiece的tokenizer,字典大小有256K个token。数据过滤使用基于模型的分类器去除有害的、低质量的内容。最后采用类似Gemini的方式进行训练数据的混合,提升高质量数据的占比。

2.4 指令微调(Instruction Tuning)

2B和7B进行有监督微调(SFT)训练中使用混合生成数据和人工标注的prompt文本对,同时进行RLHF训练。在SFT阶段,基于给定的一个prompt,通过测试模型生成多个响应的回答结果,通过一个更大更好的模型进行结果的好坏判断。基于不同的侧重方向(指令跟随/事实/创造性/安全等)构建不同的prompt。使用多种基于LM的自动判断方法,比如chain-of-thought prompting

训练和推理过程中使用相同的数据格式,格式的设计重点在于两点,一个是确定多轮对话中的角色,一个是确定一轮对话的开始结束。对应格式标记和示例的训练数据如下:

在这里插入图片描述
在这里插入图片描述

3. 源码

  • Tensorflow实现的源码在github google-deepmind/gemma中,PyTorch实现的源码在github google/gemma_pytorch。

  • 模型的配置在gemma/config.py文件中, 7B与2B区别主要在于num_hidden_layers/num_attention_heads/num_key_value_heads/hidden_size/intermediate_size

@dataclasses.dataclass
class GemmaConfig:# The number of tokens in the vocabulary.vocab_size: int = 256000# The maximum sequence length that this model might ever be used with.max_position_embeddings: int = 8192# The number of blocks in the model.num_hidden_layers: int = 28# The number of attention heads used in the attention layers of the model.num_attention_heads: int = 16# The number of key-value heads for implementing attention.num_key_value_heads: int = 16# The hidden size of the model.hidden_size: int = 3072# The dimension of the MLP representations.intermediate_size: int = 24576# The number of head dimensions.head_dim: int = 256# The epsilon used by the rms normalization layers.rms_norm_eps: float = 1e-6# The dtype of the weights.dtype: str = 'bfloat16'# Whether a quantized version of the model is used.quant: bool = False# The path to the model tokenizer.tokenizer: Optional[str] = 'tokenizer/tokenizer.model'def get_dtype(self) -> Optional[torch.dtype]:"""Gets the torch dtype from the config dtype string."""return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)def get_config_for_7b() -> GemmaConfig:return GemmaConfig()def get_config_for_2b() -> GemmaConfig:return GemmaConfig(num_hidden_layers=18,num_attention_heads=8,num_key_value_heads=1,hidden_size=2048,intermediate_size=16384)
  • 模型定义在gemma/model.py文件中,GemmaDecoderLayer的定义如下:
class GemmaDecoderLayer(nn.Module):def __init__(self,config: gemma_config.GemmaConfig,):super().__init__()self.self_attn = GemmaAttention(hidden_size=config.hidden_size,num_heads=config.num_attention_heads,num_kv_heads=config.num_key_value_heads,head_dim=config.head_dim,quant=config.quant,)self.mlp = GemmaMLP(hidden_size=config.hidden_size,intermediate_size=config.intermediate_size,quant=config.quant,)self.input_layernorm = RMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.post_attention_layernorm = RMSNorm(config.hidden_size,eps=config.rms_norm_eps)
  • GeGLU的实现跟llama的swiglu不同,geglu相比glu区是采用了gelu的激活,以下是glu的计算示例图:
    在这里插入图片描述

代码参考如下,代码中self.gate_proj对应上图中的B矩阵,gate相当于 σ ( B ) \sigma(B) σ(B)self.up_proj对应上图中的A矩阵.

class GemmaMLP(nn.Module):def __init__(self,hidden_size: int,intermediate_size: int,quant: bool,):super().__init__()self.gate_proj = Linear(hidden_size, intermediate_size, quant)self.up_proj = Linear(hidden_size, intermediate_size, quant)self.down_proj = Linear(intermediate_size, hidden_size, quant)def forward(self, x):gate = self.gate_proj(x)gate = F.gelu(gate)up = self.up_proj(x)fuse = gate * upoutputs = self.down_proj(fuse)return outputs

4. 参考

  • google-deepmind/gemma
  • Gemma 开放模型
  • Gemma: Open Models Based on Gemini Research and Technology
  • gemma-open-models
  • github google/gemma_pytorch
  • github google-deepmind/gemma
  • Grouped Query Attention论文阅读
  • SwiGLU论文阅读

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2803944.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式Linux中apt、apt-get命令用法汇总

在Linux环境开发过程中接触ubuntu虚拟机时,在安装软件或者更新软件时apt和apt-get命令使用相对较频繁,下面对这两个命令的用法进行汇总。 apt(Advanced Package Tool)和 apt-get 是用于在基于 Debian 的 Linux 发行版中进行软件包…

什么是favicon.ico图标?如何在线生成ICO图标?如何安装favicon.ico图标?

在本站首页的活跃博客中经常看到有部分博客网站没有 favicon.ico 图标,所以今天打算普及一下相关知识,希望还没有 favicon.ico 图标的博主们,能够制作出自己独特的图标。 那么到底什么是favicon.ico? 好搜百科给出的解释&#xf…

electron学习和新建窗口

首先我们要先下载electron npm install --save-dev electron 建立入口文件main.js 新建一个入口文件 main.js,然后导入eletron新建一个窗口。 const { app, BrowserWindow, ipcMain } require("electron"); const path require("path");func…

Nginx 反向代理配置

Nginx就不废话了,web服务器。 最近在备案一个域名,想要备案,部署一个服务器,平常很少自己配置Nginx,今天记录下。 1、反向代理 正向代理 指 客户端通过代理访问后端服务 反向代理 指 服务器推出一个客户&#xff0…

6.网络游戏逆向分析与漏洞攻防-游戏网络架构逆向分析-通过逆向分析确定游戏明文发送数据过程

内容参考于:易道云信息技术研究院VIP课 上一个内容:测试需求与需求拆解 在开始之前要了解一个小知识,在逆向开始之前要很清楚知道要找的东西是什么,大概长什么样子,只有这样才能看到它第一眼发现它,现在我…

Unable to make field private JavacProcessingEnvironment$DiscoveredPro报错解决办法

maven项目打包报错 报错信息 Unable to make field private com.sun.tools.javac.processing.JavacProcessingEnvironment$DiscoveredProcessors com.sun.tools.javac.processing.JavacProcessingEnvironment.discoveredProcs accessible: module jdk.compiler does not &q…

[论文精读]Do Transformers Really Perform Bad for Graph Representation?

论文网址:[2106.05234] Do Transformers Really Perform Bad for Graph Representation? (arxiv.org) 论文代码:https://github.com/Microsoft/Graphormer 英文是纯手打的!论文原文的summarizing and paraphrasing。可能会出现难以避免的拼…

springmvc+ssm+springboot房屋中介服务平台的设计与实现 i174z

本论文拟采用计算机技术设计并开发的房屋中介服务平台,主要是为用户提供服务。使得用户可以在系统上查看房屋出租、房屋出售、房屋求购、房屋求租,管理员对信息进行统一管理,与此同时可以筛选出符合的信息,给笔者提供更符合实际的…

PYQT5-自定义事件

from PyQt5.QtCore import QEvent, QObject from PyQt5.QtWidgets import QApplication import sys# 自定义事件类 class CustomEvent(QEvent):# PYQT5 预留给用户自定义事件类型的起点为 QEvent.User1000custom_event_type QEvent.registerEventType()# 也可以这样写# custom…

Python 实现 ADTM 指标计算:股票技术分析的利器系列(9)

Python 实现 ADTM 指标计算:股票技术分析的利器系列(9) 介绍算法解释 核心代码rolling函数介绍计算 DTMnp.where 使用介绍np.maximum 计算 DBM计算 STM计算 SBM计算 ADTM 完整代码 介绍 ADTM(动态买卖气指标)是一种用…

高级语言期末2012级B卷

1.编写函数&#xff0c;输出任意正整数n的位数&#xff08;n默认为存储十进制的整形变量&#xff09; 例如&#xff1a;正整数13&#xff0c;则输出2,&#xff1b;正整数3088&#xff0c;则输出4 #include <stdio.h>int func(int n) {int count0;while(n>0) {n/10;co…

【Redis服务搭建】

目录 Redis的修改配置启动以及参数调优Redis的常用基本操作Redis运维监控命令Redis的配置的动态更新和写入Redis的多用户管理Redis的慢日志Redis禁用危险命令和压测工具Redis持久化存储1.Redis的RDB持久化存储2.Redis的AOF持久化存储 Redis的主从复制redis的哨兵实现主从自动切…

捕捉消费新趋势,脉纷纷让生活更便捷

随着科技的飞速发展和消费者需求的不断升级,消费市场呈现出前所未有的新趋势。在这个变革的时代背景下,脉纷纷凭借其敏锐的市场洞察力和创新精神,致力于捕捉消费新趋势,为消费者带来更加便捷的生活体验。 脉纷纷深知消费者对于便捷生活的渴望。因此,它紧密关注市场动态,通过大数…

NLP_构建GPT模型并完成文本生成任务

文章目录 搭建GPT模型&#xff08;解码器&#xff09;构建文本生成任务的数据集训练过程中的自回归文本生成中的自回归&#xff08;贪婪搜索&#xff09;完整代码小结 搭建GPT模型&#xff08;解码器&#xff09; GPT 只使用了 Transformer的解码器部分&#xff0c;其关键组件…

【Java程序设计】【C00291】基于Springboot的网上图书商城(有论文)

基于Springboot的网上图书商城&#xff08;有论文&#xff09; 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于Springboot的网上图书商城 本系统分为系统功能模块、管理员功能模块以及卖家功能模块。 系统功能模块&#xff1a;在系统首页可以查看首页、图书…

基于事件触发机制的孤岛微电网二次电压与频率协同控制MATLAB仿真模型

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 本模型质量非常高&#xff0c;运行效果完美。本模型为4机并联孤岛系统&#xff0c;在下垂控制的基础上加入二次控制&#xff0c;二次电压与频率协同控制策略利用事件触发的方法来减少控制器的更新次数。该方法…

Python操作Kafka基础教程

01 Python操作Kafka基础教程 创建ZooKeeper容器 docker run -d --name zookeeper -p 2181:2181 -v /etc/localtime:/etc/localtime wurstmeister/zookeeper创建Kafka容器 语法是&#xff1a; docker run -d --name kafka -p 9092:9092 -e KAFKA_BROKER_ID0 -e KAFKA_ZOOKE…

观察者模式, 发布-订阅模式, 监听器模式

观察者模式, 发布-订阅模式, 监听器模式 观察者模式 观察者模式是一种行为型设计模式, 定义对象间的一种一对多的依赖关系&#xff0c;当一个对象的状态发生改变时&#xff0c;所有依赖于它的对象都得到通知并被自动更新 角色模型和结构图 在观察者模式中&#xff0c;只有两种…

码农永远高薪吃香的3项特质

最近看到Google在裁员滚滚&#xff0c;再次对CS就业环境有了清醒认知。之前听程序员担忧裁员&#xff0c;还以为他杞人忧天。然而&#xff0c;现实就是如此寒冷彻骨啊&#xff01; 当然&#xff0c;有些具备不可替代性的码农&#xff0c;永远吃香。总结发现有以下几点特质&…

vue Threejs实现任意画线(鼠标点击画线)

Threejs实现任意画线(鼠标点击画线) 鼠标左键单击添加点鼠标右键回退到上一个点,并继续画按住shift可以画平行于x轴或平行于z轴的线按Esc完成画线