题目链接:LeetCode 373. Find K Pairs with Smallest Sums。leetcode
给定两个非降序排序的整数数组 nums1、nums2 和整数 k,定义一对 (u, v) 由 nums1 中的一个元素和 nums2 中的一个元素组成,返回和最小的 k 个数对。leetcode

题目分析:暴力为什么不行

最直接的思路是:
枚举所有 pair (nums1[i], nums2[j])。
计算和,按和排序,取前 k 个。
问题在于:
nums1.length、nums2.length 最多都是 10^5,pair 总数是 10^10 级别,根本枚举不完。leetcode
时间和空间都会爆炸。
所以关键是:避免生成全部 pair,只生成前 k 小那一部分。leetcode

正解核心思想:有序 + 小根堆 + 按需扩展

由于两个数组都是非降序的,有两个重要性质:leetcode
全局最小和的 pair 是 (nums1[0], nums2[0])。
对于某个 (i, j),它的"邻居" (i+1, j)、(i, j+1) 的和只会更大或相等。
可以把下标 (i, j) 看成二维网格上的点,值是 nums1[i] + nums2[j],此时问题变成:
在一个"行、列都有序"的二维网格中,找和最小的 k 个格子。
做法类似多路归并 / Dijkstra:
用 小根堆 按 sum 升序存 (sum, i, j)。
初始把 (0,0) 入堆。
每次:
弹出堆顶 (i, j),这是当前全局最小的一个 pair,记入结果。
把它的右邻 (i, j+1) 和下邻 (i+1, j) 入堆。
用 visited 去重,避免同一个 (i,j) 反复入堆。
重复 pop,直到拿到 k 个结果或堆空。
注意:
这里小根堆 不需要限制容量为 k,只要从中 pop k 次即可。
和"Kth Largest Element in an Array"那题用"大根堆或固定容量小根堆"维护前 k 大不同,这里堆是用来 依次生成最小的 k 个。

第一个坑:visited 开成 nums1Size × nums2Size 导致 MLE

一开始的实现是这样的:

char** create_visited(int nums1Size, int nums2Size) {
    char **visited = (char **)malloc(nums1Size * sizeof(char *));
    for (int i = 0; i < nums1Size; i++) { 
        visited[i] = (char *)calloc(nums2Size, sizeof(char));
    }
    return visited;
}

最坏情况:
nums1Size = nums2Size = 10^5,visited 大小是 10^10 个 char ≈ 10GB,直接 MLE。leetcode
虽然 BFS 实际上不会访问这么多格子,但 visited 被提前开满了整个矩阵,这是浪费。
要点:visited 的规模必须和实际访问的节点数同阶(大致 O(k)),而不是 O(n1 × n2)。

第二个坑:堆容量设计不当

堆的实现使用自定义结构:

typedef struct pair {
    int index1;
    int index2;
    int sum;
} pair_t;

typedef struct heap {
    pair_t *data;
    int cnt;
    int capacity;
} heap_t;

原始做法中,堆容量设为:

heap_capacity = (nums1Size + nums2Size > k) ? (nums1Size + nums2Size) : k;
heap_t *heap = heap_create(heap_capacity);

在目前代码逻辑下,大致安全,但如果后续修改逻辑,不小心让入堆次数爆掉,heap_push 静默 return,会引入隐藏 bug。更合理的是:
设计堆可扩容,或
明确利用算法性质,让堆大小保持在 O(k) 级别。

AC 版本思路:限制访问区域 + 仍然保留 (i+1, j)/(i, j+1) 扩展

在保持原算法框架的前提下,只改两点:
visited 仅为前 min(k, nums1Size) 行 × min(k, nums2Size) 列开数组,因为要找的最多只有 k 个 pair,没有必要跨出这么大区域。
扩展 (i+1, j) / (i, j+1) 时,必须保证在这个缩小后的 visited 范围内。
这样:
visited 最大是 k × k(k 最多 1e4),内存约 10^8 个 char ≈ 100MB,虽然偏大,但在平台限制内可以通过。leetcode
实际访问的节点数受 k 限制,不会接近 nums1Size * nums2Size。

完整 AC C 代码(自定义小根堆)

这一版保留了原来的 heap 写法,只在核心地方做了“区域剪枝”:

#include <stdlib.h>
#include <string.h>
#include <stdbool.h>

typedef struct pair {
    int index1;
    int index2;
    int sum;
} pair_t;

typedef struct heap {
    pair_t *data;
    int cnt;
    int capacity;
} heap_t;

char** create_visited(int nums1Size, int nums2Size) {
    char **visited = NULL;
    int i;

    visited = (char **)malloc(nums1Size * sizeof(char *));
    for (i = 0; i < nums1Size; i++) {
        visited[i] = (char *)calloc(nums2Size, sizeof(char));
    }

    return visited;
}

void destroy_visited(char** visited, int nums1Size) {
    int i;
    for (i = 0; i < nums1Size; i++) {
        free(visited[i]);
    }
    free(visited);
}

heap_t *heap_create(int capacity) {
    heap_t *heap = (heap_t *)malloc(sizeof(heap_t));
    heap->data = (pair_t *)malloc(sizeof(pair_t) * capacity);
    heap->cnt = 0;
    heap->capacity = capacity;
    return heap;
}

void swap_pair(pair_t *a, pair_t *b) {
    pair_t temp = *a;
    *a = *b;
    *b = temp;
}

void heap_destroy(heap_t *heap) {
    free(heap->data);
    free(heap);
}

void sift_up(heap_t *heap, int index) {
    int parent;
    while (index > 0) {
        parent = (index - 1) / 2;
        if (heap->data[parent].sum <= heap->data[index].sum)
            break;
        swap_pair(&heap->data[parent], &heap->data[index]);
        index = parent;
    }
}

void sift_down(heap_t *heap, int index) {
    int left, right, smallest;

    while (true) {
        left = 2 * index + 1;
        right = 2 * index + 2;
        smallest = index;

        if (left < heap->cnt && heap->data[left].sum < heap->data[smallest].sum)
            smallest = left;
        if (right < heap->cnt && heap->data[right].sum < heap->data[smallest].sum)
            smallest = right;

        if (smallest == index)
            break;

        swap_pair(&heap->data[index], &heap->data[smallest]);
        index = smallest;
    }
}

void heap_push(heap_t *heap, int value, int index1, int index2) {
    if (heap->cnt >= heap->capacity)
        return;

    heap->data[heap->cnt].sum = value;
    heap->data[heap->cnt].index1 = index1;
    heap->data[heap->cnt].index2 = index2;
    heap->cnt++;
    sift_up(heap, heap->cnt - 1);
}

pair_t heap_pop(heap_t *heap) {
    pair_t value;

    value.sum = heap->data[0].sum;
    value.index1 = heap->data[0].index1;
    value.index2 = heap->data[0].index2;

    heap->data[0] = heap->data[heap->cnt - 1];
    heap->cnt--;
    sift_down(heap, 0);

    return value;
}

pair_t heap_top(heap_t *heap) {
    return heap->data[0];
}

/**
 * Return an array of arrays of size *returnSize.
 * The sizes of the arrays are returned as *returnColumnSizes array.
 * Note: Both returned array and *columnSizes array must be malloced, assume caller calls free().
 */
int** kSmallestPairs(int* nums1, int nums1Size, int* nums2, int nums2Size,
                     int k, int* returnSize, int** returnColumnSizes) {

    int index1, index2, resultIndex, i, heap_capacity, visited_rows, visited_cols;
    char **visited = NULL;
    int **result = NULL;
    pair_t value;

    if (nums1Size == 0 || nums2Size == 0 || k == 0) {
        *returnSize = 0;
        *returnColumnSizes = NULL;
        return NULL;
    }

    result = (int **)malloc(k * sizeof(int *));
    for (i = 0; i < k; i++) {
        result[i] = (int *)malloc(2 * sizeof(int));
    }

    // 只在前 min(nums1Size, k) 行和前 min(nums2Size, k) 列内开 visited,避免 MLE
    visited_rows = (nums1Size < k) ? nums1Size : k;
    visited_cols = (nums2Size < k) ? nums2Size : k;
    visited = create_visited(visited_rows, visited_cols);

    // 堆容量:略大于 k 即可
    heap_capacity = (nums1Size + nums2Size > k) ? (nums1Size + nums2Size) : k;
    heap_t *heap = heap_create(heap_capacity);

    index1 = 0;
    index2 = 0;
    resultIndex = 0;

    if (index1 < visited_rows && index2 < visited_cols) {
        heap_push(heap, nums1[index1] + nums2[index2], index1, index2);
        visited[index1][index2] = 1;
    }

    while (resultIndex < k && heap->cnt > 0) {
        value = heap_pop(heap);
        result[resultIndex][0] = nums1[value.index1];
        result[resultIndex][1] = nums2[value.index2];
        resultIndex++;

        index1 = value.index1;
        index2 = value.index2;

        // 向下扩展 (i+1, j),要在 visited 范围内
        if (index1 + 1 < visited_rows && index1 + 1 < nums1Size &&
            index2 < visited_cols && visited[index1 + 1][index2] == 0) {
            heap_push(heap, nums1[index1 + 1] + nums2[index2], index1 + 1, index2);
            visited[index1 + 1][index2] = 1;
        }

        // 向右扩展 (i, j+1),要在 visited 范围内
        if (index2 + 1 < visited_cols && index2 + 1 < nums2Size &&
            index1 < visited_rows && visited[index1][index2 + 1] == 0) {
            heap_push(heap, nums1[index1] + nums2[index2 + 1], index1, index2 + 1);
            visited[index1][index2 + 1] = 1;
        }
    }

    destroy_visited(visited, visited_rows);
    heap_destroy(heap);

    *returnSize = resultIndex;
    *returnColumnSizes = (int *)malloc(resultIndex * sizeof(int));
    for (i = 0; i < resultIndex; i++) {
        (*returnColumnSizes)[i] = 2;
    }

    return result;
}

总结

这题的本质是:在二维有序网格中,用小根堆做“按需扩展”的最小 k 个搜索。
小根堆里不仅要存 sum,还要存 (i, j),否则无法扩展 (i+1, j) 与 (i, j+1)。
最大的坑在于 visited 的内存:不能盲目开 nums1Size × nums2Size,要利用 k 的限制收缩搜索区域。
理解了这题之后,对“堆 + 有序结构 + 剪枝”这类题的感觉会好很多,比如:
Kth Smallest Number in Sorted Matrix
合并 K 个有序数组 / 链表
各种“从最小开始扩散”的最短路 / BFS 变体
这些题的核心思想其实是一脉相承的。leetcode
https://leetcode.com/problems/find-k-pairs-with-smallest-sums/?envType=study-plan-v2&envId=top-interview-150

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐