恩... AI 的作業 - A* Search Algorithm。
老師是要我們寫八陣圖(也就是 8-puzzle,八數碼滑塊拼圖)啦!不過我看到 A* 的時候想,其實 A* 需要的也就是這些東西:
- 當前的 state
- 接下來可能的 state
- 要能計算兩個 state 之間的距離
所以,我把 code 提煉成一個 template function,只要餵他一個 start 一個 target,他就會把中間要走的路找出來。而這個 template function 需要:
- NODE_T:表示 state
- distance:用來計算距離的 functor,需要同時支援 distance(NODE_T, NODE_T)(兩個 state 之間的距離)與 distance(NODE_T)(該 state 到 target 的估計距離,抵達 target 時必須回傳 0)
- solution(NODE_T, NODE_T):驗證是否能從一個 state 走到另一個 state
- NODE_T::childs():接下來可以走的 state
board.h:這是用來測試的 node class,以及 distance 與 solution 兩個 functor 的實作:
#ifndef _BOARD_H
#define _BOARD_H
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <set>
#include <utility>
#include <vector>
/// Square sliding-tile puzzle board with SIZE x SIZE cells.
///
/// Cells are kept row-major in a flat vector (index = y * SIZE + x). The value
/// 0 marks the empty cell; every other value is a tile number.
template <int SIZE>
class basic_board
{
public:
    typedef std::vector< char > storage_t;

private:
    /// Proxy returned by basic_board::operator[] so that board[x][y] reads the
    /// cell in column x, row y. It only borrows the board's storage.
    class INDEX_HELPER
    {
    public:
        INDEX_HELPER(int const x, const storage_t & s):
            idx_x_(x), storage_(s)
        { }

        /// Value of the cell in row idx_y of the bound column (no bounds check).
        int operator[] (int const idx_y) const
        {
            return storage_[idx_y * SIZE + idx_x_];
        }

    private:
        int idx_x_;
        storage_t const & storage_;
    };

public:
    /// Builds a board from an explicit row-major cell layout.
    /// The layout is copied unchecked; it is expected to hold SIZE*SIZE cells.
    basic_board(storage_t const & prototype):
        storage_(prototype)
    { }

    /// Builds the ordered board 0, 1, ..., SIZE*SIZE-1.
    /// @param random_init  when true, the cells are shuffled using std::rand().
    basic_board(bool random_init = false):
        storage_(SIZE*SIZE)
    {
        for (storage_t::size_type i = 0; i < storage_.size(); ++i)
            storage_[i] = static_cast<char>(i);
        if (random_init)
            shuffle_cells();
    }

    template <int S>
    friend class board_distance;
    template <int S>
    friend class board_solution;

    bool operator == (basic_board const & rhs) const
    {
        return storage_ == rhs.storage_;
    }

    bool operator != (basic_board const & rhs) const
    {
        return storage_ != rhs.storage_;
    }

    /// Column accessor: board[x][y] is the cell at column x, row y.
    INDEX_HELPER const operator [] (int const idx_x) const
    {
        return INDEX_HELPER(idx_x, storage_);
    }

    /// Returns every board reachable by sliding one tile into the empty cell,
    /// in the order: blank moves up, down, left, right.
    ///
    /// A board without an empty cell, or whose storage does not hold exactly
    /// SIZE*SIZE cells, has no legal moves and yields an empty vector.
    std::vector< basic_board > childs() const
    {
        std::vector< basic_board > ret;
        if (storage_.size() != static_cast<storage_t::size_type>(SIZE * SIZE))
            return ret;

        storage_t::const_iterator const blank =
            std::find(storage_.begin(), storage_.end(), 0);
        if (blank == storage_.end())
            return ret;

        int const pos = static_cast<int>(blank - storage_.begin());
        int const e_x = pos % SIZE;
        int const e_y = pos / SIZE;

        ret.reserve(4);
        if (e_y - 1 >= 0)   // move up?
            append_move(ret, pos, pos - SIZE);
        if (e_y + 1 < SIZE) // move down?
            append_move(ret, pos, pos + SIZE);
        if (e_x - 1 >= 0)   // move left?
            append_move(ret, pos, pos - 1);
        if (e_x + 1 < SIZE) // move right?
            append_move(ret, pos, pos + 1);
        return ret;
    }

private:
    /// Shuffles the cells in place with a Fisher-Yates pass driven by
    /// std::rand(), so a prior srand() call makes the layout reproducible.
    /// (std::random_shuffle is not used because it was removed in C++17.)
    void shuffle_cells()
    {
        for (storage_t::size_type i = 1; i < storage_.size(); ++i)
        {
            storage_t::size_type const j =
                static_cast<storage_t::size_type>(std::rand()) % (i + 1);
            if (i != j)
                std::swap(storage_[i], storage_[j]);
        }
    }

    /// Appends a copy of this board with cells `blank` and `tile` exchanged.
    void append_move(std::vector< basic_board > & out, int blank, int tile) const
    {
        out.push_back(*this);
        basic_board & child = out.back();
        std::swap(child.storage_[blank], child.storage_[tile]);
    }

    storage_t storage_;
};
/// Writes the board to `lhs`, one row per line, each cell followed by ", ".
/// Cells are printed as integers (board[x][y] yields an int).
/// @return the same stream, so insertions can be chained.
template <int SIZE>
std::ostream &
operator << (std::ostream & lhs, basic_board<SIZE> const & rhs)
{
    for (int row = 0; row < SIZE; ++row)
    {
        for (int col = 0; col < SIZE; ++col)
            lhs << rhs[col][row] << ", ";
        lhs << std::endl;
    }
    // The stream must be returned: falling off the end of a non-void function
    // is undefined behaviour and breaks `os << board << ...` chains.
    return lhs;
}
// solution verifier
template <int SIZE>
class board_solution
{
public:
bool operator() (basic_board<SIZE> lhs, basic_board<SIZE> rhs)
{
// move both empty cell to upper-left corner
typename basic_board<SIZE>::storage_t::iterator e, i;
int e_x, e_y;
e = std::find(lhs.storage_.begin(), lhs.storage_.end(), 0);
e_x = ( e - lhs.storage_.begin() ) % SIZE;
e_y = ( e - lhs.storage_.begin() ) / SIZE;
while (e_y - 1 >= 0) // move up?
{
i = lhs.storage_.begin() + (e_y - 1) * SIZE + e_x;
std::iter_swap(e, i);
e = i;
e_x = ( e - lhs.storage_.begin() ) % SIZE;
e_y = ( e - lhs.storage_.begin() ) / SIZE;
}
while (e_x - 1 >= 0) // move left?
{
i = lhs.storage_.begin() + e_y * SIZE + e_x - 1;
std::iter_swap(e, i);
e = i;
e_x = ( e - lhs.storage_.begin() ) % SIZE;
e_y = ( e - lhs.storage_.begin() ) / SIZE;
}
e = std::find(rhs.storage_.begin(), rhs.storage_.end(), 0);
e_x = ( e - rhs.storage_.begin() ) % SIZE;
e_y = ( e - rhs.storage_.begin() ) / SIZE;
while (e_y - 1 >= 0) // move up?
{
i = rhs.storage_.begin() + (e_y - 1) * SIZE + e_x;
std::iter_swap(e, i);
e = i;
e_x = ( e - rhs.storage_.begin() ) % SIZE;
e_y = ( e - rhs.storage_.begin() ) / SIZE;
}
while (e_x - 1 >= 0) // move left?
{
i = rhs.storage_.begin() + e_y * SIZE + e_x - 1;
std::iter_swap(e, i);
e = i;
e_x = ( e - rhs.storage_.begin() ) % SIZE;
e_y = ( e - rhs.storage_.begin() ) / SIZE;
}
std::cout << lhs << std::endl;
std::cout << rhs << std::endl;
// checking parity
int p(0);
for (e = lhs.storage_.begin() + 1;
e != lhs.storage_.end();
++e)
{
i = std::find(rhs.storage_.begin()+1, rhs.storage_.end(), *e);
for (typename basic_board<SIZE>::storage_t::iterator j = e;
j != lhs.storage_.end();
++j)
if (std::find(rhs.storage_.begin()+1, i, *j) != i)
++p;
}
std::cout << "Parity: " << p << " (" << p%2 << ")" << std::endl;
return (p%2) != ((SIZE*SIZE)%2);
}
};
// manhatten distance
template <int SIZE>
class board_distance
{
public:
board_distance(basic_board<SIZE> const & target):
pos_cache_(SIZE*SIZE)
{
for (int i=0;i<SIZE*SIZE;++i)
{
typename basic_board<SIZE>::storage_t::const_iterator
t = std::find(target.storage_.begin(), target.storage_.end(), i);
pos_cache_[i] = std::make_pair(
( t - target.storage_.begin() ) % SIZE,
( t - target.storage_.begin() ) / SIZE
);
}
}
int operator() (basic_board<SIZE> const & lhs)
{
int ret(0);
for (int i=1;i<SIZE*SIZE;++i)
{
typename basic_board<SIZE>::storage_t::const_iterator
l = std::find(lhs.storage_.begin(), lhs.storage_.end(), i);
ret += std::abs(
pos_cache_[i].first -
( ( l - lhs.storage_.begin() ) % SIZE )
) + std::abs(
pos_cache_[i].second -
( ( l - lhs.storage_.begin() ) / SIZE )
);
}
return ret;
}
int operator() (basic_board<SIZE> const & lhs, basic_board<SIZE> const & rhs)
{
int ret(0);
for (int i=1;i<SIZE*SIZE;++i)
{
typename basic_board<SIZE>::storage_t::const_iterator
l = std::find(lhs.storage_.begin(), lhs.storage_.end(), i),
r = std::find(rhs.storage_.begin(), rhs.storage_.end(), i);
ret += std::abs(
( ( l - lhs.storage_.begin() ) % SIZE ) -
( ( r - rhs.storage_.begin() ) % SIZE )
) + std::abs(
( ( l - lhs.storage_.begin() ) / SIZE ) -
( ( r - rhs.storage_.begin() ) / SIZE )
);
}
return ret;
}
private:
std::vector< std::pair<int, int> > pos_cache_;
};
// Board types for the 3x3, 4x4 and 5x5 instantiations.
typedef basic_board<3> Board_3x3;
typedef basic_board<4> Board_4x4;
typedef basic_board<5> Board_5x5;
// Distance functors for the same three sizes.
typedef board_distance<3> Distance_3x3;
typedef board_distance<4> Distance_4x4;
typedef board_distance<5> Distance_5x5;
// Solvability verifiers for the same three sizes.
typedef board_solution<3> Solution_3x3;
typedef board_solution<4> Solution_4x4;
typedef board_solution<5> Solution_5x5;
#endif
astar.h:這是 A* 演算法與一個 helper class。
#ifndef _ASTAR_H
#define _ASTAR_H
#include <map>
#include <vector>
/// Search-node record: a state together with the cost accumulated to reach it
/// and the index of the node it was expanded from.
///
/// @tparam COST_T  type of the accumulated cost
/// @tparam NODE_T  state type; must provide `std::vector<NODE_T> childs()`
template <typename COST_T, typename NODE_T>
class NODE_HELPER
{
public:
    /// @param c  cost from the start to this state
    /// @param s  the state itself (copied)
    /// @param p  index of the parent step
    NODE_HELPER(COST_T c, NODE_T const & s, int p):
        parent_(p), cost_(c), state_(s)   // listed in declaration order
    { }

    int parent() const { return parent_; }

    /// Returned as COST_T so that non-int costs are not truncated.
    COST_T cost() const { return cost_; }

    NODE_T & state() { return state_; }
    NODE_T const & state() const { return state_; }

    /// Successor states, as produced by NODE_T::childs().
    std::vector< NODE_T > childs() { return state_.childs(); }

private:
    int parent_;   // parent step
    COST_T cost_;  // cost from start to this state
    NODE_T state_; // current state
};
template <typename NODE_T, typename DISTANCE_T, typename VERIFIER_T>
std::vector< NODE_T >
astar_search(NODE_T const & start, NODE_T const & target, DISTANCE_T dist, VERIFIER_T sol)
{
if (!sol(start, target))
return std::vector< NODE_T >();
std::multimap<int, NODE_HELPER<int, NODE_T> > pending;
pending.insert(std::make_pair(dist(start, target), NODE_HELPER<int, NODE_T>(0, start, 0)));
std::vector< NODE_HELPER<int, NODE_T> > solution;
int n_iter(0);
while(dist(pending.begin()->second.state()))
{
solution.push_back(pending.begin()->second);
int cost_so_far = pending.begin()->second.cost();
pending.erase(pending.begin());
std::vector< NODE_T > tmp_cld = solution.rbegin()->childs();
for(typename std::vector< NODE_T >::iterator i = tmp_cld.begin();
i != tmp_cld.end();
++i)
{
int cost = cost_so_far + dist(solution.rbegin()->state(), *i);
pending.insert(
std::make_pair(
cost + dist(*i),
NODE_HELPER<int, NODE_T>(
cost,
*i,
solution.size() - 1)
)
);
}
}
std::vector< NODE_T > ret;
ret.push_back(pending.begin()->second.state());
for (int p_idx = pending.begin()->second.parent();
p_idx != 0;
p_idx = solution[p_idx].parent())
ret.push_back(solution[p_idx].state());
ret.push_back(start);
std::reverse(ret.begin(), ret.end());
return ret;
}
#endif
main.cpp:用來測試的...
#include <iostream>
#include <vector>
#include <sstream>
#include <unistd.h>
#include "board.h"
#include "astar.h"
// Types used by main(): the 4x4 board with its matching distance and
// solution functors.
typedef Board_4x4 BOARD_T;
typedef Distance_4x4 DIST_T;
typedef Solution_4x4 SOL_T;
int main(int argc, char* argv[])
{
int random_seed = 0;
if (argc > 1)
{
std::stringstream ss(argv[1]);
ss >> random_seed;
}
srand(random_seed);
std::cout << "Random seed: " << random_seed << std::endl;
BOARD_T b(true), t(false);
std::cout << "Start:" << std::endl << b << std::endl;
std::cout << "End:" << std::endl << t << std::endl;
std::cout << "Trying hard to solve (with A*)..." << std::endl;
std::vector< BOARD_T > result = astar_search(b, t, DIST_T(t), SOL_T() );
std::cout << "Solution steps: " << result.size() << std::endl;
for(std::vector< BOARD_T >::iterator i = result.begin();
i != result.end();
++i)
{
//sleep(1);
std::cout << *i << std::endl;
}
return 0;
}
如果編不起來,那是你的問題!BSD License!