A*算法解N数码问题
前言
最近上课遇到了八数码问题,正好为了练一练代码,就自己动手开始写,因为用的python,没有传统的树和链表结构,所以写起来遇到了一些麻烦,这里记录一下,大佬轻拍
一、A*算法
A*算法是一种启发式算法,具体内容可参考一下这位大佬的笔记,记录的很详细,我的算法也是基于这篇笔记复现的。这篇文章也解释了A和A*算法的重要区别,解答了我对于这两个算法的疑问。
https://blog.csdn.net/qwezhaohaihong/article/details/103353885
二、N数码问题
八数码问题是N数码的特殊情况,对于python的N数码实现参考了这篇文章。
https://blog.csdn.net/qq_35976351/article/details/82767029
这篇文章创建了一种节点状态类,将算法搜索过程中的节点状态,父子节点、启发函数值等元素加入类中,并覆写了类的相等判断,使得节点类能够模拟出搜索树的结构,这一点在我的代码中有所借鉴。
原文章赋予OPEN表堆特性,并且覆写比较函数,用启发值进行比较,使得在OPEN表的维护在堆的压入和弹出过程自动完成。为堆排序覆写的比较函数也是启发我的一点。
但是这篇文章中代码有一些错误,比如在计算manhattan距离时目标码盘对应元素的位置解算有错误。另外在求解过程中引入了hash值去重,但还是会有重复判断的问题。对于OPEN表、CLOSE表、M表和G树的维护也有所欠缺。
三、代码实现
1.码盘节点的类定义
代码如下:
class Astar_node(object):
def __init__(self, gn, hn, state=None, par=None):
'''
初始化
:param state: 节点存储的状态
:param par: 父节点
'''
self.state = state # 节点状态
self.par = par # 父节点
self.gn, self.hn = gn, hn # 启发信息
@property # 定义fn属性
def fn(self):
return self.gn + self.hn
def __eq__(self, obj): # 相等的判断
return self.state == obj.state
def __ne__(self, obj): # 不等的判断
return not self.__eq__(obj)
def print_state(self):
for row in range(Lenth):
for col in range(Lenth):
print('%3d' % self.state[row][col], end='')
print('')
print('--------------------')
def print_node(self):
print("gn=%d,hn=%d,fn=%d" % (self.gn, self.hn, self.fn))
self.print_state()
对于码盘节点的类定义,在实现过程中首先定义了基本启发函数g(n)和h(n)以及节点状态和父节点指向等基本属性。对于启发函数f(n)采用属性定义方法,避免在OPEN表排序时忘记赋值的错误。
由于在排序过程中没有使用堆特性,直接使用sort函数对实例的属性进行排序,所以删除了原来的大小比较函数的覆写。
个人认为在A*算法中已经充分考了节点重复状态下的问题。扩展节点加入M表时已经排除了当前节点的父节点,也就是不走回头路。M中的数据在加入OPEN、CLOSE表的过程中也进行了重复性验证,对重复节点遇到启发函数值更小的情况已经进行了考虑。所以我取消了用hash值对码盘的状态进行去重判断和相等判断,而是采用原始码盘状态进行重复性验证和相等判断。
另外在十五数码问题过程中我发现一个很有意思的现象,这也是为什么我想记录下实现过程的原因,就是关于g(n)的设定问题。一般在A*算法中,通常关注f(n)的选取,但很少讨论g(n)的选取,一般认为每进行一部扩展,将g(n)加1,也就是我们认为每走一步则增加一个步数的代价。但是在我测试的过程中发现,g(n)对算法的搜索效率也有非常大的影响,下面是我测试g(n)每扩展一步深度,所花费的时间,100秒为算法求解超时
①g(n)= 1
②g(n)= -1
③g(n)=0.2
④g(n)= 0.5
可以看到在g(n)=1的时候,搜索算法搜索了大量的节点,但是实际求出的解路径只有14个,最后超时未求解出来答案。
g(n)=0.5时得到了相似的结果,但是这次算法搜索了相近的节点数,虽然同样没有求出答案,但是走出了更多的解路径。
g(n)=-1时,算法很快就求出了解路径,但是解路径十分长。显然不是A*算法所要求解的最优路径。
g(n)= 0.2时,算法在一个相对合理的时间内求出了一条相对合理的解。
至此可以得出对于A*解十五数码问题的一个大致结论:
1.首先每一步的g(n)并不是取1最好,g(n)所代表的代价并不是实际所走的步数
2.g(n)有一个合适的取值区间能够同时取得解的优选时间和优选路径。
3.g(n)取值超过2中的区间后求解时间和路径会随取值增大成指数级增长。
4.g(n)取值小于2中的区间后,求解路径会随取值增加而减小,但时间会在一定范围内缩减。
那么2中所说的区间是否存在呢。我尝试验证了之后得到一下数据:
⑤g(n) = 0.22
⑥g(n)= 0.25,0.27,0.29
⑦g(n) = 0.3
⑧g(n)=0.32
⑨g(n)=0.35
⑩g(n) = 0.46,0.44,0,42超时
另外有一个不符合规律的数据,在0.23~0.24之间,算法执行时间都较长,并且0.24为几秒,而0.238为十几秒。这组数据较为反常。
g(n)=0.23
除去反常数据可以看到,g(n)对算法求解影响大致是符合三条规律的。
在我完成了代码之后,在初始使用g(n)=1求解算法时,无法得出原本能几秒求出的解,我在反复检查了代码之后发现依然不能求解,但是却能求解简单的码盘,因此我尝试不使用步数代价,发现能够很快求出解。后面才有了这些数据。在此提出我的猜想:
如果有玩过华容道或拼图游戏的朋友可能会知道,最短的求解路径有时候反而需要多走一些步数,而在十五数码这个问题中,一个空格转一圈回到原点时,实际上码盘是会发生改变的,而这个过程中可能会增加h(n)的值,但是在旋转完成 后,整个f(n)的值却可能是最小的。因此猜测这就是为什么g(n)的最优区间在0.25附近。当小于这个区间时,我们实际上是减小解路径的行走代价,鼓励多在解路径上进行尝试,因此会出现长路径和少扩展以及合适的求解时间。当大于这个区间时,我们在增加算法在解路径的行走代价,造成算法每走一步都需要谨慎小心,不断扩展节点,期望求得较小的行走代价从而得到最短的路径,最终造成扩展图十分庞大,超时后最终能得到的解路径很少。
因此g(n)的选取对算法其实是有影响的。
PS:如果有兴趣可以尝试一下h(n)函数对算法求解的影响。
2.源代码
代码如下(示例):
import copy
import re
import time
import os
# Lenth = 0 # 码盘边长
# 状态节点
class Astar_node(object):
def __init__(self, gn, hn, state=None, par=None):
'''
初始化
:param state: 节点存储的状态
:param par: 父节点
'''
self.state = state # 节点状态
self.par = par # 父节点
self.gn, self.hn = gn, hn # 启发信息
@property # 定义fn属性
def fn(self):
return self.gn + self.hn
def __eq__(self, obj): # 相等的判断
return self.state == obj.state
def __ne__(self, obj): # 不等的判断
return not self.__eq__(obj)
def print_state(self):
for row in range(Lenth):
for col in range(Lenth):
print('%3d' % self.state[row][col], end='')
print('')
print('--------------------')
def print_node(self):
print("gn=%d,hn=%d,fn=%d" % (self.gn, self.hn, self.fn))
self.print_state()
def manhattan_dis(cur):
'''
计算和目标码盘的曼哈顿距离
:param cur: 当前节点
:return: 到目的状态的曼哈顿距离
'''
cur_state = cur.state
end_state = end_node.state
dist = 0
for row in range(Lenth):
for col in range(Lenth):
if cur_state[row][col] == end_state[row][col]:
continue
num = cur_state[row][col]
# 求目标码盘对应元素的横纵坐标
num_row = num // Lenth if num % 4 != 0 else (num - 1) // Lenth
num_row = num_row if num_row != -1 else 3
num_col = num % Lenth - 1
num_col = num_col if num_col != -1 else 3
dist += (abs(row - num_row) + abs(col - num_col))
return dist
class A_start:
'''
A*算法初始化
:param start: 起始节点
:param end: 终止节点
:param heuristic_fn: 启发函数
return: G search_cnt
'''
def __init__(self, start, end, heuristic_fn,time_limite):
self.OPEN = [] # OPEN表
self.CLOSE = [] # CLOSE表
self.G = [] # 搜索树
self.start = start
self.end = end
self.cur_node = None
self.heuristic_fn = heuristic_fn
self.start_t = 0 # 计时变量
self.end_t = 0
self.G.append(self.start) # 初始化搜索图
self.OPEN.append(self.start) # 初始化OPEN表
def begin_search(self): # 算法开始
self.start_t = time.time()
# 找空位坐标
blank_pos = None
while 1:
# OPEN表为空表示无解 直接退出
if self.OPEN == [] or (time.time()-self.start_t>time_limite):
print("There is no anser!")
self.print_result()
break
else:
self.cur_node = self.OPEN.pop(0) # 弹出OPEN表中第一个元素
self.CLOSE.append(self.cur_node) # 当前节点放入CLOSE表表示扩展完成
# self.cur_node.print_node()
# 搜索到目标节点
if self.cur_node == self.end:
# self.end.par=self.cur_node
print("Success!")
self.print_result()
break
# 找节点空位
for row in range(Lenth):
for col in range(Lenth):
if self.cur_node.state[row][col] == 0:
blank_pos = [row, col]
break
# 扩展节点
M = [] # 扩展出的新节点集合(不包括当前节点的父节点)
for dict in dicts:
b_x, b_y = blank_pos[0], blank_pos[1]
n_x, n_y = b_x + dict[0], b_y + dict[1]
if n_x in range(Lenth) and n_y in range(Lenth): # 越界判定
new_node = Astar_node(0, 0, copy.deepcopy(self.cur_node.state))
new_node.state[b_x][b_y], new_node.state[n_x][n_y] = \
new_node.state[n_x][n_y], new_node.state[b_x][b_y] # 移动空位
if new_node != self.cur_node.par: # 扩展结点不是当前节点的父节点
new_node.gn = self.cur_node.gn + 0.25
new_node.hn = self.heuristic_fn(new_node) # 计算节点hn
M.append(new_node) # 新节点加入集合
# 处理新扩展的节点
for node in M:
# 去重扩展搜索树
if node not in self.G:
self.G.append(node)
# 未出现在OPEN和CLOSE表中 将扩展节点父节点设为当前节点并加入OPEN表
if node not in self.OPEN and node not in self.CLOSE:
node.par = self.cur_node
self.OPEN.append(node)
# 出现在OPEN表中 比较OPEN表和M表中的fn值 若M<OPEN 在OPEN表中将该节点父节点设为当前节点
elif node in self.OPEN:
for node_open in self.OPEN:
if node == node_open and node.fn < node_open.fn:
# node_open.par = self.cur_node
node.par=self.cur_node
self.OPEN.remove(node_open)
self.OPEN.append(node)
# 出现在CLOSE表中 比较CLOSE表和M表中的fn值 若M<CLOSE 在CLOSE表中将扩展节点子节点指向当前节点(将当前节点父节点设为CLOSE表中扩展节点)弹出CLOSE表中的该节点加入OPEN表
elif node in self.CLOSE:
for node_close in self.CLOSE:
if node == node_close and node.fn < node_close.fn:
self.cur_node.par = node_close
self.CLOSE.remove(node_close)
self.OPEN.append(node_close)
# 依照启发信息重排OPEN表
self.OPEN.sort(key=lambda x: x.fn)
def print_result(self):
self.end_t = time.time()
#打印路径
path_cnt = 1
self.cur_node.print_state()
while True:
self.cur_node = self.cur_node.par
self.cur_node.print_state()
path_cnt += 1
if self.cur_node.par == root_node:
break
print("number of searched node:%d" % len(self.G))
print("number of CLOSE:%d" % len(self.CLOSE))
print("Lenth of path:%d" % (path_cnt - 1))
print("Time:%f" % (self.end_t - self.start_t))
if __name__ == '__main__':
dicts = [[0, 1], [0, -1], [-1, 0], [1, 0]] # 空格移动方向
with open("./infile.txt", "r") as f: # 读取码盘边长和初始码盘
Lenth = int(f.readline().strip().split()[-1])
List = list(map(int, f.readline().strip().split()))
# 创建初始码盘和目标码盘
Start = [List[i:i + Lenth] for i in range(0, len(List), Lenth)]
GOAL_list = [i for i in range(1, Lenth * Lenth)]
GOAL_list.append(0)
GOAL = [GOAL_list[i:i + Lenth] for i in range(0, len(GOAL_list), Lenth)]
root_node = Astar_node(0, 0, [[0] * 4] * 4, None) # 创建树根
end_node = Astar_node(0, 0, GOAL, None) # 创建目标节点
start_node = Astar_node(0, 0, Start, root_node) # 创建初始节点
start_node.hn = manhattan_dis(start_node)
time_limite=100
Astar = A_start(start_node, end_node, manhattan_dis,time_limite)
Astar.begin_search()
os.system("pause")
总结
A*算法确实是一个优秀的算法,在亲自编写的过程中才体会到这个算法的严谨和精妙之处。算法有很强的扩展性和灵活性。对于很多问题的求解都适用。比如:路径搜索,图搜索等等。如果有机会希望能尝试它的改进算法的研究和实现。
本文地址:https://blog.csdn.net/qq_40185348/article/details/111072230