Monday, October 16, 2017

[LeetCode]Range Sum Query 2D - Mutable


2D Range Query问题，可以用2D Segment Tree、Quad Tree或者2D Binary Indexed Tree来解决，具体数据结构讲解和时间复杂度分析请参考对应链接。

2D Segment Tree解法:

// Sum segment tree over a 1-D array of ints. Nodes are stored implicitly in
// m_st: node i covers the index range [lo, hi]; its children are 2*i+1
// (left half [lo, mid]) and 2*i+2 (right half [mid+1, hi]).
class SegmentTree
{
public:
    // Builds the tree from nums. An empty nums is allowed: every query then
    // returns 0 and every update is ignored.
    SegmentTree(const vector<int>& nums)
    {
        n = static_cast<int>(nums.size());
        // The recursive layout can index past 3n (for n = 36 the leaf for
        // element 28 lands at index 112 > 108), so reserve the usual 4n nodes.
        m_st = vector<int>(4 * n, 0);
        if(n > 0)
            createTree(0, n - 1, 0, nums);
    }
    // Adds `add` to element i (a delta, not a replacement value).
    // Out-of-range indices are ignored instead of corrupting a boundary leaf.
    void update(int i, int add)
    {
        if(i < 0 || i >= n)
            return;
        updateTree(0, n - 1, i, 0, add);
    }
    // Returns the sum of elements in [lo, hi] (inclusive); positions outside
    // [0, n - 1] contribute nothing.
    int query(int lo, int hi) const
    {
        if(n == 0)
            return 0;
        return query(lo, hi, 0, n - 1, 0);
    }
private:
    vector<int> m_st;   // node sums, laid out as described above
    int n = 0;          // number of elements in the underlying array

    // Fills node i, which covers nums[lo..hi], and all of its descendants.
    void createTree(int lo, int hi, int i, const vector<int>& nums)
    {
        if(lo == hi)
        {
            m_st[i] = nums[lo];
            return;
        }
        int mid = lo + (hi - lo) / 2;
        createTree(lo, mid, 2 * i + 1, nums);
        createTree(mid + 1, hi, 2 * i + 2, nums);
        m_st[i] = m_st[2 * i + 1] + m_st[2 * i + 2];
    }
    // Adds `add` to the leaf for idx below node i (covering [lo, hi]) and
    // recomputes the sums on the path back up to node i.
    void updateTree(int lo, int hi, int idx, int i, int add)
    {
        if(lo == hi)
        {
            m_st[i] += add;
            return;
        }
        int mid = lo + (hi - lo) / 2;
        if(idx <= mid)
            updateTree(lo, mid, idx, 2 * i + 1, add);
        else
            updateTree(mid + 1, hi, idx, 2 * i + 2, add);
        m_st[i] = m_st[2 * i + 1] + m_st[2 * i + 2];
    }
    // Sum of the part of [qlo, qhi] that overlaps [lo, hi], the range of node i.
    int query(int qlo, int qhi, int lo, int hi, int i) const
    {
        if(qlo > hi || qhi < lo)
            return 0;           // disjoint from this node
        if(qlo <= lo && qhi >= hi)
            return m_st[i];     // node lies entirely inside the query
        int mid = lo + (hi - lo) / 2;
        return query(qlo, qhi, lo, mid, 2 * i + 1) + query(qlo, qhi, mid + 1, hi, 2 * i + 2);
    }
};
class NumMatrix {
public:
NumMatrix(vector<vector<int>> matrix) {
nums = matrix;
m = matrix.size();
int n = m? matrix[0].size(): 0;
if(!m || !n)return;
m_st = vector<SegmentTree*>(3 * m, nullptr);
createTree(0, m - 1, 0, matrix);
}
~NumMatrix()
{
for(auto ptr : m_st)
{
if(ptr)
delete ptr;
}
}
void update(int row, int col, int val) {
int diff = val - nums[row][col];
update(0, m - 1, row, col, 0, diff);
nums[row][col] = val;
}
int sumRegion(int row1, int col1, int row2, int col2) {
return query(row1, row2, col1, col2, 0, m - 1, 0);
}
private:
vector<SegmentTree*> m_st;
vector<vector<int>> nums;
int m;
vector<int> createTree(int lo, int hi, int i, vector<vector<int>>& matrix)
{
if(lo == hi)
{
m_st[i] = new SegmentTree(matrix[lo]);
return matrix[lo];
}
int mid = lo + (hi - lo) / 2;
auto v1 = createTree(lo, mid, 2 * i + 1, matrix);
auto v2 = createTree(mid + 1, hi, 2 * i + 2, matrix);
for(int i = 0; i < v1.size(); ++i)
v1[i] += v2[i];
m_st[i] = new SegmentTree(v1);
return v1;
}
void update(int lo, int hi, int i, int j, int idx, int add)
{
if(lo == hi)
{
m_st[idx]->update(j, add);
return;
}
int mid = lo + (hi - lo) / 2;
if(i <= mid)
update(lo, mid, i, j, 2 * idx + 1, add);
else
update(mid + 1, hi, i, j, 2 * idx + 2, add);
m_st[idx]->update(j, add);
}
int query(int xlo, int xhi, int ylo, int yhi, int lo, int hi, int i)
{
int mid = lo + (hi - lo) / 2;
if(xhi < lo || xlo > hi)
return 0;
else if(xlo <= lo && xhi >= hi)
return m_st[i]->query(ylo, yhi);
else
return query(xlo, xhi, ylo, yhi, lo, mid, 2 * i + 1) + query(xlo, xhi, ylo, yhi, mid + 1, hi, 2 * i + 2);
}
};
/**
* Your NumMatrix object will be instantiated and called as such:
* NumMatrix* obj = new NumMatrix(matrix);
* obj->update(row,col,val);
* int param_2 = obj->sumRegion(row1,col1,row2,col2);
*/
Quad Tree解法:

// Quad Tree solution: each Node stores the sum of one sub-rectangle of the
// matrix. An internal node splits its rectangle at the row and column
// midpoints into children[0..3] = top-left, top-right, bottom-left,
// bottom-right; a child is nullptr when its sub-rectangle is empty (e.g. a
// single-row range has no bottom half).
class NumMatrix
{
    struct Node
    {
        int sum = 0;
        Node* children[4] = {nullptr, nullptr, nullptr, nullptr};
        explicit Node(int val) : sum(val)
        {
        }
    };
public:
    NumMatrix(vector<vector<int>> matrix)
    {
        m_m = matrix.size(), m_n = m_m? matrix[0].size(): 0;
        if(!m_m || !m_n)return;
        root = createTree(matrix, 0, m_m - 1, 0, m_n - 1);
    }
    // The tree is owned through raw Node pointers, so the implicit
    // member-wise copy would share `root` and free every node twice.
    NumMatrix(const NumMatrix&) = delete;
    NumMatrix& operator=(const NumMatrix&) = delete;
    ~NumMatrix()
    {
        del(root);
    }
    // Returns the sum over rows [row1, row2] and columns [col1, col2]
    // (inclusive). Returns 0 if row1/col1 is negative or row2/col2 is past
    // the last row/column (which includes every query on an empty matrix).
    int sumRegion(int row1, int col1, int row2, int col2)
    {
        if(row1 < 0 || row2 >= m_m || col1 < 0 || col2 >= m_n)return 0;
        return query(root, 0, m_m - 1, 0, m_n - 1, row1, row2, col1, col2);
    }
    // Sets matrix[row][col] to val; out-of-range coordinates are ignored.
    void update(int row, int col, int val)
    {
        if(row < 0 || row >= m_m || col < 0 || col >= m_n)return;
        update(root, 0, m_m - 1, 0, m_n - 1, row, col, val);
    }
private:
    int m_m = 0, m_n = 0;    // matrix dimensions (rows, columns)
    Node* root = nullptr;    // nullptr for an empty matrix

    // Builds the subtree for rows [row1, row2] x columns [col1, col2];
    // returns nullptr for an empty rectangle.
    Node* createTree(vector<vector<int>>& matrix, int row1, int row2, int col1, int col2)
    {
        if(row1 > row2 || col1 > col2)return nullptr;
        if(row1 == row2 && col1 == col2)return new Node(matrix[row1][col1]);
        int rMid = row1 + (row2 - row1) / 2, cMid = col1 + (col2 - col1) / 2;
        auto topL = createTree(matrix, row1, rMid, col1, cMid);
        auto topR = createTree(matrix, row1, rMid, cMid + 1, col2);
        auto botL = createTree(matrix, rMid + 1, row2, col1, cMid);
        auto botR = createTree(matrix, rMid + 1, row2, cMid + 1, col2);
        Node* node = new Node(val(topL) + val(topR) + val(botL) + val(botR));
        node->children[0] = topL;
        node->children[1] = topR;
        node->children[2] = botL;
        node->children[3] = botR;
        return node;
    }
    // Returns the part of the query rectangle [qRow1, qRow2] x [qCol1, qCol2]
    // that lies inside curr's rectangle [row1, row2] x [col1, col2].
    int query(Node* curr, int row1, int row2, int col1, int col2, int qRow1, int qRow2, int qCol1, int qCol2)
    {
        // Empty rectangle: curr is nullptr here, so return before dereferencing it.
        if (row1 > row2 || col1 > col2)return 0;
        if(qRow1 > row2 || qRow2 < row1 || qCol1 > col2 || qCol2 < col1)return 0;
        if(qRow1 <= row1 && qCol1 <= col1 && qRow2 >= row2 && qCol2 >= col2)return curr->sum;
        int rMid = row1 + (row2 - row1) / 2, cMid = col1 + (col2 - col1) / 2;
        int res = query(curr->children[0], row1, rMid, col1, cMid, qRow1, qRow2, qCol1, qCol2);
        res += query(curr->children[1], row1, rMid, cMid + 1, col2, qRow1, qRow2, qCol1, qCol2);
        res += query(curr->children[2], rMid + 1, row2, col1, cMid, qRow1, qRow2, qCol1, qCol2);
        res += query(curr->children[3], rMid + 1, row2, cMid + 1, col2, qRow1, qRow2, qCol1, qCol2);
        return res;
    }
    // Sets the leaf for (row, col) to val and returns the change applied;
    // every ancestor on the way back up adds that change to its own sum.
    int update(Node* curr, int row1, int row2, int col1, int col2, int row, int col, int val)
    {
        if(row1 == row2 && col1 == col2)
        {
            int add = val - curr->sum;
            curr->sum += add;
            return add;
        }
        int add = 0, rMid = row1 + (row2 - row1) / 2, cMid = col1 + (col2 - col1) / 2;
        if(row <= rMid && col <= cMid)add = update(curr->children[0], row1, rMid, col1, cMid, row, col, val);
        else if(row <= rMid && col > cMid)add = update(curr->children[1], row1, rMid, cMid + 1, col2, row, col, val);
        else if(row > rMid && col <= cMid)add = update(curr->children[2], rMid + 1, row2, col1, cMid, row, col, val);
        else add = update(curr->children[3], rMid + 1, row2, cMid + 1, col2, row, col, val);
        curr->sum += add;
        return add;
    }
    // Frees node and its whole subtree (post-order).
    void del(Node* node)
    {
        if(!node)return;
        for(int i = 0; i < 4; ++i)del(node->children[i]);
        delete node;
    }
    // Sum stored in n, treating a missing (nullptr) child as 0.
    int val(Node* n)
    {
        return n? n->sum: 0;
    }
};
/**
* Your NumMatrix object will be instantiated and called as such:
* NumMatrix* obj = new NumMatrix(matrix);
* obj->update(row,col,val);
* int param_2 = obj->sumRegion(row1,col1,row2,col2);
*/
view raw quadtree.cpp hosted with ❤ by GitHub
2D Binary Indexed Tree解法:


// 2D Binary Indexed (Fenwick) Tree solution.
// fenwick[r][c] (1-based) holds the sum of the cells in 0-based rows
// [r - (r & -r), r - 1] and columns [c - (c & -c), c - 1], so any prefix
// rectangle anchored at (0, 0) is assembled from O(log m * log n) entries.
class NumMatrix {
public:
    NumMatrix(vector<vector<int>> matrix) {
        rows = matrix.size();
        cols = rows ? matrix[0].size() : 0;
        if (rows == 0 || cols == 0)
            return;  // empty matrix: nothing to index
        fenwick.assign(rows + 1, vector<int>(cols + 1, 0));
        values = matrix;
        // Seed the tree by applying every cell as a point increment.
        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c)
                addDelta(r, c, matrix[r][c]);
        }
    }

    // Sets matrix[row][col] to val (no-op on an empty matrix).
    void update(int row, int col, int val) {
        if (rows == 0 || cols == 0)
            return;
        int& cell = values[row][col];
        addDelta(row, col, val - cell);
        cell = val;
    }

    // Sum over rows [row1, row2] and columns [col1, col2], inclusive;
    // 0 for an empty matrix.
    int sumRegion(int row1, int col1, int row2, int col2) {
        if (rows == 0 || cols == 0)
            return 0;
        // Inclusion-exclusion over four prefix rectangles.
        int whole = prefixSum(row2, col2);
        int left = prefixSum(row2, col1 - 1);
        int above = prefixSum(row1 - 1, col2);
        int corner = prefixSum(row1 - 1, col1 - 1);
        return whole - left - above + corner;
    }

private:
    vector<vector<int>> fenwick;  // (rows + 1) x (cols + 1); index 0 unused
    vector<vector<int>> values;   // current cell values, to turn assignments into deltas
    int rows = 0;
    int cols = 0;

    // Lowest set bit of x: the span covered by Fenwick index x.
    static int lowbit(int x) { return x & -x; }

    // Adds diff to the 0-based cell (row, col).
    void addDelta(int row, int col, int diff)
    {
        if (rows == 0 || cols == 0)
            return;
        for (int r = row + 1; r <= rows; r += lowbit(r))
            for (int c = col + 1; c <= cols; c += lowbit(c))
                fenwick[r][c] += diff;
    }

    // Sum of cells in rows [0, row] and columns [0, col]; 0 if row or col is -1.
    int prefixSum(int row, int col)
    {
        int total = 0;
        for (int r = row + 1; r > 0; r -= lowbit(r))
            for (int c = col + 1; c > 0; c -= lowbit(c))
                total += fenwick[r][c];
        return total;
    }
};
/**
* Your NumMatrix object will be instantiated and called as such:
* NumMatrix* obj = new NumMatrix(matrix);
* obj->update(row,col,val);
* int param_2 = obj->sumRegion(row1,col1,row2,col2);
*/

No comments:

Post a Comment