I. Project Overview

PIKE-RAG (sPecIalized KnowledgE and Rationale Augmented Generation) is an advanced retrieval-augmented generation (RAG) framework developed by Microsoft, designed to improve the ability of large language models (LLMs) to handle complex industrial tasks. The framework focuses on specialized knowledge and rationale-augmented generation, using a modular architecture that adapts to the needs of diverse real-world scenarios.

Core Features

  • Modular framework: includes core modules for document parsing, knowledge extraction, storage, retrieval, organization, reasoning, and task decomposition, which can be flexibly adjusted to meet the needs of different scenarios.
  • Industry adaptability: tested in domains such as industrial manufacturing, mining, and pharmaceuticals, with significant improvements in question-answering accuracy.
  • Multi-level pipelines: supports tasks of varying complexity, from factual information retrieval to fact-grounded creative generation.
  • Robust evaluation: uses metrics such as exact match (EM), F1 score, precision, recall, and LLM-based accuracy.

II. Environment Setup

1. Clone the Repository

git clone https://github.com/microsoft/PIKE-RAG.git
cd PIKE-RAG

2. Set Up the Python Environment

  1. Create and activate a new Python virtual environment (optional)
conda create -n PIKE-RAG python=3.12 -y
conda activate PIKE-RAG
  2. Install the dependencies
pip install -r requirements.txt
pip install -r examples/requirements.txt
  3. Set PYTHONPATH (required each time the environment is activated)
  • Linux, assuming the repository root is the current directory
export PYTHONPATH=$PWD
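
To confirm that PYTHONPATH is picked up, a quick check (a minimal sketch, assuming the repository root is still the current directory and the environment is active) is to import the package:

# check_pikerag_import.py -- hypothetical helper; verifies that pikerag is importable
import pikerag

print("pikerag imported from:", pikerag.__file__)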

III. Serving a Local Model via an OpenAI-Compatible API with vLLM

Out of the box, the open-source codebase does not support the standard OpenAI API; it only supports endpoints deployed through Microsoft's Azure OpenAI service. Azure endpoints are configured and instantiated differently from the standard OpenAI client, but the request parameters are the same. We can therefore modify the code to support the standard OpenAI API, and then use vLLM to serve a local model through an OpenAI-compatible interface.
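
For reference, the two ways of constructing a client differ roughly as shown below (a minimal sketch with the openai Python package; the Azure endpoint, API version, and keys are placeholders):

from openai import AzureOpenAI, OpenAI

# Azure OpenAI endpoint: configured with an Azure resource endpoint and an API version.
azure_client = AzureOpenAI(
    azure_endpoint="https://<your-resource>.openai.azure.com/",  # placeholder
    api_version="2024-02-01",                                    # placeholder
    api_key="<azure-api-key>",                                   # placeholder
)

# Standard OpenAI-compatible endpoint (e.g. a local vLLM server): only base_url and api_key.
openai_client = OpenAI(
    base_url="http://localhost:8888/v1/",
    api_key="EMPTY",  # vLLM does not check the key by default
)

# The request parameters themselves are identical in both cases, e.g.:
# client.chat.completions.create(model=..., messages=[...])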

1. Install vLLM

  1. Create and activate a new Python virtual environment (optional).
conda create -n vllm python=3.12 -y
conda activate vllm
  2. Install vllm
pip3 install vllm -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. Start the server (completion endpoint, model: Qwen2.5-14B-Instruct)
vllm serve /data/LLMS/qwen/Qwen2.5-14B-Instruct/ --port 8888 --gpu-memory-utilization 0.3 --enable-chunked-prefill --max-num-batched-tokens 1024 --tensor-parallel-size 8 --host 0.0.0.0
  • /data/LLMS/qwen/Qwen2.5-14B-Instruct/: path to the local model weights
  • --port 8888: port the service listens on
  • --gpu-memory-utilization 0.3: use at most 30% of the GPU memory
  • --enable-chunked-prefill: enable chunked prefill to reduce peak memory usage on long inputs
  • --max-num-batched-tokens 1024: process at most 1024 tokens per batch
  • --tensor-parallel-size 8: tensor parallelism across 8 GPUs (requires a machine with 8 GPUs)
  • --host 0.0.0.0: listen on all network interfaces, allowing LAN/public access.

Note

  • If you hit CUDA out of memory, enable --tensor-parallel-size 8
  • If --tensor-parallel-size 8 is enabled, run ray start --head first to initialize the Ray cluster
  4. Start the server (embedding endpoint, model: bge-m3)
vllm serve /data/LLMS/bge-m3 \
  --gpu-memory-utilization 0.34 \
  --port 10010 \
  --host 0.0.0.0 \
  --max-model-len 4096 &

A successful server startup looks like this:

(Screenshot: vLLM server startup output)
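
Before wiring PIKE-RAG to these servers, it is worth sanity-checking both endpoints with the openai package (a minimal sketch; vLLM registers each model under the path passed to vllm serve, and the API key is an arbitrary placeholder):

from openai import OpenAI

# Chat completion endpoint served on port 8888.
chat_client = OpenAI(base_url="http://localhost:8888/v1/", api_key="EMPTY")
chat = chat_client.chat.completions.create(
    model="/data/LLMS/qwen/Qwen2.5-14B-Instruct/",
    messages=[{"role": "user", "content": "Hello, who are you?"}],
)
print(chat.choices[0].message.content)

# Embedding endpoint served on port 10010.
embed_client = OpenAI(base_url="http://localhost:10010/v1/", api_key="EMPTY")
emb = embed_client.embeddings.create(model="/data/LLMS/bge-m3", input="hello world")
print("embedding dimension:", len(emb.data[0].embedding))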

2. Modify the Client Code

From the project root, open the file pikerag/llm_client/azure_open_ai_client.py and adapt it into a new client file, open_ai_client.py.

The complete code is as follows:

import json
import os
import re
import time
from typing import List, Literal, Optional, Union

import openai
from langchain_core.embeddings import Embeddings
from openai.types import CreateEmbeddingResponse
from openai.types.chat.chat_completion import ChatCompletion
from pickledb import PickleDB

from pikerag.llm_client.base import BaseLLMClient
from pikerag.utils.logger import Logger


def parse_wait_time_from_error(error: openai.RateLimitError) -> Optional[int]:
    try:
        info_str: str = error.args[0]
        info_dict_str: str = info_str[info_str.find("{"):]
        error_info: dict = json.loads(re.compile('(?<!\\\\)\'').sub('\"', info_dict_str))
        error_message = error_info["error"]["message"]
        matches = re.search(r"Try again in (\d+) seconds", error_message)
        wait_time = int(matches.group(1)) + 3  # NOTE: wait 3 more seconds here.
        return wait_time
    except Exception as e:
        return None


class OpenAIClient(BaseLLMClient):
    ### Modified: renamed client class
    NAME = "OpenAIClient"

    def __init__(
        self, location: str = None, auto_dump: bool = True, logger: Logger = None,
        max_attempt: int = 5, exponential_backoff_factor: int = None, unit_wait_time: int = 60, **kwargs,
    ) -> None:
        """LLM Communication Client for Azure OpenAI endpoints.

        Args:
            location (str): the file location of the LLM client communication cache. No cache would be created if set to
                None. Defaults to None.
            auto_dump (bool): automatically save the Client's communication cache or not. Defaults to True.
            logger (Logger): client logger. Defaults to None.
            max_attempt (int): Maximum attempt time for LLM requesting. Request would be skipped if max_attempt reached.
                Defaults to 5.
            exponential_backoff_factor (int): Set to enable exponential backoff retry manner. Every time the wait time
                would be `exponential_backoff_factor ^ num_attempt`. Set to None to disable and use the `unit_wait_time`
                manner. Defaults to None.
            unit_wait_time (int): `unit_wait_time` would be used only if the exponential backoff mode is disabled. Every
                time the wait time would be `unit_wait_time * num_attempt`, with seconds (s) as the time unit. Defaults
                to 60.
        """
        super().__init__(location, auto_dump, logger, max_attempt, exponential_backoff_factor, unit_wait_time, **kwargs)

        ### Modified: use the standard OpenAI client, configured via environment variables
        from openai import OpenAI
        api_key = os.environ.get("OPENAI_API_KEY", None)
        base_url = os.environ.get("OPENAI_BASE_URL", None)
        self._client = OpenAI(api_key=api_key, base_url=base_url)

    def _get_response_with_messages(self, messages: List[dict], **llm_config) -> ChatCompletion:
        response: ChatCompletion = None
        num_attempt: int = 0
        while num_attempt < self._max_attempt:
            try:
                # TODO: handling the kwargs not passed issue for other Clients
                response = self._client.chat.completions.create(messages=messages, **llm_config)
                break

            except openai.RateLimitError as e:
                self.warning("  Failed due to RateLimitError...")
                # NOTE: mask the line below to keep trying if failed due to RateLimitError.
                # num_attempt += 1
                wait_time = parse_wait_time_from_error(e)
                self._wait(num_attempt, wait_time=wait_time)
                self.warning(f"  Retrying...")

            except openai.BadRequestError as e:
                self.warning(f"  Failed due to Exception: {e}")
                self.warning(f"  Skip this request...")
                break

            except Exception as e:
                self.warning(f"  Failed due to Exception: {e}")
                num_attempt += 1
                self._wait(num_attempt)
                self.warning(f"  Retrying...")

        return response

    def _get_content_from_response(self, response: ChatCompletion, messages: List[dict] = None) -> str:
        try:
            content = response.choices[0].message.content
            if content is None:
                finish_reason = response.choices[0].finish_reason
                warning_message = f"Non-Content returned due to {finish_reason}"

                if "content_filter" in finish_reason:
                    for reason, res_dict in response.choices[0].content_filter_results.items():
                        if res_dict["filtered"] is True or res_dict["severity"] != "safe":
                            warning_message += f", '{reason}': {res_dict}"

                self.warning(warning_message)
                self.debug(f"  -- Complete response: {response}")
                if messages is not None and len(messages) >= 1:
                    self.debug(f"  -- Last message: {messages[-1]}")

                content = ""
        except Exception as e:
            self.warning(f"Try to get content from response but get exception:\n  {e}")
            self.debug(
                f"  Response: {response}\n"
                f"  Last message: {messages}"
            )
            content = ""

        return content

    def close(self):
        super().close()
        self._client.close()

### Modified: embedding client for an OpenAI-compatible endpoint
class OpenAIEmbedding(Embeddings):
    def __init__(self, **kwargs) -> None:
        client_configs = kwargs.get("client_config", {})

        ### Modified: set base_url to the port configured when starting the vLLM embedding service
        from openai import OpenAI
        api_key = os.environ.get("OPENAI_API_KEY", None)
        #base_url = os.environ.get("OPENAI_BASE_URL", None)
        base_url = "http://localhost:10010/v1/"
        self._client = OpenAI(api_key=api_key, base_url=base_url)

        self._model = kwargs.get("model", "/data/LLMS/bge-m3")

        cache_config = kwargs.get("cache_config", {})
        cache_location = cache_config.get("location", None)
        auto_dump = cache_config.get("auto_dump", True)
        if cache_location is not None:
            self._cache: PickleDB = PickleDB(location=cache_location)
        else:
            self._cache = None

    def _save_cache(self, query: str, embedding: List[float]) -> None:
        if self._cache is None:
            return

        self._cache.set(query, embedding)
        return

    def _get_cache(self, query: str) -> Union[List[float], Literal[False]]:
        if self._cache is None:
            return False

        return self._cache.get(query)

    def _get_response(self, texts: Union[str, List[str]]) -> CreateEmbeddingResponse:
        while True:
            try:
                response = self._client.embeddings.create(input=texts, model=self._model)
                break

            except openai.RateLimitError as e:
                expected_wait = parse_wait_time_from_error(e)
                if expected_wait is not None:
                    print(f"Embedding failed due to RateLimitError, wait for {expected_wait} seconds")
                    time.sleep(expected_wait)
                else:
                    print(f"Embedding failed due to RateLimitError, but failed parsing expected waiting time, wait for 30 seconds")
                    time.sleep(30)

            except Exception as e:
                print(f"Embedding failed due to exception {e}")
                exit(0)

        return response

    def embed_documents(self, texts: List[str], batch_call: bool=False) -> List[List[float]]:
        # NOTE: call self._get_response(texts) would cause RateLimitError, it may due to large batch size.
        if batch_call is True:
            response = self._get_response(texts)
            embeddings = [res.embedding for res in response.data]
        else:
            embeddings = [self.embed_query(text) for text in texts]
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        embedding = self._get_cache(text)
        if embedding is False:
            response = self._get_response(text)
            embedding = response.data[0].embedding
            self._save_cache(text, embedding)
        return embedding
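
Depending on how PIKE-RAG resolves module_path and class_name from the YAML config (see step 4 below), the new classes may also need to be re-exported from pikerag/llm_client/__init__.py so they can be looked up on the pikerag.llm_client package. A hedged sketch of the extra lines:

# pikerag/llm_client/__init__.py -- add the new client next to the existing exports
from pikerag.llm_client.open_ai_client import OpenAIClient, OpenAIEmbedding
# If the module defines an __all__ list, also append "OpenAIClient" and "OpenAIEmbedding" to it.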

3. Create a .env File

From the project root, create a .env file in the env_configs folder.

OPENAI_BASE_URL="http://localhost:8888/v1/"
OPENAI_API_KEY="{OPENAI_API_KEY}"

OPENAI_BASE_URL points to the port configured when starting the vLLM completion service.
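
The modified OpenAIClient reads these two variables via os.environ. In case your entry script does not already load env_configs/.env, a minimal sketch of loading it explicitly with python-dotenv:

# load_env.py -- hypothetical snippet; requires: pip install python-dotenv
import os

from dotenv import load_dotenv

load_dotenv("env_configs/.env")  # path relative to the project root
print(os.environ.get("OPENAI_BASE_URL"))  # expected: http://localhost:8888/v1/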

4. Modify the .yml File

When running the examples, modify the LLM Setting section of the corresponding .yml file:

# LLM Setting
################################################################################
llm_client:
  module_path: pikerag.llm_client
  # available class_name: AzureMetaLlamaClient, AzureOpenAIClient, HFMetaLlamaClient
  class_name: OpenAIClient ### Modified
  args: {}

  llm_config:
    model: /data/LLMS/qwen/Qwen2.5-14B-Instruct/ ### Modified: path to the local model being served
    temperature: 0

  cache_config:
    # location_prefix: will be joined with log_dir to generate the full path;
    #   if set to null, the experiment_name would be used
    location_prefix: null
    auto_dump: True

IV. The MuSiQue Dataset

Take the atomic decomposition task on the MuSiQue dataset as an example.

1. Prepare the Test Data

Assuming you are in the project root:

# Install the libraries required for preprocessing
pip install -r data_process/open_benchmarks/requirements.txt
# Run the script to download MuSiQue, sample a subset, and convert the format
python data_process/main.py data_process/open_benchmarks/config/musique.yml

After the script completes, the preprocessed data can be found under the data/musique/ directory. Specifically, the sampled dataset is data/musique/dev_500.jsonl.

Note: users in mainland China need to configure a proxy for this step. For how to route a remote server's traffic through a local proxy for faster access, refer to the linked guide.
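
To get a feel for the sampled data, you can peek at the first record (a minimal sketch; the exact field names depend on what the preprocessing script emits):

import json

# Each line of the sampled dataset is one JSON object (one QA sample).
with open("data/musique/dev_500.jsonl", "r", encoding="utf-8") as f:
    first = json.loads(f.readline())

print("Top-level fields:", sorted(first.keys()))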

2. Chunk the Retrieval Contexts

Assuming you are in the project root:

# Run the script to extract the retrieval context passages from the QA data
python examples/chunking.py examples/musique/configs/chunking.yml

After the script completes, the file dev_500_retrieval_contexts_as_chunks.jsonl can be found under the data/musique/ directory.


3. Tag Atomic Questions

Assuming you are in the project root:

# Tag atomic questions for the MuSiQue sample set
python examples/tagging.py examples/musique/configs/tagging.yml

After the run completes, the file dev_500_retrieval_contexts_as_chunks_with_atom_questions.jsonl can be found under the data/musique/ directory.


4. Run the QA Task

Assuming you are in the project root:

# Run retrieval on MuSiQue based on the tagged atomic questions
python examples/qa.py examples/musique/configs/atomic_decompose.yml

(Screenshots: vLLM server-side output for the completion and embedding services while the task runs)

After the run completes, the answer data can be found in logs/musique/atomic_decompose/atomic_decompose.jsonl. Each line is one QA dict with two new fields: answer (type: str) and answer_metadata (type: dict).

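A minimal sketch for inspecting the generated answers (only the answer and answer_metadata fields mentioned above are assumed; other field names depend on the QA data format):

import json

# Each line is one QA dict enriched with "answer" (str) and "answer_metadata" (dict).
with open("logs/musique/atomic_decompose/atomic_decompose.jsonl", "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        record = json.loads(line)
        print("answer:", record.get("answer"))
        print("answer_metadata keys:", list(record.get("answer_metadata", {}).keys()))
        if i >= 2:  # only inspect the first few records
            break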

5. Evaluation

Modify the paths in examples/evaluate.yml as follows:

log_dir: logs/musique/atomic_decompose
result_path: logs/musique/atomic_decompose/atomic_decompose.jsonl
output_path: logs/musique/atomic_decompose/atomic_decompose_updated.jsonl

Assuming you are in the project root:

# Run the evaluation
python examples/evaluate.py examples/evaluate.yml

The evaluation results are printed in the terminal:

(Screenshot: terminal output of the evaluation)
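
Among the reported metrics, exact match (EM) and token-level F1 are the standard QA measures. A minimal illustrative sketch of how they are typically computed (not PIKE-RAG's exact implementation, which additionally normalizes articles and punctuation):

from collections import Counter

def exact_match(prediction: str, reference: str) -> float:
    # 1.0 if the lowercased, stripped strings are identical, else 0.0.
    return float(prediction.strip().lower() == reference.strip().lower())

def token_f1(prediction: str, reference: str) -> float:
    # Token-level F1 between the predicted and reference answers.
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    common = Counter(pred_tokens) & Counter(ref_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

print(exact_match("Barack Obama", "barack obama"))                     # 1.0
print(round(token_f1("the 44th president", "44th US president"), 3))   # 0.667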

For more details, please refer to the project's official documentation.
