import functools
import gc
import traceback
from typing import Optional
from legal_info_search_utils.rules_utils import use_rules
from itertools import islice
import os
import torch
import numpy as np
from faiss import IndexFlatIP
from datasets import Dataset as dataset
from transformers import AutoTokenizer, AutoModel
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
import requests
import re
import json
import pymorphy3
from torch.cuda.amp import autocast
from elasticsearch_module import search_company
import torch.nn.functional as F
import pickle
from llm.prompts import LLM_PROMPT_QE, LLM_PROMPT_OLYMPIC, LLM_PROMPT_KEYS
from llm.vllm_api import LlmApi, LlmParams
from huggingface import dataset_utils

# NOTE: the import order above is kept as in the original file.

global_data_path = os.environ.get("GLOBAL_DATA_PATH", "./legal_info_search_data/")

# HuggingFace access token. When HF_TOKEN and HF_DATASET are both set, the data
# root is prefixed with dataset_utils.get_global_data_path(); when HF_TOKEN and
# HF_MODEL_NAME are both set, SemanticSearch.load_model() loads the model from HF.
hf_token = os.environ.get("HF_TOKEN", None)
hf_dataset = os.environ.get("HF_DATASET", None)
hf_model_name = os.environ.get("HF_MODEL_NAME", "")

if hf_token is not None and hf_dataset is not None:
    global_data_path = dataset_utils.get_global_data_path() + global_data_path
print(f"Global data path: {global_data_path}")

global_model_path = os.environ.get("GLOBAL_MODEL_PATH", "./models/20240202_204910_ep8")

data_path_consult = global_data_path + "external_data"
internal_docs_data_path = global_data_path + "nmd_full"
spec_internal_docs_data_path = global_data_path + "nmd_short"
accounting_data_path = global_data_path + "bu"
companies_map_path = global_data_path + "companies_map/companies_map.json"
dict_path = global_data_path + "dict/dict_20241030.pkl"
general_nmd_path = global_data_path + "companies_map/general_nmd.json"
consultations_dataset_path = global_data_path + "consult_data"
explanations_dataset_path = global_data_path + "explanations"
explanations_for_llm_path = global_data_path + "explanations_for_llm/explanations_for_llm.json"
rules_list_path = global_data_path + "rules_list/terms.txt"

# Document categories; a document's type is the first entry contained in its name.
db_data_types = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина', 'Письмо ФНС', 'Приказ ФНС',
                 'Постановление Правительства', 'Судебный документ', 'ВНД', 'Бухгалтерский документ']

device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')

llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
# NOTE(review): not used in this file (requests sets the JSON content type itself).
headers = {'Content-Type': 'application/json'}

# Number of consultations returned by search_engine().
def_k = 15


@functools.lru_cache(maxsize=1)
def _get_morph_analyzer():
    """Return one shared pymorphy3 analyzer instead of building a new one per call."""
    return pymorphy3.MorphAnalyzer()


class SemanticSearch:
    """Semantic search over legal documents.

    Searches external acts plus accounting documents, internal regulations
    ("ВНД"), consultations and explanations. Supports optional LLM-based
    keyword expansion and an iterative LLM source-selection loop ("olympic"
    branch).
    """

    def __init__(self, do_normalization: bool = True):
        """Load the encoder and all data, then build one inner-product FAISS index per base.

        Args:
            do_normalization: when True, stored document embeddings are
                L2-normalised before indexing; the flag is also passed to
                ``query_embed_extraction`` for query embeddings.
        """
        self.device = device
        self.do_normalization = do_normalization
        self.load_model()

        # Whether the most recent search_engine() call searched the whole base.
        # Kept for backward compatibility; search_engine() itself uses a local value.
        self.full_base_search = True
        self.index_consult = IndexFlatIP(self.embedding_dim)
        self.index_explanations = IndexFlatIP(self.embedding_dim)
        self.index_all_docs_with_accounting = IndexFlatIP(self.embedding_dim)
        self.index_internal_docs = IndexFlatIP(self.embedding_dim)
        self.spec_index_internal_docs = IndexFlatIP(self.embedding_dim)
        # NOTE(review): index_teaser is never populated or queried in this file.
        self.index_teaser = IndexFlatIP(self.embedding_dim)
        self.load_data()

        # Internal regulations (ВНД): full and short variants.
        self.internal_docs_embeddings = self._embeddings_matrix(self.internal_docs, self.do_normalization)
        self.index_internal_docs.add(self.internal_docs_embeddings)
        self.spec_internal_docs_embeddings = self._embeddings_matrix(self.spec_internal_docs,
                                                                     self.do_normalization)
        self.spec_index_internal_docs.add(self.spec_internal_docs_embeddings)

        # External documents together with accounting documents.
        self.all_docs_with_accounting_embeddings = self._embeddings_matrix(self.all_docs_with_accounting,
                                                                           self.do_normalization)
        self.index_all_docs_with_accounting.add(self.all_docs_with_accounting_embeddings)

        # Consultations.
        self.consult_embeddings = self._embeddings_matrix(self.all_consultations, self.do_normalization)
        self.index_consult.add(self.consult_embeddings)

        # Explanations.
        self.explanations_embeddings = self._embeddings_matrix(self.all_explanations, self.do_normalization)
        self.index_explanations.add(self.explanations_embeddings)

    @staticmethod
    def _embeddings_matrix(docs, normalize):
        """Stack ``doc['doc_embedding']`` of every doc into a float32 NumPy matrix.

        The result is always an ndarray, whether or not it is L2-normalised.
        """
        embeddings = torch.stack([torch.Tensor(doc['doc_embedding']) for doc in docs])
        if normalize:
            embeddings = F.normalize(embeddings, dim=-1)
        return embeddings.numpy()

    @staticmethod
    def get_main_info_with_llm(prompt: str):
        """Send ``prompt`` (wrapped in [INST] tags) to ``llm_api_endpoint`` and return ``content``.

        NOTE(review): this is a blocking HTTP call without a timeout, and it is
        called from async methods.
        """
        response = requests.post(
            url=llm_api_endpoint,
            json={'prompt': ' [INST] ' + prompt + ' [/INST]',
                  'temperature': 0.0,
                  'n_predict': 2500.0,
                  'top_p': 0.95,
                  'min_p': 0.05,
                  'repeat_penalty': 1.2,
                  'stop': []})
        answer = response.json()['content']
        return answer

    @staticmethod
    def rerank_by_avg_score(refs, scores_to_take=3):
        """Group chunk hits by document and order documents by mean chunk score.

        Args:
            refs: iterable of ``(chunk_name, content, score)``. The chunk suffix
                ``_<1-3 digits>`` at the end of the name is stripped to get the
                document name.
            scores_to_take: only the first this-many scores of each document are
                averaged.

        Returns:
            Document names sorted by descending average score.
        """
        regex = r'_(\d{1,3})$'
        doc_scores = {}
        for name, _content, score in refs:
            doc_name = re.sub(regex, '', name)
            collected = doc_scores.setdefault(doc_name, [])
            if len(collected) < scores_to_take:
                collected.append(float(score))
        averages = {doc_name: np.mean(scores) for doc_name, scores in doc_scores.items()}
        return sorted(averages, key=averages.get, reverse=True)

    async def olymp_think(self, query, sources, llm_params: LlmParams = None):
        """Ask the LLM to reason over numbered sources for ``query`` and return its answer.

        Args:
            query: user query.
            sources: mapping name -> text. Sources are numbered from 1 in
                iteration order.
            llm_params: when None, the legacy Mixtral endpoint
                (get_main_info_with_llm) is used. Otherwise LlmApi first trims
                the sources to fit the prompt, then predicts.
        """
        sources_text = ''.join(f'Источник [{i + 1}]: {sources[source]}\n' for i, source in enumerate(sources))
        # TODO: add an LlmApi-based client for Mixtral (if it is still needed).
        if llm_params is None:
            step = LLM_PROMPT_OLYMPIC.format(query=query, sources=sources_text)
            return self.get_main_info_with_llm(step)
        llm_api = LlmApi(llm_params)
        query_for_trim = LLM_PROMPT_OLYMPIC.format(query=query, sources='')
        trimmed_sources_result = await llm_api.trim_sources(sources_text, query_for_trim)
        prompt = LLM_PROMPT_OLYMPIC.format(query=query, sources=trimmed_sources_result["result"])
        return await llm_api.predict(prompt)

    @staticmethod
    def parse_step(text):
        """Parse one olympic-step LLM answer.

        Returns:
            ``(comment, numbers)``.
            - ``comment``: the text between the "(4)" and "(5)" markers. A missing
              marker counts as position 0. The comment is empty when the markers
              are out of order or adjacent, and also whenever the text contains
              "$$".
            - ``numbers``: the integers written in square brackets after "(5)".
        """
        step4_start = text.find('(4)')
        if step4_start == -1:
            step4_start = 0
        step5_start = text.find('(5)')
        if step5_start == -1:
            step5_start = 0

        if step4_start + 3 < step5_start:
            extracted_comment = text[step4_start + 3:step5_start]
        else:
            extracted_comment = ''
        if '$$' in text:
            extracted_comment = ''

        extracted_content = re.findall(r'\[(.*?)\]', text[step5_start:])
        # isdecimal(), unlike isdigit(), only accepts characters int() can parse.
        extracted_numbers = [int(item) for item in extracted_content if item.isdecimal()]
        return extracted_comment, extracted_numbers

    @staticmethod
    def lemmatize_query(text):
        """Lemmatize each whitespace-separated word and trim surrounding punctuation.

        Words that are entirely upper case (e.g. abbreviations) are kept
        as-is. Punctuation is stripped from both ends of every word, but at
        least one character is always kept, so the output has exactly one token
        per input word.
        """
        morph = _get_morph_analyzer()
        signs = ',.<>?;\'\":}{!)(][-'
        lemmas = []
        for word in text.split():
            if not word.isupper():
                word = morph.parse(word)[0].normal_form
            while word[0] in signs and len(word) > 1:
                word = word[1:]
            while word[-1] in signs and len(word) > 1:
                word = word[:-1]
            lemmas.append(word)
        return " ".join(lemmas)

    @staticmethod
    def mark_for_one_word_dict(lem_dict):
        """Set ``one_word_searchable = False`` for every term whose first word is shared with another term.

        Mutates ``lem_dict`` in place and also returns it.
        """
        first_word_owner = {}
        shared_first_words = set()
        for name in lem_dict:
            first_word = name.split()[0]
            if first_word in first_word_owner:
                lem_dict[name]['one_word_searchable'] = False
                shared_first_words.add(first_word)
            else:
                first_word_owner[first_word] = name
        # The first term that used the shared word must be demoted as well.
        for first_word in shared_first_words:
            lem_dict[first_word_owner[first_word]]['one_word_searchable'] = False
        return lem_dict

    def lemmatize_dict(self, terms_dict):
        """Re-key the terms dictionary by lemmatized term name.

        Upper-case names are kept as-is. Each entry keeps ``name``,
        ``definitions``, ``titles``, ``sources`` and ``is_multi_def``, and gains a
        ``one_word_searchable`` flag (see mark_for_one_word_dict).
        """
        morph = _get_morph_analyzer()
        lem_dict = {}
        for name, entry in terms_dict.items():
            lem_name = name if name.isupper() else morph.parse(name)[0].normal_form
            lem_dict[lem_name] = {
                'name': name,
                'definitions': entry['definitions'],
                'titles': entry['titles'],
                'sources': entry['sources'],
                'is_multi_def': entry['is_multi_def'],
                'one_word_searchable': True,
            }
        return self.mark_for_one_word_dict(lem_dict)

    @staticmethod
    def separate_one_word_searchable_dict(lem_dict):
        """Split a lemmatized dict into ``(fast, slow)`` by the ``one_word_searchable`` flag.

        The copies keep ``name``, ``definitions``, ``titles``, ``sources`` and
        ``is_multi_def``.
        """
        fields = ('name', 'definitions', 'titles', 'sources', 'is_multi_def')
        lem_dict_fast, lem_dict_slow = {}, {}
        for name, entry in lem_dict.items():
            target = lem_dict_fast if entry['one_word_searchable'] else lem_dict_slow
            target[name] = {field: entry[field] for field in fields}
        return lem_dict_fast, lem_dict_slow

    def _get_lemmatized_terms(self):
        """Return ``(lem_dict, lem_dict_fast, lem_dict_slow)`` for ``self.terms_dict``, cached.

        The cache is rebuilt when ``self.terms_dict`` is replaced by another
        object.
        NOTE(review): in-place edits of terms_dict are not detected.
        """
        cache = getattr(self, '_lem_terms_cache', None)
        if cache is None or cache[0] is not self.terms_dict:
            lem_dict = self.lemmatize_dict(self.terms_dict)
            lem_dict_fast, lem_dict_slow = self.separate_one_word_searchable_dict(lem_dict)
            cache = (self.terms_dict, lem_dict, lem_dict_fast, lem_dict_slow)
            self._lem_terms_cache = cache
        return cache[1], cache[2], cache[3]

    @staticmethod
    def extract_original_phrase(original_text, lemmatized_text, lemmatized_phrase):
        """Map a lemmatized phrase back to the matching words of the original text.

        This relies on both texts having the same number of whitespace-separated
        tokens (lemmatize_query keeps one token per word). Returns False when the
        phrase does not start at a word boundary of the lemmatized text.
        """
        words = original_text.split()
        words_lem = lemmatized_text.split()
        words_lem_phrase = lemmatized_phrase.split()
        phrase_len = len(words_lem_phrase)
        for i, word in enumerate(words_lem):
            if word == words_lem_phrase[0] and ' '.join(words_lem[i:i + phrase_len]) == lemmatized_phrase:
                return ' '.join(words[i:i + phrase_len])
        return False

    @staticmethod
    def _first_definition(definitions):
        """Return the first definition when several are stored as a list, else the value itself."""
        return definitions[0] if isinstance(definitions, list) else definitions

    def substitute_definitions(self, original_text, lem_dict, lem_dict_fast, lem_dict_slow, for_llm=False):
        """Find dictionary terms in ``original_text`` and enrich the text with their definitions.

        Terms are found in two passes:
        - Slow-dict terms, longest first. A multi-word term matches as a
          substring of the lemmatized text; a one-word term must be a whole
          word and must not be contained in an already-found phrase.
        - Words that are in the fast dict and are not contained in an
          already-found phrase.

        Args:
            for_llm: when True, the definitions are appended after the text
                ("<text>. Дополнительная информация: term-definition, ...").
                Otherwise each definition is inserted in parentheses right after
                the first occurrence of the term's original phrase.

        Returns:
            ``(substituted_text, phrases)``, where ``phrases`` holds
            ``(lemmatized_term, original_phrase)`` pairs; ``original_phrase`` is
            False when it could not be located. Such terms are skipped.
        """
        lemmatized_text = self.lemmatize_query(original_text)
        words = lemmatized_text.split()
        found_phrases = set()
        fast_phrases = []
        slow_phrases = []

        # Longest first, so that a multi-word term wins over the words inside it.
        for lemmatized_phrase in sorted(lem_dict_slow, key=len, reverse=True):
            if len(lemmatized_phrase.split()) > 1:
                if lemmatized_phrase in lemmatized_text and lemmatized_phrase not in found_phrases:
                    found_phrases.add(lemmatized_phrase)
                    slow_phrases.append((lemmatized_phrase, self.extract_original_phrase(
                        original_text, lemmatized_text, lemmatized_phrase)))
            elif lemmatized_phrase in words:
                if not any(lemmatized_phrase in phrase for phrase in found_phrases):
                    found_phrases.add(lemmatized_phrase)
                    slow_phrases.append((lemmatized_phrase, self.extract_original_phrase(
                        original_text, lemmatized_text, lemmatized_phrase)))

        for word in words:
            if word in lem_dict_fast and not any(word in phrase for phrase in found_phrases):
                found_phrases.add(word)
                fast_phrases.append((word, self.extract_original_phrase(original_text, lemmatized_text, word)))

        phrases_to_add = fast_phrases + slow_phrases
        substituted_text = original_text
        definitions_info = []
        term = None
        try:
            for term, original_phrase in phrases_to_add:
                if not original_phrase:
                    # The lemmatized match does not align with whole original words.
                    continue
                # Context-dependent choice among several definitions could go here;
                # for now the first one is used.
                definition = self._first_definition(lem_dict[term]['definitions'])
                if for_llm:
                    definitions_info.append(f"{term}-{definition}")
                    continue
                term_start = substituted_text.find(original_phrase)
                if term_start == -1:
                    continue
                insert_at = term_start + len(original_phrase)
                substituted_text = (substituted_text[:insert_at] + f" ({definition})"
                                    + substituted_text[insert_at:])
        except Exception as e:
            print(f'error processing\n {original_text}\n {term}: {e}')

        if for_llm and definitions_info:
            definitions_str = ", ".join(definitions_info)
            substituted_text = f"{original_text}. Дополнительная информация: {definitions_str}"
        return substituted_text, phrases_to_add

    def filter_by_types(self, pred: list[str] = None, scores: list[float] = None, indexes: list[int] = None,
                        docs_embeddings: list = None, ctgs: dict = None):
        """Keep only external hits whose category is checked in ``ctgs``.

        A hit is kept when its name contains 'ВНД' and 'ВНД' is checked, or when
        ``self.all_docs_with_accounting[index]['doc_type']`` is checked.
        """
        selected = [ctg for ctg, checked in ctgs.items() if checked]
        filtred_pred, filtred_scores, filtred_indexes, filtred_docs_embeddings = [], [], [], []
        for doc_name, score, index, doc_embedding in zip(pred, scores, indexes, docs_embeddings):
            if ('ВНД' in doc_name and 'ВНД' in selected) or \
                    self.all_docs_with_accounting[index]['doc_type'] in selected:
                filtred_pred.append(doc_name)
                filtred_scores.append(score)
                filtred_indexes.append(index)
                filtred_docs_embeddings.append(doc_embedding)
        return filtred_pred, filtred_scores, filtred_indexes, filtred_docs_embeddings

    def get_types_of_docs(self, all_docs):
        """Set ``doc['doc_type']`` to the first entry of ``db_data_types`` contained in the doc name, else None.

        Mutates the docs in place and also returns ``all_docs``.
        """
        for doc in all_docs:
            doc['doc_type'] = next((ctg for ctg in db_data_types if ctg in doc['doc_name']), None)
        return all_docs

    def load_model(self):
        """Load the tokenizer and encoder, from HuggingFace when configured, else from ``global_model_path``."""
        if hf_token and hf_model_name:
            self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=hf_token)
            self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=hf_token).to(self.device)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
            self.model = AutoModel.from_pretrained(global_model_path).to(self.device)

        self.max_len = self.tokenizer.max_len_single_sentence
        self.embedding_dim = self.model.config.hidden_size

    def load_data(self):
        """Load dictionaries, company maps, rules and all datasets from ``global_data_path``."""
        # NOTE: pickle.load executes code from the file; dict_path must be trusted data.
        with open(dict_path, "rb") as f:
            self.terms_dict = pickle.load(f)
        with open(companies_map_path, "r", encoding='utf-8') as f:
            self.companies_map = json.load(f)
        with open(general_nmd_path, "r", encoding='utf-8') as f:
            self.general_nmd = json.load(f)
        with open(explanations_for_llm_path, "r", encoding='utf-8') as f:
            self.explanations_for_llm = json.load(f)
        with open(rules_list_path, 'r', encoding='utf-8') as f:
            self.rules_list = f.read().splitlines()

        self.all_docs_info = dataset.load_from_disk(data_path_consult).to_list()  # ONLY EXTERNAL DOCS
        self.internal_docs = dataset.load_from_disk(internal_docs_data_path).to_list()
        self.accounting_docs = dataset.load_from_disk(accounting_data_path).to_list()
        self.spec_internal_docs = dataset.load_from_disk(spec_internal_docs_data_path).to_list()
        self.all_docs_with_accounting = self.all_docs_info + self.accounting_docs
        self.all_docs_with_accounting = self.get_types_of_docs(self.all_docs_with_accounting)

        # Default score multipliers per document category. search_engine() works
        # on a per-query copy in which the Krasnoyarsk and rules weights are
        # switched on or off.
        self.type_weights_nu = {'НКРФ': 1,
                                'ТКРФ': 1,
                                'ГКРФ': 1,
                                'Письмо Минфина': 0.9,
                                'Письмо ФНС': 0.6,
                                'Приказ ФНС': 1,
                                'Постановление Правительства': 1,
                                'Федеральный закон': 0.9,
                                'Судебный документ': 0.2,
                                'ВНД': 0.2,
                                'Бухгалтерский документ': 0.7,
                                'Закон Красноярского края': 1.2,
                                'Правила заполнения': 1.2,
                                'Правила ведения': 1.2}
        self.all_consultations = dataset.load_from_disk(consultations_dataset_path).to_list()
        self.all_explanations = dataset.load_from_disk(explanations_dataset_path).to_list()

    @staticmethod
    def remove_duplicate_paragraphs(paragraphs):
        """Join paragraphs with newlines, dropping blank ones and repeats (compared after strip())."""
        unique_paragraphs = []
        seen = set()
        for paragraph in paragraphs:
            stripped_paragraph = paragraph.strip()
            if stripped_paragraph and stripped_paragraph not in seen:
                unique_paragraphs.append(paragraph)
                seen.add(stripped_paragraph)
        return '\n'.join(unique_paragraphs)

    @staticmethod
    def construct_base(idx_list, base):
        """Rebuild readable passages from chunk indices by merging overlapping neighbours.

        For each index, up to 2 preceding and 2 following chunks of the same
        document are merged in, as long as their texts overlap (the suffix of one
        equals the prefix of the other). Document identity is compared after
        removing ``_<1-3 digits>``. Chunks already merged are skipped. Passages
        are separated by blank lines, and "..." is added at passage edges that
        look cut off.
        """
        concatenated_text = ""
        seen_ids = set()
        pattern = re.compile(r'_(\d{1,3})')

        def find_overlap(a: str, b: str) -> int:
            """Length of the longest suffix of ``a`` that is also a prefix of ``b``."""
            max_overlap = min(len(a), len(b))
            for i in range(max_overlap, 0, -1):
                if a[-i:] == b[:i]:
                    return i
            return 0

        def add_ellipsis(text: str) -> str:
            """Prefix/suffix '...' on blank-line-separated segments that start or end mid-sentence."""
            if not text:
                return text
            segments = text.split('\n\n')
            processed_segments = []
            for segment in segments:
                if segment and not (
                        segment[0].isupper() or segment[0].isdigit() or segment[0] in ['•', '-', '—', '.']):
                    segment = '...' + segment
                if segment and not (segment.endswith('.') or segment.endswith(';')):
                    segment += '...'
                processed_segments.append(segment)
            return '\n\n'.join(processed_segments)

        for current_index in idx_list:
            if current_index in seen_ids:
                continue
            start_index = max(0, current_index - 2)
            end_index = min(len(base), current_index + 3)

            current_name_base = pattern.sub('', base[current_index]['doc_name'])
            merged_text = base[current_index]['doc_text']

            # Extend backwards while neighbours belong to the same document and overlap.
            for i in range(current_index - 1, start_index - 1, -1):
                if i in seen_ids:
                    continue
                if pattern.sub('', base[i]['doc_name']) != current_name_base:
                    break
                surrounding_text = base[i]['doc_text']
                overlap_length = find_overlap(surrounding_text, merged_text)
                if overlap_length == 0:
                    break
                merged_text = surrounding_text + merged_text[overlap_length:]
                seen_ids.add(i)

            # Extend forwards under the same conditions.
            for i in range(current_index + 1, end_index):
                if i in seen_ids:
                    continue
                if pattern.sub('', base[i]['doc_name']) != current_name_base:
                    break
                surrounding_text = base[i]['doc_text']
                overlap_length = find_overlap(merged_text, surrounding_text)
                if overlap_length == 0:
                    break
                merged_text = merged_text + surrounding_text[overlap_length:]
                seen_ids.add(i)

            concatenated_text += merged_text + '\n\n'
            seen_ids.add(current_index)

        concatenated_text = add_ellipsis(concatenated_text)
        return concatenated_text.rstrip('\n')

    def search_results_multiply_weights(self, pred: list[str] = None, scores: list[float] = None,
                                        indexes: list[int] = None, docs_embeddings: list = None,
                                        weights: Optional[dict] = None
                                        ) -> tuple[list[str], list[float], list[int], list]:
        """Multiply hit scores by category weights and sort by the weighted score, descending.

        A hit produces one entry for every category name contained in its name,
        and hits matching no category are dropped.
        NOTE(review): a hit matching two categories therefore appears twice.

        Args:
            weights: category -> multiplier. Defaults to ``self.type_weights_nu``.

        Returns:
            ``(preds, weighted_scores, indexes, embeddings)``. All four are empty
            when any input is None.
        """
        if pred is None or scores is None or indexes is None or docs_embeddings is None:
            return [], [], [], []
        if weights is None:
            weights = self.type_weights_nu

        weighted_scores = [(weights.get(ctg, 0) * score, prediction, idx, emb)
                           for prediction, score, idx, emb in zip(pred, scores, indexes, docs_embeddings)
                           for ctg in weights if ctg in prediction]
        weighted_scores.sort(reverse=True, key=lambda item: item[0])
        if not weighted_scores:
            return [], [], [], []
        sorted_scores, sorted_preds, sorted_indexes, sorted_docs_embeddings = zip(*weighted_scores)
        return list(sorted_preds), list(sorted_scores), list(sorted_indexes), list(sorted_docs_embeddings)

    def get_uniq_relevant_docs(self, top_k: int,
                               query_refs_all: list[str],
                               scores: list[float],
                               indexes: list[int],
                               docs_embeddings: list[list[float]]
                               ) -> tuple[dict[str, list[str]], dict[str, list[float]],
                                          dict[str, list[int]], dict[str, list[list[float]]]]:
        """Group chunk hits by document, keeping the first ``top_k`` documents met.

        The document key is the chunk name with ``_<1-3 digits>`` removed and
        whitespace stripped. Inside each document, chunks are reordered: those
        without a suffix come first, then those with a suffix in ascending
        numeric order. At most 20 chunks are kept per document.

        Returns:
            Four dicts keyed by document: chunk names, scores, indexes and
            embeddings.
        """
        suffix_pattern = re.compile(r'_\d{1,3}')
        groups = {}
        for ref, score, index, embedding in zip(query_refs_all, scores, indexes, docs_embeddings):
            base_ref = suffix_pattern.sub('', ref).strip()
            if base_ref not in groups:
                if len(groups) >= top_k:
                    continue
                groups[base_ref] = {'refs': [], 'scores': [], 'indexes': [], 'embeddings': []}
            group = groups[base_ref]
            group['refs'].append(ref)
            group['scores'].append(score)
            group['indexes'].append(index)
            group['embeddings'].append(embedding)

        def chunk_order(item):
            match = suffix_pattern.search(item[0])
            if match is None:
                return 0, -1
            return 1, int(match.group()[1:])

        for group in groups.values():
            ordered = sorted(zip(group['refs'], group['scores'], group['indexes'], group['embeddings']),
                             key=chunk_order)[:20]
            group['refs'] = [item[0] for item in ordered]
            group['scores'] = [item[1] for item in ordered]
            group['indexes'] = [item[2] for item in ordered]
            group['embeddings'] = [item[3] for item in ordered]

        unique_refs = {k: v['refs'] for k, v in groups.items()}
        filtered_scores = {k: v['scores'] for k, v in groups.items()}
        filtered_indexes = {k: v['indexes'] for k, v in groups.items()}
        filtered_docs_embeddings = {k: v['embeddings'] for k, v in groups.items()}
        return unique_refs, filtered_scores, filtered_indexes, filtered_docs_embeddings

    def filter_results(self, pred_internal, scores_internal, indices_internal, docs_embeddings_internal,
                       companies_files):
        """Keep internal hits whose name mentions a general regulation or one of ``companies_files``.

        Each hit is kept at most once, even when it matches several entries.
        """
        filt_pred, filt_scores, filt_indices, filt_embeddings = [], [], [], []
        for pred, score, ind, emb in zip(pred_internal, scores_internal, indices_internal,
                                         docs_embeddings_internal):
            if any(doc in pred for doc in self.general_nmd) or any(company in pred for company in companies_files):
                filt_pred.append(pred)
                filt_scores.append(score)
                filt_indices.append(ind)
                filt_embeddings.append(emb)
        return filt_pred, filt_scores, filt_indices, filt_embeddings

    @staticmethod
    def merge_dictionaries(dicts: list = None):
        """Interleave dicts round-robin: the 1st key of each dict, then the 2nd of each, and so on.

        A repeated key keeps the position where it first appears, but takes the
        value from its last occurrence. An empty or None input returns {}.
        """
        if not dicts:
            return {}
        item_lists = [list(d.items()) for d in dicts]
        merged_dict = {}
        for i in range(max(len(items) for items in item_lists)):
            for items in item_lists:
                if i < len(items):
                    key, value = items[i]
                    merged_dict[key] = value
        return merged_dict

    @staticmethod
    def check_specific_key(dictionary, key):
        """Return True iff ``dictionary[key] is True`` and no other value is ``True``."""
        if dictionary.get(key) is not True:
            return False
        return all(v is not True for k, v in dictionary.items() if k != key)

    @staticmethod
    def remove_duplicates(input_list):
        """Return the items without repeats, keeping first-occurrence order."""
        return list(dict.fromkeys(input_list))

    def _search_budget(self, categories):
        """Decide the search scope and chunk/document budgets from the category check-boxes.

        Returns:
            ``(full_base_search, nmd_chunks, nmd_refs, extra_chunks, extra_refs)``.
            The full base is searched when no category, or every category, is
            checked.
        """
        values = list(categories.values())
        if True in values and not all(values):
            if self.check_specific_key(categories, 'ВНД'):
                return False, 120, 45, 1, 1
            if not categories.get('ВНД'):
                return False, 1, 1, 120, 45
            return False, 60, 23, 60, 23
        return True, 50, 15, 75, 30

    def _query_type_weights(self, query):
        """Return a per-query copy of the category weights.

        The Krasnoyarsk-law weight is 1.2 only when the query mentions
        "красноярск". The rules weights are 1.2 only when ``use_rules`` accepts
        the query. Otherwise these weights are 0.
        """
        weights = dict(self.type_weights_nu)
        weights['Закон Красноярского края'] = 1.2 if re.search('[Кк]расноярск', query) else 0
        rules_weight = 1.2 if use_rules(query, self.rules_list) else 0
        weights['Правила ведения'] = rules_weight
        weights['Правила заполнения'] = rules_weight
        return weights

    def _rank_external_docs(self, query_embeds, full_base_search, categories, weights, top_k,
                            max_candidates=None, max_chunks=None):
        """Search the external + accounting index and return per-document hit groups.

        The pipeline is:
        1. Run the FAISS search.
        2. Keep the first ``max_candidates`` hits (all when None).
        3. Filter by checked categories, unless the full base is searched.
        4. Apply the category weights.
        5. Keep the first ``max_chunks`` hits (all when None).
        6. Group by document (``top_k`` documents).
        """
        distances, indices = self.index_all_docs_with_accounting.search(query_embeds,
                                                                       len(self.all_docs_with_accounting))
        found_indexes = list(indices[0])
        found_scores = list(distances[0])
        if max_candidates is not None:
            found_indexes = found_indexes[:max_candidates]
            found_scores = found_scores[:max_candidates]
        names = [self.all_docs_with_accounting[x]['doc_name'] for x in found_indexes]
        embeddings = [self.all_docs_with_accounting[x]['doc_embedding'] for x in found_indexes]

        if not full_base_search:
            names, found_scores, found_indexes, embeddings = self.filter_by_types(
                names, found_scores, found_indexes, embeddings, categories)

        names, found_scores, found_indexes, embeddings = self.search_results_multiply_weights(
            pred=names, scores=found_scores, indexes=found_indexes, docs_embeddings=embeddings, weights=weights)
        if max_chunks is not None:
            names, found_scores = names[:max_chunks], found_scores[:max_chunks]
            found_indexes, embeddings = found_indexes[:max_chunks], embeddings[:max_chunks]

        return self.get_uniq_relevant_docs(top_k=top_k, query_refs_all=names, scores=found_scores,
                                           indexes=found_indexes, docs_embeddings=embeddings)

    def _search_internal_docs(self, query, query_embeds, nmd_chunks, nmd_refs):
        """Search internal regulations (ВНД) relevant to the query's companies.

        Steps:
        - Chunks whose name contains 'КУП' get their score multiplied by 1.2.
        - Only general regulations and those of the companies found by
          ``search_company.find_nmd_docs`` are kept.
        - The best ``nmd_chunks`` chunks are grouped into ``nmd_refs`` documents.

        Returns:
            The four per-document dicts of get_uniq_relevant_docs.
        """
        # NOTE(review): indices come from index_internal_docs (built from
        # internal_docs), but names/embeddings are read from spec_internal_docs,
        # and texts are later taken from internal_docs. This is only consistent
        # if both datasets are aligned row by row — confirm.
        distances, indices = self.index_internal_docs.search(query_embeds, len(self.spec_internal_docs))
        found_indexes = indices[0]
        names = [self.spec_internal_docs[x]['doc_name'] for x in found_indexes]
        embeddings = [self.spec_internal_docs[x]['doc_embedding'] for x in found_indexes]
        scores = [score * 1.2 if 'КУП' in name else score for name, score in zip(names, distances[0])]

        companies_files = search_company.find_nmd_docs(query, self.companies_map)
        names, scores, found_indexes, embeddings = self.filter_results(
            names, scores, found_indexes, embeddings, companies_files)

        top_nmd = sorted(zip(names, scores, found_indexes, embeddings),
                         key=lambda item: item[1], reverse=True)[:nmd_chunks]

        if 'ЕГДС' in query:
            # Pin two specific chunks (hard-coded positions in internal_docs) at
            # the top, unless a chunk with exactly that name is already present.
            # NOTE(review): the positions depend on the current nmd_full dataset.
            pinned_chunks = (('п.5. Положение о КУП_262 (ВНД)', 22976),
                             ('п.5. Положение о КУП_130 (ВНД)', 22844))
            for position, (chunk_name, doc_index) in enumerate(pinned_chunks):
                if not any(entry[0] == chunk_name for entry in top_nmd):
                    chunk = self.internal_docs[doc_index]
                    top_nmd.insert(position, (chunk['doc_name'], 1.0, doc_index, chunk['chunks_embeddings'][0]))
            top_nmd = top_nmd[:nmd_chunks]

        if top_nmd:
            names, scores, found_indexes, embeddings = (list(column) for column in zip(*top_nmd))
        else:
            names, scores, found_indexes, embeddings = [], [], [], []

        return self.get_uniq_relevant_docs(top_k=nmd_refs, query_refs_all=names, scores=scores,
                                           indexes=found_indexes, docs_embeddings=embeddings)

    async def _keyword_search(self, query, full_base_search, categories, weights, llm_params, llm_responses):
        """Ask the LLM for search keywords and rank external documents by them.

        The raw LLM answer is appended to ``llm_responses``.

        Returns:
            The four per-document dicts of get_uniq_relevant_docs, for up to 45
            documents.
        """
        prompt = LLM_PROMPT_KEYS.format(query=query)
        if llm_params is None:
            keyword_query = self.get_main_info_with_llm(prompt)
        else:
            keyword_query = await LlmApi(llm_params).predict(prompt)
        llm_responses.append(keyword_query)

        # Drop the "[1] ..." section and the " [2]" / "[3]" markers.
        keyword_query = re.sub(r'\[1\].*?(?=\[\d+\]|$)', '', keyword_query,
                               flags=re.DOTALL).replace(' [2]', '').replace('[3]', '').strip()
        keyword_query_tokens = query_tokenization(keyword_query, self.tokenizer)
        keyword_query_embeds = query_embed_extraction(keyword_query_tokens, self.model, self.do_normalization)
        return self._rank_external_docs(keyword_query_embeds, full_base_search, categories, weights, top_k=45)

    def _collect_texts_for_llm(self, indexes):
        """Build one text per document from its chunk indices.

        For letters, court, accounting and government documents, overlapping
        chunks are merged (construct_base). For other documents the chunk lines
        are concatenated. In both cases duplicate paragraphs are removed.
        """
        texts_for_llm = []
        for key, idx_list in indexes.items():
            base = self.internal_docs if 'ВНД' in key else self.all_docs_with_accounting
            if re.search('Минфин|Бухгалтерский документ|ФНС|Судебный документ|Постановление Правительства|Федеральный закон', key):
                collected_text = [self.construct_base(idx_list, base)]
            else:
                collected_text = [line
                                  for idx in idx_list if idx < len(base)
                                  for line in base[idx]['doc_text'].split('\n')]
            texts_for_llm.append(self.remove_duplicate_paragraphs(collected_text))
        return texts_for_llm

    def _find_explanations(self, query_embeds):
        """Return ``{title: explanations_for_llm[title]}`` for the 3 best explanation documents.

        Documents are ranked by average chunk score. Returns {} when any of the
        three titles is missing from ``explanations_for_llm``.
        """
        distances, indices = self.index_explanations.search(query_embeds, len(self.all_explanations))
        results = [(self.all_explanations[x]['doc_name'], self.all_explanations[x]['doc_text'], score)
                   for x, score in zip(indices[0], distances[0])]
        explanation_titles = self.rerank_by_avg_score(results)[:3]
        try:
            return {title: self.explanations_for_llm[title] for title in explanation_titles}
        except KeyError:
            print('The relevant document was not found in the system.')
            return {}

    async def search_engine(self, query: str = None, use_qe: bool = False, categories: dict = None,
                            llm_params: LlmParams = None):
        """Run one retrieval pass for ``query``.

        Args:
            query: search query.
            use_qe: also search by LLM-generated keywords and merge those results
                (first 30 documents).
            categories: category name -> checked flag. None is treated as "nothing
                checked", i.e. the full base is searched.
            llm_params: LLM settings for keyword expansion; None selects the
                legacy endpoint.

        Returns:
            ``(query, doc_names, texts_for_llm, consultations, explanation, llm_responses)``,
            where ``consultations`` holds at most ``def_k`` entries.
        """
        categories = categories if categories is not None else {}
        full_base_search, nmd_chunks, nmd_refs, extra_chunks, extra_refs = self._search_budget(categories)
        # Local copies keep concurrent requests from affecting each other across awaits.
        self.full_base_search = full_base_search
        use_internal = full_base_search or bool(categories.get('ВНД'))
        weights = self._query_type_weights(query)

        # LLM answers to be sent to the front end.
        llm_responses = []

        query_tokens = query_tokenization(query, self.tokenizer)
        query_embeds = query_embed_extraction(query_tokens, self.model, self.do_normalization)

        # External documents: top 5000 candidates -> weights -> extra_chunks -> extra_refs docs.
        preds, scores, indexes, _ = self._rank_external_docs(
            query_embeds, full_base_search, categories, weights,
            top_k=extra_refs, max_candidates=5000, max_chunks=extra_chunks)

        if use_internal:
            preds_internal, scores_internal, indexes_internal, _ = self._search_internal_docs(
                query, query_embeds, nmd_chunks, nmd_refs)

        if use_qe:
            try:
                keyword_preds, keyword_scores, keyword_indexes, _ = await self._keyword_search(
                    query, full_base_search, categories, weights, llm_params, llm_responses)
                preds = dict(islice(self.merge_dictionaries([preds, keyword_preds]).items(), 30))
                scores = dict(islice(self.merge_dictionaries([scores, keyword_scores]).items(), 30))
                indexes = dict(islice(self.merge_dictionaries([indexes, keyword_indexes]).items(), 30))
            except Exception:
                # Best effort: continue without keyword expansion.
                traceback.print_exc()
                print(f"Error applying keys (possibly the LLM is not available)")

        if use_internal:
            # Interleave the internal documents into the results.
            preds = self.merge_dictionaries([preds, preds_internal])
            scores = self.merge_dictionaries([scores, scores_internal])
            indexes = self.merge_dictionaries([indexes, indexes_internal])

        texts_for_llm = self._collect_texts_for_llm(indexes)

        # Relevant consultations.
        distances_consult, indices_consult = self.index_consult.search(query_embeds, len(self.all_consultations))
        predicted_consultations = {self.all_consultations[x]['doc_name']: self.all_consultations[x]['doc_text']
                                   for x in indices_consult[0]}

        # Relevant explanations.
        predicted_explanation = self._find_explanations(query_embeds)

        doc_names = [name.replace('ФЕДЕРАЛЬНЫЙ СТАНДАРТ БУХГАЛТЕРСКОГО УЧЕТА',
                                  'Федеральный стандарт бухгалтерского учета ФСБУ')
                     for name in preds]
        return query, doc_names, texts_for_llm, dict(islice(predicted_consultations.items(), def_k)), \
            predicted_explanation, llm_responses

    @staticmethod
    def _save_chosen_sources(sources, chosen_numbers, saved_sources):
        """Copy the sources chosen by 1-based number into ``saved_sources`` (first copy wins).

        Returns:
            The chosen numbers converted to 0-based positions.
        """
        chosen_positions = [number - 1 for number in chosen_numbers]
        wanted = set(chosen_positions)
        for position, ref in enumerate(sources):
            if position in wanted and ref not in saved_sources:
                saved_sources[ref] = sources[ref]
        return chosen_positions

    async def olympic_branch(self, query: str = None, sources: dict = None, categories: dict = None,
                             llm_params: LlmParams = None):
        """Let the LLM pick sources, repeating search + selection while it keeps commenting.

        The first LLM step runs on ``sources``. When it returns a non-empty
        comment, up to 4 more rounds of search_engine + selection follow; the
        loop stops early as soon as a step's comment is empty.

        Returns:
            ``(saved_sources, step_by_step, llm_responses)``, where
            ``step_by_step`` holds the 0-based choices of each extra round.
        """
        # All LLM answers, for the front end.
        llm_responses = []
        text = await self.olymp_think(query, sources, llm_params)
        llm_responses.append(text)

        saved_sources = {}
        saved_step_by_step = []
        comment, chosen_numbers = self.parse_step(text)
        self._save_chosen_sources(sources, chosen_numbers, saved_sources)

        # An empty comment means the LLM has nothing to add: skip the extra rounds.
        count = 4 if comment == '' else 0
        while count < 4:
            # NOTE(review): the query never changes between rounds (the comment is
            # not fed back), so each round repeats the same search.
            query, preds, texts_for_llm, _consultations, _explanation, _responses = \
                await self.search_engine(query, use_qe=False, categories=categories)
            sources = dict(islice(dict(zip(preds, texts_for_llm)).items(), 20))
            text = await self.olymp_think(query, sources, llm_params)
            llm_responses.append(text)
            comment, chosen_numbers = self.parse_step(text)
            saved_step_by_step.append(self._save_chosen_sources(sources, chosen_numbers, saved_sources))
            if comment == '':
                break
            count += 1
        return saved_sources, saved_step_by_step, llm_responses

    async def search(self, query: str = None, use_qe: bool = False, use_olympic: bool = False,
                     categories: dict = None, llm_params: LlmParams = None):
        """Top-level search: enrich the query with term definitions, retrieve, and optionally run the olympic loop.

        Returns:
            ``(query_for_llm, doc_names, texts_for_llm, consultations, explanation, llm_responses)``.
            ``query_for_llm`` has the term definitions appended. With
            ``use_olympic``, the olympic picks come first, followed by the
            remaining base results, capped at 45 with names and texts kept
            aligned.
            NOTE(review): with use_olympic, llm_responses holds only the
            olympic-branch answers.
        """
        lem_dict, lem_dict_fast, lem_dict_slow = self._get_lemmatized_terms()
        query_for_llm, _ = self.substitute_definitions(query, lem_dict, lem_dict_fast, lem_dict_slow, for_llm=True)
        query, _ = self.substitute_definitions(query, lem_dict, lem_dict_fast, lem_dict_slow, for_llm=False)

        # Base search.
        query, base_preds, base_texts_for_llm, predicted_consultations, predicted_explanation, llm_responses = \
            await self.search_engine(query, use_qe, categories, llm_params)

        if not use_olympic:
            return query_for_llm, base_preds, base_texts_for_llm, predicted_consultations, \
                predicted_explanation, llm_responses

        sources = dict(islice(dict(zip(base_preds, base_texts_for_llm)).items(), 20))
        olymp_results, _olymp_step_by_step, llm_responses = await self.olympic_branch(
            query, sources, categories, llm_params)

        merged = dict(olymp_results)
        if len(olymp_results) <= 45:
            for pred, text in zip(base_preds, base_texts_for_llm):
                merged.setdefault(pred, text)
        top_results = list(islice(merged.items(), 45))
        preds = [pred for pred, _ in top_results]
        texts_for_llm = [text for _, text in top_results]
        return query_for_llm, preds, texts_for_llm, predicted_consultations, predicted_explanation, llm_responses