nn-search-full / ranker.py
muryshev's picture
init
b24d496
import numpy as np
import xgboost as xgb
from typing import List
def remove_duplicates(input_list: List) -> List:
result = list(dict.fromkeys(input_list))
return result
class Ranker:
def __init__(self, model_path: str = None):
self.model = xgb.Booster()
if model_path is not None:
self.model.load_model(model_path)
def rank(self,
titles: List[str],
scores: List[float],
indexes: List[int],
embeddings: List[List[float]]) -> tuple:
"""
:param titles: названия документов, которым принадлежат чанки или сами названия документов, если документ целый
:param embeddings: эмбединги чанков или цельных документов
Индексы соответствуют друг другу, то есть эмбединг в embeddings[0] относится к названию документа в titles[0]
:return: список названий документов, отранжированный и без дубликатов от чанкинга
"""
dmatrix_of_features = xgb.DMatrix(np.array(embeddings))
dmatrix_of_features.set_group([len(embeddings)])
predictions = self.model.predict(dmatrix_of_features)
indexes_ranked_documents = sorted(range(len(predictions)), key=lambda k: predictions[k], reverse=True)
titles_reranked_documents = [titles[item] for item in indexes_ranked_documents]
scores_reranked_documents = sorted([scores[item] for item in indexes_ranked_documents], reverse=True)
indexes_reranked_documents = [indexes[item] for item in indexes_ranked_documents]
return titles_reranked_documents, scores_reranked_documents, indexes_reranked_documents