C++AI大模型接入SDK—模型、会话、数据管理


项目地址: 橘子师兄/ai-model-acess-tech - Gitee.com

博客专栏:C++AI大模型接入SDK_橘子师兄的博客-CSDN博客

博主首页:橘子师兄-CSDN博客

1、模型管理

现在deepseek-chat、gpt-4o-mini、gemini-2.0-flash模型已经接入成功,每个模型都对应单独的Provider类,了后续使用简单,再封装一个LLMManager类将模型管理起来,后续通过多态的方式实现模型路由。

//////////////////////////////// LLMManager.h /////////////////////////////////
#pragma once
#include <unordered_map>
#include <memory>
#include <string>
#include <functional>
#include "ILLMProvider.h"

namespace ai_chat_sdk {
    // LLM 管理
    class LLMManager {
    public:
        // 注册 LLM 提供者
        bool registerProvider(const std::string& name,
                              std::unique_ptr<ILLMProvider> provider);
        
        // 初始化指定模型
        bool initModel(const std::string& modelName,
                       const std::map<std::string, std::string>& modelParams);
        
        // 获取可用模型列表
        std::vector<ModelInfo> getAvailableModel() const;
        
        // 检查模型是否可用
        bool isModelAvailable(const std::string& modelName) const;
        
        // 发送消息到指定模型
        std::string sendMessage(const std::string& modelName,
                                const std::vector<Message>& messages,
                                const std::map<std::string, std::string>& requestParams);
        
        // 发送消息到指定模型,流式响应
        std::string sendMessageStream(const std::string& modelName,
                                      const std::vector<Message>& messages,
                                      const std::map<std::string, std::string>& requestParam,
                                      std::function<void(const std::string&, bool)> callback);
        
    private:
        std::unordered_map<std::string, std::unique_ptr<ILLMProvider>> _providers;
        std::unordered_map<std::string, ModelInfo> _models;
    };
} // end ai_chat_sdk

//////////////////////////////// LLMManager.cpp /////////////////////////////////
#include "LLMManager.h"
#include "../../util/my_logger.h"

namespace ai_chat_sdk {
    // 注册 LLM 提供者
    bool LLMManager::registerProvider(const std::string& name,
                                      std::unique_ptr<ILLMProvider> provider)
    {
        // 参数检测
        if (!provider) {
            ERR("Cannot register null provider!!!");
            return false;
        }
        
        // 注意,unique_ptr是防拷贝的,此处只能通过move的方式将资源转移给当前对象
        // 因此,provider在设置的时候,设置成一个临时变量
        _providers[name] = std::move(provider);
        
        // 添加模型信息
        _models[name] = ModelInfo(name);
        
        // 模型注册成功
        INFO("Register LLM Provider : {}", name);
        return true;
    }
    
    // 初始化指定模型
    bool LLMManager::initModel(const std::string& modelName,
                               const std::map<std::string, std::string>& modelParams)
    {
        // 检测模型是否注册
        auto it = _providers.find(modelName);
        if (it == _providers.end()) {
            ERR("Model {} is not registered!!!", modelName);
            return false;
        }
        
        // 模型已经注册过了,可以进行初始化
        bool isSuccess = it->second->initModel(modelParams);
        if (isSuccess) {
            INFO("Model {} init success!!!", modelName);
            _models[modelName]._desc = it->second->getModelDesc();
            _models[modelName]._isAvailable = true;
        } else {
            INFO("Model {} init Failed!!!", modelName);
            _models[modelName]._isAvailable = false;
        }
        return isSuccess;
    }
    
    // 获取可用模型列表
    std::vector<ModelInfo> LLMManager::getAvailableModel() const
    {
        // 从注册的模型列表中筛选出所有可用的模型
        std::vector<ModelInfo> availableModels;
        for (const auto& pair : _models) {
            if (pair.second._isAvailable) {
                availableModels.push_back(pair.second);
            }
        }
        return availableModels;
    }
    
    // 检查模型是否可用
    bool LLMManager::isModelAvailable(const std::string& modelName) const
    {
        auto it = _models.find(modelName);
        return it != _models.end() && it->second._isAvailable;
    }
    
    // 发送消息到指定模型
    std::string LLMManager::sendMessage(const std::string& modelName,
                                        const std::vector<Message>& messages,
                                        const std::map<std::string, std::string>& requestParams)
    {
        // 检测模型是否注册
        auto it = _providers.find(modelName);
        if (it == _providers.end()) {
            ERR("Model {} is not registered!!!", modelName);
            return "";
        }
        
        // 检测模型是否可用
        if (!it->second->isAvailable()) {
            ERR("Model {} is not available!!!", modelName);
            return "";
        }
        
        // 模型已注册,并且是可用的
        return it->second->sendMessage(messages, requestParams);
    }
    
    // 发送消息到指定模型,流式响应
    std::string LLMManager::sendMessageStream(const std::string& modelName,
                                              const std::vector<Message>& messages,
                                              const std::map<std::string, std::string>& requestParam,
                                              std::function<void(const std::string&, bool)> callback)
    {
        // 检测模型是否注册
        auto it = _providers.find(modelName);
        if (it == _providers.end()) {
            ERR("Model {} is not registered!!!", modelName);
            return "";
        }
        
        // 检测模型是否可用
        if (!it->second->isAvailable()) {
            ERR("Model {} is not available!!!", modelName);
            return "";
        }
        
        // 模型已注册,并且是可用的
        return it->second->sendMessageStream(messages, requestParam, callback);
    }
} // end ai_chat_sdk

2、会话管理

假设现在借助LLMManager搭建一个大模型后端服务,用户和模型进行了多轮会话,每个会话中都包含了好多条消息,在某个会话中,和模型聊天的多轮消息该如何管理?多个会话该如何管理?

在这里插入图片描述

解决该问题的一种方式是引入Session。

2.1 会话介绍

会话是用户与大语言模型之间的一系列连续交互,它通过维护上下文和状态信息,确保对话的连贯性和一致性。由于大模型不会为用户管理会话信息,因此需要程序员手动完成会话管理。会话管理涉及以下内容:

[!NOTE]

  1. 保存会话数据。由于存在多组会话,每组会话可能有多条聊天消息,在保存会话数据时可
    以将会话id和具体的会话简历映射关系,方便后续查询。<session_id, session>
  2. 创建会话。在和大模型建立连接前,需要先创建好会话,将后续每次和大模型聊天的消息
    信息要存在该会话中。
  3. 通过会话id获取指定会话。
  4. 向某会话中添加消息。在和大模型聊天时,需要将用户提问、模型恢复的消息保存到会话
    中。
  5. 获某会话中所有消息。当用户点击某个会话时,需要将该会话中的历史消息显示到界面。
  6. 索取所有会话列表。获取所有的会话列表,显示在界面中。
  7. 删除会话。用户可以在页面中删除具体的会话。
  8. 更新会话时间戳。继上次聊完之后,用户再次打开该会话继续和模型聊天时,需更新会话
    时间戳。
  9. 清空所有会话
  10. 获取会话总数
  11. 生成会话id。需要生成唯一的会话id来标记具体某条会话。
  12. 生成消息id。需要使用唯一的消息id来标记具体某条消息。

注意:会话管理模块会保存所有的会话,在同一时刻,可能会对多个会话进行操作,因此创建会话、更新会话、删除会话等时需要考虑线程安全问题。

2.2 会话管理数据结构设计

//////////////////////////// common.h ///////////////////////////////////
// ...
namespace ai_chat_sdk {
    // ...
    // 会话结构
    struct Session {
        std::string session_id;         // 会话id
        std::string model_name;         // 模型名称
        std::vector<Message> messages;  // 消息列表
        std::time_t create_time;        // 创建时间
        std::time_t update_time;        // 更新时间
        
        Session(const std::string& model_name)
            : model_name(model_name)
            , create_time(std::time(nullptr))
            , update_time(std::time(nullptr))
        {}
    };
} // end ai_chat_sdk

2.3 会话实现

SessionMaganer类

///////////////////////////// SessionManager.h
////////////////////////////////////
#include "chat_sdk.h"
#include <vector>
#include <memory>
#include <unordered_map>
#include <mutex>
#include <atomic>
namespace ai_chat_sdk {
class SessionManager{
public:
// 创建新会话. model_name : 模型名称 返回会话id
std::string createSession(const std::string& model_name);
// 获取具体会话,通过会话id获取. session_id : 会话id 返回会话指针
std::shared_ptr<Session> getSession(const std::string& session_id);
// 添加消息到会话. session_id : 会话id, message : 消息 返回是否添加成功
bool addMessage(const std::string& session_id, const Message& message);
// 获取会话历史,即获取具体某次会话的所有消息. session_id : 会话id 返回消息列表
    std::vector<Message> getSessionHistory(const std::string& session_id);
// 获取会话列表 返回会话列表
std::vector<std::string> getSessionList()const;
// 删除会话. session_id : 会话id 返回是否删除成功
bool deleteSession(const std::string& session_id);
// 更新会话时间戳. session_id : 会话id
void updateSessionTimestamp(const std::string& session_id);
// 清空所有会话
void clearAllSessions();
// 获取会话总数 返回会话总数
size_t getSessionCount()const;
private:
// 生成唯一会话id 返回会话id
std::string generateSessionId();
// 生成唯一消息id 返回消息id
std::string generateMessageId();
private:
// 管理所有会话,key为session_id, value为会话指针
std::unordered_map<std::string, std::shared_ptr<Session>> _sessions;
mutable std::mutex _mutex; // 在const成员函数中,可能会修改所的状态,因此需要
mutable修饰
static std::atomic<int64_t> _message_counter; // 记录所有会话中消息总数
static std::atomic<int64_t> _session_counter; // 记录会话总数
};
} // end ai_chat_sdk
//////////////////////////// SessionManager.cpp
/////////////////////////////////////
#include "../include/session_manager.h"
#include <sstream>
#include "../../util/myLogger.h"
#include <iomanip>

namespace ai_chat_sdk {
    std::atomic<int64_t> SessionManager::_message_counter{0};
    std::atomic<int64_t> SessionManager::_session_counter{0};
    
    // 生成消息id
    std::string SessionManager::generateMessageId() {
        // 消息计数自增
        _message_counter.fetch_add(1);
        std::time_t time = std::time(nullptr);
        // 消息id格式:msg_时间戳_消息计数
        std::ostringstream os;
        os << "msg_" << time << "_" << std::setfill('0') << std::setw(8) << _message_counter;
        return os.str();
    }
    
    // 生成会话id
    std::string SessionManager::generateSessionId() {
        // 会话计数自增
        _session_counter.fetch_add(1);
        std::time_t time = std::time(nullptr);
        // 会话id格式:session_时间戳_会话计数
        std::ostringstream os;
        os << "session_" << time << "_" << std::setfill('0') << std::setw(8) << _session_counter;
        return os.str();
    }
    
    // 创建会话
    std::string SessionManager::createSession(const std::string& model_name) {
        std::lock_guard<std::mutex> lock(_mutex);
        // 生成会话id
        std::string session_id = generateSessionId();
        // 创建会话, 设置sessionId
        auto session = std::make_shared<Session>(model_name);
        session->session_id = session_id;
        // 加入会话列表
        _sessions[session_id] = session;
        INFO("create session, session_id: {}, model_name: {}", session_id, model_name);
        return session_id;
    }
    
    // 获取会话
    std::shared_ptr<Session> SessionManager::getSession(const std::string& session_id) {
        std::lock_guard<std::mutex> lock(_mutex);
        auto it = _sessions.find(session_id);
        if (it == _sessions.end()) {
            return nullptr;
        }
        return it->second;
    }
    
    // 添加消息
    bool SessionManager::addMessage(const std::string& session_id, const Message& message) {
        std::lock_guard<std::mutex> lock(_mutex);
        // 获取sessionId对应的会话
        auto it = _sessions.find(session_id);
        if (it == _sessions.end()) {
            return false;
        }
        // 创建消息, 设置messageId
        Message msg(message.role, message.content);
        msg.message_id = generateMessageId();
        // 添加消息到会话
        it->second->messages.push_back(msg);
        it->second->update_time = std::time(nullptr);
        INFO("add message to session {}:{}", session_id, msg.content);
        return true;
    }
    
    // 获取会话历史
    std::vector<Message> SessionManager::getSessionHistory(const std::string& session_id) {
        std::lock_guard<std::mutex> lock(_mutex);
        auto it = _sessions.find(session_id);
        if (it == _sessions.end()) {
            return {};
        }
        return it->second->messages;
    }
    
    // 获取会话列表, 包含sessionId和modelName
    std::vector<std::string> SessionManager::getSessionList() const {
        std::lock_guard<std::mutex> lock(_mutex);
        std::vector<std::pair<std::time_t, std::shared_ptr<const Session>>> temp;
        temp.reserve(_sessions.size());
        // 填充临时会话
        for (const auto& pair : _sessions) {
            const auto& session_ptr = pair.second;
            temp.emplace_back(session_ptr->update_time, session_ptr);
        }
        // 按更新时间排序, > 排降序
        std::sort(temp.begin(), temp.end(), [](const auto& a, const auto& b) {
            return a.first > b.first;
        });
        // 构造返回列表
        std::vector<std::string> session_ids;
        session_ids.reserve(temp.size());
        for (const auto& item : temp) {
            session_ids.push_back(item.second->session_id);
        }
        return session_ids;
    }
    
    // 删除会话
    bool SessionManager::deleteSession(const std::string& session_id) {
        std::lock_guard<std::mutex> lock(_mutex);
        auto it = _sessions.find(session_id);
        if (it == _sessions.end()) {
            INFO("session {} not found for delete", session_id);
            return false;
        }
        INFO("delete session {}", session_id);
        _sessions.erase(it);
        return true;
    }
    
    // 更新会话时间戳
    void SessionManager::updateSessionTimestamp(const std::string& session_id) {
        std::lock_guard<std::mutex> lock(_mutex);
        auto it = _sessions.find(session_id);
        if (it != _sessions.end()) {
            it->second->update_time = std::time(nullptr);
        }
    }
    
    // 清空所有会话
    void SessionManager::clearAllSessions() {
        std::lock_guard<std::mutex> lock(_mutex);
        INFO("clear {} sessions", _sessions.size());
        _sessions.clear();
    }
    
    // 获取会话总数
    size_t SessionManager::getSessionCount() const {
        std::lock_guard<std::mutex> lock(_mutex);
        return _sessions.size();
    }
} // end ai_chat_sdk

3、数据管理

3.1 持久化存储

​ 目前在程序中创建的会话以及和模型交互的历史数据都存储在内存中,而内存是一种带电存储设备,一旦断电或者程序重启之后数据就丢失了,想要拿到之前的会话数据就回天乏术了。Deepseek、ChatGPT提供的网页聊天机器人都是支持会话存储的,方便用户查看以往的会话记录。因此,ChatSDK中也需要添加数据持久化存储的功能,即将数据永久的保存下来,不会因为断电或关闭程序而引起数据丢失。

​ 持久化存储最常见的方案就是将数据存储到数据库,由于ChatSDK实现好之后,需要共享给其他需要
的用户使用,为了减少用户使用成本(使用时不需要安装过多依赖或三方工具),本项目采用SQLite实现
数据存储。

3.2 数据管理模块设计

接下来将会话信息和与模型聊天的数据持久化存储到SQLite数据库,存储时需要注意线程安全问题。

#pragma once
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <sqlite3.h>
#include "common.h"

namespace ai_chat_sdk {
    class DataManager {
    public:
        DataManager(const std::string& dbName);
        ~DataManager();
        
        // Session相关操作
        // 插入新会话
        bool insertSession(const Session& session);
        // 获取指定会话
        std::shared_ptr<Session> getSession(const std::string& sessionId) const;
        // 更新指定会话的时间戳
        void updateSessionTimestamp(const std::string& sessionId, std::time_t timestamp);
        // 删除指定会话,注意:会话删除后,该会话下管理的历史消息也要同步删除
        bool deleteSession(const std::string& sessionId);
        // 获取所有会话id
        std::vector<std::string> getAllSessionIds() const;
        // 获取所有会话信息
        std::vector<std::shared_ptr<Session>> getAllSessions() const;
        // 删除所有会话
        bool clearAllSessions();
        // 获取所有会话个数
        size_t getSessionCount() const;
        
        // Message相关操作
        bool insertMessage(const std::string& sessionId, const Message& message);
        // 获取指定会话的历史消息
        std::vector<Message> getMessagesBySessionId(const std::string& sessionId) const;
        // 删除指定会话的历史消息
        bool deleteMessagesBySessionId(const std::string& sessionId);
        
    private:
        // 初始化数据库(创建表)
        bool initDatabase();
        // 执行SQL语句的工具函数
        bool executeSQL(const std::string& sql);
        
    private:
        sqlite3* _db;
        std::string _dbName;
        mutable std::mutex _mutex; // mutable关键字允许在const成员函数中修改这个互斥锁,
                                   // 用于线程安全
    };
} // end ai_chat_sdk

3.3 数据管理模块实现

数据库连接:

#include "../include/dataManager.h"
#include "../include/util/my_logger.h"

namespace ai_chat_sdk {
    DataManager::DataManager(const std::string& dbName)
        : _db(nullptr)
        , _dbName(dbName) {
        // 创建或打开数据库连接
        if (sqlite3_open(_dbName.c_str(), &_db) != SQLITE_OK) {
            ERR("Failed to open database: {}", sqlite3_errmsg(_db));
        }
        INFO("Database opened successfully: {}", _dbName);
        
        // 初始化数据库(创建表)
        if (!initDatabase()) {
            sqlite3_close(_db);
            _db = nullptr;
            ERR("Failed to initialize database");
        }
    }
    
    DataManager::~DataManager() {
        if (_db) {
            sqlite3_close(_db);
            _db = nullptr;
            INFO("Database closed: {}", _dbName);
        }
    }
    
    bool DataManager::initDatabase() {
        std::lock_guard<std::mutex> lock(_mutex);
        
        // 创建sessions表
        const std::string createSessionsTable =
            "CREATE TABLE IF NOT EXISTS sessions ("
            "session_id TEXT PRIMARY KEY, "
            "model_name TEXT NOT NULL, "
            "create_time INTEGER NOT NULL, "
            "update_time INTEGER NOT NULL"
            ");";
        if (!executeSQL(createSessionsTable)) {
            return false;
        }
        
        // 创建messages表
        const std::string createMessagesTable =
            "CREATE TABLE IF NOT EXISTS messages ("
            "message_id TEXT PRIMARY KEY, "
            "session_id TEXT NOT NULL, "
            "role TEXT NOT NULL, "
            "content TEXT NOT NULL, "
            "timestamp INTEGER NOT NULL, "
            "FOREIGN KEY (session_id) REFERENCES sessions (session_id) ON DELETE CASCADE"
            ");";
        if (!executeSQL(createMessagesTable)) {
            return false;
        }
        
        // 创建索引以加速查询
        const std::string createMessageIndex =
            "CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages (session_id);";
        if (!executeSQL(createMessageIndex)) {
            return false;
        }
        
        INFO("Database initialized successfully");
        return true;
    }
    
    // 辅助方法
    bool DataManager::executeSQL(const std::string& sql) {
        char* errMsg = nullptr;
        if (sqlite3_exec(_db, sql.c_str(), nullptr, nullptr, &errMsg) != SQLITE_OK) {
            ERR("Failed to execute SQL: {}", errMsg);
            sqlite3_free(errMsg);
            return false;
        }
        return true;
    }
} // end ai_chat_sdk

插入会话

bool DataManager::insertSession(const Session& session) {
    std::lock_guard<std::mutex> lock(_mutex);
    
    const std::string sql =
        "INSERT INTO sessions (session_id, model_name, create_time, update_time) "
        "VALUES (?, ?, ?, ?);";
    
    // 准备SQL语句
    sqlite3_stmt* stmt;
    int rc = sqlite3_prepare_v2(_db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
        return false;
    }
    
    // 绑定参数
    sqlite3_bind_text(stmt, 1, session.session_id.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_text(stmt, 2, session.model_name.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_int64(stmt, 3, static_cast<int64_t>(session.create_time));
    sqlite3_bind_int64(stmt, 4, static_cast<int64_t>(session.update_time));
    
    // 执行SQL语句, 检查是否插入成功
    rc = sqlite3_step(stmt);
    if (rc != SQLITE_DONE) {
        ERR("Failed to insert session: {}", sqlite3_errmsg(_db));
        sqlite3_finalize(stmt);
        return false;
    }
    
    sqlite3_finalize(stmt); // 执行完成后, 释放stmt
    INFO("Inserted session: {}", session.session_id);
    return true;
}

获取指定会话

std::shared_ptr<Session> DataManager::getSession(const std::string& sessionId) const {
    std::lock_guard<std::mutex> lock(_mutex);
    
    // 准备SQL语句
    const std::string sql = "SELECT model_name, create_time, update_time FROM sessions WHERE session_id = ?;";
    sqlite3_stmt* stmt;
    int rc = sqlite3_prepare_v2(_db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
        return nullptr;
    }
    
    // 绑定参数
    sqlite3_bind_text(stmt, 1, sessionId.c_str(), -1, SQLITE_TRANSIENT);
    
    // 执行SQL语句, 检查是否查询到会话
    rc = sqlite3_step(stmt);
    if (rc != SQLITE_ROW) {
        sqlite3_finalize(stmt);
        return nullptr;
    }
    
    // 创建Session对象
    std::string modelName(reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0)));
    auto session = std::make_shared<Session>(modelName);
    session->session_id = sessionId;
    session->create_time = static_cast<std::time_t>(sqlite3_column_int64(stmt, 1));
    session->update_time = static_cast<std::time_t>(sqlite3_column_int64(stmt, 2));
    
    // 释放stmt
    sqlite3_finalize(stmt);
    
    // 获取该会话的所有消息
    session->messages = getMessagesBySessionId(sessionId);
    return session;
}

更新指定会话时间戳

void DataManager::updateSessionTimestamp(const std::string& sessionId,
                                        std::time_t timestamp)
{
    std::lock_guard<std::mutex> lock(_mutex);
    
    // 准备SQL语句
    const std::string sql = "UPDATE sessions SET update_time = ? WHERE session_id = ?;";
    
    // 绑定参数
    sqlite3_stmt* stmt;
    int rc = sqlite3_prepare_v2(_db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
        return;
    }
    
    // 绑定参数
    sqlite3_bind_int64(stmt, 1, static_cast<int64_t>(timestamp));
    sqlite3_bind_text(stmt, 2, sessionId.c_str(), -1, SQLITE_TRANSIENT);
    
    // 执行SQL语句, 检查是否更新成功
    rc = sqlite3_step(stmt);
    if (rc != SQLITE_DONE) {
        ERR("Failed to update session: {}", sqlite3_errmsg(_db));
        sqlite3_finalize(stmt);
        return;
    }
    
    sqlite3_finalize(stmt);
    INFO("Updated session timestamp: {}", sessionId);
}

删除指定会话

bool DataManager::deleteSession(const std::string& sessionId) {
    // 从数据库中删除会话的所有消息,在删除会话列表
    // 否则,该函数中已经加锁了,在未退出前调用deleteMessagesBySessionId,会重新加锁
    // 就导致死锁了
    deleteMessagesBySessionId(sessionId);
    
    std::lock_guard<std::mutex> lock(_mutex);
    const std::string sql = "DELETE FROM sessions WHERE session_id = ?;";
    sqlite3_stmt* stmt;
    int rc = sqlite3_prepare_v2(_db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
        return false;
    }
    
    sqlite3_bind_text(stmt, 1, sessionId.c_str(), -1, SQLITE_TRANSIENT);
    rc = sqlite3_step(stmt);
    if (rc != SQLITE_DONE) {
        ERR("Failed to delete session: {}", sqlite3_errmsg(_db));
        sqlite3_finalize(stmt);
        return false;
    }
    
    sqlite3_finalize(stmt);
    INFO("Deleted session: {}", sessionId);
    return true;
}

获取所有会话id

std::vector<std::string> DataManager::getAllSessionIds() const {
    std::lock_guard<std::mutex> lock(_mutex);
    std::vector<std::string> sessionIds;
    
    const std::string sql = "SELECT session_id FROM sessions ORDER BY update_time DESC;";
    sqlite3_stmt* stmt;
    
    // 准备SQL语句
    int rc = sqlite3_prepare_v2(_db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
        return sessionIds;
    }
    
    while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
        sessionIds.push_back(reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0)));
    }
    
    sqlite3_finalize(stmt);
    return sessionIds;
}

获取所有Session信息,并按照更新时间降序排列:

std::vector<std::shared_ptr<Session>> DataManager::getAllSessions() const {
    std::lock_guard<std::mutex> lock(_mutex);
    std::vector<std::shared_ptr<Session>> sessions;
    
    const std::string sql = "SELECT session_id, model_name, create_time, update_time FROM sessions ORDER BY update_time DESC;";
    sqlite3_stmt* stmt;
    
    // 准备SQL语句
    int rc = sqlite3_prepare_v2(_db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
        return sessions;
    }
    
    while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
        std::string sessionId(reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0)));
        std::string modelName(reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1)));
        
        auto session = std::make_shared<Session>(modelName);
        session->session_id = sessionId;
        session->create_time = static_cast<std::time_t>(sqlite3_column_int64(stmt, 2));
        session->update_time = static_cast<std::time_t>(sqlite3_column_int64(stmt, 3));
        
        // 此处暂不加载历史消息,需要历史消息可以通过会话id单独获取
        sessions.push_back(session);
    }
    
    sqlite3_finalize(stmt);
    return sessions;
}

获取会话总个数

size_t DataManager::getSessionCount() const {
    std::lock_guard<std::mutex> lock(_mutex);
    
    const std::string sql = "SELECT COUNT(*) FROM sessions;";
    sqlite3_stmt* stmt;
    
    // 准备SQL语句
    int rc = sqlite3_prepare_v2(_db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
        return 0;
    }
    
    // 执行SQL语句
    rc = sqlite3_step(stmt);
    size_t count = 0;
    if (rc == SQLITE_ROW) {
        count = static_cast<size_t>(sqlite3_column_int64(stmt, 0));
    }
    
    sqlite3_finalize(stmt);
    INFO("Session count: {}", count);
    return count;
}

清空所有会话

// 删除所有会话
bool DataManager::clearAllSessions() {
    std::lock_guard<std::mutex> lock(_mutex);
    
    // 构建SQL语句
    std::string deleteSQL = R"(
        DELETE FROM sessions;
    )";
    
    // 准备SQL语句
    sqlite3_stmt* stmt;
    int rc = sqlite3_prepare_v2(_db, deleteSQL.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("clearAllSessions - 准备语句失败:{}", sqlite3_errmsg(_db));
        return false;
    }
    
    // 执行SQL语句
    rc = sqlite3_step(stmt);
    if (rc != SQLITE_DONE) {
        ERR("clearAllSessions - 执行语句失败:{}", sqlite3_errmsg(_db));
        sqlite3_finalize(stmt);
        return false;
    }
    
    // 释放语句
    sqlite3_finalize(stmt);
    INFO("clearAllSessions - 删除所有会话成功");
    return true;
}

在指定会话中插入一条消息

bool DataManager::insertMessage(const std::string& sessionId, const Message& message) {
    std::lock_guard<std::mutex> lock(_mutex);
    
    const std::string sql =
        "INSERT INTO messages (message_id, session_id, role, content, timestamp) "
        "VALUES (?, ?, ?, ?, ?);";
    
    // 准备SQL语句
    sqlite3_stmt* stmt;
    int rc = sqlite3_prepare_v2(_db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
        return false;
    }
    
    // 绑定参数
    sqlite3_bind_text(stmt, 1, message.message_id.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_text(stmt, 2, sessionId.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_text(stmt, 3, message.role.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_text(stmt, 4, message.content.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_int64(stmt, 5, static_cast<int64_t>(message.timestamp));
    
    // 执行SQL语句, 检查是否更新成功
    rc = sqlite3_step(stmt);
    if (rc != SQLITE_DONE) {
        ERR("Failed to insert message: {}", sqlite3_errmsg(_db));
        sqlite3_finalize(stmt);
        return false;
    }
    sqlite3_finalize(stmt);
    
    // 同时更新session的update_time
    const std::string updateSessionSql =
        "UPDATE sessions SET update_time = ? WHERE session_id = ?;";
    sqlite3_stmt* updateStmt;
    rc = sqlite3_prepare_v2(_db, updateSessionSql.c_str(), -1, &updateStmt, nullptr);
    if (rc == SQLITE_OK) {
        sqlite3_bind_int64(updateStmt, 1, static_cast<int64_t>(std::time(nullptr)));
        sqlite3_bind_text(updateStmt, 2, sessionId.c_str(), -1, SQLITE_TRANSIENT);
        sqlite3_step(updateStmt);
        sqlite3_finalize(updateStmt);
    } else {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
    }
    
    DBG("Inserted message: {} into session: {}", message.message_id, sessionId);
    return true;
}

获取指定会话历史消息

std::vector<Message> DataManager::getMessagesBySessionId(const std::string& sessionId) const {
    std::lock_guard<std::mutex> lock(_mutex);
    std::vector<Message> messages;
    
    const std::string sql =
        "SELECT message_id, role, content, timestamp FROM messages WHERE "
        "session_id = ? ORDER BY timestamp ASC;";
    
    // 准备SQL语句
    sqlite3_stmt* stmt;
    int rc = sqlite3_prepare_v2(_db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
        return messages;
    }
    
    // 绑定参数
    sqlite3_bind_text(stmt, 1, sessionId.c_str(), -1, SQLITE_TRANSIENT);
    
    while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
        std::string messageId(reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0)));
        std::string role(reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1)));
        std::string content(reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2)));
        std::time_t timestamp = static_cast<std::time_t>(sqlite3_column_int64(stmt, 3));
        
        Message message(role, content);
        message.message_id = messageId;
        message.timestamp = timestamp;
        messages.push_back(message);
    }
    
    sqlite3_finalize(stmt);
    return messages;
}

删除指定会话的历史消息

bool DataManager::deleteMessagesBySessionId(const std::string& sessionId) {
    std::lock_guard<std::mutex> lock(_mutex);
    
    const std::string sql = "DELETE FROM messages WHERE session_id = ?;";
    
    // 准备SQL语句
    sqlite3_stmt* stmt;
    int rc = sqlite3_prepare_v2(_db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ERR("Failed to prepare SQL: {}", sqlite3_errmsg(_db));
        return false;
    }
    
    // 绑定参数
    sqlite3_bind_text(stmt, 1, sessionId.c_str(), -1, SQLITE_TRANSIENT);
    
    // 执行SQL语句, 检查是否更新成功
    rc = sqlite3_step(stmt);
    if (rc != SQLITE_DONE) {
        ERR("Failed to delete messages: {}", sqlite3_errmsg(_db));
        sqlite3_finalize(stmt);
        return false;
    }
    
    sqlite3_finalize(stmt);
    INFO("Deleted all messages for session: {}", sessionId);
    return true;
}

3.4 会话数据同步到数据库

每次创建新会话、删除会话、更新会话、以及生成消息记录等都需要同步到SQLite中,因此让SessionManager类持有一个DataManager的对象,当发生上述操作时,通过DataManager的对象持久化会话数据。

////////////////////////////////// SessionManager.h
///////////////////////////////
// ...
class SessionManager {
    // ...
private:
    // ...
    // 通过会话数据到数据库
    DataManager _dataManager;
};

////////////////////////////////// SessionManager.cpp
///////////////////////////////
SessionManager::SessionManager()
    : _dataManager("chatDB.db")
{
    // 获取所有会话
    auto sessions = _dataManager.getAllSessions();
    for (auto& session : sessions) {
        _sessions[session->session_id] = session;
    }
    INFO("SessionManager init, session count: {}", _sessions.size());
}

// ...
// 创建会话
std::string SessionManager::createSession(const std::string& model_name) {
    std::string session_id;
    std::shared_ptr<Session> session;
    
    _mutex.lock();
    // 生成会话id
    // ...
    _mutex.unlock();
    
    // 插入会话到数据库
    _dataManager.insertSession(*session);
    return session_id;
}

// 获取会话
std::shared_ptr<Session> SessionManager::getSession(const std::string& session_id) {
    // 先快速检查内存中是否存在
    _mutex.lock();
    auto it = _sessions.find(session_id);
    if (it != _sessions.end()) {
        _mutex.unlock();
        // 获取该会话的消息列表
        it->second->messages = _dataManager.getMessagesBySessionId(session_id);
        return it->second;
    }
    _mutex.unlock();
    
    // 内存中没有, 从数据库中查,顺便内容中也存储一份
    auto sessionPtr = _dataManager.getSession(session_id);
    if (sessionPtr) {
        // 在更新内存前再次加锁,防止其他线程已经添加
        std::lock_guard<std::mutex> lock(_mutex);
        // 再次检查,因为可能在获取锁的过程中,其他线程已经添加了这个session
        auto it = _sessions.find(session_id);
        if (it == _sessions.end()) {
            _sessions[session_id] = sessionPtr;
        }
        // 获取该会话的消息列表
        sessionPtr->messages = _dataManager.getMessagesBySessionId(session_id);
        return sessionPtr;
    }
    
    WARN("session {} not found", session_id);
    return nullptr;
}

// 添加消息
bool SessionManager::addMessage(const std::string& session_id, const Message& message) {
    _mutex.lock();
    // 获取sessionId对应的会话
    // ...
    _mutex.unlock();
    
    // 插入消息到数据库
    _dataManager.insertMessage(session_id, msg);
    return true;
}

// 获取会话历史
std::vector<Message> SessionManager::getSessionHistory(const std::string& session_id) {
    _mutex.lock();
    // 先从内存中获取,如果内存中获取不到,再到数据库获取
    // ...
    _mutex.unlock();
    
    // 从数据库中获取消息
    return _dataManager.getMessagesBySessionId(session_id);
}

// 获取会话列表, 包含sessionId和modelName
std::vector<std::string> SessionManager::getSessionList() const {
    // 先从数据库获取
    auto sessions = _dataManager.getAllSessions();
    std::lock_guard<std::mutex> lock(_mutex);
    
    std::vector<std::pair<std::time_t, std::shared_ptr<const Session>>> temp;
    temp.reserve(_sessions.size());
    
    // 填充临时会话
    for (const auto& pair : _sessions) {
        const auto& session_ptr = pair.second;
        temp.emplace_back(session_ptr->update_time, session_ptr);
    }
    
    // 合并数据库会话和内存会话
    for (const auto& session : sessions) {
        if (_sessions.find(session->session_id) == _sessions.end()) {
            temp.emplace_back(session->update_time, session);
        }
    }
    
    // 按更新时间排序, > 排降序
    // ...
}

// 删除会话
bool SessionManager::deleteSession(const std::string& session_id) {
    _mutex.lock();
    // 先删除内存中会话信息,然后删除数据库
    // ...
    _mutex.unlock();
    
    // 从数据库中删除会话以及该会话对应的消息列表
    _dataManager.deleteSession(session_id);
    return true;
}

void SessionManager::updateSessionTimestamp(const std::string& session_id) {
    _mutex.lock();
    // 先更新内容中指定会话的时间戳,然后同步到数据库
    // ...
    _mutex.unlock();
    
    // 更新数据库中的会话时间戳
    _dataManager.updateSessionTimestamp(session_id, it->second->update_time);
}

// 获取会话总数
size_t SessionManager::getSessionCount() const {
    return _dataManager.getSessionCount();
}
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐