欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Python版B树 博客分类: Python  

程序员文章站 2024-02-04 16:27:40
...
话说以前的树都用java写的,最近发现python有点生疏了,于是用python写了个B树实现,B树在索引领域用得还是蛮多了,如果没记错mysql的默认索引好像就是B树...

首先是数据实体对象,很简单,只存放key,value
class Entity(object):
    '''数据实体'''
    
    def __init__(self,key,value):
        self.key = key
        self.value = value



然后节点对象
class Node(object):
    '''B树的节点'''
    
    def __init__(self):
        self.parent = None
        self.entitys = []
        self.childs = []
        
    def find(self,key):
        '''通过key查找并返回一个数据实体'''
        
        for e in self.entitys:
            if key == e.key:
                return e
            
    
    def delete(self,key):
        '''通过key删除一个数据实体,并返回它和它的下标(下标,实体)'''
        for i,e in enumerate(self.entitys):
            if e.key == key:
                del self.entitys[i]
                return (i,e)
            
            
    def isLeaf(self):
        '''判断该节点是否是一个叶子节点'''
        
        return len(self.childs) == 0
    
    
    def addEntity(self,entity):
        '''添加一个数据实体'''
        
        self.entitys.append(entity)
        self.entitys.sort(key=lambda x:x.key)
        
        
    def addChild(self,node):
        '''添加一个子节点'''

        self.childs.append(node)
        node.parent = self
        self.childs.sort(key=lambda x:x.entitys[0].key)



最后是Tree类
class Tree(object):
    '''B树'''
    
    def __init__(self,size=6):
        self.size = size
        self.root = None
        self.length = 0
        
    
    def add(self,key,value=None):
        '''插入一条数据到B树'''
        
        self.length += 1
        
        if self.root:
            current = self.root
            
            while not current.isLeaf():
                for i,e in enumerate(current.entitys):
                    if e.key > key:
                        current = current.childs[i]
                        break
                    elif e.key == key:
                        e.value = value
                        self.length -= 1
                        return
                else:
                    current = current.childs[-1]
                    
            current.addEntity(Entity(key,value))
                
            if len(current.entitys) > self.size:
                self.__spilt(current)
        else:
            self.root = Node()
            self.root.addEntity(Entity(key,value))
    
    
    def get(self,key):
        '''通过key查询一个数据'''
        
        node = self.__findNode(key)
        
        if node:  
            return node.find(key).value
            
            
    def delete(self,key):
        '''通过key删除一个数据项并返回它'''
        
        node = self.__findNode(key)
        
        if node:
            i,e = node.delete(key)
            
            #在节点不是叶子节点时需要做修复(取对应下标的子节点的最大的一个数据项来补)
            if not node.isLeaf():
                child = node.childs[i]
                j,entity = child.delete(child.entitys[-1].key)
                node.addEntity(entity)
                
                while not child.isLeaf():
                    node = child
                    child = child.childs[j]
                    j,entity = child.delete(child.entitys[-1].key)
                    node.addEntity(entity)
            
            self.length -= 1
            return e.value
    
    
    def isEmpty(self):
        return self.length == 0
    
                
    def __findNode(self, key):
        '''通过key值查询一个数据在哪个节点,找到就返回该节点'''
        
        if self.root:
            current = self.root
            
            while not current.isLeaf():
                for i, e in enumerate(current.entitys):
                    if e.key > key:
                        current = current.childs[i]
                        break
                    elif e.key == key:
                        return current
                else:
                    current = current.childs[-1]
            
            if current.find(key):
                return current
            
            
    def __spilt(self,node):
        '''
        分裂一个节点,规则为:
        1、中间的数据项移到父节点
        2、新建一个右兄弟节点,将中间节点右边的数据项移到新节点
        '''
        
        middle = len(node.entitys) / 2
        
        top = node.entitys[middle]
        
        right = Node()
        
        for e in node.entitys[middle + 1:]:
            right.addEntity(e)
            
        for n in node.childs[middle + 1:]:
            right.addChild(n)
        
        node.entitys = node.entitys[:middle]
        node.childs = node.childs[:middle + 1]
        
        parent = node.parent
        
        if parent:
            parent.addEntity(top)
            parent.addChild(right)
            
            if len(parent.entitys) > self.size:
                self.__spilt(parent)
        else:
            self.root = Node()
            self.root.addEntity(top)
            self.root.addChild(node)
            self.root.addChild(right)



测试代码
if __name__ == '__main__':
    t = Tree(4)
    t.add(20)
    t.add(40)
    t.add(60)
    t.add(70,'c') 
    t.add(80)       
    t.add(10) 
    t.add(30)
    t.add(15,'python')
    t.add(75,'java')
    t.add(85)
    t.add(90)
    t.add(25)
    t.add(35,'c#')
    t.add(50)
    t.add(22,'c++')
    t.add(27)
    t.add(32)
    
    print t.get(15)
    print t.get(75)
    print t.delete(35)
    print t.delete(22)
    t.add(22,'lua')
    print t.get(22)
    print t.length