星期四, 4月 05, 2007

啟發式搜尋演算法 - A* Algorithm

嗯……這是 AI 課的作業 - A* Search Algorithm

老師是要我們寫八數字推盤(8-puzzle)啦!不過我看到 A* 的時候就在想,其實 A* 需要的也就只有這些東西:

  • 當前的 state
  • 接下來可能的 state
  • 要能計算兩個 state 之間的距離

所以,我把 code 提煉成一個 template function,只要餵它一個 start 和一個 target,它就會把中間要走的路找出來。而這個 template function 需要:

  • NODE_T:表示 state
  • distance(NODE_T, NODE_T):用來計算兩個 state 之間的距離
  • solution(NODE_T, NODE_T):驗證是否能從一個 state 走到另一個 state
  • NODE_T::childs():接下來可以走的 state

board.h:這是用來測試的 node class 與 distance() function 實作:

#ifndef _BOARD_H
#define _BOARD_H

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <set>
#include <vector>

// Square sliding-puzzle board with SIZE x SIZE cells stored row-major.
// The value 0 marks the empty cell; the default board holds 0, 1, ..., SIZE*SIZE-1
// in order (the empty cell in the upper-left corner).
template <int SIZE>
class basic_board
{
public:
        typedef std::vector< char > storage_t;

private:
        // Proxy returned by basic_board::operator[] so that board[x][y] reads the
        // cell in column x, row y. It keeps a reference to the board's storage and
        // therefore must not outlive the board it was obtained from.
        class INDEX_HELPER
        {
        public:
                INDEX_HELPER(int const x, const storage_t & s):
                        idx_x_(x), storage_(s)
                { }

                int const operator[] (int const idx_y) const
                {
                        return storage_[idx_y * SIZE + idx_x_];
                }

        private:
                int idx_x_;
                storage_t const & storage_;
        };
public:
        // Builds a board from raw row-major cell values; the contents are not validated.
        basic_board(storage_t const & prototype):
                storage_(prototype)
        { }

        // Builds the ordered board 0..SIZE*SIZE-1. When random_init is true the cells
        // are shuffled with std::rand(), so the result is reproducible via srand().
        basic_board(bool random_init = false):
                storage_(SIZE*SIZE)
        {
                for (storage_t::size_type i = 0; i < storage_.size(); ++i)
                        storage_[i] = static_cast<char>(i);
                if (random_init)
                        shuffle_with_rand();
        }

        template <int S>
        friend class board_distance;

        template <int S>
        friend class board_solution;

        bool operator != (basic_board const & rhs) const
        {
                return storage_ != rhs.storage_;
        }

        // Column accessor: board[x][y] is the value in column x, row y.
        INDEX_HELPER const operator [] (int const idx_x) const
        {
                return INDEX_HELPER(idx_x, storage_);
        }

        // Boards reachable by moving the empty cell one step, in the order
        // up, down, left, right (directions that leave the grid are skipped).
        // Returns an empty vector when the board contains no empty (0) cell.
        std::vector< basic_board > childs() const
        {
                std::vector< basic_board > ret;

                storage_t::const_iterator const e = std::find(storage_.begin(), storage_.end(), 0);
                if (e == storage_.end())
                        return ret;     // no empty cell: nothing can move

                int const idx = static_cast<int>(e - storage_.begin());
                int const e_x = idx % SIZE;
                int const e_y = idx / SIZE;

                if (e_y > 0)            // move up?
                        ret.push_back(swapped(idx, idx - SIZE));
                if (e_y + 1 < SIZE)     // move down?
                        ret.push_back(swapped(idx, idx + SIZE));
                if (e_x > 0)            // move left?
                        ret.push_back(swapped(idx, idx - 1));
                if (e_x + 1 < SIZE)     // move right?
                        ret.push_back(swapped(idx, idx + 1));

                return ret;
        }

private:
        // Returns a copy of this board with the cells at flat indices a and b exchanged.
        basic_board swapped(int const a, int const b) const
        {
                basic_board tmp(*this);
                std::iter_swap(tmp.storage_.begin() + a, tmp.storage_.begin() + b);
                return tmp;
        }

        // Fisher-Yates shuffle driven by std::rand(). This is the same loop that
        // libstdc++ uses for std::random_shuffle, which was removed in C++17; keeping
        // rand() as the source means srand() seeds still determine the board.
        void shuffle_with_rand()
        {
                for (storage_t::size_type i = 1; i < storage_.size(); ++i)
                {
                        storage_t::size_type const j =
                                static_cast<storage_t::size_type>(std::rand()) % (i + 1);
                        if (i != j)
                                std::iter_swap(storage_.begin() + i, storage_.begin() + j);
                }
        }

        storage_t storage_;
};

template <int SIZE>
std::ostream &
operator << (std::ostream & lhs, basic_board<SIZE> const & rhs)
{
        for (int i = 0; i < SIZE; ++i)
        {
                for (int j = 0; j < SIZE; ++j)
                        lhs << rhs[j][i] << ", ";
                lhs << std::endl;
        }
}

// solution verifier
template <int SIZE>
class board_solution
{
public:
        bool operator() (basic_board<SIZE> lhs, basic_board<SIZE> rhs)
        {
                // move both empty cell to upper-left corner
                typename basic_board<SIZE>::storage_t::iterator e, i;
                int e_x, e_y;
                
                e = std::find(lhs.storage_.begin(), lhs.storage_.end(), 0);
                e_x = ( e - lhs.storage_.begin() ) % SIZE;
                e_y = ( e - lhs.storage_.begin() ) / SIZE;
                while (e_y - 1 >= 0)        // move up?
                {
                        i = lhs.storage_.begin() + (e_y - 1) * SIZE + e_x;
                        std::iter_swap(e, i);
                        e = i;
                        e_x = ( e - lhs.storage_.begin() ) % SIZE;
                        e_y = ( e - lhs.storage_.begin() ) / SIZE;
                }
                while (e_x - 1 >= 0)        // move left?
                {
                        i = lhs.storage_.begin() + e_y * SIZE + e_x - 1;
                        std::iter_swap(e, i);
                        e = i;
                        e_x = ( e - lhs.storage_.begin() ) % SIZE;
                        e_y = ( e - lhs.storage_.begin() ) / SIZE;
                }
                
                e = std::find(rhs.storage_.begin(), rhs.storage_.end(), 0);
                e_x = ( e - rhs.storage_.begin() ) % SIZE;
                e_y = ( e - rhs.storage_.begin() ) / SIZE;
                while (e_y - 1 >= 0)        // move up?
                {
                        i = rhs.storage_.begin() + (e_y - 1) * SIZE + e_x;
                        std::iter_swap(e, i);
                        e = i;
                        e_x = ( e - rhs.storage_.begin() ) % SIZE;
                        e_y = ( e - rhs.storage_.begin() ) / SIZE;
                }
                while (e_x - 1 >= 0)        // move left?
                {
                        i = rhs.storage_.begin() + e_y * SIZE + e_x - 1;
                        std::iter_swap(e, i);
                        e = i;
                        e_x = ( e - rhs.storage_.begin() ) % SIZE;
                        e_y = ( e - rhs.storage_.begin() ) / SIZE;
                }

                std::cout << lhs << std::endl;
                std::cout << rhs << std::endl;

                // checking parity
                int p(0);
                for (e = lhs.storage_.begin() + 1;
                     e != lhs.storage_.end();
                     ++e)
                {
                        i = std::find(rhs.storage_.begin()+1, rhs.storage_.end(), *e);

                        for (typename basic_board<SIZE>::storage_t::iterator j = e;
                             j != lhs.storage_.end();
                             ++j)
                                if (std::find(rhs.storage_.begin()+1, i, *j) != i)
                                        ++p;
                }
                std::cout << "Parity: " << p << " (" << p%2 << ")" << std::endl;

                return (p%2) != ((SIZE*SIZE)%2);
        }
};

// manhatten distance
template <int SIZE>
class board_distance
{
public:
        board_distance(basic_board<SIZE> const & target):
                pos_cache_(SIZE*SIZE)
        {
                for (int i=0;i<SIZE*SIZE;++i)
                {
                        typename basic_board<SIZE>::storage_t::const_iterator
                                t = std::find(target.storage_.begin(), target.storage_.end(), i);
                        pos_cache_[i] = std::make_pair(
                                ( t - target.storage_.begin() ) % SIZE,
                                ( t - target.storage_.begin() ) / SIZE
                        );
                }
        }

        int operator() (basic_board<SIZE> const & lhs)
        {
                int ret(0);
                for (int i=1;i<SIZE*SIZE;++i)
                {
                        typename basic_board<SIZE>::storage_t::const_iterator
                                l = std::find(lhs.storage_.begin(), lhs.storage_.end(), i);
                        ret += std::abs(
                                        pos_cache_[i].first - 
                                        ( ( l - lhs.storage_.begin() ) % SIZE )
                                ) + std::abs(
                                        pos_cache_[i].second - 
                                        ( ( l - lhs.storage_.begin() ) / SIZE )
                                );
                }
                return ret;
        }

        int operator() (basic_board<SIZE> const & lhs, basic_board<SIZE> const & rhs)
        {
                int ret(0);
                for (int i=1;i<SIZE*SIZE;++i)
                {
                        typename basic_board<SIZE>::storage_t::const_iterator
                                l = std::find(lhs.storage_.begin(), lhs.storage_.end(), i),
                                r = std::find(rhs.storage_.begin(), rhs.storage_.end(), i);
                        ret += std::abs(
                                        ( ( l - lhs.storage_.begin() ) % SIZE ) -
                                        ( ( r - rhs.storage_.begin() ) % SIZE )
                                ) + std::abs(
                                        ( ( l - lhs.storage_.begin() ) / SIZE ) -
                                        ( ( r - rhs.storage_.begin() ) / SIZE )
                                );
                }
                return ret;
        }
private:
        std::vector< std::pair<int, int> > pos_cache_;
};

typedef basic_board<3> Board_3x3;
typedef basic_board<4> Board_4x4;
typedef basic_board<5> Board_5x5;

typedef board_distance<3> Distance_3x3;
typedef board_distance<4> Distance_4x4;
typedef board_distance<5> Distance_5x5;

typedef board_solution<3> Solution_3x3;
typedef board_solution<4> Solution_4x4;
typedef board_solution<5> Solution_5x5;

#endif

astar.h:這是 A* 演算法與一個 helper class。

#ifndef _ASTAR_H
#define _ASTAR_H

#include <map>
#include <vector>

// Search-tree node used by astar_search: a state, the accumulated cost from
// the start to it, and the index of its parent in the caller's expansion list.
template <typename COST_T, typename NODE_T>
class NODE_HELPER
{
public:
        // Initialisers follow declaration order (parent_, cost_, state_).
        NODE_HELPER(COST_T c, NODE_T const & s, int p):
                parent_(p), cost_(c), state_(s)
        { }

        int parent() const { return parent_; }

        // Returns the cost in its own type; an int return truncated non-int costs.
        COST_T cost() const { return cost_; }

        NODE_T & state() { return state_; }
        NODE_T const & state() const { return state_; }

        // Successor states, as produced by NODE_T::childs().
        std::vector< NODE_T > childs() { return state_.childs(); }

private:
        int parent_;            // parent step (index into the expansion list)
        COST_T cost_;           // cost from start to this state
        NODE_T state_;          // current state
};

template <typename NODE_T, typename DISTANCE_T, typename VERIFIER_T>
std::vector< NODE_T >
astar_search(NODE_T const & start, NODE_T const & target, DISTANCE_T dist, VERIFIER_T sol)
{
        if (!sol(start, target))
                return std::vector< NODE_T >();

        std::multimap<int, NODE_HELPER<int, NODE_T> > pending;
        pending.insert(std::make_pair(dist(start, target), NODE_HELPER<int, NODE_T>(0, start, 0)));

        std::vector< NODE_HELPER<int, NODE_T> > solution;

        int n_iter(0);
        while(dist(pending.begin()->second.state()))
        {
                solution.push_back(pending.begin()->second);
                int cost_so_far = pending.begin()->second.cost();
                pending.erase(pending.begin());

                std::vector< NODE_T > tmp_cld = solution.rbegin()->childs();
                for(typename std::vector< NODE_T >::iterator i = tmp_cld.begin();
                    i != tmp_cld.end();
                    ++i)
                {
                        int cost = cost_so_far + dist(solution.rbegin()->state(), *i);
                        pending.insert(
                                std::make_pair(
                                        cost + dist(*i),
                                        NODE_HELPER<int, NODE_T>(
                                                cost,
                                                *i,
                                                solution.size() - 1)
                                )
                        );
                }
        }

        std::vector< NODE_T > ret;
        ret.push_back(pending.begin()->second.state());
        for (int p_idx = pending.begin()->second.parent();
             p_idx != 0;
             p_idx = solution[p_idx].parent())
                ret.push_back(solution[p_idx].state());
        ret.push_back(start);
        std::reverse(ret.begin(), ret.end());
        return ret;
}

#endif

main.cpp:用來測試的主程式。

#include <iostream>
#include <vector>
#include <sstream>

#include <unistd.h>

#include "board.h"
#include "astar.h"

typedef Board_4x4 BOARD_T;
typedef Distance_4x4 DIST_T;
typedef Solution_4x4 SOL_T;

int main(int argc, char* argv[])
{
        int random_seed = 0;
        if (argc > 1)
        {
                std::stringstream ss(argv[1]);
                ss >> random_seed;
        }
        srand(random_seed);

        std::cout << "Random seed: " << random_seed << std::endl;

        BOARD_T b(true), t(false);

        std::cout << "Start:" << std::endl << b << std::endl;
        std::cout << "End:" << std::endl << t << std::endl;
        std::cout << "Trying hard to solve (with A*)..." << std::endl;

        std::vector< BOARD_T > result = astar_search(b, t, DIST_T(t), SOL_T() );

        std::cout << "Solution steps: " << result.size() << std::endl;
        for(std::vector< BOARD_T >::iterator i = result.begin();
            i != result.end();
            ++i)
        {
                //sleep(1);
                std::cout << *i << std::endl;
        }

        return 0;
}

如果編不起來,那是的問題!BSD License