
Implementing the sumTree from Prioritized Experience Replay: storing positive samples in a sumTree so they can be drawn at random

程序员文章站 2022-07-13 15:19:40
...

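For the definition of the sumTree, see the paper "Prioritized Experience Replay". The fully commented code is available at the link below, and its comments explain how the leaf nodes relate to the tree's total (root) node: https://download.csdn.net/download/song91425/10568762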
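The relationship between the leaves and the total node can be sketched as follows. This is a minimal illustration assuming the same 1-based array layout used by the code in this post: the root is at index 1, the `capacity` leaves sit at indices `capacity` to `2*capacity-1`, and node `i` has children `2*i` and `2*i+1`. Each internal node stores the sum of its two children, so the root ends up holding the sum of all priorities. The helpers `build_sum_tree` and `path_to_root` are hypothetical names introduced only for this sketch.

```python
def build_sum_tree(priorities):
    """Hypothetical helper: build a 1-based sum-tree array from leaf priorities."""
    capacity = len(priorities)
    tree = [0.0] * (2 * capacity)  # index 0 is unused
    # Leaves occupy indices capacity .. 2*capacity - 1.
    for i, p in enumerate(priorities):
        tree[capacity + i] = float(p)
    # Fill internal nodes bottom-up: each parent is the sum of its two children.
    for i in range(capacity - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree


def path_to_root(idx):
    """Hypothetical helper: nodes a priority change at leaf idx propagates through."""
    path = [idx]
    while idx > 1:
        idx //= 2  # parent of node idx
        path.append(idx)
    return path


tree = build_sum_tree([1, 2, 3, 4, 6])
print(tree)            # [0.0, 16.0, 11.0, 5.0, 10.0, 1.0, 2.0, 3.0, 4.0, 6.0]
print(tree[1])         # 16.0, the sum of all five priorities
print(path_to_root(8))  # [8, 4, 2, 1]
print(path_to_root(5))  # [5, 2, 1]
```

When `capacity` is not a power of two (5 here), leaves end up at different depths (leaf 5 is a direct child of node 2, while leaves 8 and 9 hang under node 4), but every internal node is still the sum of its children, so the root total stays correct.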


import numpy as np


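class Tree(object):
    """Sum tree stored in a 1-based array: node i has children 2*i and 2*i+1."""

    def __init__(self, capacity):
        self.capacity = capacity  # number of leaf nodes
        self.write = 0  # 0-based slot of the next leaf to be written
        # Nodes are numbered from 1 (index 0 is unused): root at 1, leaves at [capacity, 2*capacity-1]
        self.tree = np.zeros(2 * capacity)
        # data[k] stores the sample held by the k-th leaf, k = 1..capacity (data[0] is unused)
        self.data = np.zeros(capacity + 1, dtype=object)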

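    def add(self, p, data):
        idx = self.write + self.capacity  # tree index of the leaf to (over)write
        self.data[self.write + 1] = data  # leaf number write+1 lies in [1, capacity]
        self._updatetree(idx, p)
        self.write += 1
        if self.write >= self.capacity:  # all leaves used: wrap around and overwrite the oldest
            self.write = 0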

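    def _updatetree(self, idx, p):
        change = p - self.tree[idx]  # difference between the new and the old priority
        self._propagate(idx, change)  # add that difference to every ancestor
        self.tree[idx] = p

    def _propagate(self, idx, change):
        parent = idx // 2
        self.tree[parent] += change  # update the parent's sum: this is the upward propagation
        if parent > 1:  # stop once the root (index 1) has been updated
            self._propagate(parent, change)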

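    def _total(self):
        return self.tree[1]  # the root holds the sum of all priorities

    def get(self, s):
        idx = self._retrieve(1, s)  # walk down from the root to a leaf
        index_data = idx - self.capacity + 1  # tree index -> leaf number used in self.data
        # float() returns a plain Python number instead of a NumPy scalar
        return (idx, float(self.tree[idx]), self.data[index_data])

    def _retrieve(self, idx, s):
        left = 2 * idx
        right = left + 1
        if left >= len(self.tree):  # idx has no children, so it is a leaf
            return idx
        if s <= self.tree[left]:
            return self._retrieve(left, s)  # s falls inside the left subtree
        else:
            # skip the left subtree's total and continue in the right child
            return self._retrieve(right, s - self.tree[left])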

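
# capacity 5: the five leaves sit at tree indices 5..9
tree = Tree(5)
tree.add(1, 3)  # add(priority, data)
tree.add(2, 4)
tree.add(3, 5)
tree.add(4, 6)
tree.add(6, 11)

# s = 4 lands in leaf 8 (priority 4), which stores data 6
print(tree.get(4))  # (8, 4.0, 6)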
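The `get` call in the example above is the building block for the random sampling mentioned in the title: if `s` is drawn uniformly from `[0, total]`, it lands in a given leaf with probability equal to that leaf's priority divided by the total. The PER paper builds a minibatch of size k by splitting `[0, total]` into k equal segments and drawing one value from each. The sketch below illustrates both ideas under stated assumptions: `MiniSumTree` and `sample_batch` are hypothetical names, the tree is an iterative rewrite (not the author's code) that keeps the same 1-based tree layout but indexes `data` from 0, and every stored priority is assumed to be positive.

```python
import random


class MiniSumTree:
    """Hypothetical compact sum tree: root at 1, leaves at capacity..2*capacity-1."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity)  # index 0 unused
        self.data = [None] * capacity       # data[i] belongs to leaf capacity + i
        self.write = 0

    def add(self, p, item):
        leaf = self.write + self.capacity
        self.data[self.write] = item
        self.update(leaf, p)
        self.write = (self.write + 1) % self.capacity  # overwrite the oldest when full

    def update(self, leaf, p):
        change = p - self.tree[leaf]
        idx = leaf
        while idx >= 1:  # apply the change to the leaf and every ancestor, root included
            self.tree[idx] += change
            idx //= 2

    def total(self):
        return self.tree[1]

    def get(self, s):
        idx = 1
        while idx < self.capacity:  # internal nodes are 1..capacity-1
            left = 2 * idx
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]  # skip the left subtree's mass
                idx = left + 1
        return idx, self.tree[idx], self.data[idx - self.capacity]


def sample_batch(tree, k, rng=random):
    """Stratified proportional sampling: one uniform draw from each of k equal segments."""
    segment = tree.total() / k
    return [tree.get(rng.uniform(segment * j, segment * (j + 1))) for j in range(k)]


tree = MiniSumTree(5)
for p, item in [(1, 3), (2, 4), (3, 5), (4, 6), (6, 11)]:
    tree.add(p, item)

print(tree.get(4))  # (8, 4.0, 6)

rng = random.Random(0)
batch = sample_batch(tree, 4, rng)
print([item for _, _, item in batch])

# Plain uniform draws: each item's frequency should approach priority / total.
counts = {}
for _ in range(16000):
    _, _, item = tree.get(rng.uniform(0, tree.total()))
    counts[item] = counts.get(item, 0) + 1
print({item: n / 16000 for item, n in sorted(counts.items())})
```

Walking down with a loop instead of recursion gives the same result as `_retrieve`. The stratified version spreads the k draws across the whole priority range, so a single batch cannot collapse onto one high-priority leaf the way k independent uniform draws occasionally can.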