PIKE-RAG Deployment in Practice
PIKE-RAG is a retrieval-augmented generation framework from Microsoft that focuses on specialized knowledge and rationale-augmented generation, using a modular architecture to improve performance on industrial tasks. Environment setup covers cloning the repository, creating a Python environment, and installing dependencies. For local deployment, vLLM serves the model APIs (e.g., Qwen2.5-14B-Instruct and bge-m3), and the client code is modified to work with the standard OpenAI interface. The framework targets domains such as industrial manufacturing and mining, and system performance is evaluated with multiple metrics.
I. Project Overview
PIKE-RAG (sPecIalized KnowledgE and Rationale Augmented Generation) is an advanced retrieval-augmented generation (RAG) framework developed by Microsoft, designed to improve the ability of large language models (LLMs) to handle complex industrial tasks. It focuses on specialized knowledge and rationale-augmented generation, and its modular architecture can be adapted to a wide range of real-world scenarios.
Key features
- Modular framework: core modules for document parsing, knowledge extraction, storage, retrieval, organization, reasoning, and task decomposition that can be flexibly adjusted to fit different scenarios.
- Industry adaptability: tested in industrial manufacturing, mining, pharmaceuticals, and other domains, with significant gains in question-answering accuracy.
- Multi-level pipelines: supports tasks of varying complexity, from factual information retrieval to fact-grounded creative generation.
- Robust evaluation: uses metrics such as exact match (EM), F1 score, precision, recall, and LLM-judged accuracy.
II. Environment Setup
1. Clone the repository
git clone https://github.com/microsoft/PIKE-RAG.git
cd PIKE-RAG
2. Set up the Python environment
- Create and activate a new Python virtual environment (optional)
conda create -n PIKE-RAG python=3.12 -y
conda activate PIKE-RAG
- Install dependencies
pip install -r requirements.txt
pip install -r examples/requirements.txt
- Set PYTHONPATH each time you activate the environment (a quick sanity check follows below)
- Linux
- Assuming you are currently in the repository root
export PYTHONPATH=$PWD
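To confirm the path is picked up, a quick sanity check run from the repository root (my addition, not part of the official docs) is to import the package and see where it resolves from:

```python
# Quick sanity check: with PYTHONPATH set to the repo root, `pikerag`
# should resolve to the local source tree rather than an installed copy.
import pikerag

print(pikerag.__file__)  # expected to point into the PIKE-RAG checkout
```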
III. Serving Local Models with vLLM via an OpenAI-Compatible API
The open-source codebase does not support the standard OpenAI endpoint; it only supports APIs deployed through Microsoft's Azure platform. Azure endpoints are configured and called differently from the standard OpenAI interface, even though the request parameters are the same. We can therefore modify the code to support the standard OpenAI interface, and then use vLLM to serve local models behind an OpenAI-compatible API.
1. Install vLLM
- Create and activate a new Python virtual environment (optional).
conda create -n vllm python=3.12 -y
conda activate vllm
- Install vLLM
pip3 install vllm -i https://pypi.tuna.tsinghua.edu.cn/simple
- Start the server (chat/completions endpoint, model: Qwen2.5-14B-Instruct)
vllm serve /data/LLMS/qwen/Qwen2.5-14B-Instruct/ --port 8888 --gpu-memory-utilization 0.3 --enable-chunked-prefill --max-num-batched-tokens 1024 --tensor-parallel-size 8 --host 0.0.0.0
- /data/LLMS/qwen/Qwen2.5-14B-Instruct/: path to the local model weights
- --port 8888: port the service listens on
- --gpu-memory-utilization 0.3: use at most 30% of GPU memory
- --enable-chunked-prefill: enable chunked prefill, which lowers peak memory usage for long inputs
- --max-num-batched-tokens 1024: process at most 1024 tokens per batch
- --tensor-parallel-size 8: tensor parallelism across 8 GPUs (requires a machine with 8 GPUs)
- --host 0.0.0.0: listen on all network interfaces so the service is reachable from the LAN/Internet.
Notes
- If you hit "cuda out of memory", enable --tensor-parallel-size 8.
- If --tensor-parallel-size 8 is enabled, run "ray start --head" first to initialize a Ray cluster.
- A quick client-side check of this endpoint is sketched below.
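Once the server is up, the chat endpoint can be smoke-tested with the official openai Python client. This is a minimal sketch; it assumes the server above is reachable on port 8888 and that vLLM registered the model under the path it was launched with (GET /v1/models lists the exact name if unsure):

```python
# Minimal smoke test against the vLLM OpenAI-compatible chat endpoint.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8888/v1",  # port chosen when launching vLLM
    api_key="EMPTY",                      # vLLM does not check the key unless --api-key is set
)

resp = client.chat.completions.create(
    model="/data/LLMS/qwen/Qwen2.5-14B-Instruct/",  # model name = path passed to `vllm serve`
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    temperature=0,
)
print(resp.choices[0].message.content)
```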
- Start the server (embeddings endpoint, model: bge-m3)
vllm serve /data/LLMS/bge-m3 \
--gpu-memory-utilization 0.34 \
--port 10010 \
--host 0.0.0.0 \
--max-model-len 4096 &
A successful start prints the vLLM startup banner and the listening address in the console (screenshot omitted); a client-side check follows.
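A minimal check of the embeddings endpoint, assuming port 10010 as configured above (bge-m3 produces 1024-dimensional vectors):

```python
# Minimal smoke test against the vLLM OpenAI-compatible embeddings endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:10010/v1", api_key="EMPTY")

resp = client.embeddings.create(
    model="/data/LLMS/bge-m3",     # model name = path passed to `vllm serve`
    input=["PIKE-RAG deployment test"],
)
print(len(resp.data[0].embedding))  # bge-m3 should yield 1024-dimensional embeddings
```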
2. Modify the client code
From the project root, open pikerag/llm_client/azure_open_ai_client.py and adapt it into a new client, open_ai_client.py.
The full code is as follows (a brief usage sketch follows the listing):
import json
import os
import re
import time
from typing import List, Literal, Optional, Union

import openai
from langchain_core.embeddings import Embeddings
from openai.types import CreateEmbeddingResponse
from openai.types.chat.chat_completion import ChatCompletion
from pickledb import PickleDB

from pikerag.llm_client.base import BaseLLMClient
from pikerag.utils.logger import Logger


def parse_wait_time_from_error(error: openai.RateLimitError) -> Optional[int]:
    try:
        info_str: str = error.args[0]
        info_dict_str: str = info_str[info_str.find("{"):]
        error_info: dict = json.loads(re.compile('(?<!\\\\)\'').sub('\"', info_dict_str))
        error_message = error_info["error"]["message"]
        matches = re.search(r"Try again in (\d+) seconds", error_message)
        wait_time = int(matches.group(1)) + 3  # NOTE: wait 3 more seconds here.
        return wait_time
    except Exception:
        return None


class OpenAIClient(BaseLLMClient):
    ### modified (renamed from the Azure client)
    NAME = "OpenAIClient"

    def __init__(
        self, location: str = None, auto_dump: bool = True, logger: Logger = None,
        max_attempt: int = 5, exponential_backoff_factor: int = None, unit_wait_time: int = 60, **kwargs,
    ) -> None:
        """LLM Communication Client for OpenAI-compatible endpoints.

        Args:
            location (str): the file location of the LLM client communication cache. No cache would be created if set to
                None. Defaults to None.
            auto_dump (bool): automatically save the Client's communication cache or not. Defaults to True.
            logger (Logger): client logger. Defaults to None.
            max_attempt (int): Maximum attempt time for LLM requesting. Request would be skipped if max_attempt reached.
                Defaults to 5.
            exponential_backoff_factor (int): Set to enable exponential backoff retry manner. Every time the wait time
                would be `exponential_backoff_factor ^ num_attempt`. Set to None to disable and use the `unit_wait_time`
                manner. Defaults to None.
            unit_wait_time (int): `unit_wait_time` would be used only if the exponential backoff mode is disabled. Every
                time the wait time would be `unit_wait_time * num_attempt`, with seconds (s) as the time unit. Defaults
                to 60.
        """
        super().__init__(location, auto_dump, logger, max_attempt, exponential_backoff_factor, unit_wait_time, **kwargs)

        ### modified: use the standard OpenAI client, configured via environment variables.
        from openai import OpenAI

        api_key = os.environ.get("OPENAI_API_KEY", None)
        base_url = os.environ.get("OPENAI_BASE_URL", None)
        self._client = OpenAI(api_key=api_key, base_url=base_url)

    def _get_response_with_messages(self, messages: List[dict], **llm_config) -> ChatCompletion:
        response: ChatCompletion = None
        num_attempt: int = 0
        while num_attempt < self._max_attempt:
            try:
                # TODO: handling the kwargs not passed issue for other Clients
                response = self._client.chat.completions.create(messages=messages, **llm_config)
                break

            except openai.RateLimitError as e:
                self.warning("  Failed due to RateLimitError...")
                # NOTE: mask the line below to keep trying if failed due to RateLimitError.
                # num_attempt += 1
                wait_time = parse_wait_time_from_error(e)
                self._wait(num_attempt, wait_time=wait_time)
                self.warning("  Retrying...")

            except openai.BadRequestError as e:
                self.warning(f"  Failed due to Exception: {e}")
                self.warning("  Skip this request...")
                break

            except Exception as e:
                self.warning(f"  Failed due to Exception: {e}")
                num_attempt += 1
                self._wait(num_attempt)
                self.warning("  Retrying...")

        return response

    def _get_content_from_response(self, response: ChatCompletion, messages: List[dict] = None) -> str:
        try:
            content = response.choices[0].message.content
            if content is None:
                finish_reason = response.choices[0].finish_reason
                warning_message = f"Non-Content returned due to {finish_reason}"

                # Azure-specific content-filter details; harmless with standard OpenAI responses.
                if "content_filter" in finish_reason:
                    for reason, res_dict in response.choices[0].content_filter_results.items():
                        if res_dict["filtered"] is True or res_dict["severity"] != "safe":
                            warning_message += f", '{reason}': {res_dict}"

                self.warning(warning_message)
                self.debug(f"  -- Complete response: {response}")
                if messages is not None and len(messages) >= 1:
                    self.debug(f"  -- Last message: {messages[-1]}")

                content = ""
        except Exception as e:
            self.warning(f"Try to get content from response but get exception:\n  {e}")
            self.debug(
                f"  Response: {response}\n"
                f"  Last message: {messages}"
            )
            content = ""

        return content

    def close(self):
        super().close()
        self._client.close()


### modified
class OpenAIEmbedding(Embeddings):
    def __init__(self, **kwargs) -> None:
        client_configs = kwargs.get("client_config", {})  # kept from the original client; unused here

        ### modified: set base_url to the port configured when launching the embedding vLLM service.
        from openai import OpenAI

        api_key = os.environ.get("OPENAI_API_KEY", None)
        # base_url = os.environ.get("OPENAI_BASE_URL", None)
        base_url = "http://localhost:10010/v1/"
        self._client = OpenAI(api_key=api_key, base_url=base_url)

        self._model = kwargs.get("model", "/data/LLMS/bge-m3")

        cache_config = kwargs.get("cache_config", {})
        cache_location = cache_config.get("location", None)
        auto_dump = cache_config.get("auto_dump", True)
        if cache_location is not None:
            self._cache: PickleDB = PickleDB(location=cache_location)
        else:
            self._cache = None

    def _save_cache(self, query: str, embedding: List[float]) -> None:
        if self._cache is None:
            return
        self._cache.set(query, embedding)
        return

    def _get_cache(self, query: str) -> Union[List[float], Literal[False]]:
        if self._cache is None:
            return False
        return self._cache.get(query)

    def _get_response(self, texts: Union[str, List[str]]) -> CreateEmbeddingResponse:
        while True:
            try:
                response = self._client.embeddings.create(input=texts, model=self._model)
                break

            except openai.RateLimitError as e:
                expected_wait = parse_wait_time_from_error(e)
                if expected_wait is not None:  # fixed: check the parsed wait time, not the exception object
                    print(f"Embedding failed due to RateLimitError, wait for {expected_wait} seconds")
                    time.sleep(expected_wait)
                else:
                    print("Embedding failed due to RateLimitError, but failed parsing expected waiting time, wait for 30 seconds")
                    time.sleep(30)

            except Exception as e:
                print(f"Embedding failed due to exception {e}")
                exit(0)

        return response

    def embed_documents(self, texts: List[str], batch_call: bool = False) -> List[List[float]]:
        # NOTE: calling self._get_response(texts) may cause RateLimitError, possibly due to a large batch size.
        if batch_call is True:
            response = self._get_response(texts)
            embeddings = [res.embedding for res in response.data]
        else:
            embeddings = [self.embed_query(text) for text in texts]
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        embedding = self._get_cache(text)
        if embedding is False:
            response = self._get_response(text)
            embedding = response.data[0].embedding
            self._save_cache(text, embedding)
        return embedding
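A rough usage sketch of the new client (not part of the repository; it pokes the internal helpers directly for a quick check, whereas the framework normally drives these classes through the yml configs and the environment variables set up in the next step):

```python
# Hypothetical standalone check of the adapted clients; assumes both vLLM
# services above are running and the env vars below are appropriate.
import os

os.environ.setdefault("OPENAI_BASE_URL", "http://localhost:8888/v1/")  # chat service
os.environ.setdefault("OPENAI_API_KEY", "EMPTY")

from pikerag.llm_client.open_ai_client import OpenAIClient, OpenAIEmbedding

llm = OpenAIClient()  # defaults: no cache file, no logger
messages = [{"role": "user", "content": "What does RAG stand for?"}]
response = llm._get_response_with_messages(
    messages, model="/data/LLMS/qwen/Qwen2.5-14B-Instruct/", temperature=0,
)
print(llm._get_content_from_response(response, messages))

embedder = OpenAIEmbedding(model="/data/LLMS/bge-m3")  # talks to the embedding service hard-coded above
print(len(embedder.embed_query("hello pike-rag")))
```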
3. Create a .env file
From the project root, create a .env file under the env_configs folder:
OPENAI_BASE_URL="http://localhost:8888/v1/"
OPENAI_API_KEY="{OPENAI_API_KEY}"
Set OPENAI_BASE_URL to the port configured when starting the vLLM service for the completion model.
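To verify that the file resolves as expected, a small hedged check with python-dotenv (assuming the package is installed; the example scripts are expected to pick this file up themselves):

```python
# Hedged check that env_configs/.env loads correctly (assumes python-dotenv is installed).
import os
from dotenv import load_dotenv

load_dotenv("env_configs/.env")  # path relative to the project root
print(os.environ.get("OPENAI_BASE_URL"))        # expected: http://localhost:8888/v1/
print(bool(os.environ.get("OPENAI_API_KEY")))   # True if the key is set
```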
4. Modify the .yml file
When running the examples, modify the LLM Setting section of the corresponding .yml file (a note on how module_path and class_name are resolved follows the snippet):
# LLM Setting
################################################################################
llm_client:
  module_path: pikerag.llm_client
  # available class_name: AzureMetaLlamaClient, AzureOpenAIClient, HFMetaLlamaClient
  class_name: OpenAIClient  ### modified
  args: {}

  llm_config:
    model: /data/LLMS/qwen/Qwen2.5-14B-Instruct/  ### modified: change to the path of the local model you serve
    temperature: 0

  cache_config:
    # location_prefix: will be joined with log_dir to generate the full path;
    # if set to null, the experiment_name would be used
    location_prefix: null
    auto_dump: True
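The module_path/class_name pair is how the framework locates the client class at runtime. Below is an illustration of how such a lookup typically works (not PIKE-RAG's exact loader). Note that resolving OpenAIClient from pikerag.llm_client assumes the class is re-exported in pikerag/llm_client/__init__.py; otherwise, point module_path at pikerag.llm_client.open_ai_client as in the sketch.

```python
# Illustration of resolving a class from a (module_path, class_name) pair,
# as the yml config implies; PIKE-RAG's own loader may differ in details.
import importlib

def load_class(module_path: str, class_name: str):
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

client_cls = load_class("pikerag.llm_client.open_ai_client", "OpenAIClient")
print(client_cls.NAME)  # "OpenAIClient"
```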
IV. The MuSiQue Dataset
The atomic-decomposition task on the MuSiQue dataset serves as the example:
1. Prepare the test data
From the project root:
# Install the libraries required for preprocessing
pip install -r data_process/open_benchmarks/requirements.txt
# Run the script to download MuSiQue, sample a subset, and convert the format
python data_process/main.py data_process/open_benchmarks/config/musique.yml
After the script finishes, the preprocessed data can be found under data/musique/. Specifically, the sampled dataset is data/musique/dev_500.jsonl.
Note: users in mainland China need to configure a proxy for this step, for example by routing the remote server's traffic through a local proxy.
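To get a feel for the sampled data, you can peek at the first record; the snippet only prints the keys rather than assuming a particular schema:

```python
# Inspect the first preprocessed MuSiQue sample without assuming its field names.
import json

with open("data/musique/dev_500.jsonl", "r", encoding="utf-8") as f:
    first = json.loads(f.readline())

print(sorted(first.keys()))
```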
2. Chunk the retrieval contexts
From the project root:
# Run the script to extract the context passages from the QA data and split them into chunks
python examples/chunking.py examples/musique/configs/chunking.yml
After the script finishes, the file dev_500_retrieval_contexts_as_chunks.jsonl can be found under data/musique/.
3. Tag atomic questions
From the project root:
# Tag atomic questions for the MuSiQue sample set
python examples/tagging.py examples/musique/configs/tagging.yml
After the run finishes, the file dev_500_retrieval_contexts_as_chunks_with_atom_questions.jsonl can be found under data/musique/.
4. Run the question-answering task
From the project root:
# Run retrieval on MuSiQue based on the tagged atomic questions
python examples/qa.py examples/musique/configs/atomic_decompose.yml
(vLLM server-side console output during the run; screenshot omitted.)
After the run finishes, the answers can be found in logs/musique/atomic_decompose/atomic_decompose.jsonl. Each line is a QA dict with a new answer field (type: str) and an answer_metadata field (type: dict).
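A quick way to eyeball the generated answers, using the answer and answer_metadata fields described above:

```python
# Print the first few generated answers from the QA run.
import json

path = "logs/musique/atomic_decompose/atomic_decompose.jsonl"
with open(path, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        record = json.loads(line)
        print(record.get("answer"), "| metadata keys:", sorted(record.get("answer_metadata", {}).keys()))
        if i >= 2:  # only show the first few records
            break
```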
5. Evaluation
Modify the paths in examples/evaluate.yml as follows:
log_dir: logs/musique/atomic_decompose
result_path: logs/musique/atomic_decompose/atomic_decompose.jsonl
output_path: logs/musique/atomic_decompose/atomic_decompose_updated.jsonl
From the project root:
# Run the evaluation
python examples/evaluate.py examples/evaluate.yml
The terminal then prints the evaluation metrics (screenshot omitted); a sketch of the EM and F1 metrics follows.
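For reference, here is a minimal sketch of how exact match and token-level F1, two of the metrics listed in the project overview, are commonly computed for QA outputs. This is illustrative only; the repository's evaluate.py has its own normalization and implementation details.

```python
# Common-definition sketch of EM and token-level F1 for QA evaluation
# (illustrative; not PIKE-RAG's exact implementation).
from collections import Counter

def exact_match(pred: str, gold: str) -> float:
    return float(pred.strip().lower() == gold.strip().lower())

def token_f1(pred: str, gold: str) -> float:
    pred_tokens = pred.lower().split()
    gold_tokens = gold.lower().split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

print(exact_match("Paris", "paris"), token_f1("the city of Paris", "Paris"))
```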
For more details, refer to the project's official documentation.