这里给出三种思路:「二分」「翻转对」「线段树」
nums1[i] - nums1[j] <= nums2[i] - nums2[j] + diff
上述不等式变形可得:nums1[i] - nums2[i] <= nums1[j] - nums2[j] + diff
将nums1[i] - nums2[i]
视为一个整体,用df[i]
表示
对于一个df[i]
,需要在[0...i-1]
中寻找一个j
,使得df[j] - diff <= df[i]
public long numberOfPairs(int[] nums1, int[] nums2, int diff) {
int n = nums1.length;
int[] df = new int[n];
for (int i = 0; i < n; i++) df[i] = nums1[i] - nums2[i];
long ans = 0;
// list 存储 df[i] - diff
List<Integer> list = new ArrayList<>();
for (int i = 0; i < n; i++) {
int k = df[i];
// 在 [lo, hi] 中二分寻找 <= df[i] 的最右下标
int lo = 0, hi = list.size() - 1;
while (lo <= hi) {
int mid = lo + (hi - lo) / 2;
if (list.get(mid) <= k) lo = mid + 1;
else hi = mid - 1;
}
ans += lo;
int t = df[i] - diff;
// 为了保持有序,在 [lo, hi] 中二分寻找 t 的的存储位置
lo = 0; hi = list.size() - 1;
while (lo <= hi) {
int mid = lo + (hi - lo) / 2;
if (list.get(mid) <= t) lo = mid + 1;
else hi = mid - 1;
}
list.add(lo, t);
}
return ans;
}
关于「翻转对」的详细介绍可见 详解归并排序及其应用
xprivate int[] temp;
private long ans;
private int diff;
public long numberOfPairs(int[] nums1, int[] nums2, int diff) {
int n = nums1.length;
int[] df = new int[n];
for (int i = 0; i < n; i++) df[i] = nums1[i] - nums2[i];
temp = new int[n]; this.diff = diff; ans = 0;
sort(df, 0, n - 1);
return ans;
}
private void sort(int[] nums, int lo, int hi) {
if (lo >= hi) return ;
int mid = lo + (hi - lo) / 2;
sort(nums, lo, mid);
sort(nums, mid + 1, hi);
merge(nums, lo, mid, hi);
}
private void merge(int[] nums, int lo, int mid, int hi) {
for (int i = lo; i <= hi; i++) temp[i] = nums[i];
// 寻找部分
int end = lo;
for (int i = mid + 1; i <= hi; i++) {
while (end <= mid && temp[end] - diff <= temp[i]) end++;
ans += end - lo;
}
int i = lo, j = mid + 1, idx = lo;
while (i <= mid || j <= hi) {
if (i > mid) nums[idx++] = temp[j++];
else if (j > hi) nums[idx++] = temp[i++];
else if (temp[i] <= temp[j]) nums[idx++] = temp[i++];
else nums[idx++] = temp[j++];
}
}
关于「线段树」的详细介绍和模版可见 线段树详解
根据范围:
-10^4 <= nums1[i], nums2[i] <= 10^4
-10^4 <= diff <= 10^4
所以nums1[i] - nums2[i] + diff
的范围为-3 * 10^4 <= x <= 3 * 10^4
为了方便处理,我们将范围右移到正数区间,即右移-3 * 10^4
xxxxxxxxxx
class Solution {
public long numberOfPairs(int[] nums1, int[] nums2, int diff) {
int n = nums1.length;
// 偏移量
int move = 3 * (int) 1e4;
long ans = 0;
for (int i = 0; i < n; i++) {
// d 为更新量,target 为查询量
int d = nums1[i] - nums2[i];
int target = d + diff;
// 查询区间 [0, target + move] 上的数量
ans += query(root, 0, N, 0, target + move);
// 更新区间 [d + move, d + move]
update(root, 0, N, d + move, d + move, 1);
}
return ans;
}
// *************** 下面是模版 ***************
class Node {
Node left, right;
int val, add;
}
private int N = (int) 1e9;
private Node root = new Node();
public void update(Node node, int start, int end, int l, int r, int val) {
if (l <= start && end <= r) {
node.val += (end - start + 1) * val;
node.add += val;
return ;
}
int mid = (start + end) >> 1;
pushDown(node, mid - start + 1, end - mid);
if (l <= mid) update(node.left, start, mid, l, r, val);
if (r > mid) update(node.right, mid + 1, end, l, r, val);
pushUp(node);
}
public int query(Node node, int start, int end, int l, int r) {
if (l <= start && end <= r) return node.val;
int mid = (start + end) >> 1, ans = 0;
pushDown(node, mid - start + 1, end - mid);
if (l <= mid) ans += query(node.left, start, mid, l, r);
if (r > mid) ans += query(node.right, mid + 1, end, l, r);
return ans;
}
private void pushUp(Node node) {
node.val = node.left.val + node.right.val;
}
private void pushDown(Node node, int leftNum, int rightNum) {
if (node.left == null) node.left = new Node();
if (node.right == null) node.right = new Node();
if (node.add == 0) return ;
node.left.val += node.add * leftNum;
node.right.val += node.add * rightNum;
// 对区间进行「加减」的更新操作,下推懒惰标记时需要累加起来,不能直接覆盖
node.left.add += node.add;
node.right.add += node.add;
node.add = 0;
}
}