欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

python小试牛刀: K-best算法

程序员文章站 2024-02-23 14:10:58
...
"""
k-best解决的问题:
n个珠宝价值vi和重量wi;求保留k个珠宝的单位价值最大化
n,vi,wi 值域。。。
"""

def k_best(v_ls,w_ls,k):
    l,r = 0.0, 1e5 #maximun 10,0000

    while abs(r-l) > 1e-6:
        mid = (l + r)/2.0
        vw_ls = [v - mid * w for v,w in zip(v_ls, w_ls)]
        choose = list(range(len(v_ls)))
        for jj in range(k): #一个bug,序号会乱,因为排序后不;
            for i in range(jj+1, len(v_ls)):
                if vw_ls[jj] < vw_ls[i]:
                    vw_ls[jj], vw_ls[i] = vw_ls[i], vw_ls[jj]
                    choose[jj], choose[i] = choose[i], choose[jj]
        sk = sum(vw_ls[0:k])
        if sk < 0:
            r = mid
        else:
            l = mid
    return choose, mid

"""
        choose = []
        for jj in range(k): #一个bug,不应该设定为前几个的;
            c = 0
            for i in range(jj+1,len(v_ls)):
                if vw_ls[jj] < vw_ls[i]:
                    vw_ls[jj], vw_ls[i] = vw_ls[i], vw_ls[jj]
                    c = i
            choose.append(c)
            """

if __name__ == "__main__":
    # n,k = input("输入珠宝数目n,待选数目k:")
    # v_ls = []
    # w_ls = []
    # for i in range(int(n)):
    #     v,w = input("输入珠宝价值,重量" ).split(',')#, input()
    #     v_ls.append(float(v))
    #     w_ls.append(float(w))

    v_ls = [3,40,1.5,10,8]
    w_ls = [1,20,1,5,4]
    k = 3
    print("list : ", v_ls, w_ls)
    chose, mid = k_best(v_ls,w_ls, k)
    print(chose, " value: ", mid)
    print("\nchoose k: ")
    for i in range(k):
        print(v_ls[chose[i]], w_ls[chose[i]])

result:

list : [3, 40, 1.5, 10, 8] [1, 20, 1, 5, 4]
[0, 4, 3, 1, 2] value: 2.1000007109250873

choose k:
3 1
8 4
10 5

comment

  • k-best目的是最优选取大集中的k个子集,是排序问题的进一步衍生。不单纯能用排序解决,但是算法也离不开排序。
  • 比如上述最大性价比的k个珠宝问题:如果按照vi/wi的单个性价比选取可能不是最优的。因为如上述(3,1),(40,20),(1.5,1),按照单个性价比依次是3,2,1.5,选前两个;但是实际上最优的是1、3,因为第三个质量小,不如第二个占重大,使得和比值过于偏向40,20了。
  • 值得注意的是,其中只涉及k个元素的排序,不必要排序所有,尤其集合n很大,选取k很小时候。复杂度O(kn),)(nn)会差异巨大。。。另外,可以用堆排序维持最小的k堆,或者随机选择来替代这里的冒泡,复杂度降低到O(nlg(k)), O(k*lg(k)) ?; 因为k个元素的堆高度为lg(k)