这里给出三种思路:「二分」「翻转对」「线段树」
nums1[i] - nums1[j] <= nums2[i] - nums2[j] + diff
上述不等式变形可得:nums1[i] - nums2[i] <= nums1[j] - nums2[j] + diff
将nums1[i] - nums2[i]视为一个整体,用df[i]表示
对于一个df[i],需要在[0...i-1]中寻找一个j,使得df[j] - diff <= df[i]
public long numberOfPairs(int[] nums1, int[] nums2, int diff) { int n = nums1.length; int[] df = new int[n]; for (int i = 0; i < n; i++) df[i] = nums1[i] - nums2[i]; long ans = 0; // list 存储 df[i] - diff List<Integer> list = new ArrayList<>(); for (int i = 0; i < n; i++) { int k = df[i]; // 在 [lo, hi] 中二分寻找 <= df[i] 的最右下标 int lo = 0, hi = list.size() - 1; while (lo <= hi) { int mid = lo + (hi - lo) / 2; if (list.get(mid) <= k) lo = mid + 1; else hi = mid - 1; } ans += lo; int t = df[i] - diff; // 为了保持有序,在 [lo, hi] 中二分寻找 t 的的存储位置 lo = 0; hi = list.size() - 1; while (lo <= hi) { int mid = lo + (hi - lo) / 2; if (list.get(mid) <= t) lo = mid + 1; else hi = mid - 1; } list.add(lo, t); } return ans;}关于「翻转对」的详细介绍可见 详解归并排序及其应用
xprivate int[] temp;private long ans;private int diff;public long numberOfPairs(int[] nums1, int[] nums2, int diff) { int n = nums1.length; int[] df = new int[n]; for (int i = 0; i < n; i++) df[i] = nums1[i] - nums2[i]; temp = new int[n]; this.diff = diff; ans = 0; sort(df, 0, n - 1); return ans;}private void sort(int[] nums, int lo, int hi) { if (lo >= hi) return ; int mid = lo + (hi - lo) / 2; sort(nums, lo, mid); sort(nums, mid + 1, hi); merge(nums, lo, mid, hi);}private void merge(int[] nums, int lo, int mid, int hi) { for (int i = lo; i <= hi; i++) temp[i] = nums[i]; // 寻找部分 int end = lo; for (int i = mid + 1; i <= hi; i++) { while (end <= mid && temp[end] - diff <= temp[i]) end++; ans += end - lo; }
int i = lo, j = mid + 1, idx = lo; while (i <= mid || j <= hi) { if (i > mid) nums[idx++] = temp[j++]; else if (j > hi) nums[idx++] = temp[i++]; else if (temp[i] <= temp[j]) nums[idx++] = temp[i++]; else nums[idx++] = temp[j++]; }}关于「线段树」的详细介绍和模版可见 线段树详解
根据范围:
-10^4 <= nums1[i], nums2[i] <= 10^4-10^4 <= diff <= 10^4所以nums1[i] - nums2[i] + diff的范围为-3 * 10^4 <= x <= 3 * 10^4
为了方便处理,我们将范围右移到正数区间,即右移-3 * 10^4
xxxxxxxxxxclass Solution { public long numberOfPairs(int[] nums1, int[] nums2, int diff) { int n = nums1.length; // 偏移量 int move = 3 * (int) 1e4; long ans = 0; for (int i = 0; i < n; i++) { // d 为更新量,target 为查询量 int d = nums1[i] - nums2[i]; int target = d + diff; // 查询区间 [0, target + move] 上的数量 ans += query(root, 0, N, 0, target + move); // 更新区间 [d + move, d + move] update(root, 0, N, d + move, d + move, 1); } return ans; } // *************** 下面是模版 *************** class Node { Node left, right; int val, add; } private int N = (int) 1e9; private Node root = new Node(); public void update(Node node, int start, int end, int l, int r, int val) { if (l <= start && end <= r) { node.val += (end - start + 1) * val; node.add += val; return ; } int mid = (start + end) >> 1; pushDown(node, mid - start + 1, end - mid); if (l <= mid) update(node.left, start, mid, l, r, val); if (r > mid) update(node.right, mid + 1, end, l, r, val); pushUp(node); } public int query(Node node, int start, int end, int l, int r) { if (l <= start && end <= r) return node.val; int mid = (start + end) >> 1, ans = 0; pushDown(node, mid - start + 1, end - mid); if (l <= mid) ans += query(node.left, start, mid, l, r); if (r > mid) ans += query(node.right, mid + 1, end, l, r); return ans; } private void pushUp(Node node) { node.val = node.left.val + node.right.val; } private void pushDown(Node node, int leftNum, int rightNum) { if (node.left == null) node.left = new Node(); if (node.right == null) node.right = new Node(); if (node.add == 0) return ; node.left.val += node.add * leftNum; node.right.val += node.add * rightNum; // 对区间进行「加减」的更新操作,下推懒惰标记时需要累加起来,不能直接覆盖 node.left.add += node.add; node.right.add += node.add; node.add = 0; }}