Monday, July 9, 2018

[LeetCode]Shortest Path to Get All Keys


这是一道最短路径的问题,但是特殊的地方在于只有我们到达某一些点之后才能解锁对应的其他的点。所以普通的最短路径的算法是不行的,需要进行一定的修改。在那之前,brute force的算法是很容易想到的,因为我们的目的是取得所有的锁,锁最多也就6个,我们可以枚举所有序列的permutation。比如我们只有三把锁,我们枚举的序列为:

  • abc
  • acb
  • bac
  • bca
  • cab
  • cba
对于每一个枚举的序列,比如abc,我们依次计算a到b和b到c的最短距离,用bfs计算即可。注意每一次拿到钥匙之后要解锁对应的部分即可。假设输入的矩阵为m x n,锁的数目为k,那么permutation的数量有k!个,序列的长度为k,所以我们一共要进行k * k次bfs,每一次的时间复杂度为O(m * n),所以总的时间复杂度为O(m * n * k * k!)。代码如下:


class Solution {
public:
int shortestPathAllKeys(vector<string>& grid) {
unordered_map<char, int> map;
int m = grid.size(), n = m? grid[0].size(): 0;
for(int i = 0; i < m; ++i)
{
for(int j = 0; j < n; ++j)
{
if(grid[i][j] !='.' && grid[i][j] != '#')
map[grid[i][j]] = i * n + j;
}
}
//get permutations
int k = (map.size() - 1) / 2;
string str, curr;
for(int i = 0; i < k; ++i)str += 'a' + i;
vector<string> perms;
permutations(perms, str, curr, 0);
//for each permutation, get the shortest path by bfs
int start = map['@'], minDist = INT_MAX;
for(const auto& perm : perms)
{
int curr = start, currDist = 0, keys = 0;
bool canReach = true;
for(const auto& node : perm)
{
int res = bfs(grid, curr, map[node], keys);
//unable to reach a node
if(res == -1)
{
canReach = false;
break;
}
currDist += res;
curr = map[node];
keys |= (1 << (node - 'a'));
//early cut off
if(currDist >= minDist)
{
canReach = false;
break;
}
}
if(canReach)minDist = min(minDist, currDist);
}
return minDist == INT_MAX? -1: minDist;
}
private:
void permutations(vector<string>& res, const string& str, string& curr, int used)
{
int n = str.size();
if(curr.size() == n)
{
res.push_back(curr);
return;
}
for(int i = 0; i < n; ++i)
{
if(used & (1 << i))continue;
curr += str[i];
permutations(res, str, curr, used | (1 << i));
curr.pop_back();
}
}
int bfs(vector<string>& grid, int start, int target, int keys)
{
int m = grid.size(), n = m? grid[0].size(): 0, dist = 0;
vector<pair<int, int>> dirs = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};
queue<int> q;
unordered_set<int> visited;
q.push(start);
visited.insert(start);
while(q.size())
{
int sz = q.size();
++dist;
for(int k = 0; k < sz; ++k)
{
int curr = q.front(), i = curr / n, j = curr % n;
q.pop();
for(const auto& dir : dirs)
{
int x = i + dir.first, y = j + dir.second;
if(x >= 0 && x < m && y >= 0 && y < n && grid[x][y] != '#' && visited.find(x * n + y) == visited.end())
{
if(isupper(grid[i][j]))
{
int key = tolower(grid[i][j]) - 'a';
if((keys & (1 << key)) == 0)continue;
}
if(x * n + y == target)return dist;
q.push(x * n + y);
visited.insert(x * n + y);
}
}
}
}
return -1;
}
};



另一种思路是建带权的图,因为我们所关心的节点只有起点和大小写字母。所以我们可以只把他们看做节点,剩下的.和#都忽略。节点和节点之间用带权边相连,建图的过程我们对每个在意的节点run bfs即可。但是这并不是最终的图,因为我们只是单纯的run最短路径的算法是没有意义的,我们搜索时候的节点里同时要存当前已获取钥匙的状态。最终当节点的状态为所有钥匙都找到的状态时,就代表我们找到最短的路径。所以我们构建的图每个节点包含两个部分,分别是:

  • 对应的字符,@, a, b c 等
  • 对应获取钥匙的状态
所以最终的图是一个多层的图,每一个节点有不同的状态。通过相连的边可以从一个状态到达另一个状态对应的节点。

并且这一定是找全钥匙的最短路径,因为只要我们run dijkstra算法,其一定是按照所有节点距离源节点最短路径的长短从小到大找的。我们碰到的第一个状态为获取所有钥匙的点,一定是最短距离距源点最短的点。同样用n,m和k表示复杂度的话,建图的时间复杂度O(m * n * (2 * k + 1)),因为总共有2 * k + 1个节点是我们需要关心的。dijkstra的时间复杂度的话,因为我们搜索的节点有(2 * k + 1) * 2^k个,而对于每一个节点,其向外连接的最多可能有(2* k + 1)条边,所以边总共最多有(2 * k + 1)^2 * 2^k个,所以dijkstra的时间复杂度为E * log V = O((2 * k + 1)^2 * 2^k * log((2 * k + 1) * 2^k))。总的时间复杂度为O(m * n * (2 * k + 1) + (2 * k + 1)^2 * 2^k * log((2 * k + 1) * 2^k))。空间复杂度O(E) = O(2 * k + 1)^2 * 2^k)。代码如下:


class Solution {
public:
int shortestPathAllKeys(vector<string>& grid) {
unordered_map<char, int> map;
int m = grid.size(), n = m? grid[0].size(): 0;
for(int i = 0; i < m; ++i)
{
for(int j = 0; j < n; ++j)
{
if(grid[i][j] !='.' && grid[i][j] != '#')
map[grid[i][j]] = i * n + j;
}
}
//get permutations
int k = (map.size() - 1) / 2;
string str, curr;
for(int i = 0; i < k; ++i)str += 'a' + i;
vector<string> perms;
permutations(perms, str, curr, 0);
//for each permutation, get the shortest path by bfs
int start = map['@'], minDist = INT_MAX;
for(const auto& perm : perms)
{
int curr = start, currDist = 0, keys = 0;
bool canReach = true;
for(const auto& node : perm)
{
int res = bfs(grid, curr, map[node], keys);
//unable to reach a node
if(res == -1)
{
canReach = false;
break;
}
currDist += res;
curr = map[node];
keys |= (1 << (node - 'a'));
//early cut off
if(currDist >= minDist)
{
canReach = false;
break;
}
}
if(canReach)minDist = min(minDist, currDist);
}
return minDist == INT_MAX? -1: minDist;
}
private:
void permutations(vector<string>& res, const string& str, string& curr, int used)
{
int n = str.size();
if(curr.size() == n)
{
res.push_back(curr);
return;
}
for(int i = 0; i < n; ++i)
{
if(used & (1 << i))continue;
curr += str[i];
permutations(res, str, curr, used | (1 << i));
curr.pop_back();
}
}
int bfs(vector<string>& grid, int start, int target, int keys)
{
int m = grid.size(), n = m? grid[0].size(): 0, dist = 0;
vector<pair<int, int>> dirs = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};
queue<int> q;
unordered_set<int> visited;
q.push(start);
visited.insert(start);
while(q.size())
{
int sz = q.size();
++dist;
for(int k = 0; k < sz; ++k)
{
int curr = q.front(), i = curr / n, j = curr % n;
q.pop();
for(const auto& dir : dirs)
{
int x = i + dir.first, y = j + dir.second;
if(x >= 0 && x < m && y >= 0 && y < n && grid[i][j] != '#' && visited.find(x * n + y) == visited.end())
{
if(isupper(grid[i][j]))
{
int key = tolower(grid[i][j]) - 'a';
if((keys & (1 << key)) == 0)continue;
}
if(x * n + y == target)return dist;
q.push(x * n + y);
visited.insert(x * n + y);
}
}
}
}
return -1;
}
};

No comments:

Post a Comment