论文:
https://arxiv.org/pdf/2310.11511.pdf
code:
https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_self_rag.ipynb?ref=blog.langchain.dev
相关工作
RAG
尽管 RAG 方法可以通过检索生成得到更为准确的结果,但是不可避免的是会引入以下几个问题:
- 降低运行效率
- 对于不相关上下文的鲁棒性,即如果上下文无关,则容易检索出不相关信息
- 缺乏归因,个人理解为无法找到很好的支撑点,即无法确认产生的结果是否是因为检索回来的内容导致的
为解决上述问题,作者提出了一种方法,训练任意的 LM 模型,使其可以可以根据不同指令的需求查询检索相关文献。并且通过 reflections token 引导生成受控的内容,以进一步提高生成的质量。
Concurrent RAG work/并行 RAG 工作
已有的一些工作会提出一些新的训练方法或者 prompt 策略来提高 RAG 的效果。例如,分两步在instruction-tuning 数据集上对 retriever 和 LM 进行微调;利用摘要模型过滤或者压缩检索段落等。
在训练和生成阶段引入 critics
训练 LLM 的时候引入 RLHF(Reinforcement Learning from Human Feedback)已经被证明是一个非常有效的方法。与已有方法不同的是,Self-RAG 在包含 reflection tokens 的离线任务示例上训练 LM,相较于 RLHF 训练成本低了很多。
Reflection token 定义
作者定义了如下的几个 token 来辅助生成,在文中统一称为 reflection token,如下图所示:
- Retrieve token:决定是否需要调用 retriever 来检索相关信息
- IsRel token:决定检索回来的信息是否和 input prompt 相关
- IsSup token:决定检索回来的信息是否可以支撑生成结果
- IsUse token:判断生成结果是否是对input prompt 有用和有信息的回答,与它是否事实无关
训练流程
Training critic 模型
创建数据集
作者通过 GPT-4 生成 reflection token 来创建监督数据并将这部分知识加入到 in-house 的 critical model 中。
以 Retrieve token 为例,作者用特定的指令集(Given an instruction, make a judgment on whether finding some external documents from the web helps to generate a better response.) 提示 GPT-4 ,随后加入 few-shot 演示结果,包括原始的输入,输出以及正确的 reflection token。 手动评估显示,GPT-4 reflection 预测与人类评估高度一致。
Critic 模型学习
在收集训练数据后,作者用预先训练的 LM 作为初始化的 Critic model,之后用标准 LM 目标在训练数据集上训练它,最大化可能性,如下图所示:
训练生成模型
创建数据集
给定输入输出对(x,y),作者使用检索到的信息和 critic 模型来增强原始输出,以创建精确模拟Self-RAG 推理时间过程的监督数据。
- 对于每一个 segment,作者调用 critic model 来评估是否需要额外的信息来增强生成。
- 如果判断是需要检索,则会将 Retrieve token 设置为 True 并且检索 top-K 的段落。
- 对于每一个检索返回的段落,Critic 模型会进一步用 IsRel token 来评估该段落是否和预测相关。
- 如果段落是相关的,Critic 模型会进一步用 IsSup token 来评估该段落是否支持模型的生成结果。
- (论文中写道 IsRel token 和 IsSup token 会被追加到检索段落或者生成段落之后,这个地方和用 LangGraph 实现的 code 有一些不一致,code 中直接将其作为判断节点,进行后续操作)
- 最后 Critic 模型会用 IsUse token 评估是否需要被使用,最终将经过增强的输出和 reflection token 以及原始的输入作为 pair 对加入到数据集中。
生成器训练
作者用标准的下一个令牌目标在上述生成的具有 reflection token 的语料库上训练生成器模型,具体目标函数如下:
与 Critic 训练不同,生成器训练用于预测目标输出以及 reflection token。在训练过程中,作者屏蔽了检索到的文本块以进行损失计算,并新增了一组 reflection token{Critique,Retrieve}。
推理流程
- 对于每个输入和上次生成的结果,模型(Self-RAG)会对 retrieval token 进行解码,评估是否需要进行 retrieval
- 如果不需要检索,则模型会和正常 LM 一样预测下一个输出段落
- 如果需要检索,模型会生成一个 critique token 评估检索回来段落的相关性、下一段落的输出和一个 critique token 评估生成的下一段落输出是否得到了检索回来段落的支撑
- 最终一个新的 critique token 会用来评估整体输出的质量
在生成每个段落的时候,Self-RAG 并行处理多个段落(Tree-decoding),并用自己生成的 reflection token 对生成的结果进行软约束或者硬约束。如图所示。
Tree-decoding with critique tokens
在每个 segment ,当需要检索时,会检索 K 个段落,生成器M并行处理每个段落并输出K个不同的连续候选。每个 segment 和检索段落都会有一个 critic 分数,该分数是每个 critique token 类型的归一化概率的线性加权和。每个 segment 的分数计算公式如下:
实验结果
对比实验
消融实验,参数敏感性实验