欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

强化学习之动态规划

程序员文章站 2024-03-17 17:14:34
...
策略迭代包括两个步骤:策略评估与策略改进。下面首先给出环境类 YuanYangEnv 的代码(后文的算法代码通过 yuanyang 模块导入它):
import pygame
from load import *
import math
import time
import random
import numpy as np

class YuanYangEnv:
    def __init__(self):
        """Create the grid world: states, actions, obstacles and render state.

        The world has 100 states (0..99) and four actions ('e', 's', 'w',
        'n'). Two vertical walls of eight obstacle cells each sit at
        x=360 and x=720; each wall leaves a two-cell gap. No window is
        opened here; ``viewer`` stays ``None`` until ``render`` is called.
        """
        # Problem description used by the solvers.
        self.states = list(range(0, 100))
        self.actions = ['e', 's', 'w', 'n']
        self.gamma = 0.8
        self.value = np.zeros((10, 10))

        # Rendering state.
        self.viewer = None
        self.FPSCLOCK = pygame.time.Clock()
        # Window size in pixels.
        self.screen_size = (1200, 900)
        self.bird_position = (0, 0)
        self.path = []

        # Distances below which a position counts as touching a cell.
        self.limit_distance_x = 120
        self.limit_distance_y = 90
        self.obstacle_size = [120, 90]

        cells = range(8)
        # First wall: rows 0-3, then rows 6-9 (gap at y=360 and y=450).
        self.obstacle1_x = [360 for _ in cells]
        self.obstacle1_y = [90 * k if k <= 3 else 90 * (k + 2) for k in cells]
        # Second wall: rows 0-4, then rows 7-9 (gap at y=450 and y=540).
        self.obstacle2_x = [720 for _ in cells]
        self.obstacle2_y = [90 * k if k <= 4 else 90 * (k + 2) for k in cells]

        self.bird_male_init_position = [0, 0]
        self.bird_male_position = [0, 0]
        self.bird_female_init_position = [1080, 0]
    #def step(self):
    def collide(self,state_position):
        flag = 1
        flag1 = 1
        flag2 = 1
        # 判断第一个障碍物
        dx = []
        dy = []
        for i in range(8):
            dx1 = abs(self.obstacle1_x[i] - state_position[0])
            dx.append(dx1)
            dy1 = abs(self.obstacle1_y[i] - state_position[1])
            dy.append(dy1)
        mindx = min(dx)
        mindy = min(dy)
        if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
            flag1 = 0
        # 判断第二个障碍物
        second_dx = []
        second_dy = []
        for i in range(8):
            dx2 = abs(self.obstacle2_x[i] - state_position[0])
            second_dx.append(dx2)
            dy2 = abs(self.obstacle2_y[i] - state_position[1])
            second_dy.append(dy2)
        mindx = min(second_dx)
        mindy = min(second_dy)
        if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
            flag2 = 0
        if flag1 == 0 and flag2 == 0:
            flag = 0
        if state_position[0] > 1080 or state_position[0] < 0 or state_position[1] > 810 or state_position[1] < 0:
            flag = 1
        return flag
    def find(self,state_position):
        flag=0
        if abs(state_position[0]-self.bird_female_init_position[0])<self.limit_distance_x and abs(state_position[1]-self.bird_female_init_position[1])<self.limit_distance_y:
            flag=1
        return flag
    def state_to_position(self, state):
        i = int(state / 10)
        j = state % 10
        position = [0, 0]
        position[0] = 120 * j
        position[1] = 90 * i
        return position
    def position_to_state(self, position):
        i = position[0] / 120
        j = position[1] / 90
        return int(i + 10 * j)
    def reset(self):
        #随机产生初始状态
        flag1=1
        flag2=1
        while flag1 or flag2 ==1:
            #随机产生初始状态,0~99,randoom.random() 产生一个0~1的随机数
            state=self.states[int(random.random()*len(self.states))]
            state_position = self.state_to_position(state)
            flag1 = self.collide(state_position)
            flag2 = self.find(state_position)
        return state
    def transform(self,state, action):
        #将当前状态转化为坐标
        current_position=self.state_to_position(state)
        next_position = [0,0]
        flag_collide=0
        flag_find=0
        #判断当前坐标是否与障碍物碰撞
        flag_collide=self.collide(current_position)
        #判断状态是否是终点
        flag_find=self.find(current_position)
        if flag_collide==1 or flag_find==1:
            return state, 0, True
        #状态转移
        if action=='e':
            next_position[0]=current_position[0]+120
            next_position[1]=current_position[1]
        if action=='s':
            next_position[0]=current_position[0]
            next_position[1]=current_position[1]+90
        if action=='w':
            next_position[0] = current_position[0] - 120
            next_position[1] = current_position[1]
        if action=='n':
            next_position[0] = current_position[0]
            next_position[1] = current_position[1] - 90
        #判断next_state是否与障碍物碰撞
        flag_collide = self.collide(next_position)
        #如果碰撞,那么回报为-1,并结束
        if flag_collide==1:
            return self.position_to_state(current_position),-1,True
        #判断是否终点
        flag_find = self.find(next_position)
        if flag_find==1:
            return self.position_to_state(next_position),1,True
        return self.position_to_state(next_position), 0, False
    def gameover(self):
        """Drain the pygame event queue and exit the program on a QUIT event.

        All pending events are consumed; every event other than QUIT is
        discarded.

        :raises SystemExit: when a QUIT event is found in the queue.
        """
        for event in pygame.event.get():
            # Use the constant from pygame itself rather than relying on a
            # bare ``QUIT`` leaking in through a star import.
            if event.type == pygame.QUIT:
                # ``exit()`` is a helper injected by the ``site`` module for
                # interactive use; raising SystemExit has the same effect
                # and is always available.
                raise SystemExit
    def render(self):
        """Draw one frame of the grid world and process window events.

        On the first call the pygame window is created and the images and
        font are loaded. Every call then redraws, in order: background,
        grid lines, the female bird, both obstacle walls, the male bird,
        the value table and the path markers. It finishes by updating the
        display, handling QUIT via ``gameover`` and limiting the frame
        rate to 30 FPS.
        """
        if self.viewer is None:
            pygame.init()
            # Create the window.
            self.viewer = pygame.display.set_mode(self.screen_size, 0, 32)
            pygame.display.set_caption("yuanyang")
            # Load the images.
            self.bird_male = load_bird_male()
            self.bird_female = load_bird_female()
            self.background = load_background()
            self.obstacle = load_obstacle()
            # Initial drawing on the fresh surface.
            self.viewer.blit(self.bird_female, self.bird_female_init_position)
            self.viewer.blit(self.background, (0, 0))
            self.font = pygame.font.SysFont("times", 35)

        screen = self.viewer
        white = (255, 255, 255)
        screen.blit(self.background, (0, 0))

        # Grid lines: one vertical and one horizontal line per index.
        for k in range(11):
            vertical = ((120 * k, 0), (120 * k, 900))
            horizontal = ((0, 90 * k), (1200, 90 * k))
            pygame.draw.lines(screen, white, True, vertical, 1)
            pygame.draw.lines(screen, white, True, horizontal, 1)

        screen.blit(self.bird_female, self.bird_female_init_position)

        # Obstacles: cell k of the first wall, then cell k of the second.
        walls = (
            (self.obstacle1_x, self.obstacle1_y),
            (self.obstacle2_x, self.obstacle2_y),
        )
        for k in range(8):
            for wall_x, wall_y in walls:
                screen.blit(self.obstacle, (wall_x[k], wall_y[k]))

        # Male bird at its current position.
        screen.blit(self.bird_male, self.bird_male_position)

        # Value table: value[col, row] is printed inside the cell whose
        # top-left corner is (120 * col, 90 * row).
        for col in range(10):
            for row in range(10):
                text = str(round(float(self.value[col, row]), 3))
                label = self.font.render(text, True, (0, 0, 0))
                screen.blit(label, (120 * col + 35, 90 * row + 35))

        # Path markers: a red frame plus the step number for each visited state.
        for step, visited in enumerate(self.path):
            corner = self.state_to_position(visited)
            frame = [corner[0], corner[1], 120, 90]
            pygame.draw.rect(screen, [255, 0, 0], frame, 3)
            label = self.font.render(str(step), True, (255, 0, 0))
            screen.blit(label, (corner[0] + 5, corner[1] + 5))

        pygame.display.update()
        self.gameover()
        self.FPSCLOCK.tick(30)


if __name__=="__main__":
    # Manual check of the environment: draw it once, then keep the window
    # open until the user closes it.
    yy=YuanYangEnv()
    yy.render()
    while True:
        # gameover() exits the program on a QUIT event; the clock tick caps
        # this loop at 30 iterations per second instead of busy-waiting and
        # pinning a CPU core.
        yy.gameover()
        yy.FPSCLOCK.tick(30)
值迭代算法(使用上面定义的 YuanYangEnv 环境):
from yuanyang import YuanYangEnv

import random
class DP_Value_Iter:
    def __init__(self, yuanyang):
        self.states = yuanyang.states
        self.actions = yuanyang.actions
        self.v = [0.0 for i in range(len(self.states) + 1)]
        self.pi = dict()
        self.yuanyang = yuanyang
        self.gamma = yuanyang.gamma
        for state in self.states:
            flag1 = 0
            flag2 = 0
            flag1 = yuanyang.collide(yuanyang.state_to_position(state))
            flag2 = yuanyang.find(yuanyang.state_to_position(state))

            if flag1 == 1 or flag2 == 1:
                continue
            self.pi[state] = self.actions[int(random.random() * len(self.actions))]

    def value_iteration(self):
        '''

        :return:
        '''
        for i in range(1000):
            delta = 0.0
            for state in self.states:
                flag1 = 0
                flag2 = 0
                flag1 = yuanyang.collide(yuanyang.state_to_position(state))
                flag2 = yuanyang.find(yuanyang.state_to_position(state))
                if flag1 == 1 or flag2 == 1:
                    continue

                a1 = self.actions[int(random.random() * 4)]
                s, r, t = yuanyang.transform(state, a1)

                # 策略评估
                v1 = r + self.gamma * self.v[s]
                # 策略改进
                for action in self.actions:
                    s, r, t = yuanyang.transform(state, action)
                    if v1 < r + self.gamma * self.v[s]:
                        a1 = action
                        v1 = r + self.gamma * self.v[s]
                delta += abs(v1 - self.v[state])
                self.pi[state] = a1
                self.v[state] = v1

            if delta < 1e-6:
                print("迭代次数为:", i)
                break


if __name__ == "__main__":
    import time

    yuanyang = YuanYangEnv()
    policy_value = DP_Value_Iter(yuanyang)
    policy_value.value_iteration()

    # Copy the computed values into the environment's table for display;
    # the table is indexed [column, row].
    for state in range(100):
        row, col = divmod(state, 10)
        yuanyang.value[col, row] = policy_value.v[state]

    # Follow the computed policy from state 0, drawing every step.
    s = 0
    path = []
    yuanyang.path = path
    step_num = 0
    while True:
        # Mark the current state on the rendered path.
        path.append(s)
        a = policy_value.pi[s]
        print("%d->%s\t" % (s, a))
        yuanyang.bird_male_position = yuanyang.state_to_position(s)
        yuanyang.render()
        time.sleep(0.2)
        step_num += 1
        s_, r, t = yuanyang.transform(s, a)
        s = s_
        # Stop at a terminal transition or after more than 20 steps.
        if t or step_num > 20:
            break

    # Draw the final state of the walk.
    yuanyang.bird_male_position = yuanyang.state_to_position(s)
    path.append(s)
    yuanyang.render()
    # Keep the window alive; render() itself handles the QUIT event.
    while True:
        yuanyang.render()
        
相关标签: 强化学习