大模型中的采样(Sampling)选择机制详解

Sampling

大模型中的采样选择机制详解

在自然语言处理(NLP)和生成模型(如GPT)中,采样选择机制是一种从模型的概率分布中选择词的方法,用于控制生成文本的多样性和质量。本文将详细介绍几种经典的采样选择机制,包括随机采样、Top-k采样、Top-p(Nucleus)采样、温度采样、束搜索(Beam Search)和逆向温度采样,并配以公式和代码示例。

一、采样选择机制概述

采样选择机制通过不同的方法从模型的输出概率分布中选择下一个生成的词,从而影响生成文本的特性和质量。

二、经典采样选择机制

1. 随机采样

随机采样是最简单的一种方法,直接从模型输出的概率分布中随机选择下一个词。它保留了概率分布的多样性,但可能生成不连贯的文本。

公式描述

给定词汇表 V V V和概率分布 P t P_t Pt,在时间步 t t t时,根据概率分布直接进行采样:

w t ∼ P t ( w ) w_t \sim P_t(w) wtPt(w)

代码示例

import torch
import torch.nn.functional as Fdef random_sampling(logits):"""随机采样:param logits: 模型输出的logits:return: 采样得到的下一个词的索引"""# 计算概率分布probs = F.softmax(logits, dim=-1)# 根据概率分布进行采样next_token = torch.multinomial(probs, 1)return next_token.item()# 示例logits
logits = torch.tensor([2.5, 1.2, 0.3, 3.7, 0.8])# 执行随机采样
next_token_index = random_sampling(logits)
print("随机采样得到的下一个词索引:", next_token_index)
2. Top-k采样

Top-k采样通过选择概率最高的k个词,截断概率分布以限制候选集,随后从中采样。这种方法可以减少生成不合理词的概率。

公式描述

  1. 对概率分布 P t P_t Pt进行排序,得到前k个最高概率的词 w i 1 , w i 2 , . . . , w i k w_{i_1}, w_{i_2}, ..., w_{i_k} wi1,wi2,...,wik及其对应的概率 P t ( w i 1 ) , P t ( w i 2 ) , . . . , P t ( w i k ) P_t(w_{i_1}), P_t(w_{i_2}), ..., P_t(w_{i_k}) Pt(wi1),Pt(wi2),...,Pt(wik)
  2. 截断并重新归一化:
    P t ^ ( w i j ) = P t ( w i j ) ∑ j = 1 k P t ( w i j ) \hat{P_t}(w_{i_j}) = \frac{P_t(w_{i_j})}{\sum_{j=1}^{k} P_t(w_{i_j})} Pt^(wij)=j=1kPt(wij)Pt(wij)
  3. 根据重新归一化后的概率分布进行采样。

代码示例

import torch
import torch.nn.functional as Fdef top_k_sampling(logits, k):"""根据给定的logits进行Top-k采样:param logits: 模型输出的logits:param k: Top-k值:return: 采样得到的下一个词的索引"""# 对logits进行排序并截断top_k_logits, top_k_indices = torch.topk(logits, k)# 重新归一化概率top_k_probs = F.softmax(top_k_logits, dim=-1)# 根据概率分布进行采样next_token = torch.multinomial(top_k_probs, 1)# 获取对应的词汇索引next_token_index = top_k_indices[next_token]return next_token_index.item()# 示例logits
logits = torch.tensor([2.5, 1.2, 0.3, 3.7, 0.8])# 执行Top-k采样
next_token_index = top_k_sampling(logits, k=3)
print("Top-k采样得到的下一个词索引:", next_token_index)
3. Top-p(Nucleus)采样

Top-p(Nucleus)采样通过选择累计概率达到某个阈值p的最小词集,动态调整候选集的大小,从而在控制多样性和质量之间取得平衡。

公式描述

  1. 对概率分布 P t P_t Pt进行排序,得到排序后的词集合 w 1 , w 2 , . . . , w V w_1, w_2, ..., w_V w1,w2,...,wV及其对应的概率 P t ( w 1 ) , P t ( w 2 ) , . . . , P t ( w V ) P_t(w_1), P_t(w_2), ..., P_t(w_V) Pt(w1),Pt(w2),...,Pt(wV)
  2. 选择最小的词集合使得累计概率达到阈值p:
    ∑ i = 1 m P t ( w i ) ≥ p \sum_{i=1}^{m} P_t(w_i) \geq p i=1mPt(wi)p
  3. 截断并重新归一化选择的词集合的概率。
  4. 根据重新归一化后的概率分布进行采样。

代码示例

import torch
import torch.nn.functional as Fdef top_p_sampling(logits, p):"""根据给定的logits进行Top-p采样:param logits: 模型输出的logits:param p: Top-p值:return: 采样得到的下一个词的索引"""# 计算概率分布并排序sorted_logits, sorted_indices = torch.sort(logits, descending=True)sorted_probs = F.softmax(sorted_logits, dim=-1)# 计算累计概率cumulative_probs = torch.cumsum(sorted_probs, dim=-1)# 找到累计概率大于p的最小索引cutoff_index = torch.where(cumulative_probs >= p)[0][0]# 截断并重新归一化top_p_probs = sorted_probs[:cutoff_index + 1]top_p_indices = sorted_indices[:cutoff_index + 1]top_p_probs /= top_p_probs.sum()# 根据概率分布进行采样next_token = torch.multinomial(top_p_probs, 1)# 获取对应的词汇索引next_token_index = top_p_indices[next_token]return next_token_index.item()# 示例logits
logits = torch.tensor([2.5, 1.2, 0.3, 3.7, 0.8])# 执行Top-p采样
next_token_index = top_p_sampling(logits, p=0.8)
print("Top-p采样得到的下一个词索引:", next_token_index)
4. 温度采样

温度采样通过调整概率分布的“温度”参数来控制生成文本的多样性。温度越高,生成的文本越多样化;温度越低,生成的文本越确定性。

公式描述

给定词汇表 V V V和概率分布 P t P_t Pt,在时间步 t t t时,通过调整温度参数 τ \tau τ得到新的概率分布:

P t ( w i ) = exp ⁡ ( l o g i t s ( w i ) τ ) ∑ j = 1 V exp ⁡ ( l o g i t s ( w j ) τ ) P_t(w_i) = \frac{\exp(\frac{logits(w_i)}{\tau})}{\sum_{j=1}^{V} \exp(\frac{logits(w_j)}{\tau})} Pt(wi)=j=1Vexp(τlogits(wj))exp(τlogits(wi))

其中, τ \tau τ为温度参数。

代码示例

import torch
import torch.nn.functional as Fdef temperature_sampling(logits, temperature=1.0):"""温度采样:param logits: 模型输出的logits:param temperature: 温度参数:return: 采样得到的下一个词的索引"""# 调整logits的温度adjusted_logits = logits / temperature# 计算概率分布probs = F.softmax(adjusted_logits, dim=-1)# 根据概率分布进行采样next_token = torch.multinomial(probs, 1)return next_token.item()# 示例logits
logits = torch.tensor([2.5, 1.2, 0.3, 3.7, 0.8])# 执行温度采样
next_token_index = temperature_sampling(logits, temperature=0.7)
print("温度采样得到的下一个词索引:", next_token_index)
5. 束搜索(Beam Search)

束搜索(Beam Search是一种启发式搜索算法,通过保留多个候选序列来寻找最优序列。束搜索在每个时间步保留固定数量的候选序列,并扩展这些候选序列直到达到最大长度。

公式描述

  1. 初始化beam_width个候选序列,每个序列的初始概率为1。
  2. 在每个时间步,扩展每个候选序列,生成新的候选序列。
  3. 对所有新的候选序列进行排序,保留前beam_width个最优序列。
  4. 重复步骤2和3,直到达到最大序列长度或满足终止条件。

代码示例

import torch
import torch.nn.functional as Fdef beam_search(logits_fn, initial_input, beam_width=3, max_length=20):"""束搜索:param logits_fn: 生成下一个词的logits函数:param initial_input: 初始输入:param beam_width: 束宽度:param max_length: 最大序列长度:return: 最优序列"""sequences = [[initial_input, 1.0]]for _ in range(max_length):all_candidates = []for seq, score in sequences:logits = logits_fn(seq)probs = F.softmax(logits, dim=-1)top_k_probs, top_k_indices = torch.topk(probs, beam_width)for i in range(beam_width):candidate = [seq + [top_k_indices[i].item()], score * top_k_probs[i].item()]all_candidates.append(candidate)# 对所有候选序列进行排序,保留前beam_width个最优序列ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)sequences = ordered[:beam_width]return sequences[0][0]# 示例logits函数
def example_logits_fn(seq):# 简单模拟logits输出return torch.tensor([2.5, 1.2, 0.3, 3.7, 0.8])# 执行束搜索
initial_input = [0]
best_sequence = beam_search(example_logits_fn, initial_input, beam_width=3, max_length=5)
print("束搜索得到的最优序列:", best_sequence)
6. 逆向温度采样(Reverse Temperature Sampling)

逆向温度采样通过逐步提高温度参数,从确定性较高的分布逐步过渡到多样性更高的分布。这种方法可以生成更加自然的文本。

公式描述

给定初始温度 τ 0 \tau_0 τ0和增长速率 α \alpha α,在每个时间步 t t t调整温度参数:

τ t = τ 0 ⋅ α t \tau_t = \tau_0 \cdot \alpha^t τt=τ0αt

代码示例

import torch
import torch.nn.functional as Fdef reverse_temperature_sampling(logits, initial_temperature=1.0, alpha=1.1, step=0):"""逆向温度采样:param logits: 模型输出的logits:param initial_temperature: 初始温度:param alpha: 温度增长速率:param step: 当前时间步:return: 采样得到的下一个词的索引"""# 计算当前时间步的温度temperature = initial_temperature * (alpha ** step)# 调整logits的温度adjusted_logits = logits / temperature# 计算概率分布probs = F.softmax(adjusted_logits, dim=-1)# 根据概率分布进行采样next_token = torch.multinomial(probs, 1)return next_token.item()# 示例logits
logits = torch.tensor([2.5, 1.2, 0.3, 3.7, 0.8])# 执行逆向温度采样
next_token_index = reverse_temperature_sampling(logits, initial_temperature=1.0, alpha=1.1, step=2)
print("逆向温度采样得到的下一个词索引:", next_token_index)

三、总结

本文详细介绍了大模型中的几种经典采样选择机制,包括随机采样、Top-k采样、Top-p(Nucleus)采样、温度采样、束搜索(Beam Search)和逆向温度采样。每种机制有不同的特点和适用场景,选择适当的机制可以有效地控制生成文本的质量和多样性。希望通过本文的介绍,读者能够理解并应用这些采样选择机制,提高生成模型的表现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3248569.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

计算机毕业设计-基于Springboot的养老院管理系统-源码程序文档

项目源码,请关注❥点赞收藏并私信博主,谢谢~ 本系统开发采用技术为JSP、Bootstrap、Ajax、SSM、Java、Tomcat、Maven 此文章为本人亲自指导加编写,禁止任何人抄袭以及各类盈利性传播, 相关的代码部署论文ppt代码讲解答辩指导文件…

怎么将图片插入excel单元格中

首先选中单元格选择插入 在图片位置选择插入图片的位置 然后就插入成功了,一开始会觉得图片是附在表格上面的,并不在单元格里面,但是右边有一个小图片的图标,点击它可以缩小到单元格里面,再点击就是放大;

Redis中数据分片与分片策略

概述 数据分片是一种将数据分割并存储在多个节点上的技术,可以有效提高系统的扩展性和性能。在Redis中,数据分片主要用于解决单个实例存储容量和性能瓶颈的问题。通过将数据分散存储到多个Redis节点中,可以将负载均衡到不同的服务器上&#…

Qt5.12.2安装教程

文章目录 文章介绍下载连接安装教程 文章介绍 安装Qt5.12.2 下载连接 点击官网下载 安装包下载完毕 安装教程 点开设置,添加临时储存库,复制连接“https://download.qt.io/online/qtsdkrepository/windows_x86/root/qt/” 点击测试&#xff0…

NetSuite Item Receipt的头行To Location字段设置

最近用户有碰到一个问题是说,在没有转移或者调整,发出货品的情况下,为什么在Item Receipt上明明写的是Location A,而对应的库存却到了Location B中呢?有点奇怪,查明原因后是与To Location的头行设置与改动相…

Stable Diffusion【艺术风格】:当游戏角色遇上古代纸莎草纸艺术

提示词[character] as Oni demon | full body | ancient papyrus art | Goryeo blueprint mapping[角色] 饰演 Oni demon |全身 |古代纸莎草纸艺术 |高丽蓝图映射** 纸莎草纸艺术**通常指的是古埃及时期使用纸莎草纸进行书写和绘画的艺术形式。纸莎草纸(Papyrus&…

【数据结构】Splay详解

Splay 引入 Splay旋转操作splay操作插入操作查询x排名查询排名为x删除操作查询前驱/后继模板Splay时间复杂度分析 进阶操作截取区间区间加,区间赋值,区间查询,区间最值区间翻转原序列整体插入指定位置插入整体插入末尾区间最大子段和 一些好题…

学会这个技巧,电子画册制作从此不再难

​在数字化时代,电子画册作为一种新型的宣传和展示工具,已经越来越受到企业和个人的青睐。它不仅能够以生动活泼的形式展示内容,还能够实现高度的互动性和分享性,从而大大提高信息的传播效率。然而,制作一款精美且功能…

【机器学习】机器学习与图像分类的融合应用与性能优化新探索

文章目录 引言第一章:机器学习在图像分类中的应用1.1 数据预处理1.1.1 数据清洗1.1.2 数据归一化1.1.3 数据增强 1.2 模型选择1.2.1 卷积神经网络1.2.2 迁移学习1.2.3 混合模型 1.3 模型训练1.3.1 梯度下降1.3.2 随机梯度下降1.3.3 Adam优化器 1.4 模型评估与性能优…

GESP CCF C++ 七级认证真题 2024年6月

第 1 题 下列C代码的输出结果是&#xff08; &#xff09;。 #include <iostream> #include <cmath> using namespace std; int main() { cout << sin(3.1415926 / 2); return 0; } A. 0 B. 1 C.0.5 D.0.7071 第 2 题 对于如下图的二叉树&#x…

【免费】中国电子学会所有历年真题卷全部免费

今天登录到csdn 遇到一件非常气愤的事情 原本就是电子学会网站的试卷 某些博主为了赚那么点钱 真的是不要Face了 之前没有放开资源 是因为懒得整理 为了这个不要face 花了我一下午时间把所有的资源整合在一起 现在全部拿走 全部免费&#xff01;全部免费&#xff01;全…

【网络】掌握网络基础概念

文章目录 OSI七层模型TCP/IP五层&#xff08;或四层&#xff09;模型为什么要有TCP/IP协议网络传输的基本流程网络传输流程图数据包封装和分用 网络中的地址管理IP地址Mac地址比较IP地址和Mac地址 OSI七层模型 OSI即Open System Interconnection,开发系统互连。OSI七层模型是一…

ABAP 物料主数据屏幕增强记录

参考文章&#xff1a;https://zhuanlan.zhihu.com/p/692818545 先从SPRO进入——》SAP 参考IMG——》后勤_常规——》物料主数据——》配置物料主记录——》创建定制子屏幕的程序 然后会让你创建一个函数组,此处命名为ZTEST2 &#xff08;后面才发现这张图截图不对&#xf…

昇思25天学习打卡营第13天|LLM-基于MindSpore实现的GPT对话情绪识别

打卡 目录 打卡 预装环境 流程简述 部分执行结果演示 词向量加载过程 模型结构 模型训练过程 模型预测过程 代码 预装环境 pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore2.2.14 pip install mindnlp pip install jieba pip install spacy pip …

Typescript 实现倒计时功能 useCountdown

效果图 代码块 useCountdown.ts import {onUnmounted, reactive, ref, watch} from "vue";type union days | hours | minutes | seconds | millisecondsexport type Remains Record<union, number>;/*** 创建一个倒计时** 用法*/ export const useCountDo…

Python酷库之旅-第三方库Pandas(029)

目录 一、用法精讲 74、pandas.api.interchange.from_dataframe函数 74-1、语法 74-2、参数 74-3、功能 74-4、返回值 74-5、说明 74-6、用法 74-6-1、数据准备 74-6-2、代码示例 74-6-3、结果输出 75、pandas.Series类 75-1、语法 75-2、参数 75-3、功能 75-4…

C语言函数:编程世界的魔法钥匙(2)

引言 注&#xff1a;由于这部分内容比较抽象&#xff0c;而小编我又是一个刚刚进入编程世界的计算机小白&#xff0c;所以我的介绍可能会有点让人啼笑皆非。希望大家多多包涵&#xff01;万分感谢&#xff01;待到小编我学有所成&#xff0c;一定会把这块知识点重新介绍一遍&a…

【JAVA基础】反射

编译期和运行期 首先大家应该先了解两个概念&#xff0c;编译期和运行期&#xff0c;编译期就是编译器帮你把源代码翻译成机器能识别的代码&#xff0c;比如编译器把java代码编译成jvm识别的字节码文件&#xff0c;而运行期指的是将可执行文件交给操作系统去执行&#xff0c; …

Linux介绍和文件管理

一Linux的起源 1.Unix Dennis Ritchie和Ken Thompson发明了C语言&#xff0c;而后写出了Unix的内核 2.Minix MINIX是一种基于微 内核架构的类UNIX计算机操作系统&#xff0c;由 Andrew S. Tanenbaum发明 3.Linux内核 芬兰赫尔辛基大学的研究生Linus Torvalds基于Gcc、 ba…

stack与queue的介绍与使用

stack 栈&#xff08;stack&#xff09;是一种遵循先入后出&#xff08;FILO&#xff09;逻辑的线性数据结构。其只能从容器的一端进行元素的插入与提取操作。 我们可以把他比作串串&#xff0c;我们在串肉的时候都是从底依次往上串肉&#xff0c;然后在吃的时候是从串顶依次…