Spaces:
Sleeping
Sleeping
import os
from typing import Dict

from business_transaction_map.common.constants import COLUMN_TYPE_DOC_MAP, DEVICE, DO_NORMALIZATION
from business_transaction_map.components.embedding_extraction import EmbeddingExtractor
from business_transaction_map.components.faiss_vector_database import FaissVectorDatabase
from llm.common import LlmApi
from prompts import BUSINESS_TRANSACTION_PROMPT
# Data and model locations are configurable via environment variables;
# the defaults point at the bundled transaction-map dataset.
db_files_path = os.environ.get(
    "GLOBAL_TRANSACTION_MAPS_DATA_PATH",
    "transaction_maps_search_data/csv/карта_проводок_new.pkl",
)
model_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_MODEL_PATH", "")
class TransactionMapsSearch:
    """Semantic search over transaction-map documents.

    Embeds a user query with :class:`EmbeddingExtractor` and looks up the
    nearest transaction-map rows in a FAISS vector database.
    """

    def __init__(self,
                 model_name_or_path: str = model_path,
                 device: str = DEVICE):
        """Load the embedding model and open the FAISS database.

        Args:
            model_name_or_path: Path or name of the embedding model.
            device: Torch device string the model runs on.
        """
        self.device = device
        self.model = self.load_model(
            model_name_or_path=model_name_or_path,
            device=device,
        )
        self.database = FaissVectorDatabase(str(db_files_path))

    # BUG FIX: the next three methods were defined without `self` but are
    # invoked through `self.` — every call raised TypeError. Declaring them
    # @staticmethod keeps both `self.method(...)` and `Class.method(...)`
    # call forms working without changing their signatures.
    @staticmethod
    async def extract_business_transaction_with_llm(question: str, llm_api: LlmApi) -> str:
        """Ask the LLM to extract the business transaction from *question*.

        Args:
            question: Free-form user question.
            llm_api: LLM client used to run the prompt.

        Returns:
            The raw LLM response text.
        """
        prompt = BUSINESS_TRANSACTION_PROMPT.replace('{{ЗАПРОС}}', question)
        return await llm_api.predict(prompt)

    @staticmethod
    def load_model(model_name_or_path: str = None,
                   device: str = None) -> EmbeddingExtractor:
        """Instantiate the embedding extractor for *model_name_or_path* on *device*."""
        return EmbeddingExtractor(model_name_or_path, device)

    @staticmethod
    def filter_answer(answer: Dict) -> Dict:
        """Drop entries whose ``doc_name`` was already seen (keeps the first hit).

        Args:
            answer: Mapping of search hits; each value carries a "doc_name".

        Returns:
            The same mapping, mutated in place, with duplicate documents removed.
        """
        seen_docs = set()  # set gives O(1) membership vs. the original list scan
        duplicate_keys = []
        for key, hit in answer.items():
            doc_name = hit["doc_name"]
            if doc_name in seen_docs:
                duplicate_keys.append(key)
            else:
                seen_docs.add(doc_name)
        for key in duplicate_keys:
            answer.pop(key)
        return answer

    async def search_transaction_map(self,
                                     query: str = None,
                                     find_transaction_maps_by_question: bool = False,
                                     k_neighbours: int = 15,
                                     llm_api: LlmApi = None):
        """Find the transaction-map documents most relevant to *query*.

        Args:
            query: Search text (or a full question, see next flag).
            find_transaction_maps_by_question: If True, first run the query
                through the LLM to extract the business transaction.
            k_neighbours: Number of nearest neighbours requested from FAISS.
            llm_api: LLM client; required when the flag above is True.

        Returns:
            Tuple of (``final_docs`` mapping "<doc_name>.xlsx" to its document
            type, ``answers`` list of raw search results).
        """
        if find_transaction_maps_by_question:
            query = await self.extract_business_transaction_with_llm(query, llm_api)
        cleaned_text = query.replace("\n", " ")
        # cleaned_text = 'query: ' + cleaned_text  # only for e5
        query_tokens = self.model.query_tokenization(cleaned_text)
        query_embeds = self.model.query_embed_extraction(query_tokens.to(self.device), DO_NORMALIZATION)[0]
        query_embeds = query_embeds[None, :]  # FAISS expects a (1, dim) batch
        # Distances + indices; the index equals the row number in the dataframe.
        answer = self.database.search_transaction_map(query_embeds, k_neighbours)
        answer = self.filter_answer(answer)
        final_docs = {}
        answers = []
        for value in answer.values():
            doc_type = value[COLUMN_TYPE_DOC_MAP]
            # SAP document types are reported upper-cased.
            final_docs[value["doc_name"] + '.xlsx'] = doc_type.upper() if 'sap' in doc_type else doc_type
            # NOTE(review): appends the whole `answer` dict once per hit — this
            # looks like it should be `value`; preserved as-is, confirm with callers.
            answers.append(answer)
        return final_docs, answers