A* in Python

A* is a directed search algorithm. Given a good heuristic function, it can perform much better than a blind breadth-first search.

A* code

First, the main algorithm itself. I like using Python for lots of reasons. One that stands out, though, is that functions are first class. This allows us to write a generic A* function that works on many problems.

(Python is not unique in this. But it is still fun.)

Here is the main solving code.

import heapq
import math
import random
def solve(start, finish, heuristic):
    """Find the shortest path from START to FINISH."""
    heap = []

    link = {} # parent node link
    h = {} # heuristic function cache
    g = {} # shortest path to a node

    g[start] = 0
    h[start] = 0
    link[start] = None

    heapq.heappush(heap, (0, 0, start))
    # keep a count of the  number of steps, and avoid an infinite loop.
    for kk in xrange(1000000):
        f, junk, current = heapq.heappop(heap)
        if current == finish:
            print "distance:", g[current], "steps:", kk
            return g[current], kk, build_path(start, finish, link)

        moves = current.get_moves()
        distance = g[current]
        for mv in moves:
            if mv not in g or g[mv] > distance + 1:
                g[mv] = distance + 1
                if mv not in h:
                    h[mv] = heuristic(mv)
                link[mv] = current
                heapq.heappush(heap, (g[mv] + h[mv], -kk, mv))
        raise Exception("did not find a solution")

solve takes two positions, a start and finish, and a heuristic function. The heuristic function must always return a distance that is less or equal to the actual distance between two positions. The whole algorithm rests on that assumption.

g[mv] holds the length of the shortest known path to mv, and h[mv] holds the estimated distance from mv to finish according to the heuristic function. Then, g[mv] + h[mv] is the current best estimate of the distance from start to finish through mv.

We keep a the nodes to search next in a priority queue, ordered by the the current best estimates of the distance to finish.

solve is a generic function, and it realized on duck-typing to work. The position objects must provide a get_moves method that returns neighboring positions, a hash function __hash__, and an equality operator __eq__.

Whenever the goal position is at the top of the heap we are done. This is because the heap is ordered by the best possible score from a position to the goal. If there is a shorter path to the goal, it would have already been explored, since we explore by the best possible path first. Since it hasn't been seen before, this path is the best.

Note that this isn't true about positions other than the goal. That is because the heuristic provides an optimistic guess for the shortest path from a position to the goal, not the shortest path between any two positions. This is the reason that in the inner loop, we check if we are in the shortest path seen so far to a node. A node may be visited multiple times by A*, with shorter paths to the node each time.

The function build_path reconstructs the path to the solution for us. This is useful for debugging, and because we often want to know the path, not just the path length.

def build_path(start, finish, parent):
    Reconstruct the path from start to finish given
    a dict of parent links.

    x = finish
    xs = [x]
    while x != start:
        x = parent[x]
    return xs

Here is a heuristic that does nothing, just for testing. Using this heuristic will lead to Dijkstra's algorithm, which is a good, but uninformed search algorithm.

def no_heuristic(*args):
   """Dummy heuristic that doesn't do anything"""
   return 0

A grid example

Here is a simple test class where the program must find a path on a rectangular grid. We will assume the grid is infinite, but will also hope that our search doesn't go that far off course.

class GridPosition(object):
    """Represent a position on a grid."""
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __hash__(self):
        return hash((self.x, self.y))

    def __repr__(self):
        return "GridPosition(%d,%d)" % (self.x, self.y)

    def __eq__(self, other):
        return self.x == other.x and self.y == other.y

    def get_moves(self):
        # There are times when returning this in a shuffled order
        # would help avoid degenerate cases.  For learning, though,
        # life is easier if the algorithm behaves predictably.
        yield GridPosition(self.x + 1, self.y)
        yield GridPosition(self.x, self.y + 1)
        yield GridPosition(self.x - 1, self.y)
        yield GridPosition(self.x, self.y - 1)
grid_start = GridPosition(0,0)
grid_end = GridPosition(10,10)
def grid_test_no_heuristic():
   solve(grid_start, grid_end, no_heuristic)

This gives us:

distance: 20 steps: 840

Now, we can add a heuristic. An obvious one is to Euclidean distance, since the shortest path between two points is a straight line.

def euclidean_h(goal):
    def f(pos):
        dx, dy = pos.x - goal.x, pos.y - goal.y
        return math.hypot(dx, dy)
    return f
def grid_test_euclidean_heuristic():
    solve(grid_start, grid_end, euclidean_h(grid_end))
distance: 20 steps: 134

That result is significantly better.

We can do even better. Since our grid movements are restricted to left, right, up and down, we can use Manhattan distance as the heuristic. In this simple case, Manhattan distance is a perfect heuristic. Adding obstacles, or changing the cost of moving through grid points would keep Manhattan from being perfect.

def manhattan_h(goal):
   def f(pos):
       dx, dy = pos.x - goal.x, pos.y - goal.y
       return abs(dx) + abs(dy)
   return f
def grid_test_manhattan_heuristic():
    solve(grid_start, grid_end, manhattan_h(grid_end))
distance: 20 steps: 20

We found the path without exploring any unnecessary nodes.

Block Puzzle

In the Stanford AI class offered online, they discussed A* in the context of the classic fifteen puzzle, but simplified to just eight pieces.

Here is a class for the block puzzle. Usually the 15 puzzle is used, but the 8 puzzle is a lot faster to solve.

class BlockPuzzle(object):
    def __init__(self, n, xs=None):
        """Create an nxn block puzzle

        Use XS to initialize to a specific state.
        self.n = n
        self.n2 = n * n
        if xs is None:
            self.xs = [(x + 1) % self.n2 for x in xrange(self.n2)]
            self.xs = list(xs)
        self.hsh = None
        self.last_move = []

    def __hash__(self):
        if self.hsh is None:
            self.hsh = hash(tuple(self.xs))
        return self.hsh

    def __repr__(self):
        return "BlockPuzzle(%d, %s)" % (self.n, self.xs)

    def show(self):
        ys = ["%2d" % x for x in self.xs]
        xs = [" ".join(ys[kk:kk+self.n]) for kk in xrange(0,self.n2, self.n)]
        return "\n".join(xs)

    def __eq__(self, other):
        return self.xs == other.xs

    def copy(self):
        return BlockPuzzle(self.n, self.xs)

    def get_moves(self):
        # Find the 0 tile, and then generate any moves we
        # can by sliding another block into its place.
        tile0 = self.xs.index(0)
        def swap(i):
            j = tile0
            tmp = list(self.xs)
            last_move = tmp[i]
            tmp[i], tmp[j] = tmp[j], tmp[i]
            result = BlockPuzzle(self.n, tmp)
            result.last_move = last_move
            return result

        if tile0 - self.n >= 0:
            yield swap(tile0-self.n)
        if tile0 +self.n < self.n2:
            yield swap(tile0+self.n)
        if tile0 % self.n > 0:
            yield swap(tile0-1)
        if tile0 % self.n < self.n-1:
            yield swap(tile0+1)

We also need a way to create a shuffled puzzle. Here is a generic method for shuffling.

def shuffle(position, n):
    for kk in xrange(n):
        xs = list(position.get_moves())
        position = random.choice(xs)
    return position

Now, we need a heuristic. The empty heuristic approach will take too long here.

The first, and simplest heuristic is to count how many tiles are out of place.

def misplaced_h(position):
    """Returns the number of tiles out of place."""
    n2 = position.n2
    c = 0
    for kk in xrange(n2):
        if position.xs[kk] != kk+1:
            c += 1
    return c

Here is a sample run

def test_block_8_misplaced(num_tests):
    for kk in xrange(num_tests):
        p = shuffle(BlockPuzzle(3), 200)
        print p.show()
        solve(p, BlockPuzzle(3), misplaced_h)
 0  2  5
 8  3  7
 1  6  4
 distance: 19 steps: 872
 2  7  0
 6  8  3
 1  5  4
 distance: 19 steps: 958
 1  5  4
 7  2  8
 0  3  6
 distance: 25 steps: 13027
 6  8  7
 2  5  4
 0  1  3
 distance: 27 steps: 26762
 3  8  4
 7  0  6
 2  1  5
distance: 21 steps: 3008

Now, another heuristic is to measure the distance that each tile must move. This heuristic is ok, because we only move one tile at a time, and we know that each tile must move at least this many steps.

def distance_h(position):
    n = position.n
    def row(x): return x / n
    def col(x): return x % n
    score = 0
    for idx, x in enumerate(position.xs):
        if x == 0: continue
        ir,ic = row(idx), col(idx)
        xr,xc = row(x-1), col(x-1)
        score += abs(ir-xr) + abs(ic-xc)
    return score

And another sample run.

def test_block_8_distance(num_tests):
    for kk in xrange(num_tests):
        p = shuffle(BlockPuzzle(3), 200)
        print p.show()
        solve(p, BlockPuzzle(3), distance_h)
 4  7  2
 1  0  5
 6  8  3
distance: 22 steps: 941
 6  5  1
 4  0  3
 2  7  8
distance: 16 steps: 59
 1  3  4
 2  0  5
 7  8  6
distance: 16 steps: 235
 4  5  0
 3  8  2
 6  1  7
distance: 24 steps: 1038
 6  7  2
 5  8  3
 0  1  4
distance: 24 steps: 705

For similar sizes, the results from this heuristic are much better. This testing methodology was pretty poor, though. A more detailed analysis would be better, but this writeup is getting long as it is.

Here is a short sample of one to one comparisons.

def test_block(steps=100, count=5):
    for kk in xrange(count):
        p = shuffle(BlockPuzzle(3), steps)
        print p.show()
        print "misplaced",
        x = solve(p, BlockPuzzle(3), misplaced_h)
        print "distance",
        x = solve(p, BlockPuzzle(3), distance_h)
 2  6  1
 4  8  5
 0  3  7
misplaced distance: 24 steps: 14962
distance distance: 24 steps: 862