A* in Python
A* is a directed search algorithm. Given a good heuristic function, it can perform much better than a blind breadth-first search.
A* code
First, the main algorithm itself. I like using Python for lots of reasons. One that stands out, though, is that functions are first class. This allows us to write a generic A* function that works on many problems.
(Python is not unique in this. But it is still fun.)
Here is the main solving code.
import heapq
import math
import random

def solve(start, finish, heuristic):
    """Find the shortest path from START to FINISH."""
    heap = []

    link = {}   # parent node link
    h = {}      # heuristic function cache
    g = {}      # shortest path to a node

    g[start] = 0
    h[start] = 0
    link[start] = None
    heapq.heappush(heap, (0, 0, start))
    # keep a count of the number of steps, and avoid an infinite loop.
    for kk in xrange(1000000):
        f, junk, current = heapq.heappop(heap)
        if current == finish:
            print "distance:", g[current], "steps:", kk
            return g[current], kk, build_path(start, finish, link)
        moves = current.get_moves()
        distance = g[current]
        for mv in moves:
            if mv not in g or g[mv] > distance + 1:
                g[mv] = distance + 1
                if mv not in h:
                    h[mv] = heuristic(mv)
                link[mv] = current
                heapq.heappush(heap, (g[mv] + h[mv], -kk, mv))
    else:
        raise Exception("did not find a solution")
solve takes two positions, a start and a finish, and a heuristic function. The heuristic function must always return a distance that is less than or equal to the actual distance between two positions. The whole algorithm rests on that assumption.
g[mv] holds the length of the shortest known path to mv, and h[mv] holds the estimated distance from mv to finish according to the heuristic function. Then, g[mv] + h[mv] is the current best estimate of the distance from start to finish through mv.
We keep the nodes to search next in a priority queue, ordered by the current best estimates of the distance to finish.
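For example, a node reached after 7 moves whose heuristic estimate to finish is 13 sits on the heap with priority 7 + 13 = 20, and it is only expanded after everything on the heap with a smaller estimate.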
solve is a generic function, and it relies on duck-typing to work. The position objects must provide a get_moves method that returns neighboring positions, a hash function __hash__, and an equality operator __eq__.
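For reference, here is a minimal sketch of that interface. The class name and bodies are only placeholders; the real examples below fill them in.

class Position(object):
    """Hypothetical sketch of the interface solve expects."""

    def get_moves(self):
        """Yield (or return) the neighboring positions, one move away."""
        raise NotImplementedError

    def __hash__(self):
        """Needed because positions are used as dictionary keys."""
        raise NotImplementedError

    def __eq__(self, other):
        """Needed for dictionary lookups and the goal test."""
        raise NotImplementedError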
Whenever the goal position is at the top of the heap, we are done. This works because the heap is ordered by the best possible total path length through each position. If there were a shorter path to the goal, some position along it would have a smaller estimate, and we would have explored it first. Since we haven't, this path is the best.
Note that this isn't true of positions other than the goal. That is because the heuristic provides an optimistic guess for the shortest path from a position to the goal, not for the shortest path between any two positions. This is why, in the inner loop, we check whether the new path to a node is shorter than the best one seen so far. A node may be visited multiple times by A*, with a shorter path to it each time.
The function build_path reconstructs the path to the solution for us. This is useful for debugging, and because we often want to know the path, not just the path length.
def build_path(start, finish, parent):
    """
    Reconstruct the path from start to finish given a
    dict of parent links.
    """
    x = finish
    xs = [x]
    while x != start:
        x = parent[x]
        xs.append(x)
    xs.reverse()
    return xs
Here is a heuristic that does nothing, just for testing. Using this heuristic turns the search into Dijkstra's algorithm, which is a good but uninformed search algorithm.
def no_heuristic(*args):
    """Dummy heuristic that doesn't do anything"""
    return 0
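Since it ignores the heuristic entirely, no_heuristic also gives us a cheap way to spot-check the admissibility requirement from earlier. The helper below is not part of the original code, just a sketch:

def check_admissible(heuristic, positions, goal):
    # Sketch: for a few sample positions, verify the heuristic never
    # exceeds the true distance found by an uninformed search.
    for p in positions:
        true_distance, _, _ = solve(p, goal, no_heuristic)
        assert heuristic(p) <= true_distance, p

For instance, once the grid example below is in place, check_admissible(manhattan_h(grid_end), [grid_start], grid_end) should pass quietly.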
A grid example
Here is a simple test class where the program must find a path on a rectangular grid. We will assume the grid is infinite, but will also hope that our search doesn't go that far off course.
class GridPosition(object):
    """Represent a position on a grid."""
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __hash__(self):
        return hash((self.x, self.y))

    def __repr__(self):
        return "GridPosition(%d,%d)" % (self.x, self.y)

    def __eq__(self, other):
        return self.x == other.x and self.y == other.y

    def get_moves(self):
        # There are times when returning this in a shuffled order
        # would help avoid degenerate cases.  For learning, though,
        # life is easier if the algorithm behaves predictably.
        yield GridPosition(self.x + 1, self.y)
        yield GridPosition(self.x, self.y + 1)
        yield GridPosition(self.x - 1, self.y)
        yield GridPosition(self.x, self.y - 1)
grid_start = GridPosition(0,0)
grid_end = GridPosition(10,10)
def grid_test_no_heuristic():
    solve(grid_start, grid_end, no_heuristic)
This gives us:
distance: 20 steps: 840
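The path itself comes back as the third return value. A quick snippet (hypothetical, not part of the original tests) to look at it:

dist, steps, path = solve(grid_start, grid_end, no_heuristic)
print len(path)            # 21 positions: the start plus 20 moves
print path[0], path[-1]    # GridPosition(0,0) GridPosition(10,10)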
Now, we can add a heuristic. An obvious one is the Euclidean distance, since the shortest path between two points is a straight line.
def euclidean_h(goal):
    def f(pos):
        dx, dy = pos.x - goal.x, pos.y - goal.y
        return math.hypot(dx, dy)
    return f
def grid_test_euclidean_heuristic():
    solve(grid_start, grid_end, euclidean_h(grid_end))
distance: 20 steps: 134
That result is significantly better.
We can do even better. Since our grid movements are restricted to left, right, up and down, we can use Manhattan distance as the heuristic. In this simple case, Manhattan distance is a perfect heuristic. Adding obstacles, or changing the cost of moving through grid points, would keep Manhattan distance from being perfect.
def manhattan_h(goal):
    def f(pos):
        dx, dy = pos.x - goal.x, pos.y - goal.y
        return abs(dx) + abs(dy)
    return f
def grid_test_manhattan_heuristic():
    solve(grid_start, grid_end, manhattan_h(grid_end))
distance: 20 steps: 20
We found the path without exploring any unnecessary nodes.
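As an aside on the obstacle remark above: blocking grid cells only requires changing get_moves, since solve makes no assumptions about the grid itself. Here is a rough sketch; the class name and the wall layout are invented for illustration.

class BlockedGridPosition(GridPosition):
    """Hypothetical GridPosition variant with impassable cells."""
    walls = set()   # shared set of blocked (x, y) cells, filled in by the caller

    def get_moves(self):
        for dx, dy in ((1, 0), (0, 1), (-1, 0), (0, -1)):
            nx, ny = self.x + dx, self.y + dy
            if (nx, ny) not in self.walls:
                yield BlockedGridPosition(nx, ny)

BlockedGridPosition.walls = set((5, y) for y in xrange(8))
solve(BlockedGridPosition(0, 0), BlockedGridPosition(10, 10), manhattan_h(grid_end))

The Manhattan heuristic stays valid here, since walls can only make the real path longer.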
Block Puzzle
In the Stanford AI class offered online, they discussed A* in the context of the classic fifteen puzzle, but simplified to just eight pieces.
Here is a class for the block puzzle. Usually the 15 puzzle is used, but the 8 puzzle is a lot faster to solve.
class BlockPuzzle(object):
    def __init__(self, n, xs=None):
        """Create an nxn block puzzle

        Use XS to initialize to a specific state.
        """
        self.n = n
        self.n2 = n * n
        if xs is None:
            self.xs = [(x + 1) % self.n2 for x in xrange(self.n2)]
        else:
            self.xs = list(xs)
        self.hsh = None
        self.last_move = []

    def __hash__(self):
        if self.hsh is None:
            self.hsh = hash(tuple(self.xs))
        return self.hsh

    def __repr__(self):
        return "BlockPuzzle(%d, %s)" % (self.n, self.xs)

    def show(self):
        ys = ["%2d" % x for x in self.xs]
        xs = [" ".join(ys[kk:kk+self.n])
              for kk in xrange(0, self.n2, self.n)]
        return "\n".join(xs)

    def __eq__(self, other):
        return self.xs == other.xs

    def copy(self):
        return BlockPuzzle(self.n, self.xs)

    def get_moves(self):
        # Find the 0 tile, and then generate any moves we
        # can by sliding another block into its place.
        tile0 = self.xs.index(0)
        def swap(i):
            j = tile0
            tmp = list(self.xs)
            last_move = tmp[i]
            tmp[i], tmp[j] = tmp[j], tmp[i]
            result = BlockPuzzle(self.n, tmp)
            result.last_move = last_move
            return result
        if tile0 - self.n >= 0:
            yield swap(tile0 - self.n)
        if tile0 + self.n < self.n2:
            yield swap(tile0 + self.n)
        if tile0 % self.n > 0:
            yield swap(tile0 - 1)
        if tile0 % self.n < self.n - 1:
            yield swap(tile0 + 1)
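As a quick sanity check of the representation, a freshly constructed puzzle is already in the solved state, with the blank (0) in the bottom-right corner:

p = BlockPuzzle(3)
print p.show()
#  1  2  3
#  4  5  6
#  7  8  0
print len(list(p.get_moves()))   # 2: slide the 6 or the 8 into the blank corner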
We also need a way to create a shuffled puzzle. Here is a generic method for shuffling.
def shuffle(position, n):
    for kk in xrange(n):
        xs = list(position.get_moves())
        position = random.choice(xs)
    return position
Now, we need a heuristic. The empty heuristic approach will take too long here.
The first, and simplest heuristic is to count how many tiles are out of place.
def misplaced_h(position):
    """Returns the number of tiles out of place."""
    n2 = position.n2
    c = 0
    for kk in xrange(n2):
        if position.xs[kk] != kk+1:
            c += 1
    return c
Here is a sample run.
def test_block_8_misplaced(num_tests):
    for kk in xrange(num_tests):
        p = shuffle(BlockPuzzle(3), 200)
        print p.show()
        solve(p, BlockPuzzle(3), misplaced_h)
 0  2  5
 8  3  7
 1  6  4
distance: 19 steps: 872
 2  7  0
 6  8  3
 1  5  4
distance: 19 steps: 958
 1  5  4
 7  2  8
 0  3  6
distance: 25 steps: 13027
 6  8  7
 2  5  4
 0  1  3
distance: 27 steps: 26762
 3  8  4
 7  0  6
 2  1  5
distance: 21 steps: 3008
Now, another heuristic is to measure the distance that each tile must move to reach its home square. This heuristic never overestimates, because we only move one tile at a time, and each tile must move at least that many steps.
def distance_h(position):
    n = position.n
    def row(x):
        return x / n
    def col(x):
        return x % n
    score = 0
    for idx, x in enumerate(position.xs):
        if x == 0:
            continue
        ir, ic = row(idx), col(idx)
        xr, xc = row(x-1), col(x-1)
        score += abs(ir-xr) + abs(ic-xc)
    return score
And another sample run.
def test_block_8_distance(num_tests):
    for kk in xrange(num_tests):
        p = shuffle(BlockPuzzle(3), 200)
        print p.show()
        solve(p, BlockPuzzle(3), distance_h)
 4  7  2
 1  0  5
 6  8  3
distance: 22 steps: 941
 6  5  1
 4  0  3
 2  7  8
distance: 16 steps: 59
 1  3  4
 2  0  5
 7  8  6
distance: 16 steps: 235
 4  5  0
 3  8  2
 6  1  7
distance: 24 steps: 1038
 6  7  2
 5  8  3
 0  1  4
distance: 24 steps: 705
For solutions of similar length, the results from this heuristic are much better. This testing methodology is pretty informal, though. A more detailed analysis would be better, but this writeup is getting long as it is.
Here is a short sample of one-to-one comparisons.
def test_block(steps=100, count=5):
    for kk in xrange(count):
        p = shuffle(BlockPuzzle(3), steps)
        print p.show()
        print "misplaced",
        x = solve(p, BlockPuzzle(3), misplaced_h)
        print "distance",
        x = solve(p, BlockPuzzle(3), distance_h)
 2  6  1
 4  8  5
 0  3  7
misplaced distance: 24 steps: 14962
distance distance: 24 steps: 862