欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

A*算法解N数码问题

程序员文章站 2022-03-22 15:38:52
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档文章目录前言一、A*算法二、N数码问题三、代码实现1.码盘节点的类定义2.源代码总结前言最近上课遇到了八数码问题,正好为了练一练代码,就自己动手开始写,因为用的python,没有传统的树和链表结构,所以写起来遇到了一些麻烦,这里记录一下,大佬轻拍一、A*算法A*算法是一种启发式算法,具体内容可参考一下这位大佬的笔记,记录的很详细,我的算法也是基于这篇笔记复现的。这篇文章也解释了A和A*算法的重要区别,解答了我对于这两个算法的疑问...


前言

最近上课遇到了八数码问题,正好为了练一练代码,就自己动手开始写,因为用的python,没有传统的树和链表结构,所以写起来遇到了一些麻烦,这里记录一下,大佬轻拍


一、A*算法

A*算法是一种启发式算法,具体内容可参考一下这位大佬的笔记,记录的很详细,我的算法也是基于这篇笔记复现的。这篇文章也解释了A和A*算法的重要区别,解答了我对于这两个算法的疑问。
https://blog.csdn.net/qwezhaohaihong/article/details/103353885

二、N数码问题

八数码问题是N数码的特殊情况,对于python的N数码实现参考了这篇文章。
https://blog.csdn.net/qq_35976351/article/details/82767029

这篇文章创建了一种节点状态类,将算法搜索过程中的节点状态,父子节点、启发函数值等元素加入类中,并覆写了类的相等判断,使得节点类能够模拟出搜索树的结构,这一点在我的代码中有所借鉴。

原文章赋予OPEN表堆特性,并且覆写比较函数,用启发值进行比较,使得在OPEN表的维护在堆的压入和弹出过程自动完成。为堆排序覆写的比较函数也是启发我的一点。

但是这篇文章中代码有一些错误,比如在计算manhattan距离时目标码盘对应元素的位置解算有错误。另外在求解过程中引入了hash值去重,但还是会有重复判断的问题。对于OPEN表、CLOSE表、M表和G树的维护也有所欠缺。

三、代码实现

1.码盘节点的类定义

代码如下:

class Astar_node(object):
    def __init__(self, gn, hn, state=None, par=None):
        '''
        初始化
        :param state: 节点存储的状态
        :param par: 父节点
        '''
        self.state = state  # 节点状态
        self.par = par  # 父节点
        self.gn, self.hn = gn, hn  # 启发信息

    @property  # 定义fn属性
    def fn(self):
        return self.gn + self.hn

    def __eq__(self, obj):  # 相等的判断
        return self.state == obj.state

    def __ne__(self, obj):  # 不等的判断
        return not self.__eq__(obj)

    def print_state(self):
        for row in range(Lenth):
            for col in range(Lenth):
                print('%3d' % self.state[row][col], end='')
            print('')
        print('--------------------')

    def print_node(self):
        print("gn=%d,hn=%d,fn=%d" % (self.gn, self.hn, self.fn))
        self.print_state()

对于码盘节点的类定义,在实现过程中首先定义了基本启发函数g(n)和h(n)以及节点状态和父节点指向等基本属性。对于启发函数f(n)采用属性定义方法,避免在OPEN表排序时忘记赋值的错误。
由于在排序过程中没有使用堆特性,直接使用sort函数对实例的属性进行排序,所以删除了原来的大小比较函数的覆写。
个人认为在A*算法中已经充分考了节点重复状态下的问题。扩展节点加入M表时已经排除了当前节点的父节点,也就是不走回头路。M中的数据在加入OPEN、CLOSE表的过程中也进行了重复性验证,对重复节点遇到启发函数值更小的情况已经进行了考虑。所以我取消了用hash值对码盘的状态进行去重判断和相等判断,而是采用原始码盘状态进行重复性验证和相等判断。

另外在十五数码问题过程中我发现一个很有意思的现象,这也是为什么我想记录下实现过程的原因,就是关于g(n)的设定问题。一般在A*算法中,通常关注f(n)的选取,但很少讨论g(n)的选取,一般认为每进行一部扩展,将g(n)加1,也就是我们认为每走一步则增加一个步数的代价。但是在我测试的过程中发现,g(n)对算法的搜索效率也有非常大的影响,下面是我测试g(n)每扩展一步深度,所花费的时间,100秒为算法求解超时

①g(n)= 1
A*算法解N数码问题
②g(n)= -1
A*算法解N数码问题
③g(n)=0.2
A*算法解N数码问题
④g(n)= 0.5
A*算法解N数码问题

可以看到在g(n)=1的时候,搜索算法搜索了大量的节点,但是实际求出的解路径只有14个,最后超时未求解出来答案。
g(n)=0.5时得到了相似的结果,但是这次算法搜索了相近的节点数,虽然同样没有求出答案,但是走出了更多的解路径。
g(n)=-1时,算法很快就求出了解路径,但是解路径十分长。显然不是A*算法所要求解的最优路径。
g(n)= 0.2时,算法在一个相对合理的时间内求出了一条相对合理的解。
至此可以得出对于A*解十五数码问题的一个大致结论:
1.首先每一步的g(n)并不是取1最好,g(n)所代表的代价并不是实际所走的步数
2.g(n)有一个合适的取值区间能够同时取得解的优选时间和优选路径。
3.g(n)取值超过2中的区间后求解时间和路径会随取值增大成指数级增长。
4.g(n)取值小于2中的区间后,求解路径会随取值增加而减小,但时间会在一定范围内缩减。

那么2中所说的区间是否存在呢。我尝试验证了之后得到一下数据:
⑤g(n) = 0.22
A*算法解N数码问题

⑥g(n)= 0.25,0.27,0.29

A*算法解N数码问题
A*算法解N数码问题
A*算法解N数码问题

⑦g(n) = 0.3
A*算法解N数码问题

⑧g(n)=0.32
A*算法解N数码问题

⑨g(n)=0.35
A*算法解N数码问题

⑩g(n) = 0.46,0.44,0,42超时

另外有一个不符合规律的数据,在0.23~0.24之间,算法执行时间都较长,并且0.24为几秒,而0.238为十几秒。这组数据较为反常。
g(n)=0.23
A*算法解N数码问题

除去反常数据可以看到,g(n)对算法求解影响大致是符合三条规律的。
在我完成了代码之后,在初始使用g(n)=1求解算法时,无法得出原本能几秒求出的解,我在反复检查了代码之后发现依然不能求解,但是却能求解简单的码盘,因此我尝试不使用步数代价,发现能够很快求出解。后面才有了这些数据。在此提出我的猜想:
如果有玩过华容道或拼图游戏的朋友可能会知道,最短的求解路径有时候反而需要多走一些步数,而在十五数码这个问题中,一个空格转一圈回到原点时,实际上码盘是会发生改变的,而这个过程中可能会增加h(n)的值,但是在旋转完成 后,整个f(n)的值却可能是最小的。因此猜测这就是为什么g(n)的最优区间在0.25附近。当小于这个区间时,我们实际上是减小解路径的行走代价,鼓励多在解路径上进行尝试,因此会出现长路径和少扩展以及合适的求解时间。当大于这个区间时,我们在增加算法在解路径的行走代价,造成算法每走一步都需要谨慎小心,不断扩展节点,期望求得较小的行走代价从而得到最短的路径,最终造成扩展图十分庞大,超时后最终能得到的解路径很少。
因此g(n)的选取对算法其实是有影响的。
PS:如果有兴趣可以尝试一下h(n)函数对算法求解的影响。

2.源代码

代码如下(示例):

import copy
import re
import time
import os


# Lenth = 0  # 码盘边长


# 状态节点
class Astar_node(object):
    def __init__(self, gn, hn, state=None, par=None):
        '''
        初始化
        :param state: 节点存储的状态
        :param par: 父节点
        '''
        self.state = state  # 节点状态
        self.par = par  # 父节点
        self.gn, self.hn = gn, hn  # 启发信息

    @property  # 定义fn属性
    def fn(self):
        return self.gn + self.hn

    def __eq__(self, obj):  # 相等的判断
        return self.state == obj.state

    def __ne__(self, obj):  # 不等的判断
        return not self.__eq__(obj)

    def print_state(self):
        for row in range(Lenth):
            for col in range(Lenth):
                print('%3d' % self.state[row][col], end='')
            print('')
        print('--------------------')

    def print_node(self):
        print("gn=%d,hn=%d,fn=%d" % (self.gn, self.hn, self.fn))
        self.print_state()


def manhattan_dis(cur):
    '''
    计算和目标码盘的曼哈顿距离
    :param cur: 当前节点
    :return: 到目的状态的曼哈顿距离
    '''
    cur_state = cur.state
    end_state = end_node.state
    dist = 0
    for row in range(Lenth):
        for col in range(Lenth):
            if cur_state[row][col] == end_state[row][col]:
                continue
            num = cur_state[row][col]
            # 求目标码盘对应元素的横纵坐标
            num_row = num // Lenth if num % 4 != 0 else (num - 1) // Lenth
            num_row = num_row if num_row != -1 else 3
            num_col = num % Lenth - 1
            num_col = num_col if num_col != -1 else 3
            dist += (abs(row - num_row) + abs(col - num_col))
    return dist


class A_start:
    '''
    A*算法初始化
    :param start: 起始节点
    :param end: 终止节点
    :param heuristic_fn: 启发函数
    return: G search_cnt
    '''

    def __init__(self, start, end, heuristic_fn,time_limite):
        self.OPEN = []  # OPEN表
        self.CLOSE = []  # CLOSE表
        self.G = []  # 搜索树
        self.start = start
        self.end = end
        self.cur_node = None
        self.heuristic_fn = heuristic_fn
        self.start_t = 0  # 计时变量
        self.end_t = 0

        self.G.append(self.start)  # 初始化搜索图
        self.OPEN.append(self.start)  # 初始化OPEN表

    def begin_search(self):  # 算法开始
        self.start_t = time.time()
        # 找空位坐标
        blank_pos = None
        while 1:
            # OPEN表为空表示无解 直接退出
            if self.OPEN == [] or (time.time()-self.start_t>time_limite):
                print("There is no anser!")
                self.print_result()
                break
            else:
                self.cur_node = self.OPEN.pop(0)  # 弹出OPEN表中第一个元素
                self.CLOSE.append(self.cur_node)  # 当前节点放入CLOSE表表示扩展完成
                # self.cur_node.print_node()
                # 搜索到目标节点
                if self.cur_node == self.end:
                    # self.end.par=self.cur_node
                    print("Success!")
                    self.print_result()
                    break
                # 找节点空位
                for row in range(Lenth):
                    for col in range(Lenth):
                        if self.cur_node.state[row][col] == 0:
                            blank_pos = [row, col]
                            break
                # 扩展节点
                M = []  # 扩展出的新节点集合(不包括当前节点的父节点)
                for dict in dicts:
                    b_x, b_y = blank_pos[0], blank_pos[1]
                    n_x, n_y = b_x + dict[0], b_y + dict[1]
                    if n_x in range(Lenth) and n_y in range(Lenth):  # 越界判定
                        new_node = Astar_node(0, 0, copy.deepcopy(self.cur_node.state))
                        new_node.state[b_x][b_y], new_node.state[n_x][n_y] = \
                            new_node.state[n_x][n_y], new_node.state[b_x][b_y]  # 移动空位
                        if new_node != self.cur_node.par:  # 扩展结点不是当前节点的父节点
                            new_node.gn = self.cur_node.gn + 0.25
                            new_node.hn = self.heuristic_fn(new_node)  # 计算节点hn
                            M.append(new_node)  # 新节点加入集合
                # 处理新扩展的节点
                for node in M:
                    # 去重扩展搜索树
                    if node not in self.G:
                        self.G.append(node)
                    # 未出现在OPEN和CLOSE表中 将扩展节点父节点设为当前节点并加入OPEN表
                    if node not in self.OPEN and node not in self.CLOSE:
                        node.par = self.cur_node
                        self.OPEN.append(node)
                    # 出现在OPEN表中 比较OPEN表和M表中的fn值 若M<OPEN 在OPEN表中将该节点父节点设为当前节点
                    elif node in self.OPEN:
                        for node_open in self.OPEN:
                            if node == node_open and node.fn < node_open.fn:
                                # node_open.par = self.cur_node
                                node.par=self.cur_node
                                self.OPEN.remove(node_open)
                                self.OPEN.append(node)
                    # 出现在CLOSE表中 比较CLOSE表和M表中的fn值 若M<CLOSE 在CLOSE表中将扩展节点子节点指向当前节点(将当前节点父节点设为CLOSE表中扩展节点)弹出CLOSE表中的该节点加入OPEN表
                    elif node in self.CLOSE:
                        for node_close in self.CLOSE:
                            if node == node_close and node.fn < node_close.fn:
                                self.cur_node.par = node_close
                                self.CLOSE.remove(node_close)
                                self.OPEN.append(node_close)

                # 依照启发信息重排OPEN表
                self.OPEN.sort(key=lambda x: x.fn)

    def print_result(self):
        self.end_t = time.time()
        #打印路径
        path_cnt = 1
        self.cur_node.print_state()
        while True:
            self.cur_node = self.cur_node.par
            self.cur_node.print_state()
            path_cnt += 1
            if self.cur_node.par == root_node:
                break
        print("number of searched node:%d" % len(self.G))
        print("number of CLOSE:%d" % len(self.CLOSE))
        print("Lenth of path:%d" % (path_cnt - 1))
        print("Time:%f" % (self.end_t - self.start_t))


if __name__ == '__main__':
    dicts = [[0, 1], [0, -1], [-1, 0], [1, 0]]  # 空格移动方向
    with open("./infile.txt", "r") as f:  # 读取码盘边长和初始码盘
        Lenth = int(f.readline().strip().split()[-1])
        List = list(map(int, f.readline().strip().split()))
    # 创建初始码盘和目标码盘
    Start = [List[i:i + Lenth] for i in range(0, len(List), Lenth)]
    GOAL_list = [i for i in range(1, Lenth * Lenth)]
    GOAL_list.append(0)
    GOAL = [GOAL_list[i:i + Lenth] for i in range(0, len(GOAL_list), Lenth)]

    root_node = Astar_node(0, 0, [[0] * 4] * 4, None)  # 创建树根
    end_node = Astar_node(0, 0, GOAL, None)  # 创建目标节点
    start_node = Astar_node(0, 0, Start, root_node)  # 创建初始节点
    start_node.hn = manhattan_dis(start_node)

    time_limite=100
    Astar = A_start(start_node, end_node, manhattan_dis,time_limite)
    Astar.begin_search()

    os.system("pause")

总结

A*算法确实是一个优秀的算法,在亲自编写的过程中才体会到这个算法的严谨和精妙之处。算法有很强的扩展性和灵活性。对于很多问题的求解都适用。比如:路径搜索,图搜索等等。如果有机会希望能尝试它的改进算法的研究和实现。

本文地址:https://blog.csdn.net/qq_40185348/article/details/111072230