区间查询的问题,比较直观的思路是维护一个sorted的不相交的interval的list。每一次插入的时候,我们merge所有可能的interval,保证插入后仍然是不相交的。删除的话,一组和删除区间相交的interval,除了头尾可能还残余一部分没有被覆盖的interval,剩下的都可以抹去,我们重新插入残余的interval即可。我们可以用bst来维护这个集合,key就是interval start的值;value就是end对应的值,具体对应每一个操作:
- 插入i = [start, end]:找到大于end最小的key,对应的就是紧邻i的右边的interval。之后开始向左边扫找到所有和i相交的interval并且从bst中删除。插入merge之后的interval
- 删除i = [start, end]: 找到大于end最小的key,对应的就是紧邻i的右边的interval。之后开始向左边扫找到所有和i相交的interval并且从bst中删除。注意头尾的两个interval可能有残余的部分,要重新插入进bst中
- 查询i = [start, end]:找到大于end最小的key,对应的就是紧邻i的右边的interval。其左边相邻的interval i2必定是i2.start <= end的区间。判断其是否覆盖[start, end]即可
时间复杂度,如果bst中有k个不相交的区间。查询的时间复杂度为O(log k),插入和删除的复杂度为O(k)。代码如下:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class RangeModule { | |
public: | |
RangeModule() { | |
} | |
void addRange(int left, int right) { | |
auto iter = m_intervals.upper_bound(right); | |
if(iter == m_intervals.begin()) | |
{ | |
m_intervals[left] = right; | |
return; | |
} | |
--iter; | |
while(true) | |
{ | |
if(iter->second < left)break; | |
left = min(left, iter->first); | |
right = max(right, iter->second); | |
iter = m_intervals.erase(iter); | |
if(iter == m_intervals.begin())break; | |
else --iter; | |
} | |
m_intervals[left] = right; | |
} | |
bool queryRange(int left, int right) { | |
auto iter = m_intervals.upper_bound(left); | |
if(iter == m_intervals.begin())return false; | |
--iter; | |
return iter->first <= left && iter->second >= right; | |
} | |
void removeRange(int left, int right) { | |
auto iter = m_intervals.upper_bound(right); | |
if(iter == m_intervals.begin())return; | |
vector<pair<int, int>> intervals; | |
--iter; | |
if(iter->second > right)intervals.push_back(make_pair(right, iter->second)); | |
while(true) | |
{ | |
if(iter->first < left) | |
{ | |
if(iter->second > left) | |
{ | |
intervals.push_back(make_pair(iter->first, left)); | |
m_intervals.erase(iter); | |
} | |
break; | |
} | |
iter = m_intervals.erase(iter); | |
if(iter == m_intervals.begin())break; | |
else --iter; | |
} | |
for(auto&& p : intervals)m_intervals[p.first] = p.second; | |
} | |
private: | |
map<int, int> m_intervals; | |
}; | |
/** | |
* Your RangeModule object will be instantiated and called as such: | |
* RangeModule obj = new RangeModule(); | |
* obj.addRange(left,right); | |
* bool param_2 = obj.queryRange(left,right); | |
* obj.removeRange(left,right); | |
*/ |
既然是区间查询的题目,segment tree是标准的处理区间问题的数据结构。关于其介绍可以参考这篇文章。要存的东西也十分直观,对于每一个代表区间[s, e]的节点,我们只需要存[s, e]是否被完全覆盖了即可。并且这道题都是区间更新,我们用lazy propagation可以降低时间复杂度。
值得一提的是,这道题性质我们没有办法做区间的压缩,也即是把输入区间[minVal, maxVal]映射到[0, N]的区间上。所以我们只能按照题目给的[0, 1000,000,000]范围来处理。那么在实际实现的时候,会在后面的一些test case上MLE。为了处理这种情况,我们可以牺牲一点时间复杂度,在每次更新的时候,与其做lazy propagation,我们把完全包含于更新区间[s, e]的节点的所有子节点删除。当我们再需要的时候查询/更新这些子节点的时候,我们在重新根据父节点生成他们,这样虽然牺牲了一点时间,但是我们大大节省了空间。时间复杂度的话,如果用的是lazy propagation,更新和查询的时间复杂度都为O(log MaxVal),这里MaxVal = 1000,000,000。代码如下:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Node | |
{ | |
public: | |
int start, end; | |
bool covered; | |
Node* left, *right; | |
Node(int s, int e, bool c) | |
{ | |
start = s; | |
end = e; | |
covered = c; | |
left = nullptr; | |
right = nullptr; | |
} | |
Node* getLeft() | |
{ | |
int mid = start + (end - start) / 2; | |
if (!left)left = new Node(start, mid, covered); | |
return left; | |
} | |
Node* getRight() | |
{ | |
int mid = start + (end - start) / 2; | |
if (!right)right = new Node(mid + 1, end, covered); | |
return right; | |
} | |
//update [s, e + 1) to be true | |
void update(int s, int e, bool val) | |
{ | |
if (e < start || s > end)return; | |
else if (end <= e && start >= s) | |
{ | |
covered = val; | |
if (left)left->clear(); | |
if (right)right->clear(); | |
left = nullptr; | |
right = nullptr; | |
} | |
else | |
{ | |
getLeft()->update(s, e, val); | |
getRight()->update(s, e, val); | |
covered = getLeft()->covered && getRight()->covered; | |
} | |
} | |
//if [s, e + 1) is tracked | |
bool query(int s, int e) | |
{ | |
if (e < start || s > end)return true; | |
else if (end <= e && start >= s)return covered; | |
else return getLeft()->query(s, e) && getRight()->query(s, e); | |
} | |
void clear() | |
{ | |
if(left)left->clear(); | |
if(right)right->clear(); | |
delete this; | |
} | |
}; | |
class RangeModule { | |
public: | |
RangeModule() { | |
root = new Node(0, 1000000000, false); | |
} | |
~RangeModule() | |
{ | |
root->clear(); | |
} | |
void addRange(int left, int right) { | |
root->update(left, right - 1, 1); | |
} | |
bool queryRange(int left, int right) { | |
return root->query(left, right - 1); | |
} | |
void removeRange(int left, int right) { | |
root->update(left, right - 1, 0); | |
} | |
private: | |
Node* root; | |
}; | |
/** | |
* Your RangeModule object will be instantiated and called as such: | |
* RangeModule obj = new RangeModule(); | |
* obj.addRange(left,right); | |
* bool param_2 = obj.queryRange(left,right); | |
* obj.removeRange(left,right); | |
*/ |
No comments:
Post a Comment