The context window of an LLM is limited and cannot hold everything, so the information an LLM processes needs to be chunked.

Here we take the StatisticalChunker from Semantic Chunkers as an example to explore semantic chunking strategies for LLMs; the reference link is below.

https://github.com/aurelio-labs/semantic-chunkers/blob/main/semantic_chunkers/chunkers/statistical.py

This semantic chunking is not done in a single step: a preliminary split is first performed using basic cues such as punctuation and length.

Then, based on the similarity between splits, semantically close adjacent splits are merged, so that each final chunk stays semantically complete.

1 Preliminary splitting

1.1 RegexSplitter

By default, StatisticalChunker uses RegexSplitter, which splits text with a set of regex rules; this is also the most common splitting approach.

The regex splitting rules and example code are shown below.

from typing import List, Union

import regex

from semantic_chunkers.splitters.base import BaseSplitter


class RegexSplitter(BaseSplitter):
    """
    Enhanced regex pattern to split a given text into sentences more accurately.
    """

    regex_pattern = r"""
        # Negative lookbehind for word boundary, word char, dot, word char
        (?<!\b\w\.\w.)
        # Negative lookbehind for single uppercase initials like "A."
        (?<!\b[A-Z][a-z]\.)
        # Negative lookbehind for abbreviations like "U.S."
        (?<!\b[A-Z]\.)
        # Negative lookbehind for abbreviations with uppercase letters and dots
        (?<!\b\p{Lu}\.\p{Lu}.)
        # Negative lookbehind for numbers, to avoid splitting decimals
        (?<!\b\p{N}\.)
        # Positive lookbehind for punctuation followed by whitespace
        (?<=\.|\?|!|:|\.\.\.)\s+
        # Positive lookahead for uppercase letter or opening quote at word boundary
        (?="?(?=[A-Z])|"\b)
        # OR
        |
        # Splits after punctuation that follows closing punctuation, followed by
        # whitespace
        (?<=[\"\'\]\)\}][\.!?])\s+(?=[\"\'\(A-Z])
        # OR
        |
        # Splits after punctuation if not preceded by a period
        (?<=[^\.][\.!?])\s+(?=[A-Z])
        # OR
        |
        # Handles splitting after ellipses
        (?<=\.\.\.)\s+(?=[A-Z])
        # OR
        |
        # Matches and removes control characters and format characters
        [\p{Cc}\p{Cf}]+
        # OR
        |
        # Splits after punctuation marks followed by another punctuation mark
        (?<=[\.!?])(?=[\.!?])
        # OR
        |
        # Splits after exclamation or question marks followed by whitespace or end of string
        (?<=[!?])(?=\s|$)
    """

    def __call__(
        self, doc: str, delimiters: List[Union[str, regex.Pattern]] = []
    ) -> List[str]:
        if not delimiters:
            compiled_pattern = regex.compile(self.regex_pattern)
            delimiters.append(compiled_pattern)
        sentences = [doc]
        for delimiter in delimiters:
            sentences_for_next_delimiter = []
            for sentence in sentences:
                if isinstance(delimiter, regex.Pattern):
                    sub_sentences = regex.split(
                        self.regex_pattern, doc, flags=regex.VERBOSE
                    )
                    split_char = ""  # No single character to append for regex pattern
                else:
                    sub_sentences = sentence.split(delimiter)
                    split_char = delimiter
                for i, sub_sentence in enumerate(sub_sentences):
                    if i < len(sub_sentences) - 1:
                        sub_sentence += split_char  # Append delimiter to sub_sentence
                    if sub_sentence.strip():
                        sentences_for_next_delimiter.append(sub_sentence.strip())
            sentences = sentences_for_next_delimiter
        return sentences

https://github.com/aurelio-labs/semantic-chunkers/blob/main/semantic_chunkers/splitters/regex.py

1.2 Other preliminary splitters

Semantic Chunkers only ships the RegexSplitter scheme.

Of course, besides regex splitting, other preliminary splitting schemes can be used as well.

Taking RegexSplitter as a reference, you can implement customized preliminary splitters tailored to your own business rules.
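
As a minimal sketch of what such a customization could look like, the following hypothetical ParagraphSplitter treats blank-line-separated paragraphs as the preliminary splits. It assumes, as the RegexSplitter above suggests, that a splitter only needs to subclass BaseSplitter and implement __call__ returning a list of strings.

from typing import List

from semantic_chunkers.splitters.base import BaseSplitter


class ParagraphSplitter(BaseSplitter):
    """Hypothetical custom splitter: use blank-line-separated paragraphs
    as the preliminary splits instead of regex-detected sentences."""

    def __call__(self, doc: str) -> List[str]:
        paragraphs = [p.strip() for p in doc.split("\n\n")]
        return [p for p in paragraphs if p]

Such a splitter can then be passed to StatisticalChunker through its splitter argument (see the constructor in the complete code below).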

2 Merging splits

After the preliminary splitting comes split merging, which is the core of the StatisticalChunker module.

StatisticalChunker works in batches: in each iteration it tries to merge the splits within one batch, then moves on to the next.

2.1 Similarity computation

Here, similarity is used to decide whether split i can be merged with the splits that precede it.

So the similarity is expressed as how close the vector v_i of split i is to v_prev, the mean vector of its preceding splits.

Concretely, for split i within a batch:

first, compute v_prev, the mean of the vectors of the preceding splits inside a sliding window of at most window_size splits, i.e. splits max(0, i - window_size) through i - 1;

then, compute the cosine similarity s between v_i and v_prev, i.e. s = (v_prev · v_i) / (||v_prev|| * ||v_i|| + 1e-10).

Based on the similarity score of split i, a reasonable merging rule is designed as follows:

if v_i is sufficiently similar to v_prev, split i should be merged into the preceding chunk;

if the similarity between v_i and v_prev is too low, split i should not be merged and instead starts a new chunk.

Below is the code that computes similarity scores over the multiple encoded splits of a batch (encoded_docs).

    def _calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]:
        raw_similarities = []
        for idx in range(1, len(encoded_docs)):
            window_start = max(0, idx - self.window_size)
            cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0)
            curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / (
                np.linalg.norm(cumulative_context) * np.linalg.norm(encoded_docs[idx])
                + 1e-10
            )
            raw_similarities.append(curr_sim_score)
        return raw_similarities
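
To make the windowed computation concrete, here is a small self-contained sketch (illustrative toy vectors, not library code) that reproduces the same calculation: each score compares a split with the mean of up to window_size preceding splits.

import numpy as np

window_size = 2  # plays the role of self.window_size
encoded_docs = np.array([
    [1.0, 0.0],   # split 0
    [0.9, 0.1],   # split 1: close to split 0
    [0.0, 1.0],   # split 2: topic shift
    [0.1, 0.9],   # split 3: close to split 2
])

scores = []
for idx in range(1, len(encoded_docs)):
    window_start = max(0, idx - window_size)
    context = np.mean(encoded_docs[window_start:idx], axis=0)  # v_prev
    score = np.dot(context, encoded_docs[idx]) / (
        np.linalg.norm(context) * np.linalg.norm(encoded_docs[idx]) + 1e-10
    )
    scores.append(float(score))

print(scores)  # roughly [0.99, 0.05, 0.84]: the low score marks the topic shift at split 2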

2.2 Determining the similarity threshold

1) Threshold bounds

After obtaining the similarity scores within a batch, a reasonable similarity threshold is needed to decide whether split i can be merged with its preceding splits.

From the similarity computation above, we already have, for each split i, its similarity to the mean of the preceding splits in the batch.

Using basic statistics, lower and upper bounds for the similarity scores within the batch can be computed as follows; the threshold should fall within these bounds.

# Analyze the distribution of similarity scores to set initial bounds
median_score = np.median(similarity_scores)
std_dev = np.std(similarity_scores)

# Set initial bounds based on median and standard deviation
low = max(0.0, float(median_score - std_dev))
high = min(1.0, float(median_score + std_dev))
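
For example, with a toy batch of similarity scores (purely illustrative numbers), the bounds work out as follows.

import numpy as np

similarity_scores = [0.91, 0.35, 0.88, 0.42, 0.79]

median_score = np.median(similarity_scores)      # 0.79
std_dev = np.std(similarity_scores)              # ~0.24

low = max(0.0, float(median_score - std_dev))    # ~0.55
high = min(1.0, float(median_score + std_dev))   # ~1.03, clipped to 1.0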

2) Simulating a re-split

To decide whether a candidate threshold is suitable, the splits are re-partitioned with that threshold, and the resulting chunks are checked against the size requirements.

The re-partition logic is as follows.

For each split:

let score_i denote the similarity of split i to the mean of its preceding splits;

if score_i is below the threshold, split i cannot be merged into the preceding chunk, so a new chunk is started at split i.

The example code is shown below; new chunks are represented by recording the split indices at which they start.

    def _find_split_indices(
        self, similarities: List[float], calculated_threshold: float
    ) -> List[int]:
        split_indices = []
        for idx, score in enumerate(similarities):
            logger.debug(f"Similarity score at index {idx}: {score}")
            if score < calculated_threshold:
                logger.debug(
                    f"Adding to split_indices due to score < threshold: "
                    f"{score} < {calculated_threshold}"
                )
                # Chunk after the document at idx
                split_indices.append(idx + 1)
        return split_indices
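
A quick illustration of this decision rule on toy numbers (not library code): each score below the threshold marks the start of a new chunk at the following document index.

similarities = [0.92, 0.31, 0.88, 0.22]  # score_i for splits 1..4 vs. their preceding window
calculated_threshold = 0.5

split_indices = [i + 1 for i, s in enumerate(similarities) if s < calculated_threshold]
print(split_indices)  # [2, 4]: new chunks start after the documents at indices 1 and 3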

3) Finding the similarity threshold

We now have the bounds [low, high] for the similarity threshold, and for any candidate threshold within them, a way to re-partition the splits.

The search for a suitable threshold is a binary search.

For each pair of low and high, the candidate threshold is taken as calculated_threshold = (low + high) / 2.

The split indices are then computed with calculated_threshold, and the median token count of the resulting chunks is calculated.

The example code is as follows.

            # Calculate the token counts for each split using the cumulative sums
            split_token_counts = [
                cumulative_token_counts[end] - cumulative_token_counts[start]
                for start, end in zip(
                    [0] + split_indices, split_indices + [len(token_counts)]
                )
            ]

            # Calculate the median token count for the chunks
            median_tokens = np.median(split_token_counts)
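
Spelled out on toy numbers (illustrative only), the cumulative-sum trick turns the split indices into per-chunk token counts like this.

import numpy as np

token_counts = [50, 60, 40, 70]                           # tokens per preliminary split
cumulative_token_counts = np.cumsum([0] + token_counts)   # [0, 50, 110, 150, 220]
split_indices = [2]                                        # a new chunk starts at split index 2

split_token_counts = [
    cumulative_token_counts[end] - cumulative_token_counts[start]
    for start, end in zip([0] + split_indices, split_indices + [len(token_counts)])
]
print(split_token_counts)               # [110, 110]: chunks {0, 1} and {2, 3}
print(np.median(split_token_counts))    # 110.0, compared against the allowed token range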

If the median falls within the allowed range

[min_split_tokens - split_tokens_tolerance, max_split_tokens + split_tokens_tolerance]

the partition is considered successful and the binary search stops.

Otherwise, high or low is adjusted based on the median (if the median is below min_split_tokens, high is lowered; otherwise low is raised), and the search continues for a better similarity threshold.

The example code is shown below.

    def _find_optimal_threshold(self, docs: List[str], similarity_scores: List[float]):
        token_counts = [tiktoken_length(doc) for doc in docs]
        cumulative_token_counts = np.cumsum([0] + token_counts)

        # Analyze the distribution of similarity scores to set initial bounds
        median_score = np.median(similarity_scores)
        std_dev = np.std(similarity_scores)

        # Set initial bounds based on median and standard deviation
        low = max(0.0, float(median_score - std_dev))
        high = min(1.0, float(median_score + std_dev))

        iteration = 0
        median_tokens = 0
        calculated_threshold = 0.0
        while low <= high:
            calculated_threshold = (low + high) / 2
            split_indices = self._find_split_indices(
                similarity_scores, calculated_threshold
            )
            logger.debug(
                f"Iteration {iteration}: Trying threshold: {calculated_threshold}"
            )

            # Calculate the token counts for each split using the cumulative sums
            split_token_counts = [
                cumulative_token_counts[end] - cumulative_token_counts[start]
                for start, end in zip(
                    [0] + split_indices, split_indices + [len(token_counts)]
                )
            ]

            # Calculate the median token count for the chunks
            median_tokens = np.median(split_token_counts)
            logger.debug(
                f"Iteration {iteration}: Median tokens per split: {median_tokens}"
            )
            if (
                self.min_split_tokens - self.split_tokens_tolerance
                <= median_tokens
                <= self.max_split_tokens + self.split_tokens_tolerance
            ):
                logger.debug("Median tokens in target range. Stopping iteration.")
                break
            elif median_tokens < self.min_split_tokens:
                high = calculated_threshold - self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting high to {high}")
            else:
                low = calculated_threshold + self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting low to {low}")
            iteration += 1

        logger.debug(
            f"Optimal threshold {calculated_threshold} found "
            f"with median tokens ({median_tokens}) in target range "
            f"({self.min_split_tokens}-{self.max_split_tokens})."
        )

        return calculated_threshold

2.3 Performing the final chunking

Once a suitable similarity threshold has been determined, the actual chunking is performed with it.

After each batch is chunked, the last chunk of that batch (last_chunk) is held back and prepended to the next batch's splits, so that it can still be merged with them.

This helps avoid semantic breaks at batch boundaries.
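
Before looking at the library code, a tiny self-contained sketch of this carry-over pattern may help; merge here is a hypothetical stand-in for the similarity-based merging, not the library's method.

# Toy illustration of the batch carry-over (not library code): the tail chunk of each
# batch is re-merged with the next batch so boundaries are not cut arbitrarily.
def merge(batch):
    # stand-in for the similarity-based merging: pair up neighbouring splits
    return [batch[i:i + 2] for i in range(0, len(batch), 2)]

splits = [f"s{i}" for i in range(7)]
batch_size = 3

chunks, last_chunk = [], None
for i in range(0, len(splits), batch_size):
    batch = splits[i:i + batch_size]
    if last_chunk is not None:
        batch = last_chunk + batch          # re-expose the previous tail to the new batch
    doc_chunks = merge(batch)
    if len(doc_chunks) > 1:
        chunks.extend(doc_chunks[:-1])
        last_chunk = doc_chunks[-1]
    else:
        last_chunk = doc_chunks[0]
if last_chunk:
    chunks.append(last_chunk)

print(chunks)  # every split appears exactly once; tails are merged across batch boundaries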

The library implementation is shown below.

    @time_it
    def _chunk(
        self, splits: List[Any], batch_size: int = 64, enforce_max_tokens: bool = False
    ) -> List[Chunk]:
        """Merge splits into chunks using semantic similarity, with optional enforcement
        of maximum token limits per chunk.

        :param splits: Splits to be merged into chunks.
        :param batch_size: Number of splits to process in one batch.
        :param enforce_max_tokens: If True, further split chunks that exceed the maximum
        token limit.

        :return: List of chunks.
        """
        # Split the docs that already exceed max_split_tokens to smaller chunks
        if enforce_max_tokens:
            new_splits = []
            for split in splits:
                token_count = tiktoken_length(split)
                if token_count > self.max_split_tokens:
                    logger.info(
                        f"Single document exceeds the maximum token limit "
                        f"of {self.max_split_tokens}. "
                        "Splitting to sentences before semantically merging."
                    )
                    _splits = self._split(split)
                    new_splits.extend(_splits)
                else:
                    new_splits.append(split)

            splits = [split for split in new_splits if split and split.strip()]

        chunks = []
        last_chunk: Optional[Chunk] = None
        for i in tqdm(range(0, len(splits), batch_size)):
            batch_splits = splits[i : i + batch_size]
            if last_chunk is not None:
                batch_splits = last_chunk.splits + batch_splits

            encoded_splits = self._encode_documents(batch_splits)
            similarities = self._calculate_similarity_scores(encoded_splits)

            if self.dynamic_threshold:
                calculated_threshold = self._find_optimal_threshold(
                    batch_splits, similarities
                )
            else:
                calculated_threshold = (
                    self.encoder.score_threshold
                    if self.encoder.score_threshold
                    else self.DEFAULT_THRESHOLD
                )
            split_indices = self._find_split_indices(
                similarities=similarities, calculated_threshold=calculated_threshold
            )

            doc_chunks = self._split_documents(
                docs=batch_splits,
                split_indices=split_indices,
                similarities=similarities,
            )

            if len(doc_chunks) > 1:
                chunks.extend(doc_chunks[:-1])
                last_chunk = doc_chunks[-1]
            else:
                last_chunk = doc_chunks[0]

            if self.plot_chunks:
                self.plot_similarity_scores(
                    similarities=similarities,
                    split_indices=split_indices,
                    chunks=doc_chunks,
                    calculated_threshold=calculated_threshold,
                )

            if self.enable_statistics:
                print(self.statistics)

        if last_chunk:
            chunks.append(last_chunk)

        return chunks

3 Complete code

Below is the complete chunking code of StatisticalChunker; for how to invoke it, see the application example at the following link.

https://blog.csdn.net/liliang199/article/details/156982300
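
For orientation, a minimal invocation sketch is given below; the import paths, encoder class, and model name are assumptions for illustration, so refer to the linked application example for the exact setup.

from semantic_chunkers import StatisticalChunker
from semantic_router.encoders import OpenAIEncoder

# Hypothetical encoder configuration; any DenseEncoder implementation should work.
encoder = OpenAIEncoder(name="text-embedding-3-small")

chunker = StatisticalChunker(encoder=encoder)
all_chunks = chunker(docs=["... a long document to be chunked ..."])  # List[List[Chunk]]

for chunk in all_chunks[0]:
    print(chunk.splits, chunk.token_count)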

The complete code of the chunking module is shown below.

import asyncio
from dataclasses import dataclass
from typing import Any, List, Optional

import numpy as np
from semantic_router.encoders.base import DenseEncoder
from tqdm.auto import tqdm

from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.schema import Chunk
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.regex import RegexSplitter
from semantic_chunkers.utils.logger import logger
from semantic_chunkers.utils.text import (
    async_retry_with_timeout,
    tiktoken_length,
    time_it,
)


@dataclass
class ChunkStatistics:
    total_documents: int
    total_chunks: int
    chunks_by_threshold: int
    chunks_by_max_chunk_size: int
    chunks_by_last_split: int
    min_token_size: int
    max_token_size: int
    chunks_by_similarity_ratio: float

    def __str__(self):
        return (
            f"Chunking Statistics:\n"
            f"  - Total Documents: {self.total_documents}\n"
            f"  - Total Chunks: {self.total_chunks}\n"
            f"  - Chunks by Threshold: {self.chunks_by_threshold}\n"
            f"  - Chunks by Max Chunk Size: {self.chunks_by_max_chunk_size}\n"
            f"  - Last Chunk: {self.chunks_by_last_split}\n"
            f"  - Minimum Token Size of Chunk: {self.min_token_size}\n"
            f"  - Maximum Token Size of Chunk: {self.max_token_size}\n"
            f"  - Similarity Chunk Ratio: {self.chunks_by_similarity_ratio:.2f}"
        )


class StatisticalChunker(BaseChunker):
    encoder: DenseEncoder

    def __init__(
        self,
        encoder: DenseEncoder,
        splitter: BaseSplitter = RegexSplitter(),
        name="statistical_chunker",
        threshold_adjustment=0.01,
        dynamic_threshold: bool = True,
        window_size=5,
        min_split_tokens=100,
        max_split_tokens=300,
        split_tokens_tolerance=10,
        plot_chunks=False,
        enable_statistics=False,
    ):
        super().__init__(name=name, encoder=encoder, splitter=splitter)
        self.encoder = encoder
        self.threshold_adjustment = threshold_adjustment
        self.dynamic_threshold = dynamic_threshold
        self.window_size = window_size
        self.plot_chunks = plot_chunks
        self.min_split_tokens = min_split_tokens
        self.max_split_tokens = max_split_tokens
        self.split_tokens_tolerance = split_tokens_tolerance
        self.enable_statistics = enable_statistics
        self.statistics: ChunkStatistics
        self.DEFAULT_THRESHOLD = 0.5

    @time_it
    def _chunk(
        self, splits: List[Any], batch_size: int = 64, enforce_max_tokens: bool = False
    ) -> List[Chunk]:
        """Merge splits into chunks using semantic similarity, with optional enforcement
        of maximum token limits per chunk.

        :param splits: Splits to be merged into chunks.
        :param batch_size: Number of splits to process in one batch.
        :param enforce_max_tokens: If True, further split chunks that exceed the maximum
        token limit.

        :return: List of chunks.
        """
        # Split the docs that already exceed max_split_tokens to smaller chunks
        if enforce_max_tokens:
            new_splits = []
            for split in splits:
                token_count = tiktoken_length(split)
                if token_count > self.max_split_tokens:
                    logger.info(
                        f"Single document exceeds the maximum token limit "
                        f"of {self.max_split_tokens}. "
                        "Splitting to sentences before semantically merging."
                    )
                    _splits = self._split(split)
                    new_splits.extend(_splits)
                else:
                    new_splits.append(split)

            splits = [split for split in new_splits if split and split.strip()]

        chunks = []
        last_chunk: Optional[Chunk] = None
        for i in tqdm(range(0, len(splits), batch_size)):
            batch_splits = splits[i : i + batch_size]
            if last_chunk is not None:
                batch_splits = last_chunk.splits + batch_splits

            encoded_splits = self._encode_documents(batch_splits)
            similarities = self._calculate_similarity_scores(encoded_splits)

            if self.dynamic_threshold:
                calculated_threshold = self._find_optimal_threshold(
                    batch_splits, similarities
                )
            else:
                calculated_threshold = (
                    self.encoder.score_threshold
                    if self.encoder.score_threshold
                    else self.DEFAULT_THRESHOLD
                )
            split_indices = self._find_split_indices(
                similarities=similarities, calculated_threshold=calculated_threshold
            )

            doc_chunks = self._split_documents(
                docs=batch_splits,
                split_indices=split_indices,
                similarities=similarities,
            )

            if len(doc_chunks) > 1:
                chunks.extend(doc_chunks[:-1])
                last_chunk = doc_chunks[-1]
            else:
                last_chunk = doc_chunks[0]

            if self.plot_chunks:
                self.plot_similarity_scores(
                    similarities=similarities,
                    split_indices=split_indices,
                    chunks=doc_chunks,
                    calculated_threshold=calculated_threshold,
                )

            if self.enable_statistics:
                print(self.statistics)

        if last_chunk:
            chunks.append(last_chunk)

        return chunks

    @time_it
    async def _async_chunk(
        self, splits: List[Any], batch_size: int = 64, enforce_max_tokens: bool = False
    ) -> List[Chunk]:
        """Merge splits into chunks using semantic similarity, with optional enforcement
        of maximum token limits per chunk.

        :param splits: Splits to be merged into chunks.
        :param batch_size: Number of splits to process in one batch.
        :param enforce_max_tokens: If True, further split chunks that exceed the maximum
        token limit.

        :return: List of chunks.
        """
        # Split the docs that already exceed max_split_tokens to smaller chunks
        if enforce_max_tokens:
            new_splits = []
            for split in splits:
                token_count = tiktoken_length(split)
                if token_count > self.max_split_tokens:
                    logger.info(
                        f"Single document exceeds the maximum token limit "
                        f"of {self.max_split_tokens}. "
                        "Splitting to sentences before semantically merging."
                    )
                    _splits = self._split(split)
                    new_splits.extend(_splits)
                else:
                    new_splits.append(split)

            splits = [split for split in new_splits if split and split.strip()]

        chunks: list[Chunk] = []

        # Step 1: Define process_batch as a separate coroutine function for parallel
        async def _process_batch(batch_splits: List[str]):
            encoded_splits = await self._async_encode_documents(batch_splits)
            return batch_splits, encoded_splits

        # Step 2: Create tasks for parallel execution
        tasks = []
        for i in range(0, len(splits), batch_size):
            batch_splits = splits[i : i + batch_size]
            tasks.append(_process_batch(batch_splits))

        # Step 3: Await tasks and collect results
        encoded_split_results = await asyncio.gather(*tasks)

        # Step 4: Sequentially process results
        for batch_splits, encoded_splits in encoded_split_results:
            similarities = self._calculate_similarity_scores(encoded_splits)
            if self.dynamic_threshold:
                calculated_threshold = self._find_optimal_threshold(
                    batch_splits, similarities
                )
            else:
                calculated_threshold = (
                    self.encoder.score_threshold
                    if self.encoder.score_threshold
                    else self.DEFAULT_THRESHOLD
                )

            split_indices = self._find_split_indices(
                similarities=similarities, calculated_threshold=calculated_threshold
            )

            doc_chunks: list[Chunk] = self._split_documents(
                docs=batch_splits,
                split_indices=split_indices,
                similarities=similarities,
            )

            chunks.extend(doc_chunks)
        return chunks

    @time_it
    def __call__(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]]:
        """Split documents into smaller chunks based on semantic similarity.

        :param docs: list of text documents to be split, if only wanted to
            split a single document, pass it as a list with a single element.

        :return: list of Chunk objects containing the split documents.
        """
        if not docs:
            raise ValueError("At least one document is required for splitting.")

        all_chunks = []
        for doc in docs:
            token_count = tiktoken_length(doc)
            if token_count > self.max_split_tokens:
                logger.info(
                    f"Single document exceeds the maximum token limit "
                    f"of {self.max_split_tokens}. "
                    "Splitting to sentences before semantically merging."
                )
            if isinstance(doc, str):
                splits = self._split(doc)
                doc_chunks = self._chunk(splits, batch_size=batch_size)
                all_chunks.append(doc_chunks)
            else:
                raise ValueError("The document must be a string.")
        return all_chunks

    @time_it
    async def acall(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]]:
        """Split documents into smaller chunks based on semantic similarity.

        :param docs: list of text documents to be split, if only wanted to
            split a single document, pass it as a list with a single element.

        :return: list of Chunk objects containing the split documents.
        """
        if not docs:
            raise ValueError("At least one document is required for splitting.")

        all_chunks = []
        for doc in docs:
            token_count = tiktoken_length(doc)
            if token_count > self.max_split_tokens:
                logger.info(
                    f"Single document exceeds the maximum token limit "
                    f"of {self.max_split_tokens}. "
                    "Splitting to sentences before semantically merging."
                )
            if isinstance(doc, str):
                splits = self._split(doc)
                doc_chunks = await self._async_chunk(splits, batch_size=batch_size)
                all_chunks.append(doc_chunks)
            else:
                raise ValueError("The document must be a string.")
        return all_chunks

    @time_it
    def _encode_documents(self, docs: List[str]) -> np.ndarray:
        """
        Encodes a list of documents into embeddings. If the number of documents
        exceeds 2000, the documents are split into batches to avoid overloading
        the encoder. OpenAI has a limit of len(array) < 2048.

        :param docs: List of text documents to be encoded.
        :return: A numpy array of embeddings for the given documents.
        """
        max_docs_per_batch = 2000
        embeddings = []

        for i in range(0, len(docs), max_docs_per_batch):
            batch_docs = docs[i : i + max_docs_per_batch]
            try:
                batch_embeddings = self.encoder(batch_docs)
                embeddings.extend(batch_embeddings)
            except Exception as e:
                logger.error(f"Error encoding documents {batch_docs}: {e}")
                raise

        return np.array(embeddings)

    @async_retry_with_timeout(retries=3, timeout=5)
    @time_it
    async def _async_encode_documents(self, docs: List[str]) -> np.ndarray:
        """
        Encodes a list of documents into embeddings. If the number of documents
        exceeds 2000, the documents are split into batches to avoid overloading
        the encoder. OpenAI has a limit of len(array) < 2048.

        :param docs: List of text documents to be encoded.
        :return: A numpy array of embeddings for the given documents.
        """
        max_docs_per_batch = 2000
        embeddings = []

        for i in range(0, len(docs), max_docs_per_batch):
            batch_docs = docs[i : i + max_docs_per_batch]
            try:
                batch_embeddings = await self.encoder.acall(batch_docs)
                embeddings.extend(batch_embeddings)
            except Exception as e:
                logger.error(f"Error encoding documents {batch_docs}: {e}")
                raise

        return np.array(embeddings)

    def _calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]:
        raw_similarities = []
        for idx in range(1, len(encoded_docs)):
            window_start = max(0, idx - self.window_size)
            cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0)
            curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / (
                np.linalg.norm(cumulative_context) * np.linalg.norm(encoded_docs[idx])
                + 1e-10
            )
            raw_similarities.append(curr_sim_score)
        return raw_similarities

    def _find_split_indices(
        self, similarities: List[float], calculated_threshold: float
    ) -> List[int]:
        split_indices = []
        for idx, score in enumerate(similarities):
            logger.debug(f"Similarity score at index {idx}: {score}")
            if score < calculated_threshold:
                logger.debug(
                    f"Adding to split_indices due to score < threshold: "
                    f"{score} < {calculated_threshold}"
                )
                # Chunk after the document at idx
                split_indices.append(idx + 1)
        return split_indices

    def _find_optimal_threshold(self, docs: List[str], similarity_scores: List[float]):
        token_counts = [tiktoken_length(doc) for doc in docs]
        cumulative_token_counts = np.cumsum([0] + token_counts)

        # Analyze the distribution of similarity scores to set initial bounds
        median_score = np.median(similarity_scores)
        std_dev = np.std(similarity_scores)

        # Set initial bounds based on median and standard deviation
        low = max(0.0, float(median_score - std_dev))
        high = min(1.0, float(median_score + std_dev))

        iteration = 0
        median_tokens = 0
        calculated_threshold = 0.0
        while low <= high:
            calculated_threshold = (low + high) / 2
            split_indices = self._find_split_indices(
                similarity_scores, calculated_threshold
            )
            logger.debug(
                f"Iteration {iteration}: Trying threshold: {calculated_threshold}"
            )

            # Calculate the token counts for each split using the cumulative sums
            split_token_counts = [
                cumulative_token_counts[end] - cumulative_token_counts[start]
                for start, end in zip(
                    [0] + split_indices, split_indices + [len(token_counts)]
                )
            ]

            # Calculate the median token count for the chunks
            median_tokens = np.median(split_token_counts)
            logger.debug(
                f"Iteration {iteration}: Median tokens per split: {median_tokens}"
            )
            if (
                self.min_split_tokens - self.split_tokens_tolerance
                <= median_tokens
                <= self.max_split_tokens + self.split_tokens_tolerance
            ):
                logger.debug("Median tokens in target range. Stopping iteration.")
                break
            elif median_tokens < self.min_split_tokens:
                high = calculated_threshold - self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting high to {high}")
            else:
                low = calculated_threshold + self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting low to {low}")
            iteration += 1

        logger.debug(
            f"Optimal threshold {calculated_threshold} found "
            f"with median tokens ({median_tokens}) in target range "
            f"({self.min_split_tokens}-{self.max_split_tokens})."
        )

        return calculated_threshold

    def _split_documents(
        self, docs: List[str], split_indices: List[int], similarities: List[float]
    ) -> List[Chunk]:
        """
        This method iterates through each document, appending it to the current split
        until it either reaches a split point (determined by split_indices) or exceeds
        the maximum token limit for a split (self.max_split_tokens).
        When a document causes the current token count to exceed this limit,
        or when a split point is reached and the minimum token requirement is met,
        the current split is finalized and added to the List of chunks.
        """
        token_counts = [tiktoken_length(doc) for doc in docs]
        chunks, current_split = [], []
        current_tokens_count = 0

        # Statistics
        chunks_by_threshold = 0
        chunks_by_max_chunk_size = 0
        chunks_by_last_split = 0

        for doc_idx, doc in enumerate(docs):
            doc_token_count = token_counts[doc_idx]
            logger.debug(f"Accumulative token count: {current_tokens_count} tokens")
            logger.debug(f"Document token count: {doc_token_count} tokens")
            # Check if current index is a split point based on similarity
            if doc_idx + 1 in split_indices:
                if (
                    self.min_split_tokens
                    <= current_tokens_count + doc_token_count
                    < self.max_split_tokens
                ):
                    # Include the current document before splitting
                    # if it doesn't exceed the max limit
                    current_split.append(doc)
                    current_tokens_count += doc_token_count

                    triggered_score = (
                        similarities[doc_idx] if doc_idx < len(similarities) else None
                    )
                    chunks.append(
                        Chunk(
                            splits=current_split.copy(),
                            is_triggered=True,
                            triggered_score=triggered_score,
                            token_count=current_tokens_count,
                        )
                    )
                    logger.debug(
                        f"Chunk finalized with {current_tokens_count} tokens due to "
                        f"threshold {triggered_score}."
                    )
                    current_split, current_tokens_count = [], 0
                    chunks_by_threshold += 1
                    continue  # Move to the next document after splitting

            # Check if adding the current document exceeds the max token limit
            if current_tokens_count + doc_token_count > self.max_split_tokens:
                if current_tokens_count >= self.min_split_tokens:
                    chunks.append(
                        Chunk(
                            splits=current_split.copy(),
                            is_triggered=False,
                            triggered_score=None,
                            token_count=current_tokens_count,
                        )
                    )
                    chunks_by_max_chunk_size += 1
                    logger.debug(
                        f"Chunk finalized with {current_tokens_count} tokens due to "
                        f"exceeding token limit of {self.max_split_tokens}."
                    )
                    current_split, current_tokens_count = [], 0

            current_split.append(doc)
            current_tokens_count += doc_token_count

        # Handle the last split
        if current_split:
            chunks.append(
                Chunk(
                    splits=current_split.copy(),
                    is_triggered=False,
                    triggered_score=None,
                    token_count=current_tokens_count,
                )
            )
            chunks_by_last_split += 1
            logger.debug(
                f"Final split added with {current_tokens_count} "
                "tokens due to remaining documents."
            )

        # Validation to ensure no tokens are lost during the split
        original_token_count = sum(token_counts)
        split_token_count = sum(
            [tiktoken_length(doc) for split in chunks for doc in split.splits]
        )
        if original_token_count != split_token_count:
            logger.error(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )
            raise ValueError(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )

        # Statistics
        total_chunks = len(chunks)
        chunks_by_similarity_ratio = (
            chunks_by_threshold / total_chunks if total_chunks else 0
        )
        min_token_size = max_token_size = 0
        if chunks:
            token_counts = [
                split.token_count for split in chunks if split.token_count is not None
            ]
            min_token_size, max_token_size = (
                min(token_counts, default=0),
                max(token_counts, default=0),
            )

        self.statistics = ChunkStatistics(
            total_documents=len(docs),
            total_chunks=total_chunks,
            chunks_by_threshold=chunks_by_threshold,
            chunks_by_max_chunk_size=chunks_by_max_chunk_size,
            chunks_by_last_split=chunks_by_last_split,
            min_token_size=min_token_size,
            max_token_size=max_token_size,
            chunks_by_similarity_ratio=chunks_by_similarity_ratio,
        )

        return chunks

    def plot_similarity_scores(
        self,
        similarities: List[float],
        split_indices: List[int],
        chunks: list[Chunk],
        calculated_threshold: float,
    ):
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning(
                "Plotting is disabled. Please `pip install "
                "semantic-router[processing]`."
            )
            return

        _, axs = plt.subplots(2, 1, figsize=(12, 12))  # Adjust for two plots

        # Plot 1: Similarity Scores
        axs[0].plot(similarities, label="Similarity Scores", marker="o")
        for split_index in split_indices:
            axs[0].axvline(
                x=split_index - 1,
                color="r",
                linestyle="--",
                label="Chunk" if split_index == split_indices[0] else "",
            )
        axs[0].axhline(
            y=calculated_threshold,
            color="g",
            linestyle="-.",
            label="Threshold Similarity Score",
        )

        # Annotating each similarity score
        for i, score in enumerate(similarities):
            axs[0].annotate(
                f"{score:.2f}",  # Formatting to two decimal places
                (i, score),
                textcoords="offset points",
                xytext=(0, 10),  # Positioning the text above the point
                ha="center",
            )  # Center-align the text

        axs[0].set_xlabel("Document Segment Index")
        axs[0].set_ylabel("Similarity Score")
        axs[0].set_title(
            f"Threshold: {calculated_threshold} | Window Size: {self.window_size}",
            loc="right",
            fontsize=10,
        )
        axs[0].legend()

        # Plot 2: Chunk Token Size Distribution
        token_counts = [split.token_count for split in chunks]
        axs[1].bar(range(len(token_counts)), token_counts, color="lightblue")
        axs[1].set_title("Chunk Token Sizes")
        axs[1].set_xlabel("Chunk Index")
        axs[1].set_ylabel("Token Count")
        axs[1].set_xticks(range(len(token_counts)))
        axs[1].set_xticklabels([str(i) for i in range(len(token_counts))])
        axs[1].grid(True)

        # Annotate each bar with the token size
        for idx, token_count in enumerate(token_counts):
            if not token_count:
                continue
            axs[1].text(
                idx, token_count + 0.01, str(token_count), ha="center", va="bottom"
            )

        plt.tight_layout()
        plt.show()

    def plot_sentence_similarity_scores(
        self, docs: List[str], threshold: float, window_size: int
    ):
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning("Plotting is disabled. Please `pip install matplotlib`.")
            return
        """
        Computes similarity scores between the average of the last
        'window_size' sentences and the next one,
        plots a graph of these similarity scores, and prints the first
        sentence after a similarity score below
        a specified threshold.
        """
        sentences = [sentence for doc in docs for sentence in self._split(doc)]
        encoded_sentences = self._encode_documents(sentences)
        similarity_scores = []

        for i in range(window_size, len(encoded_sentences)):
            window_avg_encoding = np.mean(
                encoded_sentences[i - window_size : i], axis=0
            )
            sim_score = np.dot(window_avg_encoding, encoded_sentences[i]) / (
                np.linalg.norm(window_avg_encoding)
                * np.linalg.norm(encoded_sentences[i])
                + 1e-10
            )
            similarity_scores.append(sim_score)

        plt.figure(figsize=(10, 8))
        plt.plot(similarity_scores, marker="o", linestyle="-", color="b")
        plt.title("Sliding Window Sentence Similarity Scores")
        plt.xlabel("Sentence Index")
        plt.ylabel("Similarity Score")
        plt.grid(True)
        plt.axhline(y=threshold, color="r", linestyle="--", label="Threshold")
        plt.show()

        for i, score in enumerate(similarity_scores):
            if score < threshold:
                print(
                    f"First sentence after similarity score "
                    f"below {threshold}: {sentences[i + window_size]}"
                )

Source code link: https://github.com/aurelio-labs/semantic-chunkers/blob/main/semantic_chunkers/chunkers/statistical.py

References

---

Semantic Chunkers

https://github.com/aurelio-labs/semantic-chunkers

Text, video, and audio chunking tool - Semantic Chunkers

https://blog.csdn.net/liliang199/article/details/156982300
