欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

[TextMatch框架] 基于召回和排序的文本搜索

程序员文章站 2022-12-20 23:37:27
git clone https://github.com/MachineLP/TextMatchcd TextMatchexport PYTHONPATH=${PYTHONPATH}:../TextMatchpython tests/core_test/text_search_test.pytests/core_test/text_search_test.pyimport sysfrom textmatch.core.text_match import TextMatchfrom text...
git clone https://github.com/MachineLP/TextMatch
cd TextMatch
export PYTHONPATH=${PYTHONPATH}:../TextMatch
python tests/core_test/text_search_test.py

tests/core_test/text_search_test.py

import sys
from textmatch.core.text_match import TextMatch
from textmatch.core.qa_match import QMatch, AMatch, SemanticMatch
from textmatch.models.text_search.model_factory_search import ModelFactorySearch



def text_match_recall(testword, doc_dict):
    # QMatch
    q_match = QMatch( q_dict=doc_dict, match_models=['bow', 'tfidf', 'ngram_tfidf', 'albert']) 
    q_match_pre = q_match.predict(testword, match_strategy='score', vote_threshold=0.5, key_weight = {'bow': 1, 'tfidf': 1, 'ngram_tfidf': 1, 'albert': 1})
    # print ('q_match_pre>>>>>', q_match_pre )
    return q_match_pre 

def text_match_sort(testword, candidate_doc_dict):
    text_match = TextMatch( q_dict=candidate_doc_dict, match_models=['bm25', 'edit_sim', 'jaccard_sim'] )
    text_match_res = text_match.predict( query, match_strategy='score', vote_threshold=-100.0, key_weight = {'bm25': 0, 'edit_sim': 1, 'jaccard_sim': 1} )
    return text_match_res 


if __name__ == '__main__':
    doc_dict = {"0":"我去玉龙雪山并且喜欢玉龙雪山玉龙雪山", "1":"我在玉龙雪山并且喜欢玉龙雪山", "2":"我在九寨沟", "3":"我在九寨沟,很喜欢", "4":"很喜欢"}   
    query = "我在九寨沟,很喜欢"
    
    # 直接搜索
    mf = ModelFactorySearch( match_models=['bm25', 'edit_sim', 'jaccard_sim'] )
    mf.init(words_dict=doc_dict)
    pre = mf.predict(query)
    print ('pre>>>>>', pre) 

    
    # 先召回
    match_pre = text_match_recall( query, doc_dict )
    print( '召回的结果:', match_pre )

    candidate_doc_dict = dict( zip( match_pre.keys(), [doc_dict[key] for key in match_pre.keys()] ) )
    print ("candidate_doc_dict:", candidate_doc_dict)

    # 再排序
    # ['bm25', 'edit_sim', 'jaccard_sim']
    text_match_res = text_match_sort( query, candidate_doc_dict )
    print ('排序的score>>>>>', text_match_res) 



    '''
    # 排序
    mf = ModelFactorySearch( match_models=['bm25', 'edit_sim', 'jaccard_sim'] )
    mf.init(words_dict=candidate_doc_dict) 
    pre = mf.predict(query)
    print ('排序的结果>>>>>', pre) 
    '''
    
    '''
    召回的结果: {'2': 0.5995837299668828, '3': 0.9999999210000139, '4': 0.5460526286735667}
    candidate_doc_dict: {'2': '我在九寨沟', '3': '我在九寨沟,很喜欢', '4': '很喜欢'}
排序的score>>>>> {'2': 0.55, '3': 1.0, '4': 0.34285714285714286}
    '''

本文地址:https://blog.csdn.net/u014365862/article/details/107448071