2D Range Query问题，可以用2D Segment Tree、Quad Tree或者2D Binary Indexed Tree来解决，具体数据结构讲解和时间复杂度分析请参考对应链接。
2D Segment Tree解法：
// 1D segment tree over an int array: point "add" updates and inclusive
// range-sum queries, each in O(log n). Nodes are stored implicitly in m_st:
// the root is index 0 and covers [0, n-1]; the children of node i are
// 2*i+1 and 2*i+2.
class SegmentTree
{
public:
    // Builds the tree over nums. An empty array is accepted: every query on
    // it returns 0 and every update is ignored.
    SegmentTree(const vector<int>& nums)
    {
        n = static_cast<int>(nums.size());
        // 4*n slots always suffice for this midpoint split; 3*n does not
        // (with n = 36 the leaf for index 28 lands at slot 112 >= 3*36).
        m_st = vector<int>(4 * n, 0);
        if(n > 0)
            createTree(0, n - 1, 0, nums);
    }

    // Adds `add` to element i (a delta, not an assignment).
    // Indices outside [0, n) are ignored.
    void update(int i, int add)
    {
        if(i < 0 || i >= n)
            return;
        updateTree(0, n - 1, i, 0, add);
    }

    // Returns the sum of the elements in the inclusive range [lo, hi];
    // positions outside [0, n) contribute nothing.
    int query(int lo, int hi)
    {
        if(n == 0)
            return 0;
        return query(lo, hi, 0, n - 1, 0);
    }

private:
    vector<int> m_st;  // node sums in the implicit layout described above
    int n = 0;         // number of elements (leaves)

    // Fills node i, which covers nums[lo..hi], and every node below it.
    void createTree(int lo, int hi, int i, const vector<int>& nums)
    {
        if(lo == hi)
        {
            m_st[i] = nums[lo];
            return;
        }
        int mid = lo + (hi - lo) / 2;
        createTree(lo, mid, 2 * i + 1, nums);
        createTree(mid + 1, hi, 2 * i + 2, nums);
        m_st[i] = m_st[2 * i + 1] + m_st[2 * i + 2];
    }

    // Adds `add` to the leaf for idx below node i (covering [lo, hi]) and
    // recomputes the sums on the path back up to node i.
    void updateTree(int lo, int hi, int idx, int i, int add)
    {
        if(lo == hi)
        {
            m_st[i] += add;
            return;
        }
        int mid = lo + (hi - lo) / 2;
        if(idx <= mid)
            updateTree(lo, mid, idx, 2 * i + 1, add);
        else
            updateTree(mid + 1, hi, idx, 2 * i + 2, add);
        m_st[i] = m_st[2 * i + 1] + m_st[2 * i + 2];
    }

    // Sum of [qlo, qhi] intersected with [lo, hi], the range node i covers.
    int query(int qlo, int qhi, int lo, int hi, int i)
    {
        if(qlo > hi || qhi < lo)
            return 0;         // disjoint
        if(qlo <= lo && qhi >= hi)
            return m_st[i];   // fully covered
        int mid = lo + (hi - lo) / 2;
        return query(qlo, qhi, lo, mid, 2 * i + 1) + query(qlo, qhi, mid + 1, hi, 2 * i + 2);
    }
};

// Range Sum Query 2D - Mutable via a 2D segment tree: an outer segment tree
// over rows whose node i (covering rows [lo, hi]) owns a SegmentTree built
// over the column-wise sums of those rows.
class NumMatrix {
public:
    NumMatrix(vector<vector<int>> matrix) {
        nums = matrix;
        m = static_cast<int>(matrix.size());
        n = m ? static_cast<int>(matrix[0].size()) : 0;
        if(!m || !n)
            return;
        // Same 4x bound as SegmentTree; 3*m overflows (e.g. m = 36).
        m_st = vector<SegmentTree*>(4 * m, nullptr);
        createTree(0, m - 1, 0, matrix);
    }

    // m_st holds owning raw pointers; a member-wise copy would delete them
    // twice, so copying is disabled.
    NumMatrix(const NumMatrix&) = delete;
    NumMatrix& operator=(const NumMatrix&) = delete;

    ~NumMatrix()
    {
        for(SegmentTree* ptr : m_st)
            delete ptr;  // unused slots are nullptr; deleting them is a no-op
    }

    // Sets matrix[row][col] = val. Assumes (row, col) is inside the matrix.
    void update(int row, int col, int val) {
        if(!m || !n)
            return;
        int diff = val - nums[row][col];
        update(0, m - 1, row, col, 0, diff);
        nums[row][col] = val;
    }

    // Sum of the inclusive rectangle with corners (row1, col1) and (row2, col2).
    int sumRegion(int row1, int col1, int row2, int col2) {
        if(!m || !n)
            return 0;
        return query(row1, row2, col1, col2, 0, m - 1, 0);
    }

private:
    vector<SegmentTree*> m_st;  // outer tree nodes, same implicit layout as SegmentTree
    vector<vector<int>> nums;   // current values, used to turn "set" into a delta
    int m = 0;                  // number of rows
    int n = 0;                  // number of columns

    // Builds outer node i for rows [lo, hi] and returns the column-wise sums
    // of those rows.
    vector<int> createTree(int lo, int hi, int i, vector<vector<int>>& matrix)
    {
        if(lo == hi)
        {
            m_st[i] = new SegmentTree(matrix[lo]);
            return matrix[lo];
        }
        int mid = lo + (hi - lo) / 2;
        vector<int> colSums = createTree(lo, mid, 2 * i + 1, matrix);
        vector<int> rightSums = createTree(mid + 1, hi, 2 * i + 2, matrix);
        // Loop index is k, not i, so it does not shadow the node index used below.
        for(size_t k = 0; k < colSums.size(); ++k)
            colSums[k] += rightSums[k];
        m_st[i] = new SegmentTree(colSums);
        return colSums;
    }

    // Adds `add` at column j in every outer node on the path to row i.
    void update(int lo, int hi, int i, int j, int idx, int add)
    {
        m_st[idx]->update(j, add);
        if(lo == hi)
            return;
        int mid = lo + (hi - lo) / 2;
        if(i <= mid)
            update(lo, mid, i, j, 2 * idx + 1, add);
        else
            update(mid + 1, hi, i, j, 2 * idx + 2, add);
    }

    // Sum over rows [xlo, xhi] and columns [ylo, yhi], restricted to the
    // rows [lo, hi] covered by outer node i.
    int query(int xlo, int xhi, int ylo, int yhi, int lo, int hi, int i)
    {
        if(xhi < lo || xlo > hi)
            return 0;
        if(xlo <= lo && xhi >= hi)
            return m_st[i]->query(ylo, yhi);
        int mid = lo + (hi - lo) / 2;
        return query(xlo, xhi, ylo, yhi, lo, mid, 2 * i + 1) + query(xlo, xhi, ylo, yhi, mid + 1, hi, 2 * i + 2);
    }
};
/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix obj(matrix);
 * obj.update(row, col, val);
 * int param_2 = obj.sumRegion(row1, col1, row2, col2);
 */
Quad Tree解法：
// Range Sum Query 2D - Mutable with a quad tree: every node stores the sum of
// a rectangle [row1..row2] x [col1..col2] and splits it at the row and column
// midpoints into four quadrants (children 0..3 = top-left, top-right,
// bottom-left, bottom-right). Quadrants with no cells are nullptr.
class NumMatrix
{
    struct Node
    {
        int sum = 0;
        Node* children[4] = {nullptr, nullptr, nullptr, nullptr};
        explicit Node(int val) : sum(val)
        {
        }
    };
public:
    NumMatrix(vector<vector<int>> matrix)
    {
        m_m = static_cast<int>(matrix.size());
        m_n = m_m ? static_cast<int>(matrix[0].size()) : 0;
        if(!m_m || !m_n)
            return;
        root = createTree(matrix, 0, m_m - 1, 0, m_n - 1);
    }

    // The tree is owned through raw pointers; a member-wise copy would share
    // root and free it twice, so copying is disabled.
    NumMatrix(const NumMatrix&) = delete;
    NumMatrix& operator=(const NumMatrix&) = delete;

    ~NumMatrix()
    {
        del(root);
    }

    // Sum of the inclusive rectangle. Returns 0 when row1/col1 is negative or
    // row2/col2 is past the last row/column.
    int sumRegion(int row1, int col1, int row2, int col2)
    {
        if(row1 < 0 || row2 >= m_m || col1 < 0 || col2 >= m_n)
            return 0;
        return query(root, 0, m_m - 1, 0, m_n - 1, row1, row2, col1, col2);
    }

    // Sets matrix[row][col] = val; cells outside the matrix are ignored.
    void update(int row, int col, int val)
    {
        if(row < 0 || row >= m_m || col < 0 || col >= m_n)
            return;
        update(root, 0, m_m - 1, 0, m_n - 1, row, col, val);
    }

private:
    int m_m = 0;           // number of rows
    int m_n = 0;           // number of columns
    Node* root = nullptr;  // stays nullptr for an empty matrix

    // Builds the subtree for the rectangle; nullptr if it contains no cells.
    Node* createTree(vector<vector<int>>& matrix, int row1, int row2, int col1, int col2)
    {
        if(row1 > row2 || col1 > col2)
            return nullptr;
        if(row1 == row2 && col1 == col2)
            return new Node(matrix[row1][col1]);
        int rMid = row1 + (row2 - row1) / 2;
        int cMid = col1 + (col2 - col1) / 2;
        Node* topL = createTree(matrix, row1, rMid, col1, cMid);
        Node* topR = createTree(matrix, row1, rMid, cMid + 1, col2);
        Node* botL = createTree(matrix, rMid + 1, row2, col1, cMid);
        Node* botR = createTree(matrix, rMid + 1, row2, cMid + 1, col2);
        Node* node = new Node(sumOf(topL) + sumOf(topR) + sumOf(botL) + sumOf(botR));
        node->children[0] = topL;
        node->children[1] = topR;
        node->children[2] = botL;
        node->children[3] = botR;
        return node;
    }

    // Sum of the query rectangle intersected with curr's rectangle.
    int query(Node* curr, int row1, int row2, int col1, int col2, int qRow1, int qRow2, int qCol1, int qCol2)
    {
        // An empty rectangle is exactly the case where curr is nullptr.
        if(row1 > row2 || col1 > col2)
            return 0;
        if(qRow1 > row2 || qRow2 < row1 || qCol1 > col2 || qCol2 < col1)
            return 0;
        if(qRow1 <= row1 && qCol1 <= col1 && qRow2 >= row2 && qCol2 >= col2)
            return curr->sum;
        int rMid = row1 + (row2 - row1) / 2;
        int cMid = col1 + (col2 - col1) / 2;
        int res = query(curr->children[0], row1, rMid, col1, cMid, qRow1, qRow2, qCol1, qCol2);
        res += query(curr->children[1], row1, rMid, cMid + 1, col2, qRow1, qRow2, qCol1, qCol2);
        res += query(curr->children[2], rMid + 1, row2, col1, cMid, qRow1, qRow2, qCol1, qCol2);
        res += query(curr->children[3], rMid + 1, row2, cMid + 1, col2, qRow1, qRow2, qCol1, qCol2);
        return res;
    }

    // Sets the leaf for (row, col) under curr to val and returns the delta
    // applied, so every ancestor on the path can add the same delta.
    int update(Node* curr, int row1, int row2, int col1, int col2, int row, int col, int val)
    {
        if(row1 == row2 && col1 == col2)
        {
            int add = val - curr->sum;
            curr->sum += add;
            return add;
        }
        int rMid = row1 + (row2 - row1) / 2;
        int cMid = col1 + (col2 - col1) / 2;
        int add = 0;
        if(row <= rMid && col <= cMid)
            add = update(curr->children[0], row1, rMid, col1, cMid, row, col, val);
        else if(row <= rMid)
            add = update(curr->children[1], row1, rMid, cMid + 1, col2, row, col, val);
        else if(col <= cMid)
            add = update(curr->children[2], rMid + 1, row2, col1, cMid, row, col, val);
        else
            add = update(curr->children[3], rMid + 1, row2, cMid + 1, col2, row, col, val);
        curr->sum += add;
        return add;
    }

    // Frees the subtree rooted at node (post-order).
    void del(Node* node)
    {
        if(!node)
            return;
        for(Node* child : node->children)
            del(child);
        delete node;
    }

    // Sum stored at n, treating an absent quadrant as 0.
    int sumOf(Node* n)
    {
        return n ? n->sum : 0;
    }
};
/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix obj(matrix);
 * obj.update(row, col, val);
 * int param_2 = obj.sumRegion(row1, col1, row2, col2);
 */
2D Binary Indexed Tree解法：
// Range Sum Query 2D - Mutable backed by a 2D Binary Indexed (Fenwick) tree.
// The tree is 1-based, sized (rows + 1) x (cols + 1); row/column 0 is unused.
// Point updates and rectangle sums each walk O(log rows * log cols) cells.
class NumMatrix {
public:
    NumMatrix(vector<vector<int>> matrix) {
        m = matrix.size();
        n = m ? matrix[0].size() : 0;
        if (m == 0 || n == 0)
            return;
        m_tree.assign(m + 1, vector<int>(n + 1, 0));
        m_values = matrix;
        for (int r = 0; r < m; ++r) {
            for (int c = 0; c < n; ++c)
                add(r, c, matrix[r][c]);
        }
    }

    // Overwrites one cell by pushing the difference into the tree.
    void update(int row, int col, int val) {
        if (m == 0 || n == 0)
            return;
        const int delta = val - m_values[row][col];
        m_values[row][col] = val;
        add(row, col, delta);
    }

    // Inclusive rectangle sum via inclusion-exclusion over four prefix sums.
    int sumRegion(int row1, int col1, int row2, int col2) {
        if (m == 0 || n == 0)
            return 0;
        int total = prefix(row2, col2);
        total -= prefix(row2, col1 - 1);
        total -= prefix(row1 - 1, col2);
        total += prefix(row1 - 1, col1 - 1);
        return total;
    }

private:
    vector<vector<int>> m_tree;    // Fenwick cells, 1-based
    vector<vector<int>> m_values;  // current cell values, 0-based
    int m, n;                      // rows, columns

    // Lowest set bit of x: the step size of a Fenwick walk.
    static int lowbit(int x) { return x & -x; }

    // Adds delta to cell (row, col) (0-based).
    void add(int row, int col, int delta)
    {
        if (m == 0 || n == 0)
            return;
        int i = row + 1;
        while (i <= m) {
            int j = col + 1;
            while (j <= n) {
                m_tree[i][j] += delta;
                j += lowbit(j);
            }
            i += lowbit(i);
        }
    }

    // Sum of all cells in rows [0, row] and columns [0, col] (0-based);
    // a negative row or col yields 0.
    int prefix(int row, int col)
    {
        int acc = 0;
        int i = row + 1;
        while (i > 0) {
            int j = col + 1;
            while (j > 0) {
                acc += m_tree[i][j];
                j -= lowbit(j);
            }
            i -= lowbit(i);
        }
        return acc;
    }
};
/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix obj(matrix);
 * obj.update(row, col, val);
 * int param_2 = obj.sumRegion(row1, col1, row2, col2);
 */
No comments:
Post a Comment