首先如果这道题不要求实现leave()的话是很好做的,我们只需要用priority_queue维护最大的那个interval即可,然后每次seat()的时候就在最大的interval的二分位置插入,0和N - 1是特殊情况。实现leave()的话会有所不同,因为两个较小的interval会合成一个大的interval,然而我们没有太好的方法去有效地查找这些关联的interval。
第一种方法就是每次seat()的时候去查找集合中的最大interval,这样一来我们每次需要O(n)的时间,leave()仍然是O(1)。代码如下:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ExamRoom { | |
public: | |
ExamRoom(int N) { | |
n = N; | |
} | |
int seat() { | |
if(seats.empty()) | |
{ | |
seats.insert(0); | |
return 0; | |
} | |
auto iter = seats.begin(); | |
pair<int, int> interval = {0, *iter - 1}; | |
int prev = *iter + 1; | |
++iter; | |
while(iter != seats.end()) | |
{ | |
if(dist(prev, *iter - 1) > dist(interval.first, interval.second)) | |
{ | |
interval = {prev, *iter - 1}; | |
} | |
prev = *iter + 1; | |
++iter; | |
} | |
//check the last interval | |
if(dist(prev, n - 1) > dist(interval.first, interval.second)) | |
interval = {prev, n - 1}; | |
int idx = -1; | |
if(interval.first == 0)idx = 0; | |
else if(interval.second == n - 1)idx = n - 1; | |
else idx = interval.first + (interval.second - interval.first) / 2; | |
seats.insert(idx); | |
return idx; | |
} | |
void leave(int p) { | |
seats.erase(p); | |
} | |
private: | |
set<int> seats; | |
int n; | |
int dist(int s, int e) | |
{ | |
if(s == 0)return e; | |
if(e == n - 1)return e - s; | |
if(s > e)return -1; | |
return (e - s) / 2; | |
} | |
}; | |
/** | |
* Your ExamRoom object will be instantiated and called as such: | |
* ExamRoom obj = new ExamRoom(N); | |
* int param_1 = obj.seat(); | |
* obj.leave(p); | |
*/ |
第二种方法,既然是interval的问题,我们就可以尝试用segment tree解决。我们要query的是[0, N)的最大的available的interval,我们称作maxInterval,那么我们每个节点要存的就是这个节点代表的interval的最大的available的interval。每次更新的时候:
- 首先我们可以查询左右两个区间的maxInterval
- 当然除了这两个interval之外,还可能存在横跨左右两个子区间的interval,所以我们也要查最长的跨区间的interval
- 1和2中最长的interval就是当前节点对应区间的最长available interval
那么我们如何查询跨左右两个子区间的最长interval呢?显然在每个node我们还要存一些其他的信息,我们要把当前节点对应区间的从左端点和右端点开始的available interval也存下来,这样在处理横跨左右两个子区间的interval的时候我们就看左子区间的右端点对应的区间和右子区间左端点对应的区间能否merge起来,这样我们就可以得到2中对应的interval。
举例而言比如[0, 4]中0, 3被占据了,那么左端点起没有区间是available的,右端点结束最长的available的区间为[4, 4]。如果[5, 8]中7被占据了,我们知道从5起始的最长的区间为[5, 6]。那么在处理[0, 8]的时候我们知道还有一个[4, 6] = [4, 4] + [5, 6]的区间需要考虑。
这样的话seat()和leave()都可以在O(log n)的时间复杂度里完成,空间复杂度O(n),代码如下:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
struct Node | |
{ | |
int start, end, shi, elo;//[start, hi] and [lo, end] | |
pair<int, int> maxInterval; | |
Node* left, *right; | |
Node(int s, int e, bool root) | |
{ | |
start = s; | |
end = e; | |
shi = e; | |
elo = s; | |
maxInterval = { s, e }; | |
left = nullptr; | |
right = nullptr; | |
} | |
Node* getLeft() | |
{ | |
if (start == end)return nullptr; | |
int mid = start + (end - start) / 2; | |
if (!left)left = new Node(start, mid, false); | |
return left; | |
} | |
Node* getRight() | |
{ | |
if (start == end)return nullptr; | |
int mid = start + (end - start) / 2; | |
if (!right)right = new Node(mid + 1, end, false); | |
return right; | |
} | |
void update(int key, bool sit) | |
{ | |
if (start == end) | |
{ | |
if (sit) | |
{ | |
maxInterval = { -1, -3 }; | |
shi = -1; | |
elo = -1; | |
} | |
else | |
{ | |
maxInterval = { start, end }; | |
shi = end; | |
elo = start; | |
} | |
return; | |
} | |
int mid = start + (end - start) / 2; | |
if (key <= mid)getLeft()->update(key, sit); | |
else getRight()->update(key, sit); | |
//update shi and elo | |
shi = getLeft()->shi; | |
elo = getRight()->elo; | |
//update maxInterval in current node | |
auto maxL = getLeft()->maxInterval; | |
auto maxR = getRight()->maxInterval; | |
maxInterval = maxL; | |
if (getLeft()->elo != -1 && getRight()->shi != -1) | |
{ | |
pair<int, int> merged = { getLeft()->elo, getRight()->shi }; | |
if (dist(merged) > dist(maxInterval) || merged.first == maxInterval.first) | |
maxInterval = merged; | |
//update shi and elo | |
if (shi != -1 && shi >= getLeft()->elo)shi = merged.second; | |
if (elo != -1 && elo <= getRight()->shi)elo = merged.first; | |
} | |
if (dist(maxR) > dist(maxInterval))maxInterval = maxR; | |
} | |
private: | |
int dist(const pair<int, int>& interval) | |
{ | |
return (interval.second - interval.first) / 2; | |
} | |
}; | |
class ExamRoom { | |
public: | |
ExamRoom(int N) { | |
root = new Node(0, N - 1, true); | |
n = N; | |
} | |
int seat() { | |
int idx = -1; | |
pair<int, int> maxInterval = root->maxInterval; | |
if (maxInterval.first == 0)idx = root->elo != -1 && n - root->elo > maxInterval.second + 1 ? n - 1 : 0; | |
else if (maxInterval.second == n - 1)idx = n - 1; | |
else idx = maxInterval.first + (maxInterval.second - maxInterval.first) / 2; | |
root->update(idx, true); | |
return idx; | |
} | |
void leave(int p) { | |
root->update(p, false); | |
} | |
private: | |
Node* root; | |
int n; | |
}; | |
/** | |
* Your ExamRoom object will be instantiated and called as such: | |
* ExamRoom obj = new ExamRoom(N); | |
* int param_1 = obj.seat(); | |
* obj.leave(p); | |
*/ |
No comments:
Post a Comment