Q:我不太理解代码中异步async的用法,请问从学习项目的角度出发,我现在应该忽视这个问题,还是应该停下脚步去学习async

从学习项目的角度来看,我建议你先了解 async/await 的基础知识,因为这是 LightRAG 项目的一个核心设计特性。让我解释一下为什么:

1. 为什么 LightRAG 使用异步?

  1. 性能考虑:
async def upsert(self, data: dict[str, dict]):
    # 批量处理文本向量化
    embedding_tasks = [wrapped_task(batch) for batch in batches]
    # 并发执行所有任务
    embeddings_list = await asyncio.gather(*embedding_tasks)

主要场景:

  • 文档向量化(需要调用 API)
  • 数据库操作(I/O 密集)
  • LLM 调用(网络请求)
  1. 并发处理:
async def _get_edge_data(
    keywords,
    knowledge_graph_inst: BaseGraphStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
    # 并发获取边的信息
    edge_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
    )

2. async/await 基础知识

  1. 基本概念:
# 同步代码
def normal_function():
    result = slow_operation()  # 会阻塞
    return result

# 异步代码
async def async_function():
    result = await slow_operation()  # 不会阻塞
    return result
  1. 核心用法:
# 1. 定义异步函数
async def my_function():
    await some_async_operation()

# 2. 运行异步函数
import asyncio

# 方法1:使用事件循环
asyncio.run(my_function())

# 方法2:在其他异步函数中调用
async def another_function():
    await my_function()
  1. 并发执行:
async def process_many():
    # 并发执行多个任务
    tasks = [
        async_operation(item)
        for item in items
    ]
    results = await asyncio.gather(*tasks)

3. 建议学习路径

  1. 第一阶段:基础概念(花1-2天)
# 理解基本用法
async def hello():
    print("开始")
    await asyncio.sleep(1)  # 模拟异步操作
    print("结束")

# 运行
asyncio.run(hello())
  1. 第二阶段:LightRAG 中的应用(2-3天)
# 理解项目中的异步操作
async def process_document():
    # 1. 文档分块
    chunks = chunking_by_token_size(content)
    
    # 2. 并发向量化
    tasks = [
        embedding_func(chunk) 
        for chunk in chunks
    ]
    vectors = await asyncio.gather(*tasks)
  1. 第三阶段:进阶使用(根据需要)
  • 错误处理
  • 超时控制
  • 并发限制

4. 实践建议

  1. 从简单示例开始:
from lightrag import LightRAG

# 基础用法
async def simple_demo():
    rag = LightRAG(...)
    
    # 插入文档
    await rag.ainsert("测试文档")
    
    # 查询
    result = await rag.aquery("问题")
    print(result)

# 运行
asyncio.run(simple_demo())
  1. 理解常见模式:
# 并发处理模式
async def batch_process(items):
    tasks = []
    for item in items:
        task = process_item(item)
        tasks.append(task)
    
    results = await asyncio.gather(*tasks)
    return results
  1. 注意事项:
  • 异步函数必须用 await 调用
  • 异步函数内部可以调用同步函数
  • 同步函数内部不能直接调用异步函数

5. 结论

我的建议是:

  1. 不要完全忽视 async
  • 它是项目的重要特性
  • 影响代码的使用方式
  1. 但也不需要深入太多
  • 先掌握基础用法
  • 了解项目中的主要应用场景
  • 能看懂和使用项目代码即可
  1. 循序渐进
  • 先理解基本概念
  • 在使用过程中逐步深入
  • 遇到问题再专门学习相关知识

建议你花 1-2 天时间学习 async 的基础知识,然后继续研究 LightRAG 项目。这样可以:

  • 不会阻碍项目学习进度
  • 又能理解代码运行机制
  • 为后续深入学习打好基础

你觉得这个建议如何?需要我详细解释某个部分吗?

Q:那么,请你为我提供更多所需要的async的知识让我学习,以为了让我能继续研究LightRAG项目为目的

好的,让我为你介绍在研究 LightRAG 项目所需的关键 async 知识点。

1. 基本概念

  1. 异步函数定义和调用
# 定义异步函数
async def my_function():
    result = await some_async_operation()
    return result

# 调用异步函数
await my_function()  # 在其他异步函数中调用
# 或
asyncio.run(my_function())  # 在同步代码中调用

在 LightRAG 中的应用示例:

    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        if param.mode in ["local", "global", "hybrid"]:
            response = await kg_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),

        elif param.mode == "mix":
            response = await mix_kg_vector_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

2. 并发执行

  1. 使用 asyncio.gather
async def process_multiple():
    # 并发执行多个异步任务
    results = await asyncio.gather(
        task1(),
        task2(),
        task3()
    )

在 LightRAG 中的应用示例:

    kg_context, vector_context = await asyncio.gather(
        get_kg_context(), get_vector_context()
    )
  1. 任务列表并发
async def process_list():
    tasks = [process_item(item) for item in items]
    results = await asyncio.gather(*tasks)

3. 异步上下文管理

async with async_session() as session:
    await session.execute(query)

4. 异步迭代

async for item in async_iterator:
    await process_item(item)

5. LightRAG 中的关键异步模式

  1. 异步查询处理
async def naive_query(
    query,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
):
    # Handle cache
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode
    )
    if cached_response is not None:
        return cached_response

    results = await chunks_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        return PROMPTS["fail_response"]

    chunks_ids = [r["id"] for r in results]
    chunks = await text_chunks_db.get_by_ids(chunks_ids)

    # Filter out invalid chunks
    valid_chunks = [
        chunk for chunk in chunks if chunk is not None and "content" in chunk
    ]

    if not valid_chunks:
        logger.warning("No valid chunks found after filtering")
        return PROMPTS["fail_response"]

    maybe_trun_chunks = truncate_list_by_token_size(
        valid_chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )

    if not maybe_trun_chunks:
        logger.warning("No chunks left after truncation")
        return PROMPTS["fail_response"]

    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
    section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])

    if query_param.only_need_context:
        return section

    sys_prompt_temp = PROMPTS["naive_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        content_data=section, response_type=query_param.response_type
    )

    if query_param.only_need_prompt:
        return sys_prompt

    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
  1. 异步向量存储
import asyncio
import os
from dataclasses import dataclass
from typing import Union

import numpy as np
from sqlalchemy import create_engine, text
from tqdm import tqdm

from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
from lightrag.utils import logger


class TiDB(object):
    def __init__(self, config, **kwargs):
        self.host = config.get("host", None)
        self.port = config.get("port", None)
        self.user = config.get("user", None)
        self.password = config.get("password", None)
        self.database = config.get("database", None)
        self.workspace = config.get("workspace", None)
        connection_string = (
            f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
            f"?ssl_verify_cert=true&ssl_verify_identity=true"
        )

        try:
            self.engine = create_engine(connection_string)
            logger.info(f"Connected to TiDB database at {self.database}")
        except Exception as e:
            logger.error(f"Failed to connect to TiDB database at {self.database}")
            logger.error(f"TiDB database error: {e}")
            raise

    async def check_tables(self):
        for k, v in TABLES.items():
            try:
                await self.query(f"SELECT 1 FROM {k}".format(k=k))
            except Exception as e:
                logger.error(f"Failed to check table {k} in TiDB database")
                logger.error(f"TiDB database error: {e}")
                try:
                    # print(v["ddl"])
                    await self.execute(v["ddl"])
                    logger.info(f"Created table {k} in TiDB database")
                except Exception as e:
                    logger.error(f"Failed to create table {k} in TiDB database")


    async def query(
        self, sql: str, params: dict = None, multirows: bool = False
    ) -> Union[dict, None]:
        if params is None:
            params = {"workspace": self.workspace}
        else:
            params.update({"workspace": self.workspace})
        with self.engine.connect() as conn, conn.begin():
            try:
                result = conn.execute(text(sql), params)
            except Exception as e:
                logger.error(f"Tidb database error: {e}")
                print(sql)
                print(params)
                raise
            if multirows:
                rows = result.all()
                if rows:
                    data = [dict(zip(result.keys(), row)) for row in rows]
                else:
                    data = []
            else:
                row = result.first()
                if row:
                    data = dict(zip(result.keys(), row))
                else:
                    data = None
            return data

    async def execute(self, sql: str, data: list | dict = None):
        # logger.info("go into TiDBDB execute method")
        try:
            with self.engine.connect() as conn, conn.begin():
                if data is None:
                    conn.execute(text(sql))
                else:
                    conn.execute(text(sql), parameters=data)
        except Exception as e:
            logger.error(f"TiDB database error: {e}")
            print(sql)

6. 实践练习

让我们通过一个简单的例子来学习:

import asyncio
from lightrag import LightRAG, QueryParam

async def process_document():
    # 1. 初始化
    rag = LightRAG(
        working_dir="./test_rag",
        chunk_token_size=512
    )
    
    # 2. 插入文档
    with open("test.txt", "r") as f:
        await rag.ainsert(f.read())
    
    # 3. 并发查询不同模式
    modes = ["naive", "local", "global", "hybrid"]
    queries = [
        rag.aquery("问题1", param=QueryParam(mode=mode))
        for mode in modes
    ]
    
    # 4. 等待所有查询完成
    results = await asyncio.gather(*queries)
    
    return results

# 运行
asyncio.run(process_document())

7. 异步错误处理

async def safe_operation():
    try:
        await risky_operation()
    except Exception as e:
        logger.error(f"操作失败: {e}")
        # 错误恢复逻辑
    finally:
        await cleanup()

8. 性能优化

  1. 限制并发数
def limit_async_func_call(max_size: int):
    def wrapper(func):
        semaphore = asyncio.Semaphore(max_size)
        
        async def wrapped(*args, **kwargs):
            async with semaphore:
                return await func(*args, **kwargs)
        
        return wrapped
    return wrapper
  1. 批量处理
async def batch_process(items, batch_size=100):
    for i in range(0, len(items), batch_size):
        batch = items[i:i + batch_size]
        await process_batch(batch)

9. 建议学习路径

  1. 第一步:掌握基础语法
  • async/await 的使用
  • asyncio.gather 的并发处理
  1. 第二步:理解 LightRAG 的异步流程
  • 文档处理流程
  • 查询处理流程
  • 向量存储操作
  1. 第三步:学习高级特性
  • 错误处理
  • 性能优化
  • 并发控制

10. 注意事项

  1. 异步函数调用规则:
  • 异步函数内可以调用同步函数
  • 同步函数内不能直接调用异步函数
  • 异步函数必须用 await 调用
  1. 常见陷阱:
  • 忘记 await
  • 阻塞操作影响性能
  • 并发数过高导致资源耗尽

掌握这些知识点后,你就能更好地理解和使用 LightRAG 项目了。需要我详细解释某个部分吗?

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐