Building a Vector Database from Scratch: Engineering Practice for Million-Scale Vector Similarity Search
This article walks through the engineering practice of building a million-scale vector database from scratch. It covers: 1. the architecture of a vector database, including core modules for data ingestion, index management, and query processing; 2. the fundamentals of similarity computation, implementing Euclidean distance, cosine similarity, and other metrics; 3. efficient index structures, covering the KD-Tree, LSH, and HNSW algorithms; 4. engineering optimizations such as vector quantization, parallel computation, and persistent storage; 5. a complete system implementation with a performance-testing framework. Worked examples demonstrate applications such as text similarity search and image retrieval.
Introduction: Why Build Your Own Vector Database?
With the rapid advance of artificial intelligence, vector data has become a core ingredient of modern applications. From user profiles in recommender systems and word embeddings in NLP to feature extraction in computer vision, vectors are everywhere. Traditional keyword-based search cannot satisfy similarity-search workloads, which gave rise to databases purpose-built for high-dimensional vector data.
Commercial vector databases such as Pinecone, Weaviate, and Qdrant are feature-complete, but they often come with high costs, limited customizability, and data-privacy concerns. Understanding the internals of a vector database and implementing a lightweight version of your own not only deepens your grasp of this core technology, it also lets you optimize for your specific needs.
This article explains, step by step, how to build a vector database from scratch that supports similarity search over millions of vectors, covering the theory, system design, core algorithm implementations, and optimization strategies, culminating in a complete runnable system.
Chapter 1: Core Architecture of a Vector Database
1.1 Core Components
A complete vector database should include the following core modules:
- Data ingestion layer: receives and preprocesses vector data
- Index management layer: builds and maintains efficient vector index structures
- Query processing layer: handles similarity-search requests and returns results
- Storage engine layer: persists vector data and index structures
- System management layer: monitors, tunes, and keeps the system running
1.2 System Architecture Design
python
# Sketch of the vector database's core architecture
class VectorDatabase:
    def __init__(self, dimension, distance_metric='cosine'):
        self.dimension = dimension
        self.distance_metric = distance_metric
        self.index = None
        self.vectors = []
        self.metadata = []
        self.index_built = False

    def add_vector(self, vector, metadata=None):
        """Add a vector to the database."""
        if len(vector) != self.dimension:
            raise ValueError(f"Vector dimension must be {self.dimension}")
        self.vectors.append(vector)
        self.metadata.append(metadata or {})
        self.index_built = False  # mark the index as stale

    def build_index(self, index_type='hnsw', **kwargs):
        """Build the vector index."""
        # Index construction is implemented in later chapters
        pass

    def search(self, query_vector, k=10, **kwargs):
        """Similarity search."""
        if not self.index_built:
            self.build_index()
        # Search logic is implemented in later chapters
        pass
Chapter 2: Fundamentals of Vector Similarity Computation
2.1 Common Distance Metrics
Distance computation is the heart of vector similarity search. Different scenarios call for different metrics:
2.1.1 Euclidean Distance (L2)
Suitable when actual distances matter, such as spatial positions.
python
import numpy as np

def euclidean_distance(vec1, vec2):
    """Compute the Euclidean distance between two vectors."""
    return np.sqrt(np.sum((np.array(vec1) - np.array(vec2)) ** 2))

def euclidean_distance_batch(vectors, query):
    """Compute Euclidean distances in batch (vectorized)."""
    vectors = np.array(vectors)
    query = np.array(query)
    # exploit numpy broadcasting and vectorized arithmetic
    differences = vectors - query
    return np.sqrt(np.sum(differences ** 2, axis=1))
2.1.2 Cosine Similarity
Suitable when direction matters more than magnitude, as in text and recommendation workloads.
python
def cosine_similarity(vec1, vec2):
    """Compute cosine similarity (assumes neither vector is all zeros)."""
    vec1 = np.array(vec1)
    vec2 = np.array(vec2)
    dot_product = np.dot(vec1, vec2)
    norm1 = np.linalg.norm(vec1)
    norm2 = np.linalg.norm(vec2)
    return dot_product / (norm1 * norm2)

def cosine_distance(vec1, vec2):
    """Cosine distance = 1 - cosine similarity."""
    return 1 - cosine_similarity(vec1, vec2)
2.1.3 Inner-Product Similarity
Suitable for vectors that are already normalized.
python
def inner_product(vec1, vec2):
    """Compute the inner (dot) product."""
    return np.dot(np.array(vec1), np.array(vec2))
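As a quick sanity check, the snippet below (a minimal sketch using the functions defined above, with made-up vectors) verifies that for unit-normalized vectors the inner product and cosine similarity agree, which is why normalized corpora can use the cheaper inner product:
python
# For unit-length vectors, cosine similarity reduces to the inner product.
v1 = np.array([3.0, 4.0]); v1 /= np.linalg.norm(v1)
v2 = np.array([1.0, 2.0]); v2 /= np.linalg.norm(v2)
assert abs(inner_product(v1, v2) - cosine_similarity(v1, v2)) < 1e-9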
2.1.4 Manhattan Distance (L1)
Suitable for grid-like paths or sparse vectors.
python
def manhattan_distance(vec1, vec2):
    """Compute the Manhattan (L1) distance."""
    return np.sum(np.abs(np.array(vec1) - np.array(vec2)))
2.2 Optimizing Distance Computation
For searches over millions of vectors, distance computation is the performance bottleneck, so several optimizations are needed:
python
import numba
from numba import jit, prange

@jit(nopython=True, parallel=True, fastmath=True)
def euclidean_distance_batch_optimized(vectors, query):
    """Batch Euclidean distance accelerated with numba (JIT + threads)."""
    n = vectors.shape[0]
    d = vectors.shape[1]
    distances = np.empty(n, dtype=np.float32)
    for i in prange(n):
        dist = 0.0
        for j in range(d):
            diff = vectors[i, j] - query[j]
            dist += diff * diff
        distances[i] = np.sqrt(dist)
    return distances
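To see the effect, here is a minimal timing sketch comparing the numba kernel against the plain numpy version on synthetic data. The first call is made outside the timer because it triggers JIT compilation; the data shapes are illustrative assumptions:
python
import time

vectors = np.random.randn(100_000, 128).astype(np.float32)
query = np.random.randn(128).astype(np.float32)

euclidean_distance_batch_optimized(vectors, query)  # warm-up: JIT compile

start = time.perf_counter()
euclidean_distance_batch_optimized(vectors, query)
print(f"numba: {time.perf_counter() - start:.3f}s")

start = time.perf_counter()
euclidean_distance_batch(vectors, query)
print(f"numpy: {time.perf_counter() - start:.3f}s")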
Chapter 3: Designing and Implementing Efficient Index Structures
3.1 Naive Linear Scan and Its Limits
The simplest search method is a linear scan: compute the distance from the query to every vector in the database, then sort. Its time complexity is O(Nd), where N is the number of vectors and d the dimension.
python
class LinearScanIndex:
    """Linear-scan index (the baseline method)."""
    def __init__(self, distance_fn=euclidean_distance):
        self.distance_fn = distance_fn
        self.vectors = []
        self.metadata = []

    def add_batch(self, vectors, metadata_list=None):
        self.vectors.extend(vectors)
        if metadata_list:
            self.metadata.extend(metadata_list)
        else:
            self.metadata.extend([{}] * len(vectors))

    def add_vector(self, vector, idx=None):
        """Add a single vector (keeps the interface compatible with the other indexes)."""
        self.add_batch([vector])

    def search(self, query, k=10):
        distances = []
        for i, vec in enumerate(self.vectors):
            dist = self.distance_fn(vec, query)
            distances.append((dist, i))
        # take the k nearest
        distances.sort(key=lambda x: x[0])
        results = []
        for dist, idx in distances[:k]:
            results.append({
                'vector': self.vectors[idx],
                'metadata': self.metadata[idx],
                'index': idx,  # position, so callers can map back to their own ids
                'distance': dist
            })
        return results
For millions of vectors, a linear scan is impractical: with one million 128-dimensional float32 vectors, every single query has to stream through roughly 512 MB of data. We need more efficient index structures.
3.2 Tree-Based Index Structures
3.2.1 A KD-Tree Implementation
The KD-Tree is a classic space-partitioning data structure that works well in low-dimensional spaces (d < 20).
python
class KDTreeNode:
    def __init__(self, point, left=None, right=None, axis=None, index=None):
        self.point = point
        self.left = left
        self.right = right
        self.axis = axis
        self.index = index  # index into the original data

class KDTree:
    def __init__(self, points, indices=None):
        self.dimension = len(points[0]) if points else 0
        if indices is None:
            indices = list(range(len(points)))
        self.root = self._build_tree(points, indices)

    def _build_tree(self, points, indices, depth=0):
        if not points:
            return None
        axis = depth % self.dimension
        # sort by the current axis
        combined = list(zip(points, indices))
        combined.sort(key=lambda x: x[0][axis])
        sorted_points = [p for p, _ in combined]
        sorted_indices = [i for _, i in combined]
        median = len(sorted_points) // 2
        node = KDTreeNode(
            point=sorted_points[median],
            axis=axis,
            index=sorted_indices[median]
        )
        node.left = self._build_tree(
            sorted_points[:median],
            sorted_indices[:median],
            depth + 1
        )
        node.right = self._build_tree(
            sorted_points[median+1:],
            sorted_indices[median+1:],
            depth + 1
        )
        return node

    def _search_nn(self, node, point, depth, best):
        if node is None:
            return best
        axis = depth % self.dimension
        # distance to the current node
        dist = euclidean_distance(node.point, point)
        # update the best result so far
        if dist < best['distance']:
            best['distance'] = dist
            best['point'] = node.point
            best['index'] = node.index
        # choose which side to descend first
        if point[axis] < node.point[axis]:
            next_branch, opposite_branch = node.left, node.right
        else:
            next_branch, opposite_branch = node.right, node.left
        # search the nearer branch first
        best = self._search_nn(next_branch, point, depth + 1, best)
        # if the hypersphere crosses the splitting plane, search the other branch
        if abs(point[axis] - node.point[axis]) < best['distance']:
            best = self._search_nn(opposite_branch, point, depth + 1, best)
        return best

    def search_knn(self, point, k=10):
        """Find the k nearest neighbors."""
        # maintain the k nearest in a max-heap (negated distances);
        # node.index breaks ties so heapq never compares node objects
        import heapq
        heap = []

        def search_node(node, depth):
            if node is None:
                return
            axis = depth % self.dimension
            dist = euclidean_distance(node.point, point)
            if len(heap) < k:
                heapq.heappush(heap, (-dist, node.index, node))
            elif dist < -heap[0][0]:
                heapq.heappushpop(heap, (-dist, node.index, node))
            # decide the search order
            if point[axis] < node.point[axis]:
                first, second = node.left, node.right
            else:
                first, second = node.right, node.left
            search_node(first, depth + 1)
            # check whether the other branch can still contain closer points
            if len(heap) < k or abs(point[axis] - node.point[axis]) < -heap[0][0]:
                search_node(second, depth + 1)

        search_node(self.root, 0)
        # extract and sort the results by ascending distance
        results = [(-neg_dist, node.point, node.index) for neg_dist, _, node in heap]
        results.sort(key=lambda x: x[0])
        return results[:k]
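A short usage sketch (the data is random and the parameters are illustrative):
python
# Build a KD-Tree over 10,000 random 8-dimensional points and query it.
points = np.random.randn(10_000, 8).tolist()
tree = KDTree(points)
for dist, point, idx in tree.search_knn(points[0], k=5):
    print(f"idx={idx}, dist={dist:.4f}")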
3.3 Approximate Nearest-Neighbor Search Algorithms
For high-dimensional data, exact nearest-neighbor search is often impractical. Approximate nearest-neighbor (ANN) algorithms trade a little accuracy for a lot of speed.
3.3.1 Locality-Sensitive Hashing (LSH)
The core idea of LSH: similar vectors are more likely to hash into the same bucket.
python
class LSHIndex:
    def __init__(self, dimension, num_tables=10, hash_size=10):
        self.dimension = dimension
        self.num_tables = num_tables
        self.hash_size = hash_size
        self.tables = [{} for _ in range(num_tables)]
        self.random_vectors = []
        self._init_random_vectors()

    def _init_random_vectors(self):
        """Initialize the random projection vectors."""
        np.random.seed(42)
        for _ in range(self.num_tables):
            table_vectors = []
            for _ in range(self.hash_size):
                # sample a random direction from a Gaussian
                random_vec = np.random.randn(self.dimension)
                random_vec = random_vec / np.linalg.norm(random_vec)
                table_vectors.append(random_vec)
            self.random_vectors.append(table_vectors)

    def _hash_vector(self, vector, table_idx):
        """Hash a vector for the given table."""
        hash_bits = []
        for i in range(self.hash_size):
            projection = np.dot(vector, self.random_vectors[table_idx][i])
            # binarize the projection
            bit = 1 if projection >= 0 else 0
            hash_bits.append(bit)
        # pack the bit array into an integer hash
        hash_value = 0
        for bit in hash_bits:
            hash_value = (hash_value << 1) | bit
        return hash_value

    def add_vector(self, vector, idx):
        """Add a vector to every hash table."""
        for table_idx in range(self.num_tables):
            hash_value = self._hash_vector(vector, table_idx)
            if hash_value not in self.tables[table_idx]:
                self.tables[table_idx][hash_value] = []
            self.tables[table_idx][hash_value].append((vector, idx))

    def search(self, query, k=10):
        """Search for approximate nearest neighbors."""
        candidates = set()
        # gather candidates from every hash table
        for table_idx in range(self.num_tables):
            hash_value = self._hash_vector(query, table_idx)
            if hash_value in self.tables[table_idx]:
                for vector, idx in self.tables[table_idx][hash_value]:
                    candidates.add((tuple(vector), idx))
        # compute exact distances and sort
        results = []
        for vector_tuple, idx in candidates:
            vector = np.array(vector_tuple)
            dist = euclidean_distance(vector, query)
            results.append((dist, vector, idx))
        results.sort(key=lambda x: x[0])
        return results[:k]
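A short usage sketch (parameters are illustrative): more tables raise the chance of finding the true neighbors, while a larger hash_size makes each bucket more selective.
python
dim = 64
index = LSHIndex(dimension=dim, num_tables=8, hash_size=12)
data = np.random.randn(10_000, dim)
for i, v in enumerate(data):
    index.add_vector(v, i)
for dist, _, idx in index.search(data[0], k=5):
    print(f"idx={idx}, dist={dist:.4f}")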
3.3.2 HNSW (Hierarchical Navigable Small World) Graphs
HNSW is one of the strongest ANN algorithms available today, combining navigable small-world graphs with a multi-layer hierarchy.
python
import heapq
import random
import math
from collections import defaultdict

class HNSWNode:
    def __init__(self, id, vector, level):
        self.id = id
        self.vector = vector
        self.level = level
        self.neighbors = defaultdict(list)  # neighbor list per layer

    def add_neighbor(self, node, level):
        """Add a neighbor at the given layer."""
        if node not in self.neighbors[level]:
            self.neighbors[level].append(node)

    def get_neighbors(self, level):
        """Return the neighbors at the given layer."""
        return self.neighbors.get(level, [])

class HNSWIndex:
    def __init__(self, dimension, M=16, ef_construction=200, ef_search=50):
        """
        Parameters:
        - M: maximum number of connections per node per layer
        - ef_construction: size of the dynamic candidate list during construction
        - ef_search: size of the dynamic candidate list during search
        """
        self.dimension = dimension
        self.M = M
        self.ef_construction = ef_construction
        self.ef_search = ef_search
        self.max_level = 0
        self.entry_point = None
        self.nodes = []
        self.level_multiplier = 1 / math.log(M)

    def _random_level(self):
        """Draw a node's level from a geometric distribution."""
        level = 0
        while random.random() < self.level_multiplier and level < 100:
            level += 1
        return level

    def _search_layer(self, query, entry_point, ef, level):
        """Greedy best-first search for nearest neighbors within one layer."""
        visited = {entry_point.id}
        # min-heap of candidates; node.id breaks ties so nodes are never compared
        candidates = [(euclidean_distance(query, entry_point.vector), entry_point.id, entry_point)]
        results = []  # max-heap (negated distances) of the best ef nodes so far

        while candidates:
            dist, _, node = heapq.heappop(candidates)
            farthest = -results[0][0] if results else float('inf')
            # candidates is a min-heap, so once the closest candidate is worse
            # than the worst kept result, nothing better can follow
            if len(results) >= ef and dist > farthest:
                break
            if len(results) < ef:
                heapq.heappush(results, (-dist, node.id, node))
            elif dist < farthest:
                heapq.heapreplace(results, (-dist, node.id, node))
            # expand the current node's neighbors
            for neighbor in node.get_neighbors(level):
                if neighbor.id not in visited:
                    visited.add(neighbor.id)
                    d = euclidean_distance(query, neighbor.vector)
                    heapq.heappush(candidates, (d, neighbor.id, neighbor))

        # return nodes sorted by ascending distance
        sorted_results = sorted(results, key=lambda x: -x[0])
        return [node for _, _, node in sorted_results]

    def _select_neighbors_simple(self, query, candidates, M, level):
        """Pick the M closest candidates (the simple heuristic)."""
        distances = [(euclidean_distance(query, node.vector), node) for node in candidates]
        distances.sort(key=lambda x: x[0])
        return [node for _, node in distances[:M]]

    def insert(self, vector, id=None):
        """Insert a new vector."""
        if id is None:
            id = len(self.nodes)
        level = self._random_level()
        node = HNSWNode(id, vector, level)
        self.nodes.append(node)

        if self.entry_point is None:
            self.entry_point = node
            self.max_level = level
            return node

        current_node = self.entry_point
        current_max_level = self.max_level

        # greedy descent through the layers above the new node's level
        for l in range(current_max_level, level, -1):
            results = self._search_layer(vector, current_node, 1, l)
            if results:
                current_node = results[0]

        # on each layer the new node occupies, search and connect
        for l in range(min(level, current_max_level), -1, -1):
            results = self._search_layer(vector, current_node, self.ef_construction, l)
            if results:
                current_node = results[0]
            neighbors = self._select_neighbors_simple(vector, results, self.M, l)
            # create bidirectional links
            for neighbor in neighbors:
                node.add_neighbor(neighbor, l)
                neighbor.add_neighbor(node, l)
                # keep each neighbor's degree bounded by M
                if len(neighbor.get_neighbors(l)) > self.M:
                    # simplification: truncate to the first M links;
                    # a full implementation would re-select the M closest
                    neighbor.neighbors[l] = neighbor.neighbors[l][:self.M]

        # promote the entry point if the new node is the highest
        if level > self.max_level:
            self.entry_point = node
            self.max_level = level
        return node

    def search(self, query, k=10):
        """Search for the k nearest neighbors."""
        if self.entry_point is None:
            return []
        # greedy descent from the top layer down to layer 1
        current_node = self.entry_point
        for l in range(self.max_level, 0, -1):
            results = self._search_layer(query, current_node, 1, l)
            if results:
                current_node = results[0]
        # thorough search on the bottom layer
        results = self._search_layer(query, current_node, self.ef_search, 0)
        # compute exact distances and return the top k
        distances = [(euclidean_distance(query, node.vector), node) for node in results]
        distances.sort(key=lambda x: x[0])
        return [(dist, node.vector, node.id) for dist, node in distances[:k]]
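A short usage sketch (random data, illustrative parameters):
python
dim = 32
index = HNSWIndex(dimension=dim, M=16, ef_construction=100, ef_search=64)
data = np.random.randn(20_000, dim)
for i, v in enumerate(data):
    index.insert(v, i)
for dist, vec, node_id in index.search(data[0], k=5):
    print(f"id={node_id}, dist={dist:.4f}")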
Chapter 4: Engineering and Optimization
4.1 Memory Optimization Strategies
4.1.1 Vector Quantization
Vector quantization reduces memory usage by discretizing the continuous vector space.
python
from sklearn.cluster import KMeans

class ProductQuantization:
    def __init__(self, num_subvectors=8, num_clusters=256):
        self.num_subvectors = num_subvectors
        self.num_clusters = num_clusters
        self.codebooks = []  # one codebook per subspace
        self.codes = []      # encoded vectors

    def fit(self, vectors):
        """Train the quantizer."""
        n_vectors = len(vectors)
        dimension = len(vectors[0])
        subvector_dim = dimension // self.num_subvectors
        # split every vector into subvectors
        subvectors = []
        for i in range(self.num_subvectors):
            start_idx = i * subvector_dim
            end_idx = start_idx + subvector_dim if i < self.num_subvectors - 1 else dimension
            subvecs = [vec[start_idx:end_idx] for vec in vectors]
            subvectors.append(subvecs)
        # train one codebook per subspace
        self.codebooks = []
        for i in range(self.num_subvectors):
            print(f"Training subspace {i+1}/{self.num_subvectors}")
            kmeans = KMeans(n_clusters=self.num_clusters, random_state=42)
            kmeans.fit(subvectors[i])
            self.codebooks.append(kmeans.cluster_centers_)
        # encode all vectors
        self.codes = self.encode(vectors)

    def encode(self, vectors):
        """Encode vectors into codebook indices."""
        codes = []
        for vec in vectors:
            code = []
            subvector_dim = len(vec) // self.num_subvectors
            for i in range(self.num_subvectors):
                start_idx = i * subvector_dim
                end_idx = start_idx + subvector_dim if i < self.num_subvectors - 1 else len(vec)
                subvec = vec[start_idx:end_idx]
                # find the nearest cluster center
                distances = np.linalg.norm(self.codebooks[i] - subvec, axis=1)
                closest_idx = np.argmin(distances)
                code.append(closest_idx)
            codes.append(code)
        return np.array(codes, dtype=np.uint8)

    def decode(self, codes):
        """Decode codes back into approximate vectors."""
        vectors = []
        for code in codes:
            vec_parts = []
            for i, idx in enumerate(code):
                vec_parts.append(self.codebooks[i][idx])
            vectors.append(np.concatenate(vec_parts))
        return np.array(vectors)

    def asymmetric_distance(self, query, code):
        """Asymmetric (squared) distance between a raw query and an encoded vector."""
        distance = 0.0
        subvector_dim = len(query) // self.num_subvectors
        for i in range(self.num_subvectors):
            start_idx = i * subvector_dim
            end_idx = start_idx + subvector_dim if i < self.num_subvectors - 1 else len(query)
            subquery = query[start_idx:end_idx]
            centroid = self.codebooks[i][code[i]]
            distance += np.sum((subquery - centroid) ** 2)
        return distance
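The payoff is easy to estimate. The back-of-envelope sketch below ignores Python object overhead and assumes the default 8 subvectors with 256 centroids each, plus sklearn's float64 centroids; it shows the compression for one million 128-dimensional float32 vectors:
python
n, d, m, c = 1_000_000, 128, 8, 256
raw_bytes = n * d * 4                  # float32 storage: ~512 MB
code_bytes = n * m                     # one uint8 code per subvector: ~8 MB
codebook_bytes = m * c * (d // m) * 8  # float64 centroids: ~0.25 MB
print(f"raw: {raw_bytes/2**20:.0f} MiB, "
      f"PQ: {(code_bytes + codebook_bytes)/2**20:.1f} MiB "
      f"(~{raw_bytes/(code_bytes + codebook_bytes):.0f}x smaller)")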
4.2 Parallel Computation and Batching
4.2.1 Multi-Threaded Search
python
import concurrent.futures
from threading import Lock

class ParallelHNSWIndex(HNSWIndex):
    def __init__(self, dimension, num_threads=4, **kwargs):
        super().__init__(dimension, **kwargs)
        self.num_threads = num_threads
        self.insert_lock = Lock()

    def parallel_search(self, queries, k=10):
        """Process multiple queries in parallel.

        Note: for pure-Python distance code the GIL limits the speedup;
        threads mostly help when the hot loops release the GIL (numpy/numba).
        """
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as executor:
            futures = [executor.submit(self.search, query, k) for query in queries]
            results = [future.result() for future in concurrent.futures.as_completed(futures)]
        return results

    def batch_insert(self, vectors, ids=None):
        """Insert vectors in batch."""
        if ids is None:
            ids = list(range(len(self.nodes), len(self.nodes) + len(vectors)))
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as executor:
            futures = []
            for i, vector in enumerate(vectors):
                future = executor.submit(self._safe_insert, vector, ids[i] if i < len(ids) else None)
                futures.append(future)
            # wait for all inserts to finish
            concurrent.futures.wait(futures)

    def _safe_insert(self, vector, id=None):
        """Thread-safe insert (the coarse lock serializes writers)."""
        with self.insert_lock:
            return self.insert(vector, id)
4.3 Persistent Storage
4.3.1 Index Serialization
python
import os
import json
import msgpack

class PersistentHNSWIndex(HNSWIndex):
    def __init__(self, dimension, storage_path='./vector_db', **kwargs):
        super().__init__(dimension, **kwargs)
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

    def save(self, filename='index.bin'):
        """Save the index to a file."""
        filepath = os.path.join(self.storage_path, filename)
        # serialize the index parameters
        data = {
            'dimension': self.dimension,
            'M': self.M,
            'ef_construction': self.ef_construction,
            'max_level': self.max_level,
            'nodes': [],
            'entry_point_id': self.entry_point.id if self.entry_point else None
        }
        # serialize the nodes
        for node in self.nodes:
            node_data = {
                'id': node.id,
                'vector': node.vector.tolist() if hasattr(node.vector, 'tolist') else node.vector,
                'level': node.level,
                'neighbors': {str(k): [n.id for n in v] for k, v in node.neighbors.items()}
            }
            data['nodes'].append(node_data)
        # msgpack gives a compact binary serialization
        with open(filepath, 'wb') as f:
            packed = msgpack.packb(data, use_bin_type=True)
            f.write(packed)
        print(f"Index saved to {filepath}")

    def load(self, filename='index.bin'):
        """Load the index from a file."""
        filepath = os.path.join(self.storage_path, filename)
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Index file {filepath} does not exist")
        with open(filepath, 'rb') as f:
            data = msgpack.unpackb(f.read(), raw=False)
        # restore the basic parameters
        self.dimension = data['dimension']
        self.M = data['M']
        self.ef_construction = data['ef_construction']
        self.max_level = data['max_level']
        # rebuild the nodes
        self.nodes = []
        node_map = {}
        # first pass: create the node objects
        for node_data in data['nodes']:
            node = HNSWNode(
                id=node_data['id'],
                vector=np.array(node_data['vector']),
                level=node_data['level']
            )
            node_map[node.id] = node
            self.nodes.append(node)
        # second pass: restore the neighbor links
        for node_data in data['nodes']:
            node = node_map[node_data['id']]
            for level_str, neighbor_ids in node_data['neighbors'].items():
                level = int(level_str)
                for neighbor_id in neighbor_ids:
                    if neighbor_id in node_map:
                        node.add_neighbor(node_map[neighbor_id], level)
        # restore the entry point
        if data['entry_point_id'] is not None:
            self.entry_point = node_map[data['entry_point_id']]
        print(f"Loaded index from {filepath} with {len(self.nodes)} nodes")
Chapter 5: System Integration and Performance Testing
5.1 A Complete Vector Database Implementation
python
import os
import time
import json
import logging
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass
import numpy as np
import msgpack

@dataclass
class SearchResult:
    """Search result record."""
    id: int
    vector: List[float]
    metadata: Dict[str, Any]
    distance: float
    score: float  # similarity score in [0, 1]
class VectorDatabase:
    """The complete vector database implementation."""
    def __init__(self,
                 dimension: int,
                 index_type: str = 'hnsw',
                 distance_metric: str = 'cosine',
                 persist_path: str = './vector_db_data',
                 max_connections: int = 16,
                 ef_construction: int = 200):
        self.dimension = dimension
        self.index_type = index_type
        self.distance_metric = distance_metric
        self.persist_path = persist_path
        self.max_connections = max_connections
        self.ef_construction = ef_construction
        # storage
        self.vectors = []
        self.metadata = []
        self.ids = []
        self.next_id = 0
        # index
        self.index = None
        self._init_index()
        # logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        # persistence directory
        os.makedirs(persist_path, exist_ok=True)

    def _init_index(self):
        """Initialize the index structure."""
        if self.index_type == 'hnsw':
            self.index = HNSWIndex(
                dimension=self.dimension,
                M=self.max_connections,
                ef_construction=self.ef_construction
            )
        elif self.index_type == 'lsh':
            self.index = LSHIndex(
                dimension=self.dimension,
                num_tables=10,
                hash_size=12
            )
        elif self.index_type == 'flat':
            self.index = LinearScanIndex()
        else:
            raise ValueError(f"Unsupported index type: {self.index_type}")

    def _get_distance_fn(self):
        """Return the distance function for the configured metric."""
        if self.distance_metric == 'euclidean':
            return euclidean_distance
        elif self.distance_metric == 'cosine':
            return cosine_distance
        elif self.distance_metric == 'inner_product':
            return lambda x, y: -inner_product(x, y)  # larger inner product = smaller distance
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
    def add(self,
            vector: Union[List[float], np.ndarray],
            metadata: Optional[Dict] = None) -> int:
        """Add a single vector."""
        vector = np.array(vector, dtype=np.float32)
        if len(vector) != self.dimension:
            raise ValueError(f"Vector dimension must be {self.dimension}")
        vector_id = self.next_id
        self.next_id += 1
        # append to storage
        self.vectors.append(vector)
        self.metadata.append(metadata or {})
        self.ids.append(vector_id)
        # add to the index; ids are assigned sequentially, so a vector's id
        # equals its position in self.vectors
        if hasattr(self.index, 'insert'):
            self.index.insert(vector, vector_id)
        elif hasattr(self.index, 'add_vector'):
            self.index.add_vector(vector, vector_id)
        self.logger.info(f"Added vector id={vector_id}")
        return vector_id

    def add_batch(self,
                  vectors: List[Union[List[float], np.ndarray]],
                  metadata_list: Optional[List[Dict]] = None) -> List[int]:
        """Add vectors in batch."""
        ids = []
        for i, vector in enumerate(vectors):
            metadata = metadata_list[i] if metadata_list and i < len(metadata_list) else None
            vector_id = self.add(vector, metadata)
            ids.append(vector_id)
        self.logger.info(f"Added {len(vectors)} vectors in batch")
        return ids
    def search(self,
               query: Union[List[float], np.ndarray],
               k: int = 10,
               filter_fn: Optional[callable] = None) -> List[SearchResult]:
        """Similarity search."""
        query = np.array(query, dtype=np.float32)
        if len(query) != self.dimension:
            raise ValueError(f"Query dimension must be {self.dimension}")
        start_time = time.time()
        # run the index search, over-fetching to leave room for filtering
        if hasattr(self.index, 'search'):
            index_results = self.index.search(query, k * 2)
        else:
            raise RuntimeError("The index does not support search")
        # LinearScanIndex returns dicts; normalize to (distance, vector, idx) tuples
        if index_results and isinstance(index_results[0], dict):
            index_results = [(r['distance'], r['vector'], r['index']) for r in index_results]
        # post-process the results
        results = []
        for dist, vector, idx in index_results[:k * 2]:
            if idx < len(self.ids):
                vector_id = self.ids[idx]
                # apply the metadata filter
                if filter_fn and not filter_fn(self.metadata[idx]):
                    continue
                # turn distance into a similarity score
                # (note: the sample indexes rank by Euclidean distance internally;
                # for unit-normalized vectors that ordering matches cosine similarity)
                if self.distance_metric == 'cosine':
                    score = 1 - dist
                elif self.distance_metric == 'euclidean':
                    # normalize Euclidean distance into a rough score,
                    # assuming vector components lie in [-1, 1]
                    max_dist = np.sqrt(self.dimension * 4)
                    score = max(0, 1 - dist / max_dist)
                else:
                    score = 1 / (1 + dist)
                result = SearchResult(
                    id=vector_id,
                    vector=vector.tolist() if hasattr(vector, 'tolist') else vector,
                    metadata=self.metadata[idx],
                    distance=float(dist),
                    score=float(score)
                )
                results.append(result)
                if len(results) >= k:
                    break
        elapsed = time.time() - start_time
        self.logger.info(f"Search returned {len(results)} results in {elapsed:.4f}s")
        return results
    def save(self):
        """Persist the database to disk."""
        # vectors
        vectors_path = os.path.join(self.persist_path, 'vectors.npy')
        np.save(vectors_path, np.array(self.vectors))
        # metadata
        metadata_path = os.path.join(self.persist_path, 'metadata.msgpack')
        with open(metadata_path, 'wb') as f:
            f.write(msgpack.packb(self.metadata, use_bin_type=True))
        # ids
        ids_path = os.path.join(self.persist_path, 'ids.npy')
        np.save(ids_path, np.array(self.ids))
        # config
        config = {
            'dimension': self.dimension,
            'index_type': self.index_type,
            'distance_metric': self.distance_metric,
            'next_id': self.next_id
        }
        config_path = os.path.join(self.persist_path, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(config, f)
        # index
        if hasattr(self.index, 'save'):
            self.index.save()
        self.logger.info(f"Database saved to {self.persist_path}")

    def load(self):
        """Load the database from disk."""
        # config
        config_path = os.path.join(self.persist_path, 'config.json')
        if not os.path.exists(config_path):
            self.logger.warning("No database files found; starting a fresh database")
            return
        with open(config_path, 'r') as f:
            config = json.load(f)
        # validate the config
        if config['dimension'] != self.dimension:
            raise ValueError(f"Dimension mismatch: stored {config['dimension']}, configured {self.dimension}")
        # vectors
        vectors_path = os.path.join(self.persist_path, 'vectors.npy')
        self.vectors = np.load(vectors_path).tolist()
        # metadata
        metadata_path = os.path.join(self.persist_path, 'metadata.msgpack')
        with open(metadata_path, 'rb') as f:
            self.metadata = msgpack.unpackb(f.read(), raw=False)
        # ids
        ids_path = os.path.join(self.persist_path, 'ids.npy')
        self.ids = np.load(ids_path).tolist()
        # restore the id counter
        self.next_id = config['next_id']
        # rebuild the index
        if hasattr(self.index, 'load'):
            self.index.load()
        else:
            # re-insert everything
            self._rebuild_index()
        self.logger.info(f"Loaded database with {len(self.vectors)} vectors")

    def _rebuild_index(self):
        """Rebuild the index from the stored vectors."""
        self._init_index()
        for i, vector in enumerate(self.vectors):
            if hasattr(self.index, 'insert'):
                self.index.insert(vector, self.ids[i])
            elif hasattr(self.index, 'add_vector'):
                self.index.add_vector(vector, self.ids[i])

    def get_stats(self) -> Dict[str, Any]:
        """Return database statistics."""
        return {
            'total_vectors': len(self.vectors),
            'dimension': self.dimension,
            'index_type': self.index_type,
            'distance_metric': self.distance_metric,
            'next_id': self.next_id
        }
5.2 Performance Testing and Benchmarks
python
import matplotlib.pyplot as plt
from tqdm import tqdm

class VectorDBBenchmark:
    """Benchmarking utilities for the vector database."""
    @staticmethod
    def generate_test_data(num_vectors, dimension, seed=42):
        """Generate synthetic test data."""
        np.random.seed(seed)
        vectors = np.random.randn(num_vectors, dimension).astype(np.float32)
        # normalize the vectors (important for cosine similarity)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        vectors = vectors / norms
        return vectors

    @staticmethod
    def benchmark_search(db, test_queries, k=10, num_runs=10):
        """Benchmark search performance."""
        search_times = []
        for _ in tqdm(range(num_runs), desc="Running search benchmark"):
            start_time = time.time()
            for query in test_queries:
                db.search(query, k=k)
            end_time = time.time()
            search_times.append((end_time - start_time) / len(test_queries))
        avg_time = np.mean(search_times)
        std_time = np.std(search_times)
        return {
            'avg_search_time_ms': avg_time * 1000,
            'std_search_time_ms': std_time * 1000,
            'qps': 1 / avg_time  # queries per second
        }

    @staticmethod
    def benchmark_insertion(db, vectors, batch_size=1000):
        """Benchmark insertion performance."""
        insertion_times = []
        for i in tqdm(range(0, len(vectors), batch_size), desc="Running insert benchmark"):
            batch = vectors[i:i+batch_size]
            start_time = time.time()
            db.add_batch(batch.tolist())
            end_time = time.time()
            insertion_times.append((end_time - start_time) / len(batch))
        avg_time = np.mean(insertion_times)
        return {
            'avg_insert_time_ms': avg_time * 1000,
            'vectors_per_second': 1 / avg_time
        }

    @staticmethod
    def benchmark_accuracy(db, test_queries, ground_truth, k=10):
        """Benchmark search accuracy (recall against exact neighbors)."""
        accuracy_scores = []
        for i, query in enumerate(tqdm(test_queries, desc="Measuring recall")):
            results = db.search(query, k=k)
            result_ids = [r.id for r in results]
            # recall@k against the exact neighbors
            true_positives = len(set(result_ids) & set(ground_truth[i][:k]))
            recall = true_positives / min(k, len(ground_truth[i]))
            accuracy_scores.append(recall)
        return {
            'avg_recall': np.mean(accuracy_scores),
            'min_recall': np.min(accuracy_scores),
            'max_recall': np.max(accuracy_scores)
        }
    @staticmethod
    def run_comprehensive_benchmark():
        """Run the full benchmark suite."""
        print("=" * 60)
        print("Vector database benchmark suite")
        print("=" * 60)
        # configuration
        dimension = 128
        num_vectors = 100000
        num_indexed = 50000  # how many vectors actually get inserted below
        num_queries = 1000
        k = 10
        # generate test data
        print("\n1. Generating test data...")
        vectors = VectorDBBenchmark.generate_test_data(num_vectors, dimension)
        test_queries = VectorDBBenchmark.generate_test_data(num_queries, dimension, seed=43)
        # compute exact nearest neighbors over the indexed subset (for recall)
        print("2. Computing exact nearest neighbors (ground truth)...")
        base_vectors = vectors[:num_indexed]
        ground_truth = []
        for query in tqdm(test_queries):
            # linear scan gives the exact neighbors
            distances = np.linalg.norm(base_vectors - query, axis=1)
            top_k_indices = np.argsort(distances)[:k]
            ground_truth.append(top_k_indices.tolist())
        # benchmark each index type
        index_types = ['flat', 'lsh', 'hnsw']
        results = {}
        for index_type in index_types:
            print(f"\n3. Benchmarking the {index_type.upper()} index...")
            # create a fresh database instance
            db = VectorDatabase(
                dimension=dimension,
                index_type=index_type,
                distance_metric='euclidean'
            )
            # insertion
            print("   Insertion...")
            insert_results = VectorDBBenchmark.benchmark_insertion(db, base_vectors)
            # search
            print("   Search...")
            search_results = VectorDBBenchmark.benchmark_search(
                db, test_queries[:100], k=k, num_runs=3
            )
            # recall
            print("   Recall...")
            accuracy_results = VectorDBBenchmark.benchmark_accuracy(
                db, test_queries[:50], ground_truth[:50], k=k
            )
            results[index_type] = {
                'insertion': insert_results,
                'search': search_results,
                'accuracy': accuracy_results
            }
            print(f"   {index_type.upper()} done")
        # print the summary
        print("\n" + "=" * 60)
        print("Benchmark summary")
        print("=" * 60)
        for index_type in index_types:
            print(f"\n{index_type.upper()} index:")
            print(f"  insertion: {results[index_type]['insertion']['vectors_per_second']:.1f} vectors/s")
            print(f"  search:    {results[index_type]['search']['qps']:.1f} QPS")
            print(f"  latency:   {results[index_type]['search']['avg_search_time_ms']:.2f} ms")
            print(f"  recall:    {results[index_type]['accuracy']['avg_recall']:.3f}")
        return results

# run the benchmark
if __name__ == "__main__":
    benchmark_results = VectorDBBenchmark.run_comprehensive_benchmark()
5.3 Visualization and Analysis
python
class VisualizeResults:
    """Plot the benchmark results."""
    @staticmethod
    def plot_performance_comparison(results):
        """Plot a performance comparison across index types."""
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        index_types = list(results.keys())
        # 1. insertion throughput
        insertion_speeds = [results[it]['insertion']['vectors_per_second'] for it in index_types]
        axes[0, 0].bar(index_types, insertion_speeds, color='skyblue')
        axes[0, 0].set_title('Insertion throughput (vectors/s)')
        axes[0, 0].set_ylabel('vectors/s')
        axes[0, 0].grid(True, alpha=0.3)
        # 2. search throughput
        search_speeds = [results[it]['search']['qps'] for it in index_types]
        axes[0, 1].bar(index_types, search_speeds, color='lightgreen')
        axes[0, 1].set_title('Search throughput (QPS)')
        axes[0, 1].set_ylabel('queries/s')
        axes[0, 1].grid(True, alpha=0.3)
        # 3. search latency
        search_latencies = [results[it]['search']['avg_search_time_ms'] for it in index_types]
        axes[1, 0].bar(index_types, search_latencies, color='salmon')
        axes[1, 0].set_title('Search latency (ms)')
        axes[1, 0].set_ylabel('ms')
        axes[1, 0].grid(True, alpha=0.3)
        # 4. recall
        recall_rates = [results[it]['accuracy']['avg_recall'] for it in index_types]
        axes[1, 1].bar(index_types, recall_rates, color='gold')
        axes[1, 1].set_title('Recall')
        axes[1, 1].set_ylabel('recall')
        axes[1, 1].set_ylim([0, 1])
        axes[1, 1].grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_scalability_analysis():
        """Plot a (simulated) scalability analysis."""
        dimensions = [64, 128, 256, 512]
        num_vectors_list = [1000, 10000, 100000, 500000]
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        # effect of dimensionality
        dimension_results = []
        for dim in dimensions:
            # simulated numbers (real curves should come from measurements)
            perf = 1000 / (dim / 64)  # assume throughput inversely proportional to dimension
            dimension_results.append(perf)
        axes[0].plot(dimensions, dimension_results, 'o-', linewidth=2)
        axes[0].set_xlabel('vector dimension')
        axes[0].set_ylabel('relative throughput')
        axes[0].set_title('Effect of dimensionality on search')
        axes[0].grid(True, alpha=0.3)
        # effect of collection size
        scale_results = []
        for num_vec in num_vectors_list:
            # simulated numbers
            if num_vec <= 10000:
                perf = 1000
            elif num_vec <= 100000:
                perf = 1000 * (10000 / num_vec) ** 0.5
            else:
                perf = 1000 * (10000 / num_vec) ** 0.7
            scale_results.append(perf)
        axes[1].plot(num_vectors_list, scale_results, 's-', linewidth=2)
        axes[1].set_xlabel('number of vectors')
        axes[1].set_ylabel('relative throughput')
        axes[1].set_title('Effect of collection size on search')
        axes[1].set_xscale('log')
        axes[1].grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

# visualize the results
if __name__ == "__main__":
    # assuming benchmark_results is available from the previous run:
    # VisualizeResults.plot_performance_comparison(benchmark_results)
    VisualizeResults.plot_scalability_analysis()
Chapter 6: Production Deployment and Optimization Advice
6.1 Deployment Architecture
For production environments, an architecture along these lines is recommended:
text
Client applications
      |
API gateway (load balancing, authentication)
      |
Vector database cluster
├── Primary node (writes, index management)
├── Replica 1 (reads, search)
├── Replica 2 (reads, search)
└── Replica N (reads, search)
      |
Distributed storage
├── Object storage (vector data)
├── Cache layer (Redis/Memcached)
└── Metadata database (PostgreSQL)
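To make the read/write split concrete, here is a deliberately minimal routing sketch. It assumes each node exposes the VectorDatabase API from Chapter 5; a real deployment would add replication of writes, health checks, and failover.
python
import itertools

class ReplicaRouter:
    """Illustrative sketch: writes go to the primary, reads fan out over replicas."""
    def __init__(self, primary, replicas):
        self.primary = primary
        self._replica_cycle = itertools.cycle(replicas)

    def add(self, vector, metadata=None):
        # all writes go through the primary, which owns the index
        return self.primary.add(vector, metadata)

    def search(self, query, k=10):
        # round-robin reads across the replicas
        return next(self._replica_cycle).search(query, k=k)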
6.2 Performance Tuning Suggestions
- Index parameter tuning:
  - Adjust HNSW's M and ef_construction to match the data characteristics
  - For high-dimensional data, consider PCA dimensionality reduction
- Memory optimization:
  - Use vector quantization to cut memory usage
  - Consider disk-backed indexes for very large collections
- Concurrency:
  - Use read-write locks to support concurrent access
  - Batch operations to reduce lock contention
- Caching:
  - Cache query results (see the sketch after this list)
  - Preload hot data
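As one example of the caching point above, a minimal LRU query cache might look like the sketch below (it assumes numpy is imported as np, as elsewhere in this article; rounding the query components is an assumption that makes float queries hashable and tolerant of tiny numeric noise):
python
from collections import OrderedDict

class QueryCache:
    """Minimal LRU cache for search results (illustrative sketch)."""
    def __init__(self, capacity=1024):
        self.capacity = capacity
        self.cache = OrderedDict()

    def _key(self, query, k):
        # round so near-identical float queries share a key
        return (tuple(np.round(query, 6)), k)

    def get(self, query, k):
        key = self._key(query, k)
        if key in self.cache:
            self.cache.move_to_end(key)  # mark as recently used
            return self.cache[key]
        return None

    def put(self, query, k, results):
        key = self._key(query, k)
        self.cache[key] = results
        self.cache.move_to_end(key)
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)  # evict the least recently used entry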
6.3 Monitoring and Maintenance
python
class MonitoringSystem:
    """A simple metrics collector."""
    def __init__(self):
        self.metrics = {
            'search_latency': [],
            'insert_latency': [],
            'memory_usage': [],
            'cache_hit_rate': []
        }

    def record_metric(self, metric_name, value):
        """Record one metric sample."""
        if metric_name in self.metrics:
            self.metrics[metric_name].append((time.time(), value))
            # keep only the most recent 1000 samples
            if len(self.metrics[metric_name]) > 1000:
                self.metrics[metric_name] = self.metrics[metric_name][-1000:]

    def get_stats(self, metric_name, window_seconds=300):
        """Summarize a metric over a sliding time window."""
        if metric_name not in self.metrics:
            return None
        current_time = time.time()
        window_start = current_time - window_seconds
        # keep only samples inside the window
        recent_data = [v for t, v in self.metrics[metric_name] if t >= window_start]
        if not recent_data:
            return None
        return {
            'avg': np.mean(recent_data),
            'min': np.min(recent_data),
            'max': np.max(recent_data),
            'p95': np.percentile(recent_data, 95),
            'count': len(recent_data)
        }
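A short usage sketch (the latency values are made-up numbers):
python
monitor = MonitoringSystem()
for latency_ms in (0.8, 1.2, 0.9):
    monitor.record_metric('search_latency', latency_ms)
print(monitor.get_stats('search_latency', window_seconds=60))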
Chapter 7: Application Case Studies
7.1 Text Similarity Search
python
class TextVectorDatabase:
    """A text-oriented wrapper around the vector database."""
    def __init__(self, embedding_model=None):
        self.embedding_model = embedding_model or self._default_embedding_model()
        self.vector_db = VectorDatabase(
            dimension=768,  # assuming a BERT-style encoder
            index_type='hnsw',
            distance_metric='cosine'
        )

    def _default_embedding_model(self):
        """Default text embedder (toy stand-in)."""
        # In practice, use Sentence-BERT, OpenAI embeddings, etc.
        class SimpleEmbedder:
            def embed(self, text):
                # A real embedding model goes here; for the example we return
                # a random vector. Note that Python 3 salts hash() per process,
                # so these toy embeddings are not stable across runs.
                np.random.seed(hash(text) % 10000)
                vector = np.random.randn(768)
                return vector / np.linalg.norm(vector)
        return SimpleEmbedder()

    def add_document(self, text, metadata=None):
        """Add a document."""
        vector = self.embedding_model.embed(text)
        doc_id = self.vector_db.add(vector, metadata or {})
        return doc_id

    def search_similar_texts(self, query_text, k=10):
        """Search for similar texts."""
        query_vector = self.embedding_model.embed(query_text)
        results = self.vector_db.search(query_vector, k=k)
        # format the results
        formatted_results = []
        for result in results:
            formatted_results.append({
                'text': result.metadata.get('text', ''),
                'similarity_score': result.score,
                'metadata': {k: v for k, v in result.metadata.items() if k != 'text'}
            })
        return formatted_results
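A usage sketch with made-up documents (remember the toy embedder returns pseudo-random vectors, so the rankings are only illustrative):
python
text_db = TextVectorDatabase()
for doc in ("vector databases power semantic search",
            "HNSW builds a layered proximity graph",
            "my cat sleeps all afternoon"):
    text_db.add_document(doc, metadata={'text': doc})
for hit in text_db.search_similar_texts("how does semantic search work", k=2):
    print(hit['similarity_score'], hit['text'])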
7.2 An Image Retrieval System
python
class ImageRetrievalSystem:
    """A content-based image retrieval system."""
    def __init__(self, feature_extractor=None):
        self.feature_extractor = feature_extractor or self._default_feature_extractor()
        self.vector_db = VectorDatabase(
            dimension=2048,  # ResNet-50 feature dimension
            index_type='hnsw',
            distance_metric='cosine'
        )
        self.image_paths = {}

    def _default_feature_extractor(self):
        """Default feature extractor (toy stand-in)."""
        class SimpleExtractor:
            def extract(self, image_path):
                # In practice, use a CNN such as ResNet or VGG;
                # for the example we return a random vector.
                np.random.seed(hash(image_path) % 10000)
                vector = np.random.randn(2048)
                return vector / np.linalg.norm(vector)
        return SimpleExtractor()

    def index_image(self, image_path, metadata=None):
        """Index an image."""
        features = self.feature_extractor.extract(image_path)
        image_id = self.vector_db.add(features, metadata or {})
        self.image_paths[image_id] = image_path
        return image_id

    def search_similar_images(self, query_image_path, k=10):
        """Search for similar images."""
        query_features = self.feature_extractor.extract(query_image_path)
        results = self.vector_db.search(query_features, k=k)
        # format the results
        formatted_results = []
        for result in results:
            image_id = result.id
            formatted_results.append({
                'image_path': self.image_paths.get(image_id, ''),
                'similarity_score': result.score,
                'metadata': result.metadata
            })
        return formatted_results
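A usage sketch (the file paths are placeholders; with the toy extractor no files are actually read):
python
retrieval = ImageRetrievalSystem()
for path in ("img/cat_01.jpg", "img/cat_02.jpg", "img/skyline.jpg"):
    retrieval.index_image(path, metadata={'source': 'demo'})
for hit in retrieval.search_similar_images("img/cat_01.jpg", k=2):
    print(hit['similarity_score'], hit['image_path'])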
Chapter 8: Summary and Outlook
8.1 Technical Summary
Over the course of this article we implemented:
- A complete vector database core architecture: data storage, index management, and query processing modules
- Multiple index algorithms: from a simple linear scan up to the HNSW graph algorithm
- Engineering optimizations: memory reduction, parallel computation, and persistent storage
- A performance-testing framework: comprehensive benchmarking and analysis tools
- Application case studies: complete text-search and image-retrieval implementations
8.2 Performance Observations
Based on our tests, performance across scales looks like this:
- Small collections (<10K vectors): every index performs well; linear scan is simple and reliable
- Medium collections (10K-1M vectors): HNSW strikes the best balance between accuracy and speed
- Large collections (>1M vectors): quantization, partitioning, and similar techniques become necessary
8.3 Future Directions
- Hardware acceleration: use GPUs, TPUs, or dedicated AI chips to speed up vector computation
- Distributed architecture: horizontal scaling to billions of vectors
- Hybrid search: combine vector search with traditional keyword search
- Auto-tuning: pick the best index parameters automatically from data characteristics
- Cloud-native support: containerized deployment, Kubernetes orchestration, and so on
8.4 Ideas for Open-Source Contributions
Readers can build on this foundation by:
- Implementing more algorithms (IVF-PQ, ScaNN, etc.)
- Adding GraphQL/REST API endpoints
- Developing client SDKs (Python, JavaScript, Go, etc.)
- Publishing prebuilt Docker images
- Writing detailed documentation and tutorials
Closing Thoughts
Building a vector database by hand is not just a technical challenge; it is also an excellent way to understand modern AI infrastructure from the inside. Through this exercise we built a fully functional vector database, and more importantly, we came to understand the design ideas and trade-offs behind one.
In the AI era, vector databases are becoming the bridge between AI models and real applications. Mastering this technology gives you a solid foundation for work on large-model applications, recommender systems, content understanding, and more.
I hope this article encourages more developers to dig into this field and to push vector database technology forward together.
Appendix A: Environment and Dependencies
bash
# Environment requirements
#   python>=3.8
#   numpy>=1.20.0
#   numba>=0.55.0
#   scikit-learn>=1.0.0
#   msgpack>=1.0.0
#   matplotlib>=3.5.0   # visualization
#   tqdm>=4.60.0        # progress bars

# Install command
pip install numpy numba scikit-learn msgpack matplotlib tqdm
Appendix B: Repository Layout
text
vector-db-from-scratch/
├── README.md
├── requirements.txt
├── src/
│   ├── __init__.py
│   ├── distances.py        # distance functions
│   ├── indexes/            # index implementations
│   │   ├── __init__.py
│   │   ├── flat.py         # linear scan
│   │   ├── hnsw.py         # HNSW index
│   │   ├── lsh.py          # LSH index
│   │   └── kdtree.py       # KD-Tree
│   ├── quantization.py     # vector quantization
│   ├── storage.py          # persistence
│   ├── database.py         # main database class
│   └── utils.py            # utilities
├── benchmarks/             # performance tests
│   ├── __init__.py
│   ├── benchmark.py
│   └── visualization.py
├── examples/               # example code
│   ├── text_search.py
│   ├── image_retrieval.py
│   └── basic_usage.py
└── tests/                  # unit tests
    ├── __init__.py
    ├── test_distances.py
    ├── test_indexes.py
    └── test_database.py
Appendix C: Further Reading
- Papers:
  - "Efficient and Robust Approximate Nearest Neighbor Search Using Hierarchical Navigable Small World Graphs" (Yu. A. Malkov and D. A. Yashunin, 2016)
  - "Product Quantization for Nearest Neighbor Search" (H. Jégou et al., 2011)
  - "Locality-Sensitive Hashing Scheme Based on p-Stable Distributions" (M. Datar et al., 2004)
- Open-source projects:
  - FAISS (Facebook AI Similarity Search)
  - Annoy (Spotify)
  - ScaNN (Google)
- Online courses:
  - Coursera: "Similarity Based Retrieval"
  - Stanford CS276: Information Retrieval and Web Search
Copyright notice: the code in this article is for learning purposes and may be used in personal projects and for non-commercial purposes. For commercial use, please follow the relevant open-source licenses or contact the author.
Changelog:
- v1.0 (2024): initial release with the core implementation and basic tests
- Planned: distributed support, GPU acceleration, a web UI, and more
Feedback and contributions: questions, suggestions, and improvements are welcome; let's refine this project together.