AI数学基础(一):线性代数核心 —— 向量/矩阵运算的 Java 实现

——别让“黑盒调用”掩盖了你对智能的理解

大家好,我是那个总在模型输出异常时翻线性代数笔记、又在 Java 里手写矩阵乘法的老架构。今天不聊 Transformer,也不谈反向传播——我们回到 AI 的地基:

当你从电科金仓 KingbaseES(KES)读出一个 768 维用户 embedding,如何理解它?又如何用 Java 对它做有意义的操作?

很多人直接调 dl4j.nn.output(),看到一个 float 数组就完事。
但如果你不知道这个数组代表什么、为什么能做点积、为什么 L2 归一化能让相似度更准——那你只是在“使用 AI”,而不是“构建 AI”。

线性代数不是数学课的作业,而是AI 工程的语言。今天我们就用 Java,把向量、矩阵、内积、范数这些概念,变成可运行、可调试、可优化的代码。


一、为什么 Java 工程师要懂线性代数?

因为 AI 的本质,是在高维空间中寻找结构

  • 用户行为 → 向量
  • 特征组合 → 矩阵
  • 相似度计算 → 点积
  • 模型参数 → 权重矩阵

而 KES,正是这些向量的“家”。你在 KES 表 中存的 BYTEA 字段,背后就是一个 float 数组——一个向量。

如果你连向量加法都写不利索,怎么判断 embedding 是否漂移?怎么设计缓存淘汰策略?怎么 debug 模型输出异常?

所以,今天我们不依赖 DL4J,从零实现核心运算,只为建立直觉。


二、向量:不只是数组,而是有方向的量

定义一个向量类

public class Vector {
    private final float[] data;

    public Vector(float[] data) {
        this.data = Objects.requireNonNull(data).clone(); // 防篡改
    }

    public int dim() { return data.length; }
    public float get(int i) { return data[i]; }
    public float[] toArray() { return data.clone(); }
}

✅ 原则:不可变对象,避免副作用。


核心运算 1:点积(Dot Product)—— 相似度的基石

点积衡量两个向量的“方向一致性”。值越大,越相似。

public static float dot(Vector a, Vector b) {
    if (a.dim() != b.dim()) {
        throw new IllegalArgumentException("维度不匹配");
    }
    float sum = 0.0f;
    for (int i = 0; i < a.dim(); i++) {
        sum += a.get(i) * b.get(i);
    }
    return sum;
}

📌 应用:计算用户 A 与商品 B 的匹配分

float score = Vector.dot(userEmbedding, itemEmbedding);

核心运算 2:L2 范数(Euclidean Norm)—— 向量的“长度”

public static float norm(Vector v) {
    float sum = 0.0f;
    for (float x : v.toArray()) {
        sum += x * x;
    }
    return (float) Math.sqrt(sum);
}

💡 为什么重要?

  • 归一化后,点积 = 余弦相似度(不受向量长度影响)
  • 可用于检测异常 embedding(范数突增)

核心运算 3:余弦相似度(Cosine Similarity)

public static float cosineSimilarity(Vector a, Vector b) {
    float dot = dot(a, b);
    float normA = norm(a);
    float normB = norm(b);
    if (normA == 0 || normB == 0) return 0.0f;
    return dot / (normA * normB);
}

✅ 这才是推荐系统、语义搜索的真实度量标准。


三、矩阵:线性变换的载体

在 AI 中,矩阵常用于:

  • 全连接层权重(W × x + b)
  • 特征投影(降维/升维)
  • 批量 embedding 处理

简单矩阵类

public class Matrix {
    private final float[][] data;
    private final int rows, cols;

    public Matrix(int rows, int cols) {
        this.rows = rows;
        this.cols = cols;
        this.data = new float[rows][cols];
    }

    public float get(int i, int j) { return data[i][j]; }
    public void set(int i, int j, float value) { data[i][j] = value; }
    public int rows() { return rows; }
    public int cols() { return cols; }
}

核心运算:矩阵 × 向量

这是神经网络前向传播的核心。

public static Vector multiply(Matrix W, Vector x) {
    if (W.cols() != x.dim()) {
        throw new IllegalArgumentException("矩阵列数 ≠ 向量维度");
    }
    float[] result = new float[W.rows()];
    for (int i = 0; i < W.rows(); i++) {
        float sum = 0.0f;
        for (int j = 0; j < W.cols(); j++) {
            sum += W.get(i, j) * x.get(j);
        }
        result[i] = sum;
    }
    return new Vector(result);
}

⚠️ 注意:这是 O(n²) 操作。生产环境应使用 ND4J(DL4J 底层)Eclipse Collections 优化。


四、与 KES 协同:从数据库到向量运算

现在,把理论落地。

假设你在 KES 中有如下表:

CREATE TABLE ai_features.user_embedding (
    user_id VARCHAR(64) PRIMARY KEY,
    embedding BYTEA  -- 存储 float[768]
);

步骤 1:从 KES 读取 embedding

public Vector loadUserEmbedding(String userId) {
    String sql = "SELECT embedding FROM ai_features.user_embedding WHERE user_id = ?";
    try (Connection conn = KESDataSource.getConnection();
         PreparedStatement ps = conn.prepareStatement(sql)) {
        
        ps.setString(1, userId);
        ResultSet rs = ps.executeQuery();
        if (rs.next()) {
            byte[] bytes = rs.getBytes("embedding");
            float[] array = deserializeFloatArray(bytes); // 自定义反序列化
            return new Vector(array);
        }
    } catch (SQLException e) {
        throw new RuntimeException("加载 embedding 失败", e);
    }
    return null;
}

🔗 驱动请从 电科金仓官网下载,确保兼容 BYTEA 类型。


步骤 2:计算用户相似度

Vector userA = loadUserEmbedding("U123");
Vector userB = loadUserEmbedding("U456");

if (userA != null && userB != null) {
    float sim = Vector.cosineSimilarity(userA, userB);
    System.out.printf("用户 U123 与 U456 相似度: %.4f%n", sim);
    
    // 若相似度 > 0.8,可触发“好友推荐”
    if (sim > 0.8f) {
        triggerRecommendation("U123", "U456");
    }
}

五、性能提醒:别在生产环境手写矩阵乘法

我让你手写,是为了建立直觉,不是让你上线。

真实场景中,请用专业库:

  • DL4J / ND4J:支持 CPU/GPU 加速、自动内存管理;
  • EJML:轻量级,适合嵌入式;
  • Apache Commons Math:通用数学库。

但无论用哪个库,你必须理解背后的线性代数——否则,当模型输出 NaN 时,你只能祈祷。


结语:数学是 AI 的母语

AI 不是魔法,而是数学在数据上的投影。
而 Java,作为企业级系统的主力语言,必须能承载这份理性。

当你能从 KES 中读出一个向量,理解它的几何意义,并用几行代码计算出业务价值——你就不再是“调包侠”,而是一名真正的 AI 工程师。

下一期,我们会讲:AI数学基础(二):概率与统计 —— 从贝叶斯到 A/B 测试的 Java 实践
敬请期待。

—— 一位相信“不懂数学的 AI 工程师,就像不会看地图的司机”的架构师

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐