检索增强大语言模型生成RAG

12/17/2023 检索增强生成RAGBM25及BGE检索Agentic-RAGRAG效果优化

# 1. 检索增强生成

# 1.1 RAG基本介绍

# 1.1.1 RAG是什么

开源的基座模型参数量不够大,本身拥有的能力有限。要完成复杂的知识密集型的任务,可以基于语言模型构建一个系统,通过访问外部知识源来做到。这样可以使生成的答案更可靠,有助于缓解“幻觉”问题。

RAG 会接受输入并检索出一组相关/支撑的文档,并给出文档的来源。这些文档作为上下文和输入的原始提示词组合,送给文本生成器得到最终的输出。这样 RAG 更加适应事实会随时间变化的情况,这非常有用,因为 LLM 的参数化知识是静态的,RAG 让语言模型不用重新训练就能够获取最新的信息,基于检索生成产生可靠的输出。

RAG基本介绍

# 1.1.2 RAG发展历程

“RAG”概念由Lewis在2020年引入,其发展迅速,标志着研究旅程中的不同阶段。最初,这项研究旨在通过在预训练阶段为它们注入额外知识来增强语言模型。ChatGPT的推出引发了对利用大型模型进行深度上下文理解的高度兴趣,加速了RAG在推断阶段的发展。随着研究人员更深入地探索大型语言模型(LLMs)的能力,焦点转向提升他们的可控性和推理技巧以跟上日益增长的需求。GPT-4 的出现标志着一个重要里程碑,它革新了 RAG ,采取一种将其与微调技术相结合的新方法,并继续优化预训练策略。

RAG发展时间轴

# 1.1.3 RAG生态及挑战

RAG的应用已不再局限于问答系统,其影响力正在扩展到更多领域。现在,诸如推荐系统、信息提取和报告生成等各种任务开始从RAG技术的应用中受益。与此同时,RAG技术栈正在经历一次繁荣。除了众所周知的工具如Langchain和LlamaIndex外,市场上也出现了更多针对性强的RAG工具,例如:为满足更专注场景需求而定制化的;为进一步降低入门门槛而简化使用的;以及功能专业化、逐渐面向生产环境目标发展的。

RAG当前面临的挑战:

  • 上下文长度:当检索到的内容过多并超出窗口限制时该怎么办?如果LLMs的上下文窗口不再受限,应如何改进RAG?
  • 鲁棒性:如何处理检索到的错误内容?如何筛选和验证检索到的内容?如何增强模型对毒化和噪声的抵抗力?
  • 与微调协同工作:如何同时利用RAG和FT的效果,它们应该如何协调、组织,是串行、交替还是端对端?
  • 规模定律:RAG模型是否满足规模定律?会有什么情况下可能让RAG经历逆向规模定律现象呢?
  • 生产环境应用:如何减少超大规模语料库的检索延迟? 如何确保被 LLMS 检索出来的内容不会泄露?

# 1.2 RAG技术实现

# 1.2.1 RAG技术范式

在RAG的技术发展中,我们从技术范式的角度总结了其演变过程,主要分为以下几个阶段:

  • 初级RAG:初级RAG主要包括三个基本步骤:1)索引——将文档语料库切分成更短的片段,并通过编码器建立向量索引。2)检索——根据问题和片段之间的相似性检索相关文档片段。3)生成——依赖于检索到的上下文来生成对问题的回答。
  • 高级RAG:初级RAG在检索、生成和增强方面面临多重挑战。随后提出了高级RAG范式,涉及到预检索和后检索阶段额外处理。在检索之前,可以使用查询重写、路由以及扩展等方法来调整问题与文档片段之间语义差异。在检索之后,重新排列已获取到的文档语料库可以避免"迷失在中间"现象,或者可以过滤并压缩上下文以缩短窗口长度。
  • 模块化RAG:随着RAG技术进一步发展和演变,模块化RAG的概念诞生了。结构上,它更自由灵活,引入更具体功能模块如查询搜索引擎以及多答案融合。技术层面上,它将信息查找与微调、强化学习等技术集成起来。在流程方面,RAG模块设计并协同工作形成各种不同类型RAG。

然而,模块化 RAG 并非突然出现,这三种范式存在继承与发展关系。高级RAG是模块化RAG的特殊情况,而初级RAG是高级RAG的特殊情况。

RAG技术范式

# 1.2.2 选择RAG还是微调

除了RAG之外,LLMs的主要优化策略还包括提示工程和微调(FT)。每种都有其独特的特点。根据它们对外部知识的依赖性以及对模型调整的需求,每种都有适合的应用场景。

RAG与FT的比较

RAG就像是给模型提供了一本定制信息检索的教科书,非常适合特定的查询。另一方面,FT就像一个学生随着时间内化知识,更适合模仿特定的结构、风格或格式。通过增强基础模型的知识、调整输出和教授复杂指令,FT可以提高模型的性能和效率。然而,它并不擅长整合新知识或快速迭代新用例。RAG和FT并不互斥,它们相辅相成,并且同时使用可能会产生最好的结果。

RAG与FT的关系

# 1.2.3 如何评价RAG的效果

对RAG的评估方法多种多样,主要包括三个质量分数:上下文相关性、答案准确性和答案相关性。此外,评估还涉及四项关键能力:抗噪声能力、拒绝能力、信息整合以及反事实鲁棒性。这些评价维度将传统的定量指标与针对RAG特点的专门评估标准相结合,尽管这些标准尚未得到标准化。

在评价框架方面,有RGB和RECALL等基准测试,以及像RAGAS、ARES和TruLens等自动化评价工具,它们帮助全面衡量RAG模型的表现。

如何评价RAG的效果

# 1.3 朴素RAG与Agentic RAG

# 1.3.1 朴素RAG的流程

朴素RAG的基本流程概述:用户输入问题——>问题重构(补全指代信息,保证多轮对话的能力)——>从检索库检索答案——用LLM总结答案

朴素RAG的流程

RAG 由两部分组成:

  • 第一部分负责在知识库中,根据 query 检索出匹配的文档。
  • 第二部分将 query 和文档拼接起来作为 QA 的 prompt,送入 seq2seq 模型,生成回复。

RAG原理

一部分简单场景下,朴素的 RAG 已经可以满足用户意图明确的场景的要求,因为答案已经包含在检索出来的结果中,只要交给 LLM 即可。然而在更多的情况下用户意图并不明确,无法直接通过检索找到答案,例如一些针对多文档的总结类提问需要进行多步推理等。这类场景就需要引入 Agentic RAG ,也就是在问答的过程中引入任务编排机制。

# 1.3.2 Agentic RAG的流程

Agentic RAG,顾名思义,是基于 Agent 的 RAG。Agent 与 RAG 关系紧密,两者互为基石。Agentic RAG 和简单 RAG 的最大区别在于 Agentic RAG 引入了 Agent 的动态编排机制,因此可以根据用户提问的不同意图,引入反馈和查询改写机制,并进行“多跳”式的知识推理,从而实现对复杂提问的回答。

  • 开放域问答:直接通过 LLM 产生答案而无需依赖 RAG 检索。
  • 多跳问答:首先将多跳查询分解为更简单的单跳查询,重复访问 LLM 和 RAG 检索器来解决这些子查询,并合并它们的答案以形成完整答案。
  • 自适应检索:适用于需要多步逻辑推理的复杂问题。复杂的问答往往需要从多个数据源综合信息并进行多步推理。自适应检索通过迭代地访问 RAG 检索器和 LLM,逐步构建起解决问题所需的信息链。

Agentic-RAG的流程

# 1.4 信息检索算法

稠密检索(Dense Retrieval, DR)一般指的是将documents编码为稠密向量(Dense Vector),这个如今一般都通过预训练模型的encoder进行完成,例如BERT或者T5等(GPT这种decoder架构的也可以做到)。随后基于向量数据库(如FAISS)等进行类似于K近邻的搜索方法,来查找与查询内容接近的高维文档向量。【需要的空间大,查询速度快】

稀疏检索(Sparse Retrieval, SR)将文档投射到一个稀疏向量上,顾名思义,这个稀疏向量通常与文档的语言词汇一致,例如你的一篇文章,对每个词进行向量化,随后在词这个维度上进行执行你的检索策略。当然,这个传统的BM25或者TF-IDF也可以做到,但随着Transformer接管了这一领域,你会看到像 SPLADE 这样的方法,使用神经模型来推断与文档相关的词汇,即使这些词汇并不存在。这种方法的好处是,你可以离线处理文章中的词等细粒度的向量表示,从而大大加速检索的效率。【需要的空间小,查询速度慢】

# 1.4.1 BM25检索

在信息检索领域,BM25算法被广泛认为是一种经典且有效的排名函数,用于估计文档与用户查询之间的相关性。BM25(Best Matching 25)是基于Okapi TF-IDF算法的改进版本,旨在解决一些Okapi算法存在的问题。BM25的核心思想是利用词频(TF)和逆文档频率(IDF)来衡量文档与查询之间的相关性,同时引入文档长度信息来进一步调整相关性的计算。

  • 词频(TF):词频是衡量一个词在文档中重要性的基本指标。在BM25算法中,词频是通过计算查询中的词在文档中出现的频率来确定的。词频越高,这个词在文档中的重要性越大。
  • 逆文档频率(IDF):逆文档频率用于衡量一个词在整个文档集合中的独特性或信息量。它是由整个文档集合中包含该词的文档数量决定的。一个词在很多文档中出现,其IDF值就会低,反之则高。这意味着罕见的词通常有更高的IDF值,从而在相关性评分中拥有更大的权重。
  • 文档长度:除了词频和逆文档频率,BM25还引入了文档长度信息来调整相关性的计算。较长的文档可能仅因为它们的长度就有更高的词频,因此需要用文档长度来调整词频的影响。

BM25算法公式解析

# 1.4.2 BGE检索

智源研究院发布了一款开源的中英文语义向量模型BGE(BAAI General Embedding),在中英文语义检索精度与整体语义表征能力方面全面超越了OpenAI、Meta等同类模型。BGE模型的发布,标志着语义向量模型(Embedding Model)在搜索、推荐、数据挖掘等领域的应用迈入了一个新的阶段。

BGE的技术亮点:

  • 高效预训练和大规模文本微调;
  • 在两个大规模语料集上采用了RetroMAE预训练算法,进一步增强了模型的语义表征能力;
  • 通过负采样和难负样例挖掘,增强了语义向量的判别力;
  • 借鉴Instruction Tuning的策略,增强了在多任务场景下的通用能力。
BGE向量检索与其他检索的对比

# 1.5 RAG效果优化

基础RAG架构的流程是十分简单的,其最大的特点是数据单向流通,因此搭建一个这样的系统是十分快捷的,但离真正能投入到生产环境中使用还是很远的。为了增强原有架构的文档召回率和系统鲁棒性,其优化路径大致有两条:增加召回管道和增加反馈机制。增加召回管道就是查询变换(子查询、rag-fusion)、混合检索这类通过多路召回来最大化召回率的优化方法;增加反馈机制就是rerank、后退提示、self-rag这类基于原始结果进行优化来最大化准确率的优化方法。

# 1.5.1 数据预处理

不管RAG系统结构怎样复杂,由于其数据驱动的特性,高信噪比的数据是十分重要的,在检索之前对原始数据的优化包括以下方法:

  • 实体解析:消除实体和术语的歧义以实现一致的引用。例如,将“LLM”、“大语言模型”和“大模型”标准化为通用术语。
  • 文档划分:合理地划分不同主题的文档。不同主题的文档是集中在一处还是分散在多处?如果人类都不能轻松地判断出需要查阅哪个文档才能回答提问,那么检索系统也无法做到。
  • 数据增强:使用同义词、释义甚至其它语言的翻译来增加知识库的多样性。
  • 处理特殊数据:例如时间敏感数据,对于经常更新的主题,实施一种机制来使过时的文档失效或更新。
  • 增加元数据:增加内容摘要、时间戳、用户可能提出的问题等附加信息来丰富知识库。

# 1.5.2 文本分块

通常被检索知识库中的数据量是远超于LLM所能接受的输入长度的,因此合理的分块应尽可能做到在不超出LLM输入长度限制的情况下,保证块之间的差异性和块内部的一致性。因此我们需要尽可能提供有用的信息给LLM,而不是提供无关的信息分散其注意力,可以采用以下高级的分块方法:

  • 句分割:使用 NLTK 或者 spaCy 库提供的句子分割功能,主流开发框架如 LangChain 都有集成。
  • 递归分割:通过重复地应用分块规则来递归地分解文本。例如,在 LangChain 中会先通过段落换行符(\n\n)进行分割。然后检查这些块的大小,如果大小不超过一定阈值,则该块被保留。对于超过阈值的块,使用单换行符(\n)再次分割。以此类推,不断根据块大小更新更小的分块规则(如空格,句号)。这种方法可以灵活地调整块的大小。例如,对于文本中的密集信息部分,可能需要更细的分割来捕捉细节;而对于信息较少的部分,则可以使用更大的块。
  • 语义分割:通过计算向量化后的文本的相似度来进行语义层面的分割。
  • 特殊结构分割:针对特定结构化内容(例如Markdown、LaTex、JSON等)的专门分割器。这些分割器特别设计来处理这些类型的文档,以确保正确地保留其结构。

分块还有一个因素比较重要,就是块的大小。除了嵌入模型,文档的类型和用户查询的长度及复杂性也是决定分块大小的重要因素。处理长篇文章或书籍时,较大的分块有助于保留更多的上下文和主题连贯性;而对于社交媒体帖子,较小的分块可能更适合捕捉每个帖子的精确语义。如果用户的查询通常是简短和具体的,较小的分块可能更为合适;相反,如果查询较为复杂,可能需要更大的分块。实际场景中,我们可能还是需要不断实验调整,在一些测试中,128大小的分块往往是最佳选择,在无从下手时,可以从这个大小作为起点进行测试。

# 1.5.3 嵌入

相当于数据的类型转换,即对文本数据使用嵌入(Embedding)模型进行向量化(Vectorization),以便于在检索阶段使用向量检索(Vector Retrieval)。嵌入阶段有以下几个可以优化的点:

  • 尽量使用动态嵌入:动态嵌入相较于静态嵌入更能够处理一词多义的情况,语义完全不一样的词使用静态嵌入其向量是固定的。相比之下,引入自注意力机制的模型,能够提供动态的词义理解,这意味着它可以根据上下文动态地调整词义,使得同一个词在不同语境下有不同的向量表示。
  • 微调嵌入:大多数嵌入模型都是在通用语料上进行训练的,有些项目为了让模型对垂直领域的词汇有更好的理解,会对嵌入模型进行微调。使模型能够对垂直领域词汇和通用词汇一视同仁,不被分散注意力。
  • 混合嵌入:对用户问题和知识库文本使用不同的嵌入模型。

# 1.5.4 查询优化

在实际环境中,可能由于用户的表述多样性亦或是模糊的,导致在检索阶段召回率和准确率较低,这时就需要对查询做一个优化,能够规范和丰富查询所包含的信息,便于在系统中检索到与用户相关的文档。对查询的优化方法有以下几个:

  • 查询重写:通过提示LLM或者使用专门的“问题重写器”(通常是经过微调的小型Transformer)来对用户的问题进行改写。
  • 后退提示:提示LLM提出一个关于高层次概念或原则的抽象通用问题(称之为“后退”问题)。后退问题的抽象程度需要根据特定任务进行调整。最终后退问题和原始问题一起进行检索。例如,对于问题“Estella Leopold在1954年8月至11月期间上了哪所学校?”这个问题很难直接解决,因为有时间范围的详细限制。在这两种情况下,提出一个后退问题“Estella Leopold的教育经历怎么样的?”则有助于系统的检索。
  • 后续问题:使用LLM针对历史对话和当前问题生成一个独立问题。这个方法主要针对以下情况:a. 后续问题建立在前一次对话的基础上,或引用了前一次谈话。例如,如果用户先问“我在意大利能做什么”,然后问“那里有什么类型的食物”——如果只嵌入“那里有哪种类型的食物“,LLM就不知道“那里”在哪里。b.嵌入整个对话(或最后k条消息)。如果后续问题与之前的对话完全无关,那么它可能会返回完全无关的结果,从而在生成过程中分散LLM的注意力。
  • HyDE:用LLM生成一个“假设”答案,将其和问题一起进行检索。HyDE的核心思想是接收用户提问后,先让LLM在没有外部知识的情况下生成一个假设性的回复。然后,将这个假设性回复和原始查询一起用于向量检索。假设回复可能包含虚假信息,但蕴含着LLM认为相关的信息和文档模式,有助于在知识库中寻找类似的文档。
  • 多问题查询:基于原始问题,提示LLM从不同角度产生多个新问题或者子问题,并使用每一个新问题进行检索,在后续阶段使用RRF或者Rerank合并来自不同问题的检索结果。例如,对于原始问题:谁最近赢得了总冠军,红袜队还是爱国者队?,可以生成两个子问题:a. 红袜者队上一次赢得总冠军是什么时候?b. 爱国者队上一次赢得总冠军是什么时候?

# 1.5.5 检索

检索(Retrieval)最终的目标就是获取最相关的文档或者保证最相关的文档在获取的文档列表中存在。为了达成这个目标,该环节有以下几个优化方法:

  • 上下文压缩:当文档块过大时,可能包含太多不相关的信息,传递这样的文档块可能导致更昂贵的LLM调用和更差的响应。上下文压缩的思想就是通过LLM的帮助根据上下文对单个文档内容进行压缩,或者对返回结果进行一定程度的过滤仅返回相关信息。
  • 句子窗口搜索:相反,文档文块太小会导致上下文的缺失。其中一种解决方案就是窗口搜索,该方法的核心思想是当提问匹配好文档块后,将该文档块周围的块作为上下文一并交给LLM进行输出,来增加LLM对文档上下文的理解。
  • 父文档搜索:父文档搜索也是一种很相似的解决方案,父文档搜索先将文档分为尺寸更大的主文档,再把主文档分割为更短的子文档两个层级,用户问题会与子文档匹配,然后将该子文档所属的主文档发送给LLM。
  • 自动合并:自动合并是在父文档搜索上更进一步的复杂解决方案。同样地,我们先对文档进行结构切割,比如将文档按三层树状结构进行切割,顶层节点的块大小为1024,中间层的块大小为512,底层的叶子节点的块大小为128。而在检索时只拿叶子节点和问题进行匹配,当某个父节点下的多数叶子节点都与问题匹配则将该父节点作为结果返回。
  • 混合检索:RAG系统从根本上来说是作为开放域、基于自然语言的问答系统。为了获得开放式用户查询的高事实召回率,概括和聚焦应用场景以选择合适的检索模式或组合至关重要。在大多数文本搜索场景中,主要目标是确保最相关的结果出现在候选列表中。混合检索通过混合多个检索方法来实现不同检索技术的协同作用从而能够最大化事实召回率。例如,可以采用向量检索+关键词检索的组合来构建RAG系统的检索模块。
  • 路由机制:当建立了多个针对不同数据类型和查询需求的索引后,例如,可能有一个索引专门处理摘要类问题,另一个专门应对直接寻求具体答案的问题,还有一个专门针对需要考虑时间因素的问题。这时就需要使用路由机制来选择最合适的索引进行数据检索,从而提升检索质量和响应速度。
  • 使用Agent:该方法就是使用Agent来决定应该采用什么样的检索方法,从不同的检索方法中选取一种或多种进行召回。同时组合方式也是灵活的,是垂直关系还是平行关系。例如:对于查询“最新上映的科幻电影推荐”,Agent可能首先将其路由至专门处理当前热点话题的索引,然后利用专注于娱乐和影视内容的索引来生成相关推荐。

# 1.5.6 检索后处理

检索后处理这个概念还是很宽泛的,是对检索结果进行进一步的处理以便于后续LLM更好的生成,比较典型的就是重排序(Rerank)。向量检索其实就是计算语义层面的相似性,但语义最相似并不总是代表最相关。重排模型通过对初始检索结果进行更深入的相关性评估和排序,确保最终展示给用户的结果更加符合其查询意图。实现重排序除了可以提示LLM进行重排,更多的是使用了专门的重排序模型。这些模型会考虑更多的特征,如查询意图、词汇的多重语义、用户的历史行为和上下文信息,从而保证最相关的文档排在结果列表的最前面。

# 1.5.7 生成

在生成(Generation)阶段的优化更多的是考虑用户体验,有以下几点可以供参考:

  • 多轮对话:也就是带聊天历史的RAG,以AI搜索为例,明星产品perplexity就是支持多轮对话的,这样用户可以通过连续对话来深入了解解决某个问题。
  • 增加追问机制:在prompt中加入“如果无法从背景知识回答用户的问题,则根据背景知识内容,对用户进行追问,问题限制在3个以内”。这个机制并没有什么技术含量,主要依靠大模型的能力。不过大大改善了用户体验,用户在多轮引导中逐步明确了自己的问题,从而能够得到合适的答案。
  • Prompt优化:RAG系统中的Prompt应明确指出回答仅基于搜索结果,不要添加任何其他信息。例如,可以设置Prompt:“你是一名智能客服。你的目标是提供准确的信息,并尽可能帮助提问者解决问题。你应保持友善,但不要过于啰嗦。请根据提供的上下文信息,在不考虑已有知识的情况下,回答相关查询。”当然也可以根据场景需要,适当让模型的回答融入一些主观性或其对知识的理解。此外,使用Few-shot的方法指导LLM如何利用检索到的知识,也是提升LLM生成内容质量的有效方法。
  • 用户反馈循环:基于现实世界用户的反馈不断更新数据库,标记它们的真实性。

# 2. 实例场景及服务器环境

# 2.1 服务器测试环境

实验环境:实体GPU服务器,NVIDIA RTX 4090 / 24GB,CentOS 7.9,Anaconda3-2019.03,CUDA 12.4

如果没有GPU服务器,可以租用AutoDL等平台的。服务器的租用及基础环节的安装这里就不赘述了,详见我的另一篇博客:常用深度学习平台的使用指南 (opens new window)

# 2.2 实例场景及源码

实例场景概述:有一批内部的政府政策文档,需要基于这些文档进行垂直领域的RAG问答。

本项目我已经在Github上进行了开源,项目地址为:https://github.com/Logistic98/rag-omni (opens new window)

  • 注:此项目仅作为示例演示整个RAG的过程,实际应用追求效果的话,这些是远远不够的,开源模型和检索算法的性能不够好,也有很多特殊情况需要处理。
.
├── README.md
├── data                  // 示例数据
│   ├── original_data          // 原始文档数据
│   └── preprocess_data        // 处理后的结构化数据
├── convert               // 转换数据
│   ├── data_convert_json      // 数据结构化转换脚本
│   └── marker_parse_pdf       // Marker解析PDF工具
├── llm                   // 大模型服务
│   ├── nginx_balance          // Nginx负载均衡
│   ├── llm_server.py          // 部署本地大模型服务
│   ├── llmtuner               // 部署本地大模型服务的核心代码
│   ├── models                 // 存放本地大模型的模型文件
│   └── test                   // 测试大模型服务的脚本
├── retrieval            // 检索服务
│   ├── bge                    // BGE检索算法的核心代码
│   ├── bm25                   // BM25检索算法的核心代码
│   ├── openai_embedding       // OpenAI Embedding检索算法的核心代码
│   ├── code.py
│   ├── log.py
│   ├── response.py
│   ├── retrieval_index.py     // 构建索引文件的脚本
│   ├── retrieval_server.py    // 部署检索服务
│   └── test                   // 测试检索服务的脚本
├── rag                  // RAG服务
│   ├── code.py
│   ├── log.py
│   ├── rag_server.py           // 部署RAG服务
│   ├── rag_solve.py            // RAG处理逻辑的核心代码
│   ├── response.py
│   └── test                    // 测试RAG服务的脚本
└── chat                 // RAG页面 
    ├── babel.config.js
    ├── jsconfig.json
    ├── package.json
    ├── public
    ├── src                     // RAG页面的主要源码
    └── vue.config.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

# 2.3 原始数据预处理

# 2.3.1 数据预处理要求

数据预处理:需要将数据预处理成结构化数据之后,才能方便的构建检索库。

  • 数据预处理要求:每个文档拆开,拆开后每个数据是文档中的某一段,目的是保证每条数据都有较完整的语义,并且长度不会太长。
  • 数据预处理方式:提供的文档主要是Word、PDF等格式,无法直接使用。数据量少的话,可以直接人工去处理。数据量大的话,建议先使用脚本批量处理一下,有些解析不成功的再人工处理。

用于RAG的结构化数据

# 2.3.2 数据预处理脚本

PDF格式是非常难处理的,如果是文本类型的可以使用以下脚本来初步处理,如果本身就是图片类型的,那该脚本解析不了,就需要OCR技术来辅助了。关于复杂PDF文件的解析可以使用 Marker 工具。

./rag-omni/convert/data_convert_json/pdf_to_docx.py

# -*- coding: utf-8 -*-

import os
from pdf2docx import Converter
import argparse

parser = argparse.ArgumentParser(description="服务调用方法:python3 pdf_to_docx.py --pdf_path 'xxx.pdf' --docx_path 'xxx.docx'")
parser.add_argument("--pdf_path", type=str, required=True, help="要解析的 PDF 文件地址")
parser.add_argument("--docx_path", type=str, required=True, help="解析后的 DOCX 文件输出地址")
args = parser.parse_args()

docx_dir = os.path.dirname(args.docx_path)
if not os.path.exists(docx_dir):
    os.makedirs(docx_dir)

try:
    # 初始化转换器并转换 PDF 到 DOCX
    cv = Converter(args.pdf_path)
    cv.convert(args.docx_path)  # 默认转换所有页面
    cv.close()
    print("PDF 文件已成功转换为 DOCX 格式。")
except Exception as e:
    print(f"转换过程中发生错误:{str(e)}")
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

./rag-omni/convert/data_convert_json/docx_to_json.py

# -*- coding: utf-8 -*-

import os
from docx import Document
import json
import argparse

parser = argparse.ArgumentParser(description="服务调用方法:python3 docx_to_json.py --docx_path 'xxx.docx' --output_path 'xxx.json' --max_length 500")
parser.add_argument("--docx_path", type=str, required=True, help="docx 文件地址")
parser.add_argument("--output_path", type=str, required=True, help="结果输出地址")
parser.add_argument("--max_length", default=500, type=int, help="切片大小")
args = parser.parse_args()

docx = Document(args.docx_path)
max_length = args.max_length

result = []
current_text = ""

for paragraph in docx.paragraphs:
    section = paragraph.text.strip()
    if not current_text or len(current_text) + len(section) + 1 <= max_length:
        current_text += " " + section
    else:
        result.append({
            "file_name": os.path.basename(args.docx_path),
            "part_content": current_text.strip()
        })
        current_text = section

if current_text:
    result.append({
        "file_name": os.path.basename(args.docx_path),
        "part_content": current_text.strip()
    })

output_dir = os.path.dirname(args.output_path)
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

with open(args.output_path, "w", encoding="utf-8") as file:
    json.dump(result, file, ensure_ascii=False, indent=2)

print(f"{args.docx_path} 处理完成")
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

# 3. 部署大模型服务

后面的检索服务和RAG服务,对接了 OpenAI 风格的 API,可以使用任意符合该格式的服务。如果有数据保密、本地化部署的需求,可以使用本地化部署的大模型服务。如果直接使用外部的API,本节可跳过。

# 3.1 LLaMA-Factory的推理服务

这里用了 LLaMA-Factory (opens new window) 项目的 /src/llmtuner 部分,它支持了 vLLM,对推理进行了加速,本项目代码里用的版本需要要求 vllm==0.4.0 版本。

llm
├── nginx_balance
├── llm_server.py
├── llmtuner
│   ├── api
│   ├── chat
│   ├── data
│   ├── extras
│   ├── hparams
│   └── model
├── models
│   ├── download_baichuan_model.py
│   └── download_qwen_model.py
└── test
    ├── llm_server_stress_test.py
    └── llm_server_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

注:开启vLLM可以充分利用显卡计算资源,带来更好的推理性能,详见我的另一篇博客:基于vLLM加速大模型推理服务 (opens new window)

# 3.2 部署大模型服务并进行测试

# 3.2.1 下载基座大模型

这里下载Qwen1.5-0.5B大模型

models文件夹提供了 Qwen、Baichuan 模型的下载脚本(不仅限于这些,模型的支持情况详见 LLaMA-Factory 项目),这里使用 Qwen1.5-0.5B 进行实验。

# -*- coding: utf-8 -*-

import os
from huggingface_hub import snapshot_download

# 设置代理
# os.environ['http_proxy'] = 'http://127.0.0.1:7890'
# os.environ['https_proxy'] = 'http://127.0.0.1:7890'

# 模型仓库的标识
repo_id = "Qwen/Qwen1.5-0.5B"

# 下载模型到指定目录
local_dir = "./Qwen1.5-0.5B"

# 检查目录是否存在,如果不存在则创建
if not os.path.exists(local_dir):
    os.makedirs(local_dir)

snapshot_download(repo_id=repo_id, local_dir=local_dir)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

# 3.2.2 启动大模型服务

启动Qwen大模型服务

$ cd ./llm
$ CUDA_VISIBLE_DEVICES=0 python3 llm_server.py \
    --model_name_or_path ./models/Qwen1.5-0.5B \
    --template default \
    --infer_backend vllm \
    --vllm_gpu_util 0.9
1
2
3
4
5
6

注:vllm_gpu_util 参数用于控制显存占用比例,默认值为0.9,详见 ./rag-omni/llm/llmtuner/hparams/model_args.py

vllm_gpu_util参数支持

不同vllm_gpu_util参数设置的显存占用对比:

不同vllm_gpu_util参数设置的显存占用

启动之后,使用浏览器打开此地址:http://<your_server_ip>:5000/docs,可以访问到接口文档。

开启vllm的大模型推理服务

# 3.2.3 测试大模型服务

测试Qwen大模型服务,执行 ./rag-omni/llm/test/llm_server_test.py 脚本即可进行测试。

# -*- coding: utf-8 -*-

import json
import requests


def send_post_request(url, payload):
    """
    向指定的URL发送POST请求。
    """
    headers = {
        "accept": "application/json",
        "Content-Type": "application/json"
    }

    updated_payload = {
        "model": "qwen-1.5-0.5b",
        "messages": [
            {
                "role": "user",
                "content": payload["prompt"]
            }
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "string",
                    "description": "string",
                    "parameters": {}
                }
            }
        ],
        "temperature": 0,
        "top_p": 0,
        "n": 1,
        "max_tokens": 0,
        "stream": False
    }

    response = requests.post(url, headers=headers, data=json.dumps(updated_payload))
    try:
        response_json = response.json()
        print(response_json)
    except ValueError:
        print("Response could not be decoded as JSON:", response.text)


if __name__ == '__main__':
    api_url = 'http://127.0.0.1:5000/v1/chat/completions'
    payload = {
        "prompt": "解释一下量子计算"
    }
    send_post_request(api_url, payload)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

它的请求与返回格式都是标准 OpenAI API 形式。

{
    "id": "chatcmpl-987741d7e7f049dd864b7fae170d36e1",
    "object": "chat.completion",
    "created": 1715835635,
    "model": "qwen-1.5-0.5b",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "量子计算是一种利用量子力学原理进行计算的计算机技术。与经典计算机使用二进制位(0和1)进行计算不同,量子计算机使用量子比特(qubit)进行计算。这种量子比特可以同时处于0和1的状态,而传统计算机只能在0和1之间进行比较和计算。量子计算的优点包括量子纠错和量子计算速度的提高,这些都使得它在许多领域,如模拟和预测、加密和加密、模拟和模拟、模拟和模拟、模拟和模拟等方面具有潜力。\n量子计算与经典计算机相比有许多显著的优势,例如:\n1.量子纠错:量子计算机可以对输入数据进行更加精确的计算,因为它可以同时计算出多个状态,从而避免了传统计算机中的错误。\n2.量子计算速度的提高:由于量子比特可以同时处于0和1的状态,因此量子计算机的计算速度比经典计算机更快。\n3.量子纠错和量子计算速度的提高:量子计算机可以对输入数据进行更加精确的计算,因为它可以同时计算出多个状态,从而避免了传统计算机中的错误。\n4.量子纠错和量子计算速度的提高:量子计算机可以对输入数据进行更加精确的计算,因为它可以同时计算出多个状态,从而避免了传统计算机中的错误。\n因此,量子计算是一种非常有前途的计算机技术,具有许多优势,例如量子纠错、量子计算速度的提高、量子纠错和量子计算速度的提高等。",
                "tool_calls": null
            },
            "finish_reason": "stop"
        }
    ],
    "usage": {
        "prompt_tokens": 12,
        "completion_tokens": 302,
        "total_tokens": 314
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

注:执行 ./rag-omni/llm/test/llm_server_stress_test.py 脚本即可进行压力测试。

大模型服务压力测试效果

# 3.3 使用Nginx配置负载均衡

需求情景:一台服务器上有多张显卡,用不同的显卡部署了多个大模型服务,现在想要进一步提高大模型服务的并发量,可以使用Nginx负载均衡来实现。

这里假设启动了3个大模型服务,端口号分别是4997、4998、4999,现在想要将其都配置到5000端口上。修改以下配置文件,换成实际的服务地址,weight=1是权重,这里默认各服务为相同权重。

./rag-omni/llm/nginx_balance/nginx_balance.conf

upstream nginx_balance {
        server 127.0.0.1:4999 weight=1;
        server 127.0.0.1:4998 weight=1;
        server 127.0.0.1:4997 weight=1;
}
server {
    listen       5000;
    server_name  127.0.0.1;
    location ~* ^(/) {
        gzip on;
        gzip_vary on;
	    gzip_min_length 1k;
	    gzip_buffers 16 16k;
        gzip_http_version 1.1;
        gzip_comp_level 9;
        gzip_types text/plain application/javascript application/x-javascript text/css text/xml text/javascript application/json;
        proxy_pass http://nginx_balance;
        client_max_body_size    48m;
        include proxy.conf;
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

./rag-omni/llm/nginx_balance/build.sh

#!/bin/bash

docker build -t 'nginx_balance_image' .
docker run -itd --name nginx_balance -h nginx_balance -p 5000:5000 nginx_balance_image
docker update nginx_balance --restart=always
1
2
3
4
5

给 build.sh 添加可执行权限,执行该脚本即可部署负载均衡。

# 4. 部署检索服务

# 4.1 源码结构概述

构建检索服务分为两步,先使用知识文件构建索引,再使用索引构建检索服务。

retrieval
├── bge
│   ├── bge-large-zh-v1.5                 // bge模型文件
│   ├── bge_download_model.py             // 下载bge模型文件的脚本    
│   ├── bge_index.py                      // 构建bge索引
│   ├── bge_retrieval.py                  // 使用bge索引进行检索
│   └── index                             // bge索引文件
├── bm25
│   ├── bm25_index.py                     // 构建bm25索引
│   ├── bm25_retrieval.py                 // 使用bm25索引进行检索
│   ├── index                             // bm25索引文件
│   └── stop_words.txt                    // 停用词
├── openai_embedding
│   ├── openai_index.py                   // 构建openai_embedding索引
│   ├── openai_retrieval.py               // 使用openai_embedding索引进行检索
│   └── index                             // openai_embedding索引文件
├── code.py
├── log.py
├── response.py
├── retrieval_index.py                    // 构建索引文件脚本
├── retrieval_server.py                   // 部署检索服务
└── test
    └── retrieval_test.py                 // 测试检索服务
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

# 4.2 BM25检索算法

BM25算法较为简单,这里就直接实现了。除此之外,BM25为ES默认的相关性排序算法,也可以借助ES去实现。

# 4.2.1 构建BM25索引

支持增量构建BM25索引,因此 main 里的构建索引拆成了两步作为示例。索引路径和索引名可以自行指定,如果不指定索引名则自动使用uuid生成。

./rag-omni/retrieval/bm25/bm25_index.py

# -*- coding: utf-8 -*-

import math
import os
import jieba
import logging
import json
import uuid

jieba.setLogLevel(log_level=logging.INFO)


class BM25Param(object):
    def __init__(self, f, df, idf, length, avg_length, docs_list, line_length_list, k1=1.5, k2=1.0, b=0.75):
        self.f = f
        self.df = df
        self.k1 = k1
        self.k2 = k2
        self.b = b
        self.idf = idf
        self.length = length
        self.avg_length = avg_length
        self.docs_list = docs_list
        self.line_length_list = line_length_list

    def __str__(self):
        return f"k1:{self.k1}, k2:{self.k2}, b:{self.b}"


class BM25Indexer(object):
    def __init__(self, file_paths, old_index_path=None):
        self.file_paths = file_paths
        self.old_index_path = old_index_path
        self._stop_words = self._load_stop_words()

    def _load_stop_words(self):
        current_dir = os.path.dirname(os.path.abspath(__file__))
        stop_words_path = os.path.join(current_dir, 'stop_words.txt')
        if not os.path.exists(stop_words_path):
            raise Exception(f"system stop words: {stop_words_path} not found")
        stop_words = []
        with open(stop_words_path, 'r', encoding='utf8') as reader:
            for line in reader:
                line = line.strip()
                stop_words.append(line)
        return stop_words

    def _load_old_index(self):
        if not self.old_index_path or not os.path.exists(self.old_index_path):
            return None
        with open(self.old_index_path, 'r', encoding='utf8') as f:
            old_index_data = json.load(f)
        return BM25Param(**old_index_data)

    def _merge_indexes(self, old_param, new_param):
        if not old_param:
            return new_param

        combined_length = old_param.length + new_param.length
        combined_avg_length = (
            (old_param.avg_length * old_param.length) + (new_param.avg_length * new_param.length)
        ) / combined_length

        for word, freq in new_param.df.items():
            if word in old_param.df:
                old_param.df[word] += freq
            else:
                old_param.df[word] = freq

        for word, score in new_param.idf.items():
            if word in old_param.idf:
                old_param.idf[word] = (old_param.idf[word] * old_param.length + score * new_param.length) / combined_length
            else:
                old_param.idf[word] = score

        old_param.f.extend(new_param.f)
        old_param.docs_list.extend(new_param.docs_list)
        old_param.line_length_list.extend(new_param.line_length_list)

        old_param.length = combined_length
        old_param.avg_length = combined_avg_length

        return old_param

    def _build_param(self):
        def _cal_param(docs_data):
            f = []
            df = {}
            idf = {}
            length = len(docs_data)
            words_count = 0
            docs_list = []
            line_length_list = []
            for doc in docs_data:
                content = doc.get("part_content", "").strip()
                if not content:
                    continue
                words = [word for word in jieba.lcut(content) if word and word not in self._stop_words]
                line_length_list.append(len(words))
                docs_list.append(doc)
                words_count += len(words)
                tmp_dict = {}
                for word in words:
                    tmp_dict[word] = tmp_dict.get(word, 0) + 1
                f.append(tmp_dict)
                for word in tmp_dict.keys():
                    df[word] = df.get(word, 0) + 1
            for word, num in df.items():
                idf[word] = math.log((length - num + 0.5) / (num + 0.5) + 1)
            param = BM25Param(f, df, idf, length, words_count / length, docs_list, line_length_list)
            return param

        docs_data = []
        for file_path in self.file_paths:
            if not os.path.exists(file_path):
                raise Exception(f"input docs {file_path} not found")
            with open(file_path, 'r', encoding='utf8') as reader:
                docs = json.load(reader)
                for doc in docs:
                    doc["file_path"] = file_path
                docs_data.extend(docs)

        new_param = _cal_param(docs_data)
        old_param = self._load_old_index()

        return self._merge_indexes(old_param, new_param)

    def build_index(self, output_path, index_name=None):
        param = self._build_param()
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        if not index_name:
            index_name = str(uuid.uuid4())
        index_file = os.path.join(output_path, f'{index_name}.json')
        with open(index_file, 'w', encoding='utf8') as f:
            json.dump(param.__dict__, f, ensure_ascii=False, indent=4)
        print(f"Index saved to {index_file}")


if __name__ == '__main__':

    index_name = "bm25_index"      # 定义索引名(如果不指定则自动使用uuid生成)
    output_path = "./index"        # 定义索引的存储路径

    # 用一个文件构建初始索引
    file_paths = [
        "../../data/preprocess_data/国务院关于加强地方政府性债务管理的意见.json"
    ]
    indexer = BM25Indexer(file_paths)
    indexer.build_index(output_path, index_name=index_name)

    # 用另一个文件和旧索引增量构建新索引
    file_paths = [
        "../../data/preprocess_data/中共中央办公厅国务院办公厅印发《关于做好地方政府专项债券发行及项目配套融资工作的通知》.json"
    ]
    old_index_path = "{}/{}.json".format(output_path, index_name)
    indexer = BM25Indexer(file_paths, old_index_path)
    indexer.build_index(output_path, index_name=index_name)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158

# 4.2.2 实现BM25检索

./rag-omni/retrieval/bm25/bm25_retrieval.py

# -*- coding: utf-8 -*-

import os
import jieba
import logging
import json

jieba.setLogLevel(log_level=logging.INFO)


class BM25Param(object):
    def __init__(self, f, df, idf, length, avg_length, docs_list, line_length_list, k1=1.5, k2=1.0, b=0.75):
        self.f = f
        self.df = df
        self.k1 = k1
        self.k2 = k2
        self.b = b
        self.idf = idf
        self.length = length
        self.avg_length = avg_length
        self.docs_list = docs_list
        self.line_length_list = line_length_list

    def __str__(self):
        return f"k1:{self.k1}, k2:{self.k2}, b:{self.b}"


class BM25Retrieval(object):
    def __init__(self, index_path):
        self.index_path = index_path
        self.param: BM25Param = self._load_param()
        self._stop_words = self._load_stop_words()

    def _load_stop_words(self):
        current_dir = os.path.dirname(os.path.abspath(__file__))
        stop_words_path = os.path.join(current_dir, 'stop_words.txt')
        if not os.path.exists(stop_words_path):
            raise Exception(f"system stop words: {stop_words_path} not found")
        stop_words = []
        with open(stop_words_path, 'r', encoding='utf8') as reader:
            for line in reader:
                line = line.strip()
                stop_words.append(line)
        return stop_words

    def _load_param(self):
        if not os.path.exists(self.index_path):
            raise Exception(f"Index file {self.index_path} not found")
        with open(self.index_path, 'r', encoding='utf8') as f:
            data = json.load(f)
            param = BM25Param(**data)
        param.length = len(param.f)
        return param

    def _cal_similarity(self, words, index):
        score = 0
        for word in words:
            if word not in self.param.f[index]:
                continue
            molecular = self.param.idf[word] * self.param.f[index][word] * (self.param.k1 + 1)
            denominator = self.param.f[index][word] + self.param.k1 * (1 - self.param.b +
                                                                       self.param.b * self.param.line_length_list[
                                                                           index] /
                                                                       self.param.avg_length)
            score += molecular / denominator
        return score

    def search(self, query: str, top_k: int = -1):
        if top_k != -1 and top_k <= 0:
            raise ValueError("top_k should be -1 or a positive integer")

        words = [word for word in jieba.lcut(query) if word and word not in self._stop_words]
        score_list = []
        for index in range(len(self.param.f)):
            if index >= len(self.param.f):
                raise IndexError(f"Index {index} is out of range for parameter f")
            score = self._cal_similarity(words, index)
            score_list.append((self.param.docs_list[index], score))

        score_list.sort(key=lambda x: -x[1])
        if top_k != -1:
            score_list = score_list[:top_k]

        result = [
            {
                "file_name": doc["file_name"],
                "part_content": doc["part_content"],
                "score": score
            }
            for doc, score in score_list
        ]
        return result


if __name__ == '__main__':
    index_path = "./index/bm25_index.json"
    bm25 = BM25Retrieval(index_path)
    query_content = "国务院对于地方政府性债务管理的意见"
    top_k = 5  # 可以设置为任意正整数,或者-1表示不限制
    result = bm25.search(query_content, top_k)
    print(json.dumps(result, ensure_ascii=False, indent=4))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

注:代码中会用到 stop_words.txt 文件,已经放到项目里了,这里就不展示了。

# 4.3 BGE检索算法

BGE向量检索需要下载 BAAI/bge-large-zh-v1.5 模型文件,项目里提供了模型下载脚本。

./rag-omni/retrieval/bge/download_bge_model.py

# -*- coding: utf-8 -*-

import os
from transformers import AutoTokenizer, AutoModel
from transformers.utils import logging

# 设置代理
# os.environ['http_proxy'] = 'http://127.0.0.1:7890'
# os.environ['https_proxy'] = 'http://127.0.0.1:7890'

# 配置 transformers 日志
logging.set_verbosity_info()


def download_and_save_model(model_name, save_directory):
    # 下载模型和分词器
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    # 保存模型和分词器
    tokenizer.save_pretrained(save_directory)
    model.save_pretrained(save_directory)
    print(f"模型和分词器已保存到 {save_directory}")


if __name__ == '__main__':
    model_name = 'BAAI/bge-large-zh-v1.5'
    save_directory = './bge-large-zh-v1.5'
    download_and_save_model(model_name, save_directory)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

# 4.3.1 构建BGE索引

支持增量构建BGE索引,因此 main 里的构建索引拆成了两步作为示例。索引路径和索引名可以自行指定,如果不指定索引名则自动使用uuid生成。

./rag-omni/retrieval/bge/bge_index.py

# -*- coding: utf-8 -*-

import os
import json
import numpy as np
from tqdm import trange
from transformers import AutoTokenizer, AutoModel
import torch
import uuid


class BGEIndexer:
    def __init__(self, file_paths, old_index_path=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(current_dir, 'bge-large-zh-v1.5')
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModel.from_pretrained(self.model_path).to(self.device)
        self.old_index_path = old_index_path
        self.data_list = self.load_data(file_paths)
        self.embeddings_list = self.generate_embeddings()
        self.cuda_oom_flag = False

    def load_data(self, file_paths):
        data_list = []
        for file_path in file_paths:
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            data_list.extend(data)
        return data_list

    def generate_embeddings(self):
        embeddings_list = []
        batch_size = 4
        has_switched_to_cpu = False
        for i in trange(0, len(self.data_list), batch_size):
            batch_texts = [item['part_content'] for item in self.data_list[i:i + batch_size]]
            inputs = self.tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(
                self.device)
            try:
                with torch.no_grad():
                    outputs = self.model(**inputs)
                embeddings = outputs.last_hidden_state.mean(dim=1).to('cpu').numpy()
            except RuntimeError as e:
                if 'CUDA out of memory' in str(e):
                    if not has_switched_to_cpu:
                        print("CUDA out of memory. Switching to CPU for this batch.")
                        has_switched_to_cpu = True
                    torch.cuda.empty_cache()
                    inputs = inputs.to('cpu')
                    self.model.to('cpu')
                    with torch.no_grad():
                        outputs = self.model(**inputs)
                    embeddings = outputs.last_hidden_state.mean(dim=1).numpy()
                    self.model.to(self.device)
                else:
                    raise e
            embeddings_list.extend(embeddings)
            torch.cuda.empty_cache()
        return np.array(embeddings_list)

    def _load_old_index(self):
        if not self.old_index_path or not os.path.exists(self.old_index_path):
            return None, None
        data = np.load(self.old_index_path, allow_pickle=True)
        old_embeddings_list = data['embeddings_list']
        old_data_list_json = data['data_list'].item()
        old_data_list = json.loads(old_data_list_json)
        return old_data_list, old_embeddings_list

    def _merge_indexes(self, old_data_list, old_embeddings_list):
        if old_data_list is None or old_embeddings_list is None:
            return self.data_list, self.embeddings_list
        new_data_list = old_data_list + self.data_list
        new_embeddings_list = np.vstack((old_embeddings_list, self.embeddings_list))
        return new_data_list, new_embeddings_list

    def build_index(self, output_path, index_name=None):
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        if not index_name:
            index_name = str(uuid.uuid4())
        index_file = os.path.join(output_path, f'{index_name}.npz')

        old_data_list, old_embeddings_list = self._load_old_index()
        merged_data_list, merged_embeddings_list = self._merge_indexes(old_data_list, old_embeddings_list)

        data_list_json = json.dumps(merged_data_list, ensure_ascii=False, indent=4)
        np.savez(index_file, embeddings_list=merged_embeddings_list, data_list=data_list_json)
        print(f"Index saved to {index_file}")


if __name__ == '__main__':
    index_name = "bge_index"  # 定义索引名(如果不指定则自动使用uuid生成)
    output_path = "./index"   # 定义索引的存储路径

    # 用一个文件构建初始索引
    file_paths = [
        "../../data/preprocess_data/国务院关于加强地方政府性债务管理的意见.json"
    ]
    indexer = BGEIndexer(file_paths)
    indexer.build_index(output_path, index_name=index_name)

    # 用另一个文件和旧索引增量构建新索引
    file_paths = [
        "../../data/preprocess_data/中共中央办公厅国务院办公厅印发《关于做好地方政府专项债券发行及项目配套融资工作的通知》.json"
    ]
    old_index_path = os.path.join(output_path, f'{index_name}.npz')
    indexer = BGEIndexer(file_paths, old_index_path)
    indexer.build_index(output_path, index_name=index_name)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110

# 4.3.2 实现BGE检索

./rag-omni/retrieval/bge/bge_retrieval.py

# -*- coding: utf-8 -*-

import os
import json
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch
import faiss


class BGERetrieval:
    def __init__(self, index_file):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(current_dir, 'bge-large-zh-v1.5')
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModel.from_pretrained(self.model_path)
        self.data_list, self.embeddings_list = self.load_index(index_file)
        self.faiss_index = self.build_faiss_index()

    def load_index(self, index_file):
        data = np.load(index_file, allow_pickle=True)
        embeddings_list = data['embeddings_list']
        data_list_json = data['data_list'].item()
        data_list = json.loads(data_list_json)
        return data_list, embeddings_list

    def build_faiss_index(self):
        faiss_index = faiss.IndexFlatIP(self.embeddings_list.shape[1])
        faiss_index.add(self.embeddings_list)
        return faiss_index

    def search(self, query, top_k=-1):
        inputs = self.tokenizer(query, return_tensors='pt', padding=True, truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        query_emb = outputs.last_hidden_state.mean(dim=1).to('cpu').numpy()
        if top_k == -1:
            top_k = len(self.data_list)
        score, rank = self.faiss_index.search(query_emb, top_k)
        rank = rank[0]
        score = score[0]
        results = [
            {
                "file_name": self.data_list[rank[i]]["file_name"],
                "part_content": self.data_list[rank[i]]["part_content"],
                "score": float(score[i])
            }
            for i in range(top_k)
        ]
        return results


if __name__ == '__main__':
    index_file = "./index/bge_index.npz"
    query_text = "国务院对于地方政府性债务管理的意见"
    top_k = -1  # 可以设置为任意正整数,或者-1表示不限制
    retriever = BGERetrieval(index_file)
    results = retriever.search(query_text, top_k)
    print(json.dumps(results, ensure_ascii=False, indent=4))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

# 4.4 OpenAI Embedding检索算法

# 4.4.1 构建OpenAI Embedding索引

./rag-omni/retrieval/openai_embedding/openai_index.py

# -*- coding: utf-8 -*-

import os
import json
import pickle
import uuid

from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings


class OpenAIIndexer:
    def __init__(self, file_paths, openai_api_base, openai_api_key, embedding_model="text-embedding-3-large", old_index_path=None):
        self.embedding_model = self.create_embedding_model(openai_api_base, openai_api_key, embedding_model)
        self.old_index_path = old_index_path
        self.new_data_list, self.new_content, self.new_metadata = self.load_data(file_paths)

        if self.old_index_path and os.path.exists(self.old_index_path):
            self.old_data_list, self.old_content, self.old_metadata = self.load_existing_data()
            self.data_list = self.old_data_list + self.new_data_list
            self.content = self.old_content + self.new_content
            self.metadata = self.old_metadata + self.new_metadata
        else:
            self.data_list, self.content, self.metadata = self.new_data_list, self.new_content, self.new_metadata

        self.faiss_vectorstore = self.generate_embedding()

    def create_embedding_model(self, openai_api_base, openai_api_key, embedding_model):
        return {
            "openai_api_key": openai_api_key,
            "openai_api_base": openai_api_base,
            "model": embedding_model
        }

    def load_data(self, file_paths):
        """读取数据文件"""
        data_list = []
        content_list = []
        metadata_list = []
        for file_path in file_paths:
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for item in data:
                item['file_name'] = os.path.basename(file_path)
                data_list.append(item)
                content_list.append(item['part_content'])
                metadata_list.append({'file_name': os.path.basename(file_path)})
        return data_list, content_list, metadata_list

    def generate_embedding(self):
        embedding_model_instance = OpenAIEmbeddings(**self.embedding_model)
        return FAISS.from_texts(self.content, embedding_model_instance, metadatas=self.metadata)

    def get_index_folder(self, output_path, index_name):
        return os.path.join(output_path, index_name)

    def load_existing_data(self):
        index_folder = self.old_index_path
        embedding_path = os.path.join(index_folder, 'embeddings')

        with open(os.path.join(embedding_path, 'index.faiss'), 'rb') as f:
            index = pickle.load(f)

        with open(os.path.join(embedding_path, 'docstore.pkl'), 'rb') as f:
            docstore = pickle.load(f)

        with open(os.path.join(embedding_path, 'index_to_docstore_id.pkl'), 'rb') as f:
            index_to_docstore_id = pickle.load(f)

        embedding_model_instance = OpenAIEmbeddings(**self.embedding_model)
        faiss_vectorstore = FAISS(
            index=index,
            embedding_function=embedding_model_instance,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id
        )

        with open(os.path.join(index_folder, 'content.pkl'), 'rb') as file:
            content = pickle.load(file)

        with open(os.path.join(index_folder, 'metadata.pkl'), 'rb') as file:
            metadata = pickle.load(file)

        # 重新构造数据列表
        data_list = [{'part_content': c, 'metadata': m} for c, m in zip(content, metadata)]
        return data_list, content, metadata

    def build_index(self, output_path, index_name=None):
        if not index_name:
            index_name = str(uuid.uuid4())
        index_folder = self.get_index_folder(output_path, index_name)
        os.makedirs(index_folder, exist_ok=True)
        embedding_path = os.path.join(index_folder, 'embeddings')
        os.makedirs(embedding_path, exist_ok=True)

        with open(os.path.join(embedding_path, 'index.faiss'), 'wb') as f:
            pickle.dump(self.faiss_vectorstore.index, f)

        with open(os.path.join(embedding_path, 'docstore.pkl'), 'wb') as f:
            pickle.dump(self.faiss_vectorstore.docstore, f)

        with open(os.path.join(embedding_path, 'index_to_docstore_id.pkl'), 'wb') as f:
            pickle.dump(self.faiss_vectorstore.index_to_docstore_id, f)

        with open(os.path.join(index_folder, 'embedding_model_params.pkl'), 'wb') as file:
            pickle.dump(self.embedding_model, file)

        with open(os.path.join(index_folder, 'content.pkl'), 'wb') as file:
            pickle.dump(self.content, file)

        with open(os.path.join(index_folder, 'metadata.pkl'), 'wb') as file:
            pickle.dump(self.metadata, file)

        print(f"Index saved to {index_folder}")


if __name__ == '__main__':
    index_name = "openai_index"  # 定义索引名(如果不指定则自动使用uuid生成)
    output_path = "./index"  # 定义索引的存储路径
    openai_api_base = "https://api.openai.com/v1"
    openai_api_key = "sk-xxx"

    # 用一个文件构建初始索引
    file_paths = [
        "../../data/preprocess_data/国务院关于加强地方政府性债务管理的意见.json"
    ]
    indexer = OpenAIIndexer(file_paths, openai_api_base, openai_api_key)
    indexer.build_index(output_path, index_name)

    # 用另一个文件和旧索引增量构建新索引
    new_file_paths = [
        "../../data/preprocess_data/中共中央办公厅国务院办公厅印发《关于做好地方政府专项债券发行及项目配套融资工作的通知》.json"
    ]
    old_index_path = os.path.join(output_path, index_name)
    indexer = OpenAIIndexer(new_file_paths, openai_api_base, openai_api_key, old_index_path=old_index_path)
    indexer.build_index(output_path, index_name)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

# 4.4.2 实现OpenAI Embedding检索

./rag-omni/retrieval/openai_embedding/openai_retrieval.py

# -*- coding: utf-8 -*-

import os
import pickle
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
import json


class OpenAIRetrieval:
    def __init__(self, index_folder):
        self.embedding_model_params = self.load_embedding_model_params(index_folder)
        self.embedding_model = OpenAIEmbeddings(**self.embedding_model_params)
        self.faiss_vectorstore, self.content, self.metadata = self.load_data(index_folder)

    def load_embedding_model_params(self, index_folder):
        embedding_model_path = os.path.join(index_folder, 'embedding_model_params.pkl')
        if not os.path.exists(embedding_model_path):
            raise FileNotFoundError(f"Embedding model params file not found: {embedding_model_path}")
        with open(embedding_model_path, 'rb') as file:
            embedding_model_params = pickle.load(file)
        return embedding_model_params

    def load_data(self, index_folder):
        embedding_path = os.path.join(index_folder, 'embeddings')
        index_file = os.path.join(embedding_path, 'index.faiss')
        docstore_file = os.path.join(embedding_path, 'docstore.pkl')
        index_to_docstore_id_file = os.path.join(embedding_path, 'index_to_docstore_id.pkl')

        if not os.path.exists(index_file):
            raise FileNotFoundError(f"FAISS index file not found: {index_file}")
        if not os.path.exists(docstore_file):
            raise FileNotFoundError(f"Docstore file not found: {docstore_file}")
        if not os.path.exists(index_to_docstore_id_file):
            raise FileNotFoundError(f"Index to docstore ID file not found: {index_to_docstore_id_file}")

        with open(index_file, 'rb') as f:
            index = pickle.load(f)
        with open(docstore_file, 'rb') as f:
            docstore = pickle.load(f)
        with open(index_to_docstore_id_file, 'rb') as f:
            index_to_docstore_id = pickle.load(f)

        embedding_model_instance = OpenAIEmbeddings(**self.embedding_model_params)
        faiss_vectorstore = FAISS(
            index=index,
            embedding_function=embedding_model_instance,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id
        )

        content_file = os.path.join(index_folder, 'content.pkl')
        metadata_file = os.path.join(index_folder, 'metadata.pkl')

        if not os.path.exists(content_file):
            raise FileNotFoundError(f"Content file not found: {content_file}")
        if not os.path.exists(metadata_file):
            raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

        with open(content_file, 'rb') as file:
            content = pickle.load(file)
        with open(metadata_file, 'rb') as file:
            metadata = pickle.load(file)
        return faiss_vectorstore, content, metadata

    def search(self, query, top_k=5):
        results = self.faiss_vectorstore.similarity_search(query, k=len(self.content))
        # 如果 top_k 为 -1,则返回所有结果
        if top_k == -1:
            top_k = len(results)
        results = results[:top_k]
        search_results = []
        for item in results:
            result = {
                "file_name": item.metadata['file_name'],
                "part_content": item.page_content
            }
            search_results.append(result)
        return search_results


if __name__ == '__main__':
    index_folder = "./index/openai_index"
    query_text = "国务院对于地方政府性债务管理的意见"
    top_k = 5  # 可以设置为任意正整数,或者-1表示不限制
    openai_retriever = OpenAIRetrieval(index_folder)
    results = openai_retriever.search(query_text, top_k=top_k)
    print(json.dumps(results, ensure_ascii=False, indent=4))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

# 4.5 构建索引文件

# 4.5.1 封装索引构建

./rag-omni/retrieval/retrieval_index.py

# -*- coding: utf-8 -*-

import argparse
import logging
from bge.bge_index import BGEIndexer
from bm25.bm25_index import BM25Indexer
from openai_embedding.openai_index import OpenAIIndexer


# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="构建索引的参数")
    parser.add_argument('--file_paths', type=str, nargs='+', required=True, help="JSON知识文件路径(支持一个或多个文件路径)")
    parser.add_argument('--algorithm', type=str, choices=['BM25', 'BGE', 'OpenAI'], required=True, help="索引算法:目前仅支持BM25、BGE、OpenAI")
    parser.add_argument('--api_base_url', type=str, default="https://api.openai.com/v1", help="LLM API Base URL")
    parser.add_argument('--api_key', type=str, help="LLM API Key")
    parser.add_argument('--output_path', type=str, required=True, help="索引存储路径")
    parser.add_argument('--index_name', type=str, required=False, help="索引名(可选,如果不指定则自动使用UUID生成)")
    parser.add_argument('--old_index_path', type=str, required=False, help="旧索引路径(可选,传递旧索引则增量构建)")
    args = parser.parse_args()

    file_paths = args.file_paths
    algorithm = args.algorithm
    api_base_url = args.api_base_url
    api_key = args.api_key
    output_path = args.output_path
    index_name = args.index_name
    old_index_path = args.old_index_path

    try:
        if algorithm == 'BGE':
            logging.info("开始构建BGE索引...")
            indexer = BGEIndexer(file_paths, old_index_path)
            indexer.build_index(output_path, index_name)
            logging.info("BGE索引构建成功")
        elif algorithm == 'BM25':
            logging.info("开始构建BM25索引...")
            indexer = BM25Indexer(file_paths, old_index_path)
            indexer.build_index(output_path, index_name)
            logging.info("BM25索引构建成功")
        elif algorithm == 'OpenAI':
            logging.info("开始构建OpenAI索引...")
            indexer = OpenAIIndexer(file_paths, api_base_url, api_key, old_index_path=old_index_path)
            indexer.build_index(output_path, index_name)
            logging.info("OpenAI索引构建成功")
        else:
            raise ValueError("Unsupported algorithm. Please choose either 'BM25' 'BGE' or 'OpenAI'.")
    except Exception as e:
        logging.error(f"索引构建失败: {e}")
        raise

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

# 4.5.2 生成索引文件并测试

以下示例命令里为了演示增量构建索引的流程,将构建索引文件分成两步了,实际使用时可以一步进行构建。file_paths 参数传递知识文件,多个使用逗号进行分隔,旧索引路径是可选项,如果传递进去则会增量构建,不传递则使用知识文件从零构建。

// 构建BM25索引
$ python3 ./retrieval/retrieval_index.py --file_paths "./data/preprocess_data/国务院关于加强地方政府性债务管理的意见.json" --algorithm BM25 --output_path "./retrieval/bm25/index" --index_name "bm25_index"
$ python3 ./retrieval/retrieval_index.py --file_paths "./data/preprocess_data/中共中央办公厅国务院办公厅印发《关于做好地方政府专项债券发行及项目配套融资工作的通知》.json" --algorithm BM25 --output_path "./retrieval/bm25/index" --index_name "bm25_index" --old_index_path "./retrieval/bm25/index/bm25_index.json"

// 构建BGE索引
$ python3 ./retrieval/retrieval_index.py --file_paths "./data/preprocess_data/国务院关于加强地方政府性债务管理的意见.json" --algorithm BGE --output_path "./retrieval/bge/index" --index_name "bge_index"
$ python3 ./retrieval/retrieval_index.py --file_paths "./data/preprocess_data/中共中央办公厅国务院办公厅印发《关于做好地方政府专项债券发行及项目配套融资工作的通知》.json" --algorithm BGE --output_path "./retrieval/bge/index" --index_name "bge_index" --old_index_path "./retrieval/bge/index/bge_index.npz"

// 构建OpenAI Embedding索引
$ python3 ./retrieval/retrieval_index.py --file_paths "./data/preprocess_data/国务院关于加强地方政府性债务管理的意见.json" --algorithm OpenAI --output_path "./retrieval/openai_embedding/index" --index_name "openai_index" --api_base_url "https://api.openai.com/v1" --api_key "sk-xxx"
$ python3 ./retrieval/retrieval_index.py --file_paths "./data/preprocess_data/中共中央办公厅国务院办公厅印发《关于做好地方政府专项债券发行及项目配套融资工作的通知》.json" --algorithm OpenAI --output_path "./retrieval/openai_embedding/index" --index_name "openai_index" --old_index_path "./retrieval/openai_embedding/index/openai_index" --api_base_url "https://api.openai.com/v1" --api_key "sk-xxx"
1
2
3
4
5
6
7
8
9
10
11

注:构建完之后,拿对应检索程序里的 main 测试是否能够检索即可。

# 4.6 部署检索服务

# 4.6.1 封装检索服务

这里使用 Flask 框架将检索算法封装成一个服务(log.py、response.py、code.py此处省略)。启动时需要传入知识库文件路径(json_files)、检索算法(algorithm)、服务端口号(port),/api/rag/retrieval 接口入参接受输入问题(question)和检索条数(top_k)。

./rag-omni/retrieval/retrieval_server.py

# -*- coding: utf-8 -*-

import argparse
import json
from flask import Flask, jsonify
from flask_cors import CORS
from pre_request import pre, Rule

from log import logger
from response import ResponseCode, ResponseMessage
from bm25.bm25_retrieval import BM25Retrieval
from bge.bge_retrieval import BGERetrieval
from openai_embedding.openai_retrieval import OpenAIRetrieval

# 解析启动参数
parser = argparse.ArgumentParser(description="启动参数")
parser.add_argument('--index_path', type=str, required=True, help="索引路径")
parser.add_argument('--algorithm', type=str, choices=['BM25', 'BGE', 'OpenAI'], required=True, help="检索算法:目前仅支持BM25、BGE、OpenAI")
parser.add_argument('--port', type=int, default=5001, help="启动的端口号,默认5001")
args = parser.parse_args()

index_path = args.index_path
retrieval_algorithm = args.algorithm
port = args.port

# 创建一个服务
app = Flask(__name__)
CORS(app, supports_credentials=True)

# 初始化检索算法
if retrieval_algorithm == 'BM25':
    search_engine = BM25Retrieval(index_path)
elif retrieval_algorithm == 'BGE':
    search_engine = BGERetrieval(index_path)
elif retrieval_algorithm == 'OpenAI':
    search_engine = OpenAIRetrieval(index_path)
else:
    raise ValueError("Unsupported retrieval algorithm")

"""
# 检索算法服务
"""
@app.route(rule='/api/rag/retrieval', methods=['GET'])
def retrieval():

    # 参数校验
    rule = {
        "question": Rule(type=str, required=True),
        "top_k": Rule(type=int, required=True, gte=-1, custom=lambda x: x == -1 or x > 0)
    }
    try:
        params = pre.parse(rule=rule)
    except Exception as e:
        logger.error(e)
        fail_response = dict(code=ResponseCode.PARAM_FAIL, msg=ResponseMessage.PARAM_FAIL, data=None)
        logger.error(fail_response)
        response = jsonify(fail_response)
        response.data = json.dumps(fail_response, ensure_ascii=False, indent=4)
        return response

    # 获取参数
    question = params.get("question")
    top_k = params.get("top_k")

    # 业务处理模块
    try:
        results = search_engine.search(question, top_k)
    except Exception as e:
        logger.error(e)
        fail_response = dict(code=ResponseCode.BUSINESS_FAIL, msg=ResponseMessage.BUSINESS_FAIL, data=None)
        logger.error(fail_response)
        response = jsonify(fail_response)
        response.data = json.dumps(fail_response, ensure_ascii=False, indent=4)
        return response

    # 成功的结果返回,格式化JSON
    success_response = dict(code=ResponseCode.SUCCESS, msg=ResponseMessage.SUCCESS, data=results)
    logger.info(success_response)
    response = jsonify(success_response)
    response.data = json.dumps(success_response, ensure_ascii=False, indent=4)
    return response


if __name__ == '__main__':
    # 解决中文乱码问题
    app.config['JSON_AS_ASCII'] = False
    # 启动服务,指定主机和端口
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

# 4.6.2 启动检索服务并测试

选择索引文件启动检索服务,以下两种检索服务选择一个进行启动即可。

$ cd ./retrieval
$ python3 retrieval_server.py --index_path "./bm25/index/bm25_index.json" --algorithm BM25 --port 5001  // 启动BM25检索服务
$ python3 retrieval_server.py --index_path "./bge/index/bge_index.npz" --algorithm BGE --port 5001      // 启动BGE检索服务
$ python3 retrieval_server.py --index_path "./openai_embedding/index/openai_index" --algorithm OpenAI --port 5001 // 启动OpenAI Embedding检索服务
1
2
3
4

./rag-omni/retrieval/test/retrieval_test.py 可用来测试检索服务

# -*- coding: utf-8 -*-

import requests


def retrieval_test(url, params):
    r = requests.get(url, params=params)
    print(r.text)


if __name__ == '__main__':
    url = 'http://{0}:{1}/api/rag/retrieval'.format("127.0.0.1", "5001")
    params = {'question': "国务院对于地方政府性债务管理的意见", 'top_k': 3}
    retrieval_test(url, params)
1
2
3
4
5
6
7
8
9
10
11
12
13
14

输出top_k条数据记录,得分降序排列,返回值格式如下:

BM25检索算法的返回值

# 5. 部署RAG服务

# 5.1 RAG服务的实现

# 5.1.1 RAG服务核心逻辑

核心逻辑:用户输入的问题——>问题重构(根据历史对话补全信息得到新的问题)——>文档检索(用重构后的问题从检索库里搜索相关文档)——>给出大模型总结的答案(如果检索出来的文档与问题相关,则使用大模型根据相关文档进行总结;如果检索出来的文档与问题无关,则直接使用大模型进行回复并给出提示)

./rag-omni/rag/rag_solve.py

# -*- coding: utf-8 -*-

import requests
import json
import os
import logging
from time import sleep

# 全局参数
RETRIEVAL_TOP_K = 5
LLM_HISTORY_LEN = 30
UNRELATED_RESPONSE = "很抱歉,检索库内不存在与问题相关的参考材料,以下是大模型直接生成的结果:"

logging.basicConfig(level=logging.INFO)


class LLMService:
    def __init__(self, url, api_key, model):
        self.url = url
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        self.model = model

    def __call__(self, messages: list) -> str:
        data = {
            "model": self.model,
            "messages": messages
        }
        response = requests.post(self.url, headers=self.headers, json=data)
        try:
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except requests.exceptions.JSONDecodeError as e:
            logging.error(f"Error decoding JSON: {e}")
            logging.error(f"Response content: {response.text}")
            raise
        except requests.exceptions.RequestException as e:
            logging.error(f"Request error: {e}")
            raise


class History:
    def __init__(self, session_id):
        self.session_id = session_id
        self.history = []


def get_docs(question: str, url: str, top_k=RETRIEVAL_TOP_K, retries=3):
    params = {"question": question, "top_k": top_k}
    for attempt in range(retries):
        try:
            response = requests.get(url, params=params)
            response.raise_for_status()
            try:
                docs_response = response.json()
                docs = [doc["part_content"] for doc in docs_response["data"]]
                return docs
            except requests.exceptions.JSONDecodeError as e:
                logging.error(f"Error decoding JSON: {e}")
                logging.error(f"Response content: {response.text}")
                if attempt < retries - 1:
                    sleep(2 ** attempt)
                else:
                    raise
        except Exception as e:
            logging.error(f"Error in get_docs: {e}")
            if attempt < retries - 1:
                sleep(2 ** attempt)
            else:
                raise


def get_knowledge_based_answer(query, history_obj, url_retrieval, llm):
    global RETRIEVAL_TOP_K, UNRELATED_RESPONSE

    if len(history_obj.history) > LLM_HISTORY_LEN:
        history_obj.history = history_obj.history[-LLM_HISTORY_LEN:]

    # 重构问题
    if len(history_obj.history) > 0:
        rewrite_question_input = history_obj.history.copy()
        rewrite_question_input.append(
            {
                "role": "user",
                "content": f"""请基于对话历史,对后续问题进行补全重构。如果后续问题与历史相关,你必须结合语境将代词替换为相应的指代内容,让它的提问更加明确;否则直接返回原始的后续问题。
                注意:请不要对后续问题做任何回答和解释。

                历史对话:{json.dumps(history_obj.history, ensure_ascii=False)}
                后续问题:{query}

                修改后的后续问题:"""
            }
        )
        new_query = llm(rewrite_question_input).strip()
        if "请不要对后续问题做任何回答和解释" in new_query:
            new_query = query
    else:
        new_query = query

    # 获取相关文档
    docs = get_docs(new_query, url_retrieval, RETRIEVAL_TOP_K)
    doc_string = "\n".join([json.dumps(doc, ensure_ascii=False) for doc in docs])

    # 判断文档与重构后的问题是否相关
    relevance_check_input = [
        {"role": "system", "content": "你是一个帮助判断内容是否相关的助手。"},
        {"role": "user", "content": f"问题:{new_query}\n文档:{doc_string}\n请判断这些文档是否与问题相关,如果相关,请返回'相关',否则返回'无关'。"}
    ]
    relevance_response = llm(relevance_check_input).strip()

    if "无关" in relevance_response:
        # 使用重构的问题调用大模型
        direct_response_input = [{"role": "user", "content": new_query}]
        direct_response = llm(direct_response_input)
        response = f"{UNRELATED_RESPONSE}\n\n{direct_response}"
    else:
        history_obj.history.append(
            {
                "role": "user",
                "content": f"请基于参考,回答问题,并给出参考依据:\n问题:\n{query}\n参考:\n{doc_string}\n答案:"
            }
        )
        response = llm(history_obj.history)
        history_obj.history[-1] = {"role": "user", "content": query}
        history_obj.history.append({"role": "assistant", "content": response})

    # 保存history
    current_dir = os.path.dirname(os.path.abspath(__file__))
    history_dir = os.path.join(current_dir, 'history')
    os.makedirs(history_dir, exist_ok=True)
    history_file_path = os.path.join(history_dir, f'history_{history_obj.session_id}.json')

    if not os.path.exists(history_file_path):
        with open(history_file_path, "w", encoding="utf-8") as file:
            json.dump([], file, ensure_ascii=False, indent=2)

    with open(history_file_path, "r", encoding="utf-8") as file:
        data = json.load(file)
    data.append({"query": query, "new_query": new_query, "docs": docs, "response": response})
    with open(history_file_path, "w", encoding="utf-8") as file:
        json.dump(data, file, ensure_ascii=False, indent=2)

    return {"response": response, "docs": docs}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145

# 5.1.2 封装RAG服务

使用 Flask 框架将RAG封装成一个服务(log.py、response.py、code.py此处省略)。启动时需要传入大模型服务地址(api_url)、大模型服务密钥(api_key)、大模型型号(model)、服务端口号(port)、检索服务地址(retrieval_url),/api/rag/summary 接口入参接受输入问题(content)和检索标识(id)。

./rag-omni/rag/rag_server.py

# -*- coding: utf-8 -*-

import argparse
import json
import time
import os
from flask import Flask, jsonify
from flask_cors import CORS
from pre_request import Rule, pre

from rag_solve import LLMService, History, get_knowledge_based_answer
from response import ResponseCode, ResponseMessage
from log import logger

# 解析启动参数
parser = argparse.ArgumentParser(description="启动参数")
parser.add_argument('--api_url', type=str, default="https://api.openai.com/v1/chat/completions", help="LLM API URL")
parser.add_argument('--api_key', type=str, help="LLM API Key")
parser.add_argument('--model', type=str, help="LLM模型名称")
parser.add_argument('--port', type=int, default=5002, help="启动的端口号,默认5002")
parser.add_argument('--retrieval_url', type=str, default="http://127.0.0.1:5001/api/rag/retrieval", help="检索服务的URL")
args = parser.parse_args()

# 初始化参数
api_url = args.api_url
api_key = args.api_key
model = args.model
port = args.port
retrieval_url = args.retrieval_url

# 初始化LLM服务
llm = LLMService(url=api_url, api_key=api_key, model=model)

# 初始化历史消息
session_histories = {}

# 创建一个服务
app = Flask(__name__)
CORS(app, supports_credentials=True)

"""
# 基于RAG的LLM对话服务
"""
@app.route("/api/rag/summary", methods=["POST"])
def get_bot_response():
    global session_histories, llm

    # 获取请求数据
    rule = {
        "user_prompt": Rule(type=str, required=True),
        "session_id": Rule(type=str, required=True)
    }
    try:
        params = pre.parse(rule=rule)
    except Exception as e:
        logger.error(e)
        fail_response = dict(code=ResponseCode.PARAM_FAIL, msg=ResponseMessage.PARAM_FAIL, data=None)
        logger.error(fail_response)
        response = jsonify(fail_response)
        response.data = json.dumps(fail_response, ensure_ascii=False, indent=4)
        return response

    user_prompt = params["user_prompt"]
    session_id = params["session_id"]

    # 获取对话历史,如果有的话
    if session_id in session_histories:
        history_obj = session_histories[session_id]["history"]
        session_histories[session_id]["last_access_time"] = time.time()
    else:
        history_obj = History(session_id)
        session_histories[session_id] = {
            "history": history_obj,
            "last_access_time": time.time(),
        }

    # 如果用户超过一个小时没有交互,则删除该用户的对话历史
    max_idle_time = 60 * 60
    for sid, session_data in session_histories.copy().items():
        idle_time = time.time() - session_data["last_access_time"]
        if idle_time > max_idle_time:
            del session_histories[sid]

    # 清空对话历史
    if user_prompt == "$清空对话历史":
        history_obj.history = []
        history_file_path = f'./history/history_{session_id}.json'
        if os.path.exists(history_file_path):
            os.remove(history_file_path)
        success_response = dict(code=ResponseCode.SUCCESS, msg=ResponseMessage.SUCCESS, data="已清空对话历史")
        logger.info(success_response)
        response = jsonify(success_response)
        response.data = json.dumps(success_response, ensure_ascii=False, indent=4)
        if session_id in session_histories:
            del session_histories[session_id]
        return response

    # 获取知识库回答
    try:
        answer = get_knowledge_based_answer(
            query=user_prompt, history_obj=history_obj, url_retrieval=retrieval_url, llm=llm
        )
        success_response = dict(code=ResponseCode.SUCCESS, msg=ResponseMessage.SUCCESS, data=answer)
        logger.info(success_response)
        response = jsonify(success_response)
        response.data = json.dumps(success_response, ensure_ascii=False, indent=4)
        return response
    except Exception as e:
        logger.error(e)
        fail_response = dict(code=ResponseCode.BUSINESS_FAIL, msg=ResponseMessage.BUSINESS_FAIL, data=None)
        logger.error(fail_response)
        response = jsonify(fail_response)
        response.data = json.dumps(fail_response, ensure_ascii=False, indent=4)
        return response


if __name__ == '__main__':
    # 解决中文乱码问题
    app.config['JSON_AS_ASCII'] = False
    # 启动服务,指定主机和端口
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121

# 5.2 启动RAG服务并测试

# 5.2.1 两种方式启动RAG服务

方式一:使用外部OpenAI服务启动

$ cd ./rag
$ python3 rag_server.py --api_url "https://xxx.xxx.xxx/v1/chat/completions" --api_key "sk-xxx" --model "gpt-3.5-turbo-1106" --port 5002 --retrieval_url "http://127.0.0.1:5001/api/rag/retrieval"
1
2

方式二:使用本地部署的Qwen服务启动

$ cd ./rag
$ python3 rag_server.py --api_url "http://127.0.0.1:5000/v1/chat/completions" --model "qwen-1.5-0.5b" --port 5002 --retrieval_url "http://127.0.0.1:5001/api/rag/retrieval"
1
2

注:如果是使用本地部署的大模型服务,因为没有权限验证,因此这里就不需要传 api_key 参数了。

# 5.2.2 测试RAG服务

./rag-omni/rag/test/rag_test.py 可用来测试RAG服务

# -*- coding: utf-8 -*-

import requests
import json


def get_summary(url, user_prompt, session_id):
    headers = {
        "Content-Type": "application/json"
    }
    data = {
        "user_prompt": user_prompt,
        "session_id": session_id
    }

    response = requests.post(url, headers=headers, data=json.dumps(data))
    return response.json()


if __name__ == "__main__":
    url = "http://127.0.0.1:5002/api/rag/summary"
    session_id = "session_id_001"

    user_prompt_1 = "简要总结一下国家对于地方政府性债务管理的意见"
    response_1 = get_summary(url, user_prompt_1, session_id)
    print("第一个问题的回复:")
    print(response_1)

    user_prompt_2 = "再详细一些"
    response_2 = get_summary(url, user_prompt_2, session_id)
    print("第二个问题的回复:")
    print(response_2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

输出结果里response为大模型总结的答案,docs是检索出来的相关文档,返回值格式如下:

{
    "code": 200,
    "msg": "请求成功",
    "data": {
        "response": "国家对地方政府性债务管理的意见包括:实行地方政府债务规模限额管理,严格限定地方政府举债程序和资金用途,将地方政府债务纳入全口径预算管理,加强组织领导,妥善处理存量债务和在建项目后续融资,控制和化解地方政府性债务风险,加快建立规范的地方政府举债融资机制。\n\n参考依据:参考文本中列举了国家对地方政府性债务管理的具体措施和要求,包括规模控制和预算管理、加强组织领导、妥善处理存量债务和在建项目后续融资、控制和化解地方政府性债务风险、加快建立规范的地方政府举债融资机制等方面的要求和措施。",
        "docs": [
            "三、对地方政府债务实行规模控制和预算管理\n(一)对地方政府债务实行规模控制。地方政府债务规模实行限额管理,地方政府举债不得突破批准的限额。地方政府一般债务和专项债务规模纳入限额管理,由国务院确定并报全国人大或其常委会批准,分地区限额由财政部在全国人大或其常委会批准的地方政府债务规模内根据各地区债务风险、财力状况等因素测算并报国务院批准。(二)严格限定地方政府举债程序和资金用途。地方政府在国务院批准的分地区限额内举借债务,必须报本级人大或其常委会批准。地方政府不得通过企事业单位等举借债务。地方政府举借债务要遵循市场化原则。建立地方政府信用评级制度,逐步完善地方政府债券市场。地方政府举借的债务,只能用于公益性资本支出和适度归还存量债务,不得用于经常性支出。(三)把地方政府债务分门别类纳入全口径预算管理。地方政府要将一般债务收支纳入一般公共预算管理,将专项债务收支纳入政府性基金预算管理,将政府与社会资本合作项目中的财政补贴等支出按性质纳入相应政府预算管理。地方政府各部门、各单位要将债务收支纳入部门和单位预算管理。或有债务确需地方政府或其部门、单位依法承担偿债责任的,偿债资金要纳入相应预算管理。",
            "七、加强组织领导\n各地区、各部门要高度重视,把思想和行动统一到党中央、国务院决策部署上来。地方政府要切实担负起加强地方政府性债务管理、防范化解财政金融风险的责任,结合实际制定具体方案,政府主要负责人要作为第一责任人,认真抓好政策落实。要建立地方政府性债务协调机制,统筹加强地方政府性债务管理。财政部门作为地方政府性债务归口管理部门,要完善债务管理制度,充实债务管理力量,做好债务规模控制、债券发行、预算管理、统计分析和风险监控等工作;发展改革部门要加强政府投资计划管理和项目审批,从严审批债务风险较高地区的新开工项目;金融监管部门要加强监管、正确引导,制止金融机构等违法违规提供融资;审计部门要依法加强对地方政府性债务的审计监督,促进完善债务管理制度,防范风险,规范管理,提高资金使用效益。各地区、各部门要切实履行职责,加强协调配合,全面做好加强地方政府性债务管理各项工作,确保政策贯彻落实到位。",
            "六、妥善处理存量债务和在建项目后续融资\n(一)抓紧将存量债务纳入预算管理。以2013年政府性债务审计结果为基础,结合审计后债务增减变化情况,经债权人与债务人共同协商确认,对地方政府性债务存量进行甄别。对地方政府及其部门举借的债务,相应纳入一般债务和专项债务。对企事业单位举借的债务,凡属于政府应当偿还的债务,相应纳入一般债务和专项债务。地方政府将甄别后的政府存量债务逐级汇总上报国务院批准后,分类纳入预算管理。纳入预算管理的债务原有债权债务关系不变,偿债资金要按照预算管理要求规范管理。(二)积极降低存量债务利息负担。对甄别后纳入预算管理的地方政府存量债务,各地区可申请发行地方政府债券置换,以降低利息负担,优化期限结构,腾出更多资金用于重点项目建设。(三)妥善偿还存量债务。处置到期存量债务要遵循市场规则,减少行政干预。对项目自身运营收入能够按时还本付息的债务,应继续通过项目收入偿还。对项目自身运营收入不足以还本付息的债务,可以通过依法注入优质资产、加强经营管理、加大改革力度等措施,提高项目盈利能力,增强偿债能力。地方政府应指导和督促有关债务举借单位加强财务管理、拓宽偿债资金渠道、统筹安排偿债资金。对确需地方政府偿还的债务,地方政府要切实履行偿债责任,必要时可以处置政府资产偿还债务。对确需地方政府履行担保或救助责任的债务,地方政府要切实依法履行协议约定,作出妥善安排。有关债务举借单位和连带责任人要按照协议认真落实偿债责任,明确偿债时限,按时还本付息,不得单方面改变原有债权债务关系,不得转嫁偿债责任和逃废债务。对确已形成损失的存量债务,债权人应按照商业化原则承担相应责任和损失。(四)确保在建项目后续融资。地方政府要统筹各类资金,优先保障在建项目续建和收尾。对使用债务资金的在建项目,原贷款银行等要重新进行审核,凡符合国家有关规定的项目,要继续按协议提供贷款,推进项目建设;对在建项目确实没有其他建设资金来源的,应主要通过政府与社会资本合作模式和地方政府债券解决后续融资。",
            "四、控制和化解地方政府性债务风险\n(一)建立地方政府性债务风险预警机制。财政部根据各地区一般债务、专项债务、或有债务等情况,测算债务率、新增债务率、偿债率、逾期债务率等指标,评估各地区债务风险状况,对债务高风险地区进行风险预警。列入风险预警范围的债务高风险地区,要积极采取措施,逐步降低风险。债务风险相对较低的地区,要合理控制债务余额的规模和增长速度。(二)建立债务风险应急处置机制。要硬化预算约束,防范道德风险,地方政府对其举借的债务负有偿还责任,中央政府实行不救助原则。各级政府要制定应急处置预案,建立责任追究机制。地方政府出现偿债困难时,要通过控制项目规模、压缩公用经费、处置存量资产等方式,多渠道筹集资金偿还债务。地方政府难以自行偿还债务时,要及时上报,本级和上级政府要启动债务风险应急处置预案和责任追究机制,切实化解债务风险,并追究相关人员责任。(三)严肃财经纪律。建立对违法违规融资和违规使用政府性债务资金的惩罚机制,加大对地方政府性债务管理的监督检查力度。地方政府及其所属部门不得在预算之外违法违规举借债务,不得以支持公益性事业发展名义举借债务用于经常性支出或楼堂馆所建设,不得挪用债务资金或改变既定资金用途;对企业的注资、财政补贴等行为必须依法合规,不得违法为任何单位和个人的债务以任何方式提供担保;不得违规干预金融机构等正常经营活动,不得强制金融机构等提供政府性融资。地方政府要进一步规范土地出让管理,坚决制止违法违规出让土地及融资行为。",
            "二、加快建立规范的地方政府举债融资机制\n(一)赋予地方政府依法适度举债权限。经国务院批准,省、自治区、直辖市政府可以适度举借债务,市县级政府确需举借债务的由省、自治区、直辖市政府代为举借。明确划清政府与企业界限,政府债务只能通过政府及其部门举借,不得通过企事业单位等举借。(二)建立规范的地方政府举债融资机制。地方政府举债采取政府债券方式。没有收益的公益性事业发展确需政府举借一般债务的,由地方政府发行一般债券融资,主要以一般公共预算收入偿还。有一定收益的公益性事业发展确需政府举借专项债务的,由地方政府通过发行专项债券融资,以对应的政府性基金或专项收入偿还。(三)推广使用政府与社会资本合作模式。鼓励社会资本通过特许经营等方式,参与城市基础设施等有一定收益的公益性事业投资和运营。政府通过特许经营权、合理定价、财政补贴等事先公开的收益约定规则,使投资者有长期稳定收益。投资者按照市场化原则出资,按约定规则独自或与政府共同成立特别目的公司建设和运营合作项目。投资者或特别目的公司可以通过银行贷款、企业债、项目收益债券、资产证券化等市场化方式举债并承担偿债责任。政府对投资者或特别目的公司按约定规则依法承担特许经营权、合理定价、财政补贴等相关责任,不承担投资者或特别目的公司的偿债责任。(四)加强政府或有债务监管。剥离融资平台公司政府融资职能,融资平台公司不得新增政府债务。地方政府新发生或有债务,要严格限定在依法担保的范围内,并根据担保合同依法承担相关责任。地方政府要加强对或有债务的统计分析和风险防控,做好相关监管工作。"
        ]
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14

对应的history文件记录请求历史,里面包含了重构后的问题:

RAG请求历史记录-含问题重构

# 5.3 RAG整体效果评测与调优

可以准备个规范化的测试用例数据集,里面包含若干类评测问题,然后写个脚本调用RAG服务,将结果写入进去,然后对比标准答案评估效果。

{
  "conversations": [
    {
      "question_type": "测试问题类型",
      "rounds": [
        {
          "question": "测试问题用例1",
          "standard_answer": "测试问题用例1的标准答案",
          "rag_answer": "测试问题用例1的RAG回复",
          "rag_answer_result": "准确"
        },
        {
          "question": "测试问题用例1的后续问题",
          "standard_answer": "测试问题用例1的后续问题的标准答案",
          "rag_answer": "测试问题用例1的后续问题的RAG回复",
          "rag_answer_result": "不准确"
        }
      ]
    }
  ]
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

实际RAG问答的部分优化方向:上下文记录历史连续对话、对于无关问题的处理、输入信息不完备时给出推理或者让用户补充、判断是否使用检索出的数据、检索相近语义数据、开源模型计算和推理的能力差、组合问题需要分解步骤去操作、检索结果涉及多跳问题等。

如果需要实际使用效果的话,可参考借鉴 https://github.com/infiniflow/ragflow (opens new window) 项目,如何搭建及使用详见我的另一篇博客:基于RAG的知识库问答平台使用指南 (opens new window)

# 5.4 将RAG服务接入场景页面

场景页面是使用 Vue 开发的,服务地址已经在 ./rag-omni/chat/src/App.vue 上进行对接了,如果要更换的话,直接在上面修改即可。

$ cd ./chat
$ npm run serve
1
2

使用Chrome浏览器访问 http://127.0.0.1:5003 页面。

将RAG服务接入场景页面

# 6. 参考资料

[1] 检索增强生成 (RAG) from Prompt Engineering Guide (opens new window)

[2] 用检索增强生成让大模型更强大,这里有个手把手的Python实现 from 机器之心 (opens new window)

[3] 检索、提示:检索增强的(Retrieval Augmented)自然语言处理 from 知乎 (opens new window)

[4] RAG: Streamlining the creation of intelligent natural language processing models from Meta (opens new window)

[5] 一文详解检索增强语言模型新范式REPLUG from CSDN (opens new window)

[6] 大模型+检索增强(RAG、Atlas 和 REPLUG)from CSDN (opens new window)

[7] 检索增强生成RAG的技术趋势调查仓库RAG-Survey from Github (opens new window)

[8] 一文读懂RAG的来源、发展和前沿 from 微信公众号 (opens new window)

[9] 多模态RAG综述 from 知乎 (opens new window)

[10] RAG 分块Chunk技术优劣、技巧、方法汇总(五)from 知乎 (opens new window)

[11] 向量模型BGE与M3E from 知乎 (opens new window)

[12] RAG提效利器——BM25检索算法原理和Python实现 from 知乎 (opens new window)

[13] 稠密检索和稀疏检索分别指的是什么(以向量检索为例)from 知乎 (opens new window)

[14] 基于python的BM25文本匹配算法实现 from Github (opens new window)

[15] 非监督文本匹配算法——BM25 from CSDN (opens new window)

[16] LLM+Embedding构建问答系统的局限性及优化方案 from 知乎 (opens new window)

[17] OpenAI分享他们在RAG技术的最佳实践 from 微信公众号 (opens new window)

[18] 大语言模型何时需要检索?UCLA提出全新自监督选择性检索策略 from 微信公众号 (opens new window)

[19] Agentic RAG 与图任务编排 from AIGC开放社区 (opens new window)

[20] 基于 RAPTOR 实现长上下文 RAG from RAGFlow官方文档 (opens new window)

[21] 一文详谈20多种RAG优化方法 from 微信公众号 (opens new window)

Last Updated: 7/9/2024, 11:12:06 AM