给定两个大小为 n 的数组 nums1 和 nums2 以及整数 diff,统计满足以下条件的数对 (i, j) 的个数:0 <= i < j <= n-1,并且 nums1[i] - nums1[j] <= nums2[i] - nums2[j] + diff。
2<=n<=1e5; -1e4<=nums1[i],nums2[i]<=1e4; -1e4<=diff<=1e4
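先对条件做下变形,将下标相同的项移到同一侧,得到:nums1[i] - nums2[i] <= nums1[j] - nums2[j] + diff。如果记 A[i] = nums1[i] - nums2[i],条件就变为 A[i] <= A[j] + diff(i < j),转化成一个类似逆序对的统计问题:从左到右扫描,对每个 j 统计此前已出现的 A[i] 中不超过 A[j] + diff 的个数,这里用平衡树(Treap)来维护并统计。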
/// Treap: a randomized balanced BST used as an ordered (multi)set with rank
/// and prefix-sum queries.
///
/// Ordering uses only `operator<` on TYPE. Nodes live in the fixed-size pool
/// `node` and are addressed by index; index 0 is the null sentinel. The pool is
/// value-initialized by `resize`, and slot 0 is never written afterwards, so
/// `node[0].siz` and `node[0].sum` read as zero.
///
/// When `multiple` is true, equal keys share one node whose `dup` counts the
/// copies; when false, inserting an existing key leaves the tree unchanged.
///
/// Capacity: at most `sz - 1` nodes can ever be created (erased nodes are not
/// reused).
template <typename TYPE>
struct Treap {
    struct Node {
        TYPE data;   // key stored in this node
        TYPE sum;    // sum of data * dup over the whole subtree
        int rnd;     // random heap priority (parent's rnd >= child's rnd)
        int siz;     // number of keys, counting duplicates, in the subtree
        int dup;     // number of copies of `data` held by this node
        int son[2];  // children: son[0] holds smaller keys, son[1] larger; 0 = none

        void init(const TYPE& d) {
            data = sum = d;
            rnd = rand();
            siz = dup = 1;
            son[0] = son[1] = 0;
        }
    };

    /// @param sz    pool size; index 0 is the sentinel, so sz - 1 nodes are usable.
    /// @param multi true for multiset semantics, false for set semantics.
    Treap(size_t sz, bool multi) : multiple(multi) {
        node.resize(sz);
        reset();
    }

    /// Allocates the next pool slot and initializes it with key `d`.
    int newnode(const TYPE& d) {
        total += 1;
        // The pool must not grow here: insert() assigns the result through an
        // int& that points into `node`, and a reallocation would leave it dangling.
        assert(static_cast<size_t>(total) < node.size());
        node[total].init(d);
        return total;
    }

    /// Empties the tree (pool slots are simply reused from index 1 again).
    void reset() { root = total = 0; }

    /// Recomputes siz and sum of node x from its own dup and its children.
    void maintain(int x) {
        node[x].siz = node[x].dup;
        node[x].sum = node[x].data * node[x].dup;
        for (int c : node[x].son) {
            if (c) {
                node[x].siz += node[c].siz;
                node[x].sum += node[c].sum;
            }
        }
    }

    /// Rotates subtree r in direction d: child son[d^1] becomes the new root.
    void rotate(int d, int& r) {
        int k = node[r].son[d ^ 1];
        node[r].son[d ^ 1] = node[k].son[d];
        node[k].son[d] = r;
        maintain(r);
        maintain(k);
        r = k;
    }

    /// Inserts `data` into subtree r; sets ans = false if the key already existed.
    void insert(const TYPE& data, int& r, bool& ans) {
        if (!r) {
            r = newnode(data);
            return;
        }
        if (!(data < node[r].data) && !(node[r].data < data)) {
            ans = false;
            if (multiple) {
                node[r].dup += 1;
                maintain(r);
            }
            return;
        }
        int d = data < node[r].data ? 0 : 1;
        insert(data, node[r].son[d], ans);
        // Restore the max-heap order on rnd by lifting the child if needed.
        if (node[node[r].son[d]].rnd > node[r].rnd) {
            rotate(d ^ 1, r);
        } else {
            maintain(r);
        }
    }

    /// Stores the k-th smallest key (1-based, duplicates counted) of subtree r.
    /// Caller must ensure 1 <= k <= node[r].siz.
    void getkth(int k, int r, TYPE& data) {
        int x = node[r].son[0] ? node[node[r].son[0]].siz : 0;
        int y = node[r].dup;
        if (k <= x) {
            getkth(k, node[r].son[0], data);
        } else if (k <= x + y) {
            data = node[r].data;
        } else {
            getkth(k - x - y, node[r].son[1], data);
        }
    }

    /// Sum of the k smallest keys (duplicates counted) in subtree r; 0 if k <= 0.
    TYPE getksum(int k, int r) {
        if (k <= 0 || r == 0) return 0;
        int x = node[r].son[0] ? node[node[r].son[0]].siz : 0;
        int y = node[r].dup;
        if (k <= x) return getksum(k, node[r].son[0]);
        // node[0].sum is zero, so a missing left child contributes nothing.
        if (k <= x + y) return node[node[r].son[0]].sum + node[r].data * (k - x);
        return node[node[r].son[0]].sum + node[r].data * y + getksum(k - x - y, node[r].son[1]);
    }

    /// Removes one copy of `data` from subtree r (no-op if absent).
    void erase(const TYPE& data, int& r) {
        if (r == 0) return;
        int d = -1;
        if (data < node[r].data) {
            d = 0;
        } else if (node[r].data < data) {
            d = 1;
        }
        if (d == -1) {
            node[r].dup -= 1;
            if (node[r].dup > 0) {
                maintain(r);
            } else if (node[r].son[0] == 0) {
                r = node[r].son[1];
            } else if (node[r].son[1] == 0) {
                r = node[r].son[0];
            } else {
                // Rotate the higher-priority child up, then keep sinking the
                // node until it has at most one child. Its dup may go below
                // zero on the way down; that is harmless because the node is
                // spliced out and every ancestor is re-maintained afterwards.
                int dd = node[node[r].son[0]].rnd > node[node[r].son[1]].rnd ? 1 : 0;
                rotate(dd, r);
                erase(data, node[r].son[dd]);
            }
        } else {
            erase(data, node[r].son[d]);
        }
        if (r) maintain(r);
    }

    /// Number of keys strictly less than `data` in subtree r.
    int ltcnt(const TYPE& data, int r) {
        if (r == 0) return 0;
        int x = node[r].son[0] ? node[node[r].son[0]].siz : 0;
        if (data < node[r].data) return ltcnt(data, node[r].son[0]);
        if (!(node[r].data < data)) return x;  // equal key
        return x + node[r].dup + ltcnt(data, node[r].son[1]);
    }

    /// Number of keys strictly greater than `data` in subtree r.
    int gtcnt(const TYPE& data, int r) {
        if (r == 0) return 0;
        int x = node[r].son[1] ? node[node[r].son[1]].siz : 0;
        // Uses only operator<, like the rest of the tree.
        if (node[r].data < data) return gtcnt(data, node[r].son[1]);
        if (!(data < node[r].data)) return x;  // equal key
        return x + node[r].dup + gtcnt(data, node[r].son[0]);
    }

    /// Number of copies of `data` in subtree r.
    int count(const TYPE& data, int r) {
        if (r == 0) return 0;
        if (data < node[r].data) return count(data, node[r].son[0]);
        if (node[r].data < data) return count(data, node[r].son[1]);
        return node[r].dup;
    }

    /// Largest key < data in subtree r, folded into result; ret marks "found".
    void prev(const TYPE& data, int r, TYPE& result, bool& ret) {
        if (!r) return;
        if (node[r].data < data) {
            if (ret) {
                result = std::max(result, node[r].data);
            } else {
                result = node[r].data;
                ret = true;
            }
            prev(data, node[r].son[1], result, ret);
        } else {
            prev(data, node[r].son[0], result, ret);
        }
    }

    /// Smallest key > data in subtree r, folded into result; ret marks "found".
    void next(const TYPE& data, int r, TYPE& result, bool& ret) {
        if (!r) return;
        if (data < node[r].data) {
            if (ret) {
                result = std::min(result, node[r].data);
            } else {
                result = node[r].data;
                ret = true;
            }
            next(data, node[r].son[0], result, ret);
        } else {
            next(data, node[r].son[1], result, ret);
        }
    }

    std::vector<Node> node;  // node pool; index 0 is the null sentinel
    int root, total;         // root index; number of pool slots used
    bool multiple;           // multiset (true) or set (false) semantics

    /// Inserts `data`; returns false if an equal key was already present.
    bool insert(const TYPE& data) {
        bool ret = true;
        insert(data, root, ret);
        return ret;
    }

    /// Writes the k-th smallest key (1-based) to `data`; false if k is out of range.
    bool kth(int k, TYPE& data) {
        if (!root || k <= 0 || k > node[root].siz) return false;
        getkth(k, root, data);
        return true;
    }

    /// Sum of the k smallest keys; requires 1 <= k <= size().
    TYPE ksum(int k) {
        assert(root && k > 0 && k <= node[root].siz);
        return getksum(k, root);
    }

    int count(const TYPE& data) { return count(data, root); }
    int size() const { return root ? node[root].siz : 0; }
    void erase(const TYPE& data) { return erase(data, root); }
    int ltcnt(const TYPE& data) { return ltcnt(data, root); }
    int gtcnt(const TYPE& data) { return gtcnt(data, root); }
    /// Number of keys <= data.
    int lecnt(const TYPE& data) { return size() - gtcnt(data, root); }
    /// Number of keys >= data.
    int gecnt(const TYPE& data) { return size() - ltcnt(data, root); }

    bool prev(const TYPE& data, TYPE& result) {
        bool ret = false;
        prev(data, root, result, ret);
        return ret;
    }

    bool next(const TYPE& data, TYPE& result) {
        bool ret = false;
        next(data, root, result, ret);
        return ret;
    }
};

class Solution {
public:
    /// Counts pairs i < j with nums1[i] - nums1[j] <= nums2[i] - nums2[j] + diff.
    ///
    /// With A[k] = nums1[k] - nums2[k], the condition is A[i] <= A[j] + diff.
    /// Scan j left to right; the treap holds A[0..j-1], and lecnt counts the
    /// qualifying i. At most one distinct value per position is stored.
    long long numberOfPairs(std::vector<int>& nums1, std::vector<int>& nums2, int diff) {
        const int n = static_cast<int>(nums1.size());
        // At most n nodes are created, plus the sentinel slot 0.
        Treap<int> tp(static_cast<size_t>(n) + 1, true);
        long long ans = 0;
        for (int j = 0; j < n; ++j) {
            const int a = nums1[j] - nums2[j];
            ans += tp.lecnt(a + diff);
            tp.insert(a);
        }
        return ans;
    }
};