import os
from typing import Dict, Optional

import requests

from business_transaction_map.common.constants import (
    COLUMN_TYPE_DOC_MAP,
    DEVICE,
    DO_NORMALIZATION,
)
from business_transaction_map.components.embedding_extraction import EmbeddingExtractor
from business_transaction_map.components.faiss_vector_database import FaissVectorDatabase
from prompts import BUSINESS_TRANSACTION_PROMPT

db_files_path = os.environ.get(
    "GLOBAL_TRANSACTION_MAPS_DATA_PATH",
    "transaction_maps_search_data/csv/карта_проводок_new.pkl",
)
model_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_MODEL_PATH", "")
llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")


class TransactionMapsSearch:
    def __init__(self, model_name_or_path: str = model_path, device: str = DEVICE):
        self.device = device
        self.model = self.load_model(model_name_or_path=model_name_or_path, device=device)
        self.database = FaissVectorDatabase(str(db_files_path))

    @staticmethod
    def extract_business_transaction_with_llm(question: str) -> str:
        """Ask the LLM to extract the business transaction from a free-form question."""
        # "{{ЗАПРОС}}" is the prompt template's placeholder (Russian for "QUERY").
        question = BUSINESS_TRANSACTION_PROMPT.replace("{{ЗАПРОС}}", question)
        # The spaces inside the [INST] tags turned out to matter: without them
        # the LLM can fall into endlessly generating nonsense.
        response = requests.post(
            url=llm_api_endpoint,
            json={"prompt": f"[INST] {question} [/INST]", "temperature": 0.0},
            timeout=60,  # assumed timeout; the original request had none
        )
        return response.json()["content"]

    @staticmethod
    def load_model(model_name_or_path: Optional[str] = None, device: Optional[str] = None):
        return EmbeddingExtractor(model_name_or_path, device)

    @staticmethod
    def filter_answer(answer: Dict) -> Dict:
        """Deduplicate answers by document name.

        Args:
            answer: Dictionary with the answers and additional information.
        Returns:
            Dictionary of unique answers (the first occurrence of each doc_name wins).
        """
        seen_doc_names = set()
        keys_to_delete = []
        for key in answer:
            if answer[key]["doc_name"] in seen_doc_names:
                keys_to_delete.append(key)
            else:
                seen_doc_names.add(answer[key]["doc_name"])
        for key in keys_to_delete:
            answer.pop(key)
        return answer

    def search_transaction_map(
        self,
        query: Optional[str] = None,
        find_transaction_maps_by_question: bool = False,
        k_neighbours: int = 15,
    ):
        if find_transaction_maps_by_question:
            query = self.extract_business_transaction_with_llm(query)
        cleaned_text = query.replace("\n", " ")
        # cleaned_text = "query: " + cleaned_text  # only for e5 models
        query_tokens = self.model.query_tokenization(cleaned_text)
        query_embeds = self.model.query_embed_extraction(
            query_tokens.to(self.device), DO_NORMALIZATION
        )[0]
        query_embeds = query_embeds[None, :]  # FAISS expects a 2D batch of queries
        # Returns distances and indices; the indices equal row numbers in the dataframe.
        answer = self.database.search_transaction_map(query_embeds, k_neighbours)
        answer = self.filter_answer(answer)
        final_docs = {}
        answers = []
        for value in answer.values():
            doc_type = value[COLUMN_TYPE_DOC_MAP]
            # Uppercase SAP document types for display; leave other types as-is.
            final_docs[value["doc_name"] + ".xlsx"] = (
                doc_type.upper() if "sap" in doc_type else doc_type
            )
            # Collect each filtered hit alongside the doc mapping.
            answers.append(value)
        return final_docs, answers
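

# --- Usage sketch (illustrative only) ---
# A minimal example of driving TransactionMapsSearch end to end. It assumes the
# GLOBAL_TRANSACTION_MAPS_DATA_PATH, GLOBAL_TRANSACTION_MAPS_MODEL_PATH, and
# LLM_API_ENDPOINT environment variables point at a real FAISS index, embedding
# model, and LLM endpoint; the query string below is a hypothetical example.
if __name__ == "__main__":
    searcher = TransactionMapsSearch()
    final_docs, answers = searcher.search_transaction_map(
        query="fixed asset purchase",
        find_transaction_maps_by_question=False,  # True would extract the transaction via the LLM first
        k_neighbours=15,
    )
    for doc_name, doc_type in final_docs.items():
        print(f"{doc_name}: {doc_type}")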