nn-search-full / transaction_maps_search.py
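"""Semantic search over transaction maps.

Embeds a user query with an embedding model, looks up the nearest transaction
maps in a FAISS vector database, and can optionally first extract the business
transaction from the question with an LLM.
"""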
import os
from typing import Dict, Optional

from business_transaction_map.common.constants import (
    COLUMN_TYPE_DOC_MAP,
    DEVICE,
    DO_NORMALIZATION,
)
from business_transaction_map.components.embedding_extraction import EmbeddingExtractor
from business_transaction_map.components.faiss_vector_database import FaissVectorDatabase
from llm.common import LlmApi
from prompts import BUSINESS_TRANSACTION_PROMPT

db_files_path = os.environ.get(
    "GLOBAL_TRANSACTION_MAPS_DATA_PATH",
    "transaction_maps_search_data/csv/карта_проводок_new.pkl",
)
model_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_MODEL_PATH", "")


class TransactionMapsSearch:
    def __init__(self,
                 model_name_or_path: str = model_path,
                 device: str = DEVICE):
        self.device = device
        self.model = self.load_model(
            model_name_or_path=model_name_or_path,
            device=device,
        )
        self.database = FaissVectorDatabase(str(db_files_path))

    @staticmethod
    async def extract_business_transaction_with_llm(question: str, llm_api: LlmApi) -> str:
        # Substitute the user question into the prompt template and ask the LLM
        # to extract the business transaction.
        prompt = BUSINESS_TRANSACTION_PROMPT.replace('{{ЗАПРОС}}', question)
        res = await llm_api.predict(prompt)
        return res

    @staticmethod
    def load_model(model_name_or_path: Optional[str] = None,
                   device: Optional[str] = None):
        model = EmbeddingExtractor(model_name_or_path, device)
        return model

    @staticmethod
    def filter_answer(answer: Dict) -> Dict:
        """
        Filters the answers, keeping only the first hit per document name.

        Args:
            answer: Dictionary with the answers and additional information.

        Returns:
            Dictionary of unique answers.
        """
        seen_doc_names = []
        keys_to_delete = []
        for key in answer:
            if answer[key]["doc_name"] in seen_doc_names:
                keys_to_delete.append(key)
            else:
                seen_doc_names.append(answer[key]["doc_name"])
        for key in keys_to_delete:
            answer.pop(key)
        return answer
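
    # Illustrative example of the deduplication done by filter_answer (not from
    # the original file): hits sharing the same "doc_name" collapse to the first
    # one, e.g.
    #   {"0": {"doc_name": "a"}, "1": {"doc_name": "a"}, "2": {"doc_name": "b"}}
    #   -> {"0": {"doc_name": "a"}, "2": {"doc_name": "b"}}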

    async def search_transaction_map(self,
                                     query: Optional[str] = None,
                                     find_transaction_maps_by_question: bool = False,
                                     k_neighbours: int = 15,
                                     llm_api: Optional[LlmApi] = None):
        if find_transaction_maps_by_question:
            # Let the LLM extract the business transaction from the free-form question.
            query = await self.extract_business_transaction_with_llm(query, llm_api)

        cleaned_text = query.replace("\n", " ")
        # cleaned_text = 'query: ' + cleaned_text  # only for e5
        query_tokens = self.model.query_tokenization(cleaned_text)
        query_embeds = self.model.query_embed_extraction(query_tokens.to(self.device), DO_NORMALIZATION)[0]
        query_embeds = query_embeds[None, :]

        # Predicts distance and index. The INDEX corresponds to row numbers in the df.
        answer = self.database.search_transaction_map(query_embeds, k_neighbours)
        answer = self.filter_answer(answer)

        final_docs = {}
        answers = []
        for value in answer.values():
            doc_type = value[COLUMN_TYPE_DOC_MAP]
            # Upper-case the document type for SAP maps; keep it as-is otherwise.
            final_docs[value["doc_name"] + '.xlsx'] = doc_type.upper() if 'sap' in doc_type else doc_type
            # Collect the individual hit for each returned document.
            answers.append(value)
        return final_docs, answers
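

if __name__ == "__main__":
    # Minimal usage sketch (illustrative addition, not part of the original
    # module). It runs a plain vector search without the LLM rewriting step;
    # passing find_transaction_maps_by_question=True would also require a
    # configured LlmApi instance from llm.common.
    import asyncio

    async def _demo() -> None:
        searcher = TransactionMapsSearch()
        final_docs, answers = await searcher.search_transaction_map(
            query="example business transaction query",
            find_transaction_maps_by_question=False,
            k_neighbours=15,
        )
        print(final_docs)

    asyncio.run(_demo())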