「多路归并」的核心思想与合并n
个有序链表极其相似,题目详情可见 合并两个有序链表 合并K个升序链表
将n
个有序链表的头节点加入小根堆的优先队列中,由于n
个链表都是递增的顺序,所以第一个最小的元素肯定在这n
个头节点中产生
选择最小元素后,将该最小元素所在的链表后移一位,并将后一位元素加入队列中,以此类推...
而我们今天要介绍的「多路归并」本质就是上面的思想,唯一不同的就是,n
个有序链表需要我们根据题目意思抽象出来而已,下面我们根据题目来依次分析!!
题目详情可见 丑数 II
对于一个丑数n
,均可以衍生出三个与之对应的丑数:n * 2, n * 3, n * 5
这个题目的有序链表需要通过求得的丑数来动态获取,所以我们利用三个指针P1, P2, P3
分别指向正在被处理的丑数
注意:需要去除相同元素。如上图第四小部分,第一条链表和第二条链表的当前值均为6
,如果选择把第一条链表中的6
加入ans
中,那么第二条链条需要向前移动一格,否则6
就加入了两次
下面给出具体具体代码:
public int nthUglyNumber(int n) {
// 从下标 1 开始
int[] ans = new int[n + 1];
// 初始化
ans[1] = 1;
// p2 p3 p5 分别表示 3 个质因数的指针
// idx 表示 ans 下一个存储的下标
for (int p2 = 1, p3 = 1, p5 = 1, idx = 2; idx <= n; idx++) {
// a b c 表示当前的三个元素,如上图橙色标识出的元素
int a = ans[p2] * 2, b = ans[p3] * 3, c = ans[p5] * 5;
// 求出三者的最小值
int min = Math.min(a, Math.min(b, c));
// 存储到 ans 中
ans[idx] = min;
// 指针后移 (同时具有去重的效果)
if (min == a) p2++;
if (min == b) p3++;
if (min == c) p5++;
}
// 返回第 n 个丑数
return ans[n];
}
题目详情可见 超级丑数
这个题目其实和上面的题目大同小异,无非就是质因数从 3 个衍生成更多个而已!
下面给出具体具体代码:
xxxxxxxxxx
public int nthSuperUglyNumber(int n, int[] primes) {
// 指针:每个质因数对应一个指针
int[] p = new int[primes.length];
// 由于数据加强了,这里需要使用 long
long[] ans = new long[n + 1];
ans[1] = 1;
// 初始化每个指针指向第一个丑数
Arrays.fill(p, 1);
for (int i = 2; i <= n; i++) {
// 求最小值
long min = Integer.MAX_VALUE;
for (int j = 0; j < primes.length; j++) {
min = Math.min(min, ans[p[j]] * primes[j]);
}
// 指针后移 (同时具有去重的效果)
for (int j = 0; j < primes.length; j++) {
if (min == ans[p[j]] * primes[j]) {
p[j]++;
}
}
// 存储到 ans 中
ans[i] = min;
}
// 返回第 n 个丑数
return (int) ans[n];
}
题目详情可见 查找和最小的 K 对数字
对于例子:nums1 = [1,7,11], nums2 = [2,4,6]
,我们可以构造出三条递增的有序链表,如下图所示:
我们再来分析一下时间复杂度,假设n = nums1.length, m = num2.length
按照上图的方法构造有序链表的话,每次需要从n
个元素中找出最小的元素,需要找k
次,所以时间复杂度为O(klogn)
所以为了更优的时间复杂度,尽量让nums1
长度更短;如果nums1
长度更长,我们就交换两个数组的位置
下面给出具体具体代码:
xxxxxxxxxx
// 标志是否交换了位置 true : 未交换;false : 交换了
private boolean flag = true;
public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
int n = nums1.length, m = nums2.length;
// 判断是否需要交换顺序
if (n > m && !(flag = false)) return kSmallestPairs(nums2, nums1, k);
// 注意:队列中存储的只是下标
// 按照「两数和」递增排列
Queue<int[]> q = new PriorityQueue<>((a, b) -> nums1[a[0]] + nums2[a[1]] - nums1[b[0]] - nums2[b[1]]);
// 加入头节点
// 这里有一个技巧:如果 k < n,那么一开始只需要往队列中添加前 k 个元素即可
// 后面的 n - k 个元素肯定比前面 k 个元素大,所以加入没有意义
for (int i = 0; i < Math.min(n, k); i++) q.offer(new int[]{i, 0});
List<List<Integer>> ans = new ArrayList<>();
while (ans.size() < k && !q.isEmpty()) {
// 弹出队顶元素,即最小元素
int[] cur = q.poll();
int a = cur[0], b = cur[1];
ans.add(new ArrayList<Integer>(){{
add(flag ? nums1[a] : nums2[b]);
add(flag ? nums2[b] : nums1[a]);
}});
// 如果 b + 1 < m 表示该条链条后面还有元素,可以继续加入队列中
if (b + 1 < m) q.offer(new int[]{a, b + 1});
}
return ans;
}
题目详情可见 第 K 个最小的素数分数
对于例子:arr = [1,2,3,5]
,我们可以构造出三条递增的有序链表
如下图所示:
这里有一个小技巧:如果直接比较除完之后的结果,可能会存在误差,所以可以来一波曲线救国
对于
下面给出具体具体代码:
xxxxxxxxxx
public int[] kthSmallestPrimeFraction(int[] arr, int k) {
// 定义比较规则
Queue<int[]> q = new PriorityQueue<>((a, b) -> arr[a[0]] * arr[b[1]] - arr[b[0]] * arr[a[1]]);
// 加入头节点
for (int j = 1; j < arr.length; j++) q.offer(new int[]{0, j});
for (int cnt = 1; cnt <= k; cnt++) {
// 弹出队顶元素,即最小元素
int[] cur = q.poll();
int i = cur[0], j = cur[1];
if (cnt == k) return new int[]{arr[i], arr[j]};
// 分子下标不能超过分母
if (i + 1 < j) q.offer(new int[]{i + 1, j});
}
return new int[]{-1, -1};
}
题目详情可见 子数组和排序后的区间和
对于例子:nums = [1,2,3,4]
,我们可以构造出四条递增的有序链表
如下图所示:
下面给出具体具体代码:
xxxxxxxxxx
public int rangeSum(int[] nums, int n, int left, int right) {
int mod = (int) 1e9 + 7;
// 定义比较规则
Queue<int[]> q = new PriorityQueue<>((a, b) -> a[0] - b[0]);
// 加入头节点
// 头节点包含两个值:子数组和,子数组的右边界
for (int i = 0; i < n; i++) q.offer(new int[]{nums[i], i});
int ans = 0;
for (int i = 1; i <= right; i++) {
int[] cur = q.poll();
int sum = cur[0], j = cur[1];
// 开始记录结果
if (i >= left) ans = (ans + sum) % mod;
// 如果 j + 1 < n 表示该条链条后面还有元素,可以继续加入队列中
if (j + 1 < n) q.offer(new int[]{sum + nums[j + 1], j + 1});
}
return ans;
}
题目详情可见 找出第 K 小的数对距离
对于例子:nums = [1,3,1]
,我们可以构造出两条递增的有序链表 (不过需要提前对数组排序,排序后的数组为[1,1,3]
)
下面给出具体具体代码:
xxxxxxxxxx
public int smallestDistancePair(int[] nums, int k) {
// 排序
Arrays.sort(nums);
// 定义比较规则
Queue<int[]> q = new PriorityQueue<>((a, b) -> nums[a[1]] - nums[a[0]] - (nums[b[1]] - nums[b[0]]));
// 加入头节点
for (int i = 0; i < nums.length - 1; i++) q.offer(new int[]{i, i + 1});
while (k-- > 0) {
int[] cur = q.poll();
int i = cur[0], j = cur[1];
if (k == 0) return nums[j] - nums[i];
if (j + 1 < nums.length) q.offer(new int[]{i, j + 1});
}
return -1;
}
不过很遗憾,这个题目用「多路归并」直接超时,这里给出这种方法只是为了加深对「多路归并」的理解!!
这个题目最佳的方法是「二分➕双指针」,详情可见 二分搜索:第 K 小问题
题目详情可见 有序矩阵中的第 k 个最小数组和
这个题目稍微有一丢丢的不一样,对于每次弹出队顶元素后,需要把n
个元素入队,如下图所示:
下面给出具体具体代码:
xxxxxxxxxx
public int kthSmallest(int[][] mat, int k) {
int m = mat.length, n = mat[0].length;
// 去重的作用
Set<String> set = new HashSet<>();
Queue<int[]> q = new PriorityQueue<>((a, b) -> a[m] - b[m]);
// init[0..m-1] 存储矩阵的一列元素的下标
// inti[m] 存储该列的和
int[] init = new int[m + 1];
for (int i = 0; i < m; i++) {
init[m] += mat[i][0];
init[i] = 0;
}
q.offer(init);
set.add(Arrays.toString(init));
while (k-- > 0) {
int[] cur = q.poll();
if (k == 0) return cur[m];
// 构造出需要加入队列的元素,如上图所示
for (int i = 0; i < m; i++) {
int[] temp = (int[])Arrays.copyOf(cur, m + 1);
if (temp[i] + 1 >= n) continue;
temp[m] += mat[i][temp[i] + 1] - mat[i][temp[i]];
temp[i] += 1;
String s = Arrays.toString(temp);
if (set.contains(s)) continue;
q.offer(temp);
set.add(s);
}
}
return -1;
}
对于这个题目,这种方法效率并不是很高,这里给出这种方法只是为了加深对「多路归并」的理解!!
这个题目最佳的方法是「二分➕DFS」,详情可见 二分搜索:第 K 小问题