Thursday, September 14, 2017

[LeetCode]Number of Longest Increasing Subsequence


DP类型的题目,我们用DP[0][i]表示在i处结尾的LIS的长度,DP[1][i]表示在i处结尾的LIS的数量。DP方程为:

  • DP[0][i] = max(DP[0][j]) + 1 (0 <= j < i && nums[j] < nums[i])
  • DP[1][i] = sum(DP[1][j]) where DP[0][j] = DP[0][i] - 1
由此我们可以写出代码,时间复杂度O(n^2),空间复杂度O(n):

class Solution {
public:
int findNumberOfLIS(vector<int>& nums) {
int len = nums.size(), maxLen = 0, maxNum = 0;
//dp[0][i] lenght of LIS ending at i
//dp[1][i] numbers of LISs ending at i
vector<vector<int>> dp(2, vector<int>(len, 0));
for(int i = 0; i < len; ++i)
{
int currLen = 1, currNum = 1;
for(int j = 0; j < i; ++j)
{
if(nums[j] < nums[i])
{
if(dp[0][j] + 1 > currLen)
{
currLen = dp[0][j] + 1;
currNum = dp[1][j];
}
else if(dp[0][j] + 1 == currLen)
currNum += dp[1][j];
}
}
dp[0][i] = currLen;
dp[1][i] = currNum;
if(currLen > maxLen)
{
maxLen = currLen;
maxNum = currNum;
}
else if(currLen == maxLen)
maxNum += currNum;
}
return maxNum;
}
};

另一种方法,这道题可以用segment tree来做。segment tree存的区间是值域上的区间而不是以index来划分。我们要存的东西比较有意思,对于当前节点和其所表示的区间[s, e],从所有以处于[s, e]区间范围中的元素结尾的子序列当中,找出最长的子序列长度和数量,存入node中。
插入的话,我们左向右扫数组,对于每一个数nums[i],我们查询以[minVal, nums[i - 1]]范围中的数结尾的最长的子序列的长度和数量。那么我们可以根据查询的结果更新segment tree,比如我们得知以小于nums[i]结尾的子序列的长度为len, 数量为cnt。那么我们要在segment tree中把以nums[i]结尾的最长子序列的长度和数量更新为len + 1和cnt。
我们每次更新的时候,只需要去左右两个子区间:

  • 如果左子区间最大长度比右子区间长,我们取左子区间的长度和数量
  • 如果左子区间最大长度比右子区间短,我们取右子区间的长度和数量
  • 如果左子区间最大长度和右子区间相等,我们取任意区间的长度,数量为两个子区间之和
实现的时候,对于长度为N的数组,我们要把值域map到[0, N - 1]的区间,这样可以省下很多空间。时间复杂度O(N *log N),代码如下:

struct Value
{
int len, cnt;
Value()
{
len = 0;
cnt = 0;
}
Value(int l, int c)
{
len = l;
cnt = c;
}
};
class Node
{
public:
int start, end;
Node* left, *right;
Value val;
Node(int s, int e)
{
start = s;
end = e;
left = nullptr;
right = nullptr;
}
Node* getLeft()
{
int mid = start + (end - start) / 2;
if (!left)left = new Node(start, mid);
return left;
}
Node* getRight()
{
int mid = start + (end - start) / 2;
if (!right)right = new Node(mid + 1, end);
return right;
}
void insert(int key, Value updateVal)
{
if (start == end)
{
val = merge(val, updateVal);
return;
}
int mid = start + (end - start) / 2;
if (start > key || end < key)return;
else if (key <= mid)getLeft()->insert(key, updateVal);
else getRight()->insert(key, updateVal);
val = merge(getLeft()->val, getRight()->val);
}
//get value for all possible sequences which end between [s, e]
Value query(int s, int e)
{
if (e < start || s > end)return Value(0, 1);
else if (start >= s && end <= e)return val;
else return merge(getLeft()->query(s, e), getRight()->query(s, e));
}
private:
Value merge(const Value& lhs, const Value& rhs)
{
if (lhs.len == rhs.len)
{
if (!lhs.len)return Value(0, 1);
return Value(lhs.len, lhs.cnt + rhs.cnt);
}
return lhs.len > rhs.len ? lhs : rhs;
}
};
class Solution {
public:
int findNumberOfLIS(vector<int>& nums) {
int len = nums.size(), minNum = INT_MAX, maxNum = INT_MIN;
for (const auto& num : nums)
{
minNum = min(minNum, num);
maxNum = max(maxNum, num);
}
Node* root = new Node(minNum, maxNum);
for (const auto& num : nums)
{
Value v = root->query(minNum, num - 1);
Value updateVal(v.len + 1, v.cnt);
root->insert(num, updateVal);
}
return root->val.cnt;
}
};


区间查询的话Binary Indexed Tree当然也可以做,虽然和我们在链接文章中分析的,BIT实现Range Min/Max Query的话会比普通的BIT复杂一些,但是这一道题,对于每一个nums[i]我们只插入不更新,所以简单的BIT即可实现。时间复杂度O(N *log N),代码如下:
struct Value
{
int len, cnt;
Value()
{
len = 0;
cnt = 0;
}
Value(int l, int c)
{
len = l;
cnt = c;
}
};
class BIT
{
public:
BIT(int n)
{
N = n + 1;
tree = vector<Value>(N);
}
Value queryUntil(int key)
{
int i = key + 1;
Value val(0, 1);
while(i)
{
val = merge(val, tree[i]);
i -= i & -i;
}
return val;
}
void update(int key, Value val)
{
int i = key + 1;
while(i < N)
{
tree[i] = merge(val, tree[i]);
i += i & -i;
}
}
private:
int N;
vector<Value> tree;
Value merge(const Value& lhs, const Value& rhs)
{
if (lhs.len == rhs.len)
{
if (!lhs.len)return Value(0, 1);
return Value(lhs.len, lhs.cnt + rhs.cnt);
}
return lhs.len > rhs.len ? lhs : rhs;
}
};
class Solution {
public:
int findNumberOfLIS(vector<int>& nums) {
int len = nums.size();
if(!len)return 0;
vector<int> aux = nums;
sort(aux.begin(), aux.end());
unordered_map<int, int> xMap;
for(int i = 0; i < len; ++i)xMap[aux[i]] = i;
BIT bit(len);
for(int i = 0; i < len; ++i)
{
Value val = bit.queryUntil(xMap[nums[i]] - 1);
Value updateVal(val.len + 1, val.cnt);
bit.update(xMap[nums[i]], updateVal);
}
return bit.queryUntil(len - 1).cnt;
}
};

No comments:

Post a Comment