Saturday, November 25, 2017

[LeetCode] Count of Smaller Numbers After Self

这道题有两种做法。首先我们知道merge sort可以解counting inversions类的题目,具体做法请参考上面的链接。我们可以用类似的做法来解这一题,上面的题是求解总的inversions,但是这道题需要求解每个元素和其后面的元素形成的inversions,显然我们需要在sort的时候追踪当前数在原来array num中对应的index是多少。我们要实现这一点,可以间接地sort array,也就是说,我们新建一个index array,存num中每个元素对于的下标,那么index起始的时候就应该是{1,2,3,4,5...}的样子,我们sort index array,而不是原来的num。例如num = {5,1,4,2}, index = num{0,1,2,3},我们只sort index array, num array是用来比较的时候用index里的下标来access真正的值。那么sort之后,num不变,index = {1,3,2,0},转换成num对应的数的话就是{1,2,4,5}。这样当我们counting当前数对应的reverse paris的时候,我们可以直接知道其在num中对应的原坐标是多少。时间复杂度就是merge sort的复杂度O(n * log n),代码如下:


class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
int len = nums.size();
vector<int> idx(len, 0), res(len, 0);
for(int i = 0; i < len; ++i)
idx[i] = i;
mergeSort(nums, idx, res, 0, len - 1);
return res;
}
private:
void mergeSort(vector<int>& nums, vector<int>& idx, vector<int>& res, int lo, int hi)
{
if(lo >= hi)return;
int mid = lo + (hi - lo) / 2;
mergeSort(nums, idx, res, lo, mid);
mergeSort(nums, idx, res, mid + 1, hi);
int p1 = lo, p2 = mid + 1;
vector<int> aux(hi - lo + 1, 0);
int curr = 0;
while(p1 <= mid && p2 <= hi)
{
int idx1 = idx[p1], idx2 = idx[p2], num1 = nums[idx1], num2 = nums[idx2];
if(num1 <= num2)
{
res[idx1] += p2 - mid - 1;
aux[curr++] = idx[p1++];
}
else
{
aux[curr++] = idx[p2++];
}
}
while(p1 <= mid)
{
res[idx[p1]] += p2 - mid - 1;
aux[curr++] = idx[p1++];
}
while(p2 <= hi)
aux[curr++] = idx[p2++];
for(int i = lo; i <= hi; ++i)
idx[i] = aux[i - lo];
}
};
另外一种方法我们可以用BST,因为在这边文章中讲过,node中多存一个size的节点,就可以在log n的时间里查询有多少个小于它的节点,我们只需要从右向左,一遍构建BST一遍查询就可以。时间复杂度也是O(n * log n),代码如下:

struct Node
{
int key, size, dup;
Node* left, *right;
Node(int num) : key(num), size(1), dup(1), left(nullptr), right(nullptr)
{
}
};
class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
int len = nums.size();
vector<int> res(len, 0);
for(int i = len - 1; i >= 0; --i)
{
res[i] = search(m_root, nums[i]);
insert(nums[i]);
}
return res;
}
private:
Node* m_root;
int size(Node* root)
{
return root? root->size: 0;
}
int search(Node* curr, int key)
{
if(!curr)return 0;
if(curr->key < key)return curr->dup + size(curr->left) + search(curr->right, key);
else if(curr->key >= key)return search(curr->left, key);
}
void insert(int key)
{
if(m_root == nullptr)
{
m_root = new Node(key);
return;
}
Node* curr = m_root;
while(true)
{
++curr->size;
if(curr->key < key)
{
if(curr->right)curr = curr->right;
else
{
curr->right = new Node(key);
break;
}
}
else if(curr->key > key)
{
if(curr->left)curr = curr->left;
else
{
curr->left = new Node(key);
break;
}
}
else
{
++curr->dup;
break;
}
}
}
};
除了以上两个方法之外,这同时也是一道range query的题目。因为对于每一个数num,我们想要知道其右边有多少个数小于它,就相当于对所有在其右边的数进行一个[minVal, num)的range query来统计出现了多少个。而我们可以用Binary Indexed Tree或者Segment Tree实现这些区间查询,我们要存的就是当前节点表示的范围里面有多少个数已经出现了(重复的当然也要计算)。具体的原理可以参考上面给出的文章链接,这里因为数字的范围很大,我们将其压缩到[0, N)的区间即可,N为输入数组的长度。值得一提的一点,Reverse Pairs这道很相似的题我们没有办法进行输入数据的压缩,因为压缩过后很有可能对于一些数,原先满足题目中给出关系的会不再满足,也就是说原先是两倍的关系,压缩之后可能就不是了,所以我们没法用BIT或者Segment Tree实现。BIT和Segment Tree在区间查询的时候都可以达到O(log N)的时间复杂度,所以总的时间复杂度也是O(log N),空间复杂度O(N)。代码如下:

Segment Tree:

class Node
{
public:
int start, end;
int cnt;//# of nums appear in [start, end]
Node* left, *right;
Node(int s, int e)
{
start = s;
end = e;
cnt = 0;
left = nullptr;
right = nullptr;
}
Node* getLeft()
{
int mid = start + (end - start) / 2;
if (!left)left = new Node(start, mid);
return left;
}
Node* getRight()
{
int mid = start + (end - start) / 2;
if (!right)right = new Node(mid + 1, end);
return right;
}
void insert(int key)
{
if (start == end)
{
++cnt;
return;
}
int mid = start + (end - start) / 2;
if (key <= mid)getLeft()->insert(key);
else getRight()->insert(key);
cnt = getLeft()->cnt + getRight()->cnt;
}
//how many nums we have in [s, e]
int query(int s, int e)
{
if (s > end || e < start)return 0;
else if (start >= s && end <= e)return cnt;
else return getLeft()->query(s, e) + getRight()->query(s, e);
}
void clear()
{
if(left)left->clear();
if(right)right->clear();
delete this;
}
};
class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
int len = nums.size();
//mapping to index
vector<int> aux = nums;
sort(aux.begin(), aux.end());
unordered_map<int, int> map;
for (int i = 0; i < len; ++i)map[aux[i]] = i;
Node* root = new Node(0, len - 1);
vector<int> res;
for (int i = len - 1; i >= 0; --i)
{
int ans = root->query(0, map[nums[i]] - 1);
root->insert(map[nums[i]]);
res.push_back(ans);
}
reverse(res.begin(), res.end());
return res;
}
};
BIT:


class BIT
{
public:
BIT(int n)
{
N = n + 1;
tree = vector<int>(N, 0);
}
void insert(int key)
{
int i = key + 1;
while(i < N)
{
++tree[i];
i += i & -i;
}
}
int queryUntil(int key)
{
int i = key + 1, res = 0;
while(i)
{
res += tree[i];
i -= i & -i;
}
return res;
}
private:
int N;
vector<int> tree;
};
class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
int len = nums.size();
//mapping to index
vector<int> aux = nums;
sort(aux.begin(), aux.end());
unordered_map<int, int> map;
for (int i = 0; i < len; ++i)map[aux[i]] = i;
BIT bit(len);
vector<int> res;
for(int i = len - 1; i >= 0; --i)
{
int ans = bit.queryUntil(map[nums[i]] - 1);
res.push_back(ans);
bit.insert(map[nums[i]]);
}
reverse(res.begin(), res.end());
return res;
}
};

No comments:

Post a Comment