线段树(segment tree)是一种用来解决区间查询和区间更新问题的数据结构。对于区间[m, n],我们可以把它递归地二分,直到子区间的长度变为1。比如,存储任意区间和的线段树可以表示为下图(image from here):
可以看出,segment tree的每个节点要么有两个子节点,要么没有子节点,所以segment tree是一棵full binary tree,同时也是一棵balanced binary tree。对于有n个叶节点的full binary tree,恰好有n - 1个非叶节点,所以节点总数为2n - 1,空间复杂度是O(n)。我们实现的时候用数组存储(节点i的两个子节点下标为2i + 1和2i + 2),因为这样更加方便。但是值得注意的是,开2*n的空间是不够的:和heap不同,segment tree不是一棵complete binary tree,所以数组中有一些slot是用不上的,最大下标可能远超2n。3*n也并不总是够用:例如n = 36时,按下面代码的二分方式最大下标为112,而3n = 108。安全的上界是4*n(最大下标一定小于4n)。这里我们以求区间和为例,建树的时候我们只需bottom up地不断用子节点的值更新当前节点即可,时间复杂度O(n)。代码如下,先不用在意mark,我们之后会讲:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// One segment-tree slot.
///   val  - sum of the slot's interval.
///   mark - lazy "add to every element" tag used by the range-update version;
///          the plain build/query code below leaves it at 0.
struct Node
{
    int val = 0;
    int mark = 0;
};
SegmentTree(vector<int> nums) { | |
n = nums.size(); | |
if (!n)return; | |
st = vector<Node>(3 * n); | |
createTree(nums, 0, n - 1, 0); | |
} | |
void createTree(vector<int>& nums, int lo, int hi, int i) | |
{ | |
if (lo == hi) | |
{ | |
st[i].val = nums[lo]; | |
return; | |
} | |
int mid = lo + (hi - lo) / 2; | |
createTree(nums, lo, mid, 2 * i + 1); | |
createTree(nums, mid + 1, hi, 2 * i + 2); | |
st[i].val = st[2 * i + 1].val + st[2 * i + 2].val; | |
} |
单点更新的时候,我们只需像binary search一样从根节点出发找到要更新的叶节点,然后自底向上更新路径上的所有节点,时间复杂度O(log(n)),代码如下:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
void update(int start, int end, int add) | |
{ | |
update(start, end, 0, 0, n - 1, add); | |
} | |
void update(int index, int val, int node, int lo, int hi) | |
{ | |
if (lo == hi) | |
{ | |
st[node].val = val; | |
return; | |
} | |
int mid = lo + (hi - lo) / 2; | |
if (index <= mid)update(index, val, 2 * node + 1, lo, mid); | |
else update(index, val, 2 * node + 2, mid + 1, hi); | |
st[node].val = st[2 * node + 1].val + st[2 * node + 2].val; | |
} |
查询的时候我们要考虑三种情况:
- 查询区间和当前区间没有交集,return 0
- 查询区间完全包含当前区间,return当前node的值即可
- 查询区间和当前区间部分相交(或者当前区间严格包含查询区间),递归地查询两个子区间
那么有时我们要同时查询两个子区间,查询的时间复杂度还是O(log(n))吗?考虑以下两种情况:
- 查询区间没有跨过当前区间的中点,那么我们只递归查询了一个子区间
- 查询区间跨过了当前区间的中点,考虑左右两个子区间c1、c2,我们有以下结论:
  - 在c1中,查询区间要么覆盖c1的整个右半边和左半边的一部分,要么只覆盖c1右半边的一部分
  - 在c2中,查询区间要么覆盖c2的整个左半边和右半边的一部分,要么只覆盖c2左半边的一部分
对于第2种情况,如果查询区间覆盖了某个子区间的整个右半部分或者左半部分L,那么对于L我们只需要直接return对应节点的值即可,不会产生额外的递归。所以对于第2种情况,我们每一层最多访问4个节点:两个继续递归的节点,和两个被完全覆盖、直接return的节点。因此查询的时间复杂度仍然是O(log(n))。代码如下:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
int sumRange(int i, int j) { | |
return sumQuery(i, j, 0, n - 1, 0); | |
} | |
int sumQuery(int qlo, int qhi, int clo, int chi, int i) | |
{ | |
if(qlo > chi || qhi < clo)return 0; | |
if(qlo <= clo && qhi >= chi)return st[i].val; | |
int mid = clo + (chi - clo) / 2; | |
return sumQuery(qlo, qhi, clo, mid, 2 * i + 1) + sumQuery(qlo, qhi, mid + 1, chi, 2 * i + 2); | |
} |
对于区间更新,这是segment tree的精华和优势所在。我们采用lazy propagation,把更新操作分摊到后续每一次经过该节点的更新和查询上。具体来讲,每个node会额外存一个lazy标记(代码中的mark),用来存从父节点得到、但还没有apply到当前节点的更新信息,我们不急于把它apply到当前节点和所有子节点。对于一个完全被包含于更新区间[s, e]的节点node,一般来讲我们需要递归地更新其所有子节点。但是如果应用了lazy propagation,我们只需要把node的值加上add × (区间长度),然后把add累加到两个子节点的mark上。当下一次某个查询或者更新经过了某个子节点,我们再根据mark的值更新对应的节点值,并且继续把mark传递给它的子节点。这样的话,只有当要用(查询/更新)这个节点时我们才去更新它,也就是所谓的lazy propagation。此外,更新信息也不会丢失,因为每次apply mark的时候,我们都会把它传递给子节点,直到叶节点。区间更新的时间复杂度也是O(log(n)),分析和查询类似;查询的时候我们只多了每个节点向子节点传播mark的时间,所以仍然是O(log(n))。代码如下:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
void update(int start, int end, int add) | |
{ | |
update(start, end, 0, 0, n - 1, add); | |
} | |
void update(int s, int e, int node, int lo, int hi, int add) | |
{ | |
if (st[node].mark)propagate(node, hi - lo + 1); | |
if (lo > e || hi < s)return; | |
if (s <= lo && e >= hi) | |
{ | |
st[node].val += add; | |
if (lo != hi) | |
{ | |
st[2 * node + 1].mark += add; | |
st[2 * node + 2].mark += add; | |
} | |
return; | |
} | |
int mid = lo + (hi - lo) / 2; | |
update(s, e, 2 * node + 1, lo, mid, add); | |
update(s, e, 2 * node + 2, mid + 1, hi, add); | |
st[node].val = st[2 * node + 1].val + st[2 * node + 2].val; | |
} | |
void propagate(int node, int len) | |
{ | |
st[node].val += len * st[node].mark; | |
if (len != 1) | |
{ | |
st[2 * node + 1].mark += st[node].mark; | |
st[2 * node + 2].mark += st[node].mark; | |
} | |
st[node].mark = 0; | |
} | |
int sumQuery(int qlo, int qhi, int clo, int chi, int i) | |
{ | |
if(st[i].mark)propagate(i, chi - clo + 1); | |
if (qlo > chi || qhi < clo)return 0; | |
if (qlo <= clo && qhi >= chi)return st[i].val; | |
int mid = clo + (chi - clo) / 2; | |
int res = sumQuery(qlo, qhi, clo, mid, 2 * i + 1) + sumQuery(qlo, qhi, mid + 1, chi, 2 * i + 2); | |
} |
完整代码如下:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// One segment-tree slot.
///   val  - sum of the slot's interval, not counting this slot's own pending mark.
///   mark - per-element add still owed to this slot's interval (lazy tag); it is
///          folded into val and handed to the children by propagate().
struct Node
{
    int val = 0;
    int mark = 0;
};

/// Sum segment tree supporting point assignment, lazy range add and range-sum query.
/// Slot i has children 2i+1 (left half) and 2i+2 (right half).
class NumArray {
public:
    /// Builds the tree over `nums`. An empty input yields an empty tree on which
    /// updates are no-ops and every sum is 0.
    NumArray(std::vector<int> nums) {
        n = static_cast<int>(nums.size());
        if (!n) return;
        // 4n slots: with this layout the largest slot index can exceed 3n
        // (n = 36 reaches index 112 while 3n = 108) but is always < 4n.
        st = std::vector<Node>(4 * n);
        createTree(nums, 0, n - 1, 0);
    }

    /// Sets element i to val.
    void update(int i, int val) {
        if (!n) return;
        update(i, val, 0, 0, n - 1);
    }

    /// Adds `add` to every element in [start, end] (inclusive).
    void update(int start, int end, int add)
    {
        if (!n) return;
        update(start, end, 0, 0, n - 1, add);
    }

    /// Returns the sum of the elements in [i, j] (inclusive).
    int sumRange(int i, int j) {
        if (!n) return 0;   // empty tree: st has no root slot to read
        return sumQuery(i, j, 0, n - 1, 0);
    }

private:
    int n = 0;
    std::vector<Node> st;

    /// Stores in slot i the sum of nums[lo..hi].
    void createTree(const std::vector<int>& nums, int lo, int hi, int i)
    {
        if (lo == hi)
        {
            st[i].val = nums[lo];
            return;
        }
        int mid = lo + (hi - lo) / 2;
        createTree(nums, lo, mid, 2 * i + 1);
        createTree(nums, mid + 1, hi, 2 * i + 2);
        st[i].val = st[2 * i + 1].val + st[2 * i + 2].val;
    }

    /// Point assignment of `val` at `index` inside the subtree rooted at `node` ([lo, hi]).
    void update(int index, int val, int node, int lo, int hi)
    {
        // Push a pending add down first; otherwise it would later be applied on top
        // of the freshly assigned leaf value.
        if (st[node].mark) propagate(node, hi - lo + 1);
        if (lo == hi)
        {
            st[node].val = val;
            return;
        }
        int mid = lo + (hi - lo) / 2;
        if (index <= mid) update(index, val, 2 * node + 1, lo, mid);
        else update(index, val, 2 * node + 2, mid + 1, hi);
        // The child off the path may still carry a pending add; settle it so the
        // recombined sum is exact.
        if (st[2 * node + 1].mark) propagate(2 * node + 1, mid - lo + 1);
        if (st[2 * node + 2].mark) propagate(2 * node + 2, hi - mid);
        st[node].val = st[2 * node + 1].val + st[2 * node + 2].val;
    }

    /// Lazy range add of `add` over [s, e] inside the subtree rooted at `node` ([lo, hi]).
    void update(int s, int e, int node, int lo, int hi, int add)
    {
        if (st[node].mark) propagate(node, hi - lo + 1);
        if (lo > e || hi < s) return;
        if (s <= lo && e >= hi)
        {
            // Each of the (hi - lo + 1) elements grows by `add`.
            st[node].val += add * (hi - lo + 1);
            if (lo != hi)
            {
                st[2 * node + 1].mark += add;
                st[2 * node + 2].mark += add;
            }
            return;
        }
        int mid = lo + (hi - lo) / 2;
        update(s, e, 2 * node + 1, lo, mid, add);
        update(s, e, 2 * node + 2, mid + 1, hi, add);
        // Both children were propagated by the calls above, so their vals are exact.
        st[node].val = st[2 * node + 1].val + st[2 * node + 2].val;
    }

    /// Folds node's pending add into its sum (len = interval length) and hands it
    /// to the children.
    void propagate(int node, int len)
    {
        st[node].val += len * st[node].mark;
        if (len != 1)
        {
            st[2 * node + 1].mark += st[node].mark;
            st[2 * node + 2].mark += st[node].mark;
        }
        st[node].mark = 0;
    }

    /// Range sum of [qlo, qhi] restricted to slot i's range [clo, chi].
    int sumQuery(int qlo, int qhi, int clo, int chi, int i)
    {
        if (st[i].mark) propagate(i, chi - clo + 1);
        if (qlo > chi || qhi < clo) return 0;
        if (qlo <= clo && qhi >= chi) return st[i].val;
        int mid = clo + (chi - clo) / 2;
        return sumQuery(qlo, qhi, clo, mid, 2 * i + 1) + sumQuery(qlo, qhi, mid + 1, chi, 2 * i + 2);
    }
};
2D扩展
2D Segment Tree可以用来处理二维的range query。值得注意的是,2D Segment Tree并不是QuadTree。2D Segment Tree的每一个节点都是一棵1D的Segment Tree。我们以Region Sum为例(图中节点从1开始编号):4、5、6、7号节点分别对应矩阵第1、2、3、4列的1D Segment Tree;2、3号节点分别对应第1、2列之和与第3、4列之和形成的1D Segment Tree;1号节点就是所有列之和所形成的Segment Tree。下面的代码按行而不是按列划分(外层树的每个叶节点对应矩阵的一行),原理完全相同。2D的construct、query和update都和1D十分类似。假设输入矩阵大小为m × n,以上操作对应的时间复杂度分别为O(m * n)、O(log(m) * log(n))和O(log(m) * log(n)),空间复杂度为O(m * n)。代码如下:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Sum segment tree over a fixed-length int array with point add and range-sum
/// query (no lazy tags). NumMatrix keeps one of these per outer-tree node.
class SegmentTree
{
public:
    /// Builds the tree over `nums`. An empty input yields an empty tree on which
    /// update() is a no-op and query() returns 0.
    SegmentTree(const std::vector<int>& nums)
    {
        n = static_cast<int>(nums.size());
        if (n == 0) return;   // createTree(0, -1, ...) would write past an empty array
        // 4n slots: with children at 2i+1 / 2i+2 the largest index can exceed 3n
        // (n = 36 reaches 112 while 3n = 108) but is always < 4n.
        m_st = std::vector<int>(4 * n, 0);
        createTree(0, n - 1, 0, nums);
    }
    /// Adds `add` to element i.
    void update(int i, int add)
    {
        if (n == 0) return;
        updateTree(0, n - 1, i, 0, add);
    }
    /// Returns the sum of the elements in [lo, hi] (inclusive).
    int query(int lo, int hi)
    {
        if (n == 0) return 0;
        return query(lo, hi, 0, n - 1, 0);
    }
private:
    std::vector<int> m_st;
    int n = 0;
    /// Stores in slot i the sum of nums[lo..hi].
    void createTree(int lo, int hi, int i, const std::vector<int>& nums)
    {
        if (lo == hi)
        {
            m_st[i] = nums[lo];
            return;
        }
        int mid = lo + (hi - lo) / 2;
        createTree(lo, mid, 2 * i + 1, nums);
        createTree(mid + 1, hi, 2 * i + 2, nums);
        m_st[i] = m_st[2 * i + 1] + m_st[2 * i + 2];
    }
    /// Adds `add` to leaf `idx` within slot i's range [lo, hi], then re-sums the path.
    void updateTree(int lo, int hi, int idx, int i, int add)
    {
        if (lo == hi)
        {
            m_st[i] += add;
            return;
        }
        int mid = lo + (hi - lo) / 2;
        if (idx <= mid)
            updateTree(lo, mid, idx, 2 * i + 1, add);
        else
            updateTree(mid + 1, hi, idx, 2 * i + 2, add);
        m_st[i] = m_st[2 * i + 1] + m_st[2 * i + 2];
    }
    /// Sum of [qlo, qhi] restricted to slot i's range [lo, hi].
    int query(int qlo, int qhi, int lo, int hi, int i)
    {
        if (qlo > hi || qhi < lo)
            return 0;
        if (qlo <= lo && qhi >= hi)
            return m_st[i];
        int mid = lo + (hi - lo) / 2;
        return query(qlo, qhi, lo, mid, 2 * i + 1) + query(qlo, qhi, mid + 1, hi, 2 * i + 2);
    }
};
class NumMatrix { | |
public: | |
NumMatrix(vector<vector<int>> matrix) { | |
nums = matrix; | |
m = matrix.size(); | |
int n = m? matrix[0].size(): 0; | |
if(!m || !n)return; | |
m_st = vector<SegmentTree*>(3 * m, nullptr); | |
createTree(0, m - 1, 0, matrix); | |
} | |
~NumMatrix() | |
{ | |
for(auto ptr : m_st) | |
{ | |
if(ptr) | |
delete ptr; | |
} | |
} | |
void update(int row, int col, int val) { | |
int diff = val - nums[row][col]; | |
update(0, m - 1, row, col, 0, diff); | |
nums[row][col] = val; | |
} | |
int sumRegion(int row1, int col1, int row2, int col2) { | |
return query(row1, row2, col1, col2, 0, m - 1, 0); | |
} | |
private: | |
vector<SegmentTree*> m_st; | |
vector<vector<int>> nums; | |
int m; | |
vector<int> createTree(int lo, int hi, int i, vector<vector<int>>& matrix) | |
{ | |
if(lo == hi) | |
{ | |
m_st[i] = new SegmentTree(matrix[lo]); | |
return matrix[lo]; | |
} | |
int mid = lo + (hi - lo) / 2; | |
auto v1 = createTree(lo, mid, 2 * i + 1, matrix); | |
auto v2 = createTree(mid + 1, hi, 2 * i + 2, matrix); | |
for(int i = 0; i < v1.size(); ++i) | |
v1[i] += v2[i]; | |
m_st[i] = new SegmentTree(v1); | |
return v1; | |
} | |
void update(int lo, int hi, int i, int j, int idx, int add) | |
{ | |
if(lo == hi) | |
{ | |
m_st[idx]->update(j, add); | |
return; | |
} | |
int mid = lo + (hi - lo) / 2; | |
if(i <= mid) | |
update(lo, mid, i, j, 2 * idx + 1, add); | |
else | |
update(mid + 1, hi, i, j, 2 * idx + 2, add); | |
m_st[idx]->update(j, add); | |
} | |
int query(int xlo, int xhi, int ylo, int yhi, int lo, int hi, int i) | |
{ | |
int mid = lo + (hi - lo) / 2; | |
if(xhi < lo || xlo > hi) | |
return 0; | |
else if(xlo <= lo && xhi >= hi) | |
return m_st[i]->query(ylo, yhi); | |
else | |
return query(xlo, xhi, ylo, yhi, lo, mid, 2 * i + 1) + query(xlo, xhi, ylo, yhi, mid + 1, hi, 2 * i + 2); | |
} | |
}; | |
/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix* obj = new NumMatrix(matrix);
 * obj->update(row,col,val);
 * int param_2 = obj->sumRegion(row1,col1,row2,col2);
 */
No comments:
Post a Comment