Spaces:
Sleeping
Sleeping
import gc | |
import traceback | |
from legal_info_search_utils.rules_utils import use_rules | |
from itertools import islice | |
import os | |
import torch | |
import numpy as np | |
from faiss import IndexFlatIP | |
from datasets import Dataset as dataset | |
from transformers import AutoTokenizer, AutoModel | |
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction | |
import requests | |
import re | |
import json | |
import pymorphy3 | |
from torch.cuda.amp import autocast | |
from elasticsearch_module import search_company | |
import torch.nn.functional as F | |
import pickle | |
from llm.prompts import LLM_PROMPT_QE, LLM_PROMPT_OLYMPIC, LLM_PROMPT_KEYS | |
from llm.vllm_api import LlmApi, LlmParams | |
# Root locations; both can be overridden through environment variables.
# NOTE: data paths below are built by plain string concatenation, so
# GLOBAL_DATA_PATH is expected to end with a path separator.
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "./legal_info_search_data/")
global_model_path = os.environ.get("GLOBAL_MODEL_PATH", "./models/20240202_204910_ep8")

# Dataset and auxiliary file locations derived from the data root.
data_path_consult = f"{global_data_path}external_data"
internal_docs_data_path = f"{global_data_path}nmd_full"
spec_internal_docs_data_path = f"{global_data_path}nmd_short"
accounting_data_path = f"{global_data_path}bu"
companies_map_path = f"{global_data_path}companies_map/companies_map.json"
dict_path = f"{global_data_path}dict/dict_20241030.pkl"
general_nmd_path = f"{global_data_path}companies_map/general_nmd.json"
consultations_dataset_path = f"{global_data_path}consult_data"
explanations_dataset_path = f"{global_data_path}explanations"
explanations_for_llm_path = f"{global_data_path}explanations_for_llm/explanations_for_llm.json"
rules_list_path = f"{global_data_path}rules_list/terms.txt"

# Document categories recognised by substring match in document names.
db_data_types = [
    'НКРФ',
    'ГКРФ',
    'ТКРФ',
    'Федеральный закон',
    'Письмо Минфина',
    'Письмо ФНС',
    'Приказ ФНС',
    'Постановление Правительства',
    'Судебный документ',
    'ВНД',
    'Бухгалтерский документ',
]

# Inference device: explicit override first, otherwise CUDA when available.
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')

# Hugging Face access token. When set (together with a model name), the model is loaded from HF.
hf_token = os.environ.get("HF_TOKEN", "")
hf_model_name = os.environ.get("HF_MODEL_NAME", "")
llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
headers = {'Content-Type': 'application/json'}

# Default number of consultations returned to the caller.
def_k = 15
class SemanticSearch: | |
def __init__(self, do_normalization: bool = True):
    """Load the encoder and all document bases, then fill the FAISS indexes.

    Args:
        do_normalization: when True, document embeddings are L2-normalised
            before being added to the inner-product indexes.
    """
    self.device = device
    self.do_normalization = do_normalization
    self.load_model()

    # Main base
    self.full_base_search = True
    self.index_consult = IndexFlatIP(self.embedding_dim)
    self.index_explanations = IndexFlatIP(self.embedding_dim)
    self.index_all_docs_with_accounting = IndexFlatIP(self.embedding_dim)
    self.index_internal_docs = IndexFlatIP(self.embedding_dim)
    self.spec_index_internal_docs = IndexFlatIP(self.embedding_dim)
    self.index_teaser = IndexFlatIP(self.embedding_dim)
    self.load_data()

    def process_embeddings(docs):
        """Stack the 'doc_embedding' field of every record into one NumPy matrix."""
        embeddings = torch.stack([torch.Tensor(x['doc_embedding']) for x in docs], dim=0)
        if self.do_normalization:
            embeddings = F.normalize(embeddings, dim=-1)
        # Always hand NumPy to FAISS; previously the un-normalised path
        # returned a torch tensor, which the index cannot ingest.
        return embeddings.numpy()

    # Internal regulatory documents base
    self.internal_docs_embeddings = process_embeddings(self.internal_docs)
    self.index_internal_docs.add(self.internal_docs_embeddings)
    self.spec_internal_docs_embeddings = process_embeddings(self.spec_internal_docs)
    self.spec_index_internal_docs.add(self.spec_internal_docs_embeddings)
    self.all_docs_with_accounting_embeddings = process_embeddings(self.all_docs_with_accounting)
    self.index_all_docs_with_accounting.add(self.all_docs_with_accounting_embeddings)

    # Consultations base
    self.consult_embeddings = process_embeddings(self.all_consultations)
    self.index_consult.add(self.consult_embeddings)

    # Explanations base
    self.explanations_embeddings = process_embeddings(self.all_explanations)
    self.index_explanations.add(self.explanations_embeddings)
@staticmethod
def get_main_info_with_llm(prompt: str):
    """Send a prompt to the legacy completion endpoint and return its text.

    The method takes no ``self`` yet is invoked as ``self.get_main_info_with_llm(...)``,
    so it has to be a static method.

    Args:
        prompt: raw prompt text; it is wrapped in ``[INST] ... [/INST]`` markers.

    Returns:
        The value of the ``content`` field of the JSON response.
    """
    payload = {
        'prompt': ' [INST] ' + prompt + ' [/INST]',
        'temperature': 0.0,
        'n_predict': 2500.0,
        'top_p': 0.95,
        'min_p': 0.05,
        'repeat_penalty': 1.2,
        'stop': [],
    }
    response = requests.post(url=llm_api_endpoint, json=payload)
    return response.json()['content']
def rerank_by_avg_score(refs, scores_to_take=3): | |
docs = {} | |
regex = r'_(\d{1,3})$' | |
refs = [(re.sub(regex, '', ref[0]), ref[1], float(ref[2])) for ref in refs] | |
for ref in refs: | |
if ref[0] not in docs.keys(): | |
docs[ref[0]] = {'contents': [ref[1]], 'scores': [ref[2]]} | |
elif len(docs[ref[0]]['scores']) < scores_to_take: | |
docs[ref[0]]['contents'].append(ref[1]) | |
docs[ref[0]]['scores'].append(ref[2]) | |
for ref in docs: | |
docs[ref]['avg_score'] = np.mean(docs[ref]['scores']) | |
sorted_docs = sorted(docs.items(), key=lambda x: x[1]['avg_score'], reverse=True) | |
result_refs = [ref[0] for ref in sorted_docs] | |
return result_refs | |
async def olymp_think(self, query, sources, llm_params: LlmParams = None):
    """Ask the LLM to reason over numbered sources using the "olympic" prompt.

    Args:
        query: user question.
        sources: mapping of source name to source text; sources are numbered
            from 1 in iteration order.
        llm_params: parameters for ``LlmApi``; when omitted the legacy
            completion endpoint is used instead.

    Returns:
        The model's answer.
    """
    sources_text = ''.join(
        f'Источник [{number}]: {sources[name]}\n'
        for number, name in enumerate(sources, start=1)
    )

    if llm_params is not None:
        llm_api = LlmApi(llm_params)
        # The prompt without sources is passed alongside the sources so the API can trim them.
        query_for_trim = LLM_PROMPT_OLYMPIC.format(query=query, sources='')
        trimmed_sources_result = await llm_api.trim_sources(sources_text, query_for_trim)
        prompt = LLM_PROMPT_OLYMPIC.format(query=query, sources=trimmed_sources_result["result"])
        return await llm_api.predict(prompt)

    # No llm_params were passed, so Mixtral is used via the old algorithm.
    # TODO: build an API for Mixtral (is it needed?)
    legacy_prompt = LLM_PROMPT_OLYMPIC.format(query=query, sources=sources_text)
    return self.get_main_info_with_llm(legacy_prompt)
def parse_step(text): | |
step4_start = text.find('(4)') | |
if step4_start != -1: | |
step4_start = 0 | |
step5_start = text.find('(5)') | |
if step5_start == -1: | |
step5_start = 0 | |
if step4_start + 3 < step5_start: | |
extracted_comment = text[step4_start + 3:step5_start] | |
else: | |
extracted_comment = '' | |
if '$$' in text: | |
extracted_comment = '' | |
extracted_content = re.findall(r'\[(.*?)\]', text[step5_start:]) | |
extracted_numbers = [] | |
for item in extracted_content: | |
if item.isdigit(): | |
extracted_numbers.append(int(item)) | |
return extracted_comment, extracted_numbers | |
@staticmethod
def lemmatize_query(text):
    """Lemmatize a whitespace-separated query and trim punctuation around words.

    The method takes no ``self`` yet is invoked as ``self.lemmatize_query(...)``,
    so it has to be a static method.

    Words written entirely in upper case are kept as they are; every other
    word is replaced by the normal form of its first pymorphy3 parse.
    Leading and trailing punctuation is removed, but a token is never
    reduced below one character.

    Args:
        text: query text.

    Returns:
        The lemmas joined by single spaces.
    """
    # NOTE: a new analyzer is built on every call; consider sharing one instance.
    morph = pymorphy3.MorphAnalyzer()
    signs = ',.<>?;\'\":}{!)(][-'
    lemmas = []
    for word in text.split():
        if not word.isupper():
            word = morph.parse(word)[0].normal_form
        # The length guard comes first so an empty lemma can never be indexed.
        while len(word) > 1 and word[0] in signs:
            word = word[1:]
        while len(word) > 1 and word[-1] in signs:
            word = word[:-1]
        lemmas.append(word)
    return " ".join(lemmas)
def mark_for_one_word_dict(lem_dict): | |
terms_first_word = set() | |
first_word_matching_names = {} | |
first_word_names_to_remove = {} | |
for name in lem_dict: | |
first_word = name.split()[0] | |
if first_word in terms_first_word: | |
lem_dict[name]['one_word_searchable'] = False | |
first_word_names_to_remove[first_word] = first_word_matching_names[first_word] | |
else: | |
terms_first_word.add(first_word) | |
first_word_matching_names[first_word] = name | |
for first_word in first_word_names_to_remove: | |
name = first_word_names_to_remove[first_word] | |
lem_dict[name]['one_word_searchable'] = False | |
return lem_dict | |
def lemmatize_dict(self, terms_dict):
    """Re-key the terms dictionary by lemmatized term name.

    Names written entirely in upper case are kept unchanged; other names are
    replaced by the normal form of their first pymorphy3 parse. Every entry
    keeps the original name and starts as ``one_word_searchable``; the flag
    is then adjusted by ``mark_for_one_word_dict``.

    Args:
        terms_dict: mapping of term name to a dict with ``definitions``,
            ``titles``, ``sources`` and ``is_multi_def``.

    Returns:
        The lemmatized dictionary.
    """
    morph = pymorphy3.MorphAnalyzer()
    copied_fields = ('definitions', 'titles', 'sources', 'is_multi_def')
    lem_dict = {}
    for name, info in terms_dict.items():
        lem_name = name if name.isupper() else morph.parse(name)[0].normal_form
        entry = {'name': name}
        for field in copied_fields:
            entry[field] = info[field]
        entry['one_word_searchable'] = True
        lem_dict[lem_name] = entry
    return self.mark_for_one_word_dict(lem_dict)
def separate_one_word_searchable_dict(lem_dict): | |
lem_dict_fast = {} | |
lem_dict_slow = {} | |
for name in lem_dict: | |
if lem_dict[name]['one_word_searchable']: | |
lem_dict_fast[name] = {} | |
lem_dict_fast[name]['name'] = lem_dict[name]['name'] | |
lem_dict_fast[name]['definitions'] = lem_dict[name]['definitions'] | |
lem_dict_fast[name]['titles'] = lem_dict[name]['titles'] | |
lem_dict_fast[name]['sources'] = lem_dict[name]['sources'] | |
lem_dict_fast[name]['is_multi_def'] = lem_dict[name]['is_multi_def'] | |
else: | |
lem_dict_slow[name] = {} | |
lem_dict_slow[name]['name'] = lem_dict[name]['name'] | |
lem_dict_slow[name]['definitions'] = lem_dict[name]['definitions'] | |
lem_dict_slow[name]['titles'] = lem_dict[name]['titles'] | |
lem_dict_slow[name]['sources'] = lem_dict[name]['sources'] | |
lem_dict_slow[name]['is_multi_def'] = lem_dict[name]['is_multi_def'] | |
return lem_dict_fast, lem_dict_slow | |
def extract_original_phrase(original_text, lemmatized_text, lemmatized_phrase): | |
words = original_text.split() | |
words_lem = lemmatized_text.split() | |
words_lem_phrase = lemmatized_phrase.split() | |
for i, word in enumerate(words_lem): | |
if word == words_lem_phrase[0]: | |
words_full = ' '.join(words_lem[i:i + len(words_lem_phrase)]) | |
if words_full == lemmatized_phrase: | |
original_phrase = ' '.join(words[i:i + len(words_lem_phrase)]) | |
return original_phrase | |
return False | |
def substitute_definitions(self, original_text, lem_dict, lem_dict_fast, lem_dict_slow, for_llm=False):
    """Enrich a query with dictionary definitions of the terms it mentions.

    Args:
        original_text: query as typed by the user.
        lem_dict: full lemmatized dictionary (term -> info with ``definitions``).
        lem_dict_fast: terms that can be matched by a single-word lookup.
        lem_dict_slow: terms that need a phrase scan (checked longest first).
        for_llm: when True, definitions are appended after the query as
            "Дополнительная информация"; otherwise each definition is
            inserted in parentheses right after the term.

    Returns:
        ``(substituted_text, phrases_to_add)`` where ``phrases_to_add`` is a
        list of ``(lemmatized_term, original_phrase)``; ``original_phrase`` is
        ``False`` for a term that matched only as a substring, not on word
        boundaries.
    """
    lemmatized_text = self.lemmatize_query(original_text)
    words = lemmatized_text.split()
    found_phrases = set()
    phrases_to_add1 = []
    phrases_to_add2 = []

    # Longest phrases first, so a short term contained in an already matched phrase is skipped.
    # (The sort could be hoisted out of this method to save a few milliseconds.)
    for lemmatized_phrase in sorted(lem_dict_slow, key=len, reverse=True):
        is_one_word = len(lemmatized_phrase.split()) <= 1
        if is_one_word:
            if lemmatized_phrase not in words:
                continue
            if any(lemmatized_phrase in phrase for phrase in found_phrases):
                continue
        elif lemmatized_phrase not in lemmatized_text or lemmatized_phrase in found_phrases:
            continue
        found_phrases.add(lemmatized_phrase)
        original_phrase = self.extract_original_phrase(original_text, lemmatized_text, lemmatized_phrase)
        phrases_to_add2.append((lemmatized_phrase, original_phrase))

    for word in words:
        if word not in lem_dict_fast:
            continue
        if any(word in phrase for phrase in found_phrases):
            continue
        found_phrases.add(word)
        original_phrase = self.extract_original_phrase(original_text, lemmatized_text, word)
        phrases_to_add1.append((word, original_phrase))

    phrases_to_add = phrases_to_add1 + phrases_to_add2

    # Context-dependent selection of a definition could be plugged in here; for now the first one is used.
    definition_num = 0
    definitions_info = []
    substituted_text = original_text
    for term, original_phrase in phrases_to_add:
        if not original_phrase:
            # Substring-only match: there is no original wording to attach to.
            # Previously this raised TypeError and aborted every remaining substitution.
            continue
        try:
            definitions = lem_dict[term]['definitions']
            definition = definitions[definition_num] if isinstance(definitions, list) else definitions
            if for_llm:
                definitions_info.append(f"{term}-{definition}")
                continue
            term_start = substituted_text.find(original_phrase)
            if term_start == -1:
                # An earlier insertion split this phrase; do not splice at a bogus offset.
                continue
            insert_at = term_start + len(original_phrase)
            substituted_text = substituted_text[:insert_at] + f" ({definition})" + substituted_text[insert_at:]
        except Exception as e:
            # Best effort: a broken dictionary entry must not fail the whole query.
            print(f'error processing\n {original_text}\n {term}: {e}')

    if for_llm and definitions_info:
        definitions_str = ", ".join(definitions_info)
        substituted_text = f"{original_text}. Дополнительная информация: {definitions_str}"
    return substituted_text, phrases_to_add
def filter_by_types(self,
                    pred: list[str] = None,
                    scores: list[float] = None,
                    indexes: list[int] = None,
                    docs_embeddings: list = None,
                    ctgs: dict = None):
    """Keep only the search hits whose document type is enabled.

    A hit is kept when its name contains 'ВНД' and 'ВНД' is enabled, or when
    the ``doc_type`` of the record at its index in
    ``self.all_docs_with_accounting`` is enabled.

    Args:
        pred: document names of the hits.
        scores: hit scores, parallel to ``pred``.
        indexes: positions of the hits in ``self.all_docs_with_accounting``.
        docs_embeddings: hit embeddings, parallel to ``pred``.
        ctgs: mapping of category name to an enabled/disabled flag.

    Returns:
        Four parallel lists: names, scores, indexes and embeddings of the kept hits.
    """
    enabled = [name for name, flag in ctgs.items() if flag]
    internal_enabled = 'ВНД' in enabled

    kept = []
    for row in zip(pred, scores, indexes, docs_embeddings):
        doc_name, _, index, _ = row
        if (internal_enabled and 'ВНД' in doc_name) \
                or self.all_docs_with_accounting[index]['doc_type'] in enabled:
            kept.append(row)

    filtred_pred = [row[0] for row in kept]
    filtred_scores = [row[1] for row in kept]
    filtred_indexes = [row[2] for row in kept]
    filtred_docs_embeddings = [row[3] for row in kept]
    return filtred_pred, filtred_scores, filtred_indexes, filtred_docs_embeddings
def get_types_of_docs(self, all_docs):
    """Annotate every document with a ``doc_type`` derived from its name.

    The type is the first known category that occurs as a substring of
    ``doc_name``; documents matching no category get ``None``.

    Args:
        all_docs: list of document dicts with a ``doc_name`` key; modified in place.

    Returns:
        The same list.
    """
    known_types = ('НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина', 'Письмо ФНС', 'Приказ ФНС',
                   'Постановление Правительства', 'Судебный документ', 'ВНД', 'Бухгалтерский документ')
    for doc in all_docs:
        doc_name = doc['doc_name']
        doc['doc_type'] = next((ctg for ctg in known_types if ctg in doc_name), None)
    return all_docs
def load_model(self):
    """Load the tokenizer and encoder and record their size limits.

    When both ``HF_TOKEN`` and ``HF_MODEL_NAME`` are configured, the model is
    fetched from Hugging Face; otherwise it is read from the local model path.
    Sets ``self.tokenizer``, ``self.model``, ``self.max_len`` and
    ``self.embedding_dim``.
    """
    if hf_token and hf_model_name:
        # Pass the configured token itself. `use_auth_token=True` ignored HF_TOKEN
        # and relied on a login cached on the machine.
        self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=hf_token)
        self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=hf_token).to(self.device)
    else:
        self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
        self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
    self.max_len = self.tokenizer.max_len_single_sentence
    self.embedding_dim = self.model.config.hidden_size
def load_data(self):
    """Read every dictionary, mapping and document dataset from disk into attributes."""

    def read_json(path):
        with open(path, "r", encoding='utf-8') as f:
            return json.load(f)

    def read_dataset(path):
        return dataset.load_from_disk(path).to_list()

    # NOTE(review): pickle must only be used on trusted files; the terms dictionary is
    # presumably produced by the project itself — confirm.
    with open(dict_path, "rb") as f:
        self.terms_dict = pickle.load(f)
    self.companies_map = read_json(companies_map_path)
    self.general_nmd = read_json(general_nmd_path)
    self.explanations_for_llm = read_json(explanations_for_llm_path)
    with open(rules_list_path, 'r', encoding='utf-8') as f:
        self.rules_list = f.read().splitlines()

    self.all_docs_info = read_dataset(data_path_consult)  # ONLY EXTERNAL DOCS
    self.internal_docs = read_dataset(internal_docs_data_path)
    self.accounting_docs = read_dataset(accounting_data_path)
    self.spec_internal_docs = read_dataset(spec_internal_docs_data_path)
    self.all_docs_with_accounting = self.get_types_of_docs(self.all_docs_info + self.accounting_docs)

    # Score multipliers per document category (matched by substring of the document name).
    self.type_weights_nu = {
        'НКРФ': 1,
        'ТКРФ': 1,
        'ГКРФ': 1,
        'Письмо Минфина': 0.9,
        'Письмо ФНС': 0.6,
        'Приказ ФНС': 1,
        'Постановление Правительства': 1,
        'Федеральный закон': 0.9,
        'Судебный документ': 0.2,
        'ВНД': 0.2,
        'Бухгалтерский документ': 0.7,
        'Закон Красноярского края': 1.2,
        'Правила заполнения': 1.2,
        'Правила ведения': 1.2,
    }

    self.all_consultations = read_dataset(consultations_dataset_path)
    self.all_explanations = read_dataset(explanations_dataset_path)
def remove_duplicate_paragraphs(paragraphs): | |
unique_paragraphs = [] | |
seen = set() | |
for paragraph in paragraphs: | |
stripped_paragraph = paragraph.strip() | |
if stripped_paragraph and stripped_paragraph not in seen: | |
unique_paragraphs.append(paragraph) | |
seen.add(stripped_paragraph) | |
return '\n'.join(unique_paragraphs) | |
def construct_base(idx_list, base): | |
concatenated_text = "" | |
seen_ids = set() | |
pattern = re.compile(r'_(\d{1,3})') | |
def find_overlap(a: str, b: str) -> int: | |
max_overlap = min(len(a), len(b)) | |
for i in range(max_overlap, 0, -1): | |
if a[-i:] == b[:i]: | |
return i | |
return 0 | |
def add_ellipsis(text: str) -> str: | |
if not text: | |
return text | |
segments = text.split('\n\n') | |
processed_segments = [] | |
for segment in segments: | |
if segment and not ( | |
segment[0].isupper() or segment[0].isdigit() or segment[0] in ['•', '-', '—', '.']): | |
segment = '...' + segment | |
if segment and not (segment.endswith('.') or segment.endswith(';')): | |
segment += '...' | |
processed_segments.append(segment) | |
return '\n\n'.join(processed_segments) | |
for current_index in idx_list: | |
if current_index in seen_ids: | |
continue | |
start_index = max(0, current_index - 2) | |
end_index = min(len(base), current_index + 3) | |
current_name_base = pattern.sub('', base[current_index]['doc_name']) | |
current_doc_text = base[current_index]['doc_text'] | |
texts_to_concatenate = [current_doc_text] | |
for i in range(current_index - 1, start_index - 1, -1): | |
if i in seen_ids: | |
continue | |
surrounding_name_base = pattern.sub('', base[i]['doc_name']) | |
if current_name_base != surrounding_name_base: | |
break | |
surrounding_text = base[i]['doc_text'] | |
overlap_length = find_overlap(surrounding_text, texts_to_concatenate[0]) | |
if overlap_length == 0: | |
break | |
new_text = surrounding_text + texts_to_concatenate[0][overlap_length:] | |
texts_to_concatenate[0] = new_text | |
seen_ids.add(i) | |
for i in range(current_index + 1, end_index): | |
if i in seen_ids: | |
continue | |
surrounding_name_base = pattern.sub('', base[i]['doc_name']) | |
if current_name_base != surrounding_name_base: | |
break | |
surrounding_text = base[i]['doc_text'] | |
overlap_length = find_overlap(texts_to_concatenate[-1], surrounding_text) | |
if overlap_length == 0: | |
break | |
new_text = texts_to_concatenate[-1] + surrounding_text[overlap_length:] | |
texts_to_concatenate[-1] = new_text | |
seen_ids.add(i) | |
combined_text = ' '.join(texts_to_concatenate) | |
concatenated_text += combined_text + '\n\n' | |
seen_ids.add(current_index) | |
concatenated_text = add_ellipsis(concatenated_text) | |
return concatenated_text.rstrip('\n') | |
def search_results_multiply_weights(self,
                                    pred: list[str] = None,
                                    scores: list[float] = None,
                                    indexes: list[int] = None,
                                    docs_embeddings: list = None) -> tuple[list[str], list[float], list[int], list]:
    """Scale hit scores by the weight of their document category and re-sort.

    A hit is emitted once for every key of ``self.type_weights_nu`` that
    occurs in its name; hits matching no key are dropped.

    Args:
        pred: document names of the hits.
        scores: hit scores, parallel to ``pred``.
        indexes: hit indexes, parallel to ``pred``.
        docs_embeddings: hit embeddings, parallel to ``pred``.

    Returns:
        Names, weighted scores, indexes and embeddings ordered by descending
        weighted score; four empty lists when any argument is ``None``.
    """
    if any(arg is None for arg in (pred, scores, indexes, docs_embeddings)):
        return [], [], [], []

    weights = self.type_weights_nu
    weighted_rows = []
    for prediction, score, idx, emb in zip(pred, scores, indexes, docs_embeddings):
        for ctg in weights:
            if ctg in prediction:
                weighted_rows.append((weights.get(ctg, 0) * score, prediction, idx, emb))
    weighted_rows.sort(key=lambda row: row[0], reverse=True)

    if not weighted_rows:
        return [], [], [], []
    sorted_scores, sorted_preds, sorted_indexes, sorted_docs_embeddings = zip(*weighted_rows)
    return list(sorted_preds), list(sorted_scores), list(sorted_indexes), list(sorted_docs_embeddings)
def get_uniq_relevant_docs(self,
                           top_k: int,
                           query_refs_all: list[str],
                           scores: list[float],
                           indexes: list[int],
                           docs_embeddings: list[list[float]]
                           ) -> tuple[dict[str, list[str]], dict[str, list[float]], dict[str, list[int]], dict[str, list[list[float]]]]:
    """Group chunk hits by document and keep the first ``top_k`` documents.

    The document name is the chunk name with every ``_<1-3 digits>`` fragment
    removed. Chunks of documents beyond the first ``top_k`` distinct ones are
    discarded. Within a document, chunks are ordered by chunk number (chunks
    without a number first) and capped at 20.

    Args:
        top_k: maximum number of distinct documents.
        query_refs_all: chunk names in ranking order.
        scores: chunk scores, parallel to ``query_refs_all``.
        indexes: chunk indexes, parallel to ``query_refs_all``.
        docs_embeddings: chunk embeddings, parallel to ``query_refs_all``.

    Returns:
        Four dicts keyed by document name: chunk names, scores, indexes and embeddings.
    """
    regex = r'_\d{1,3}'
    fields = ('refs', 'scores', 'indexes', 'embeddings')

    grouped = {}
    for position, ref in enumerate(query_refs_all):
        base_ref = re.sub(regex, '', ref).strip()
        bucket = grouped.get(base_ref)
        if bucket is None:
            if len(grouped) >= top_k:
                continue
            bucket = {field: [] for field in fields}
            grouped[base_ref] = bucket
        bucket['refs'].append(ref)
        bucket['scores'].append(scores[position])
        bucket['indexes'].append(indexes[position])
        bucket['embeddings'].append(docs_embeddings[position])

    def chunk_order(row):
        """Sort key: chunks without a number first, then by the first chunk number."""
        found = re.findall(regex, row[0])
        if not found:
            return (0, -1)
        return (1, int(found[0].replace('_', '')))

    for bucket in grouped.values():
        rows = sorted(zip(*(bucket[field] for field in fields)), key=chunk_order)
        for column, field in enumerate(fields):
            bucket[field] = [row[column] for row in rows][:20]

    unique_refs = {name: bucket['refs'] for name, bucket in grouped.items()}
    filtered_scores = {name: bucket['scores'] for name, bucket in grouped.items()}
    filtered_indexes = {name: bucket['indexes'] for name, bucket in grouped.items()}
    filtered_docs_embeddings = {name: bucket['embeddings'] for name, bucket in grouped.items()}
    return unique_refs, filtered_scores, filtered_indexes, filtered_docs_embeddings
def filter_results(self, pred_internal, scores_internal, indices_internal, docs_embeddings_internal, companies_files):
    """Keep internal-document hits that are general or belong to a requested company.

    A hit is kept when its name contains any entry of ``self.general_nmd`` or
    any entry of ``companies_files``. Each hit is kept at most once; the
    previous implementation appended it once per matching company entry.

    Args:
        pred_internal: document names of the hits.
        scores_internal: hit scores, parallel to ``pred_internal``.
        indices_internal: hit indexes, parallel to ``pred_internal``.
        docs_embeddings_internal: hit embeddings, parallel to ``pred_internal``.
        companies_files: name fragments of the documents of the companies in the query.

    Returns:
        Four parallel lists: names, scores, indexes and embeddings of the kept hits.
    """
    kept = []
    for row in zip(pred_internal, scores_internal, indices_internal, docs_embeddings_internal):
        pred = row[0]
        is_general = any(doc in pred for doc in self.general_nmd)
        if is_general or any(company in pred for company in companies_files):
            kept.append(row)

    filt_pred_internal = [row[0] for row in kept]
    filt_scores_internal = [row[1] for row in kept]
    filt_indices_internal = [row[2] for row in kept]
    filt_docs_embeddings_internal = [row[3] for row in kept]
    return filt_pred_internal, filt_scores_internal, filt_indices_internal, filt_docs_embeddings_internal
def merge_dictionaries(dicts: list = None): | |
merged_dict = {} | |
max_length = max(len(d) for d in dicts) | |
for i in range(max_length): | |
for d in dicts: | |
keys = list(d.keys()) | |
values = list(d.values()) | |
if i < len(keys): | |
merged_dict[keys[i]] = values[i] | |
return merged_dict | |
def check_specific_key(dictionary, key): | |
if key in dictionary and dictionary[key] is True: | |
for k, v in dictionary.items(): | |
if k != key and v is True: | |
return False | |
return True | |
return False | |
def remove_duplicates(input_list): | |
unique_dict = {} | |
for item in input_list: | |
unique_dict[item] = None | |
return list(unique_dict.keys()) | |
async def search_engine(self,
                        query: str = None,
                        use_qe: bool = False,
                        categories: dict = None,
                        llm_params: LlmParams = None):
    """Run the full retrieval pipeline for one query.

    Steps: embed the query, search external and internal document indexes,
    apply category filters and per-type weights, optionally expand the query
    with LLM-generated keywords, assemble chunk texts for the LLM and look up
    related consultations and explanations.

    Args:
        query: user question.
        use_qe: when True, an LLM produces a keyword query whose results are
            merged into the external results.
        categories: mapping of category name to an enabled flag; must contain 'ВНД'.
        llm_params: parameters for ``LlmApi``; when omitted the legacy endpoint is used.

    Returns:
        A tuple ``(query, document_names, texts_for_llm, consultations,
        explanations, llm_responses)``; consultations are limited to ``def_k`` entries.
    """
    # A partial category selection switches from full-base search to a filtered one
    # and rebalances how many chunks/documents come from each base.
    if True in list(categories.values()) and not all(categories.values()):
        full_base_search = False
        if self.check_specific_key(categories, 'ВНД'):
            nmd_chunks, nmd_refs = 120, 45
            extra_chunks, extra_refs = 1, 1
        elif not categories['ВНД']:
            extra_chunks, extra_refs = 120, 45
            nmd_chunks, nmd_refs = 1, 1
        else:
            nmd_chunks, nmd_refs = 60, 23
            extra_chunks, extra_refs = 60, 23
    else:
        full_base_search = True
        nmd_chunks, nmd_refs = 50, 15
        extra_chunks, extra_refs = 75, 30
    # The attribute is still updated for code that reads it, but this coroutine only uses
    # the local flag: a concurrent request could otherwise flip it during an `await`.
    self.full_base_search = full_base_search
    search_internal = full_base_search or categories['ВНД']

    # LLM answers to forward to the front end
    llm_responses = []

    # Query tokenization and embedding
    query_tokens = query_tokenization(query, self.tokenizer)
    query_embeds = query_embed_extraction(query_tokens, self.model, self.do_normalization)

    # Search over the external documents base
    distances, indices = self.index_all_docs_with_accounting.search(query_embeds, len(self.all_docs_with_accounting))
    pred = [self.all_docs_with_accounting[x]['doc_name'] for x in indices[0]]
    docs_embeddings = [self.all_docs_with_accounting[x]['doc_embedding'] for x in indices[0]]

    # Query-dependent weights.
    # NOTE(review): these are written to shared instance state and read again after an
    # `await` in the keyword branch, so concurrent requests can affect each other.
    if not re.search('[Кк]расноярск', query):
        self.type_weights_nu['Закон Красноярского края'] = 0
    else:
        self.type_weights_nu['Закон Красноярского края'] = 1.2
    if not use_rules(query, self.rules_list):
        self.type_weights_nu['Правила ведения'] = 0
        self.type_weights_nu['Правила заполнения'] = 0
    else:
        self.type_weights_nu['Правила ведения'] = 1.2
        self.type_weights_nu['Правила заполнения'] = 1.2

    preds, scores, indexes, docs_embeddings = pred[:5000], list(distances[0])[:5000], \
        list(indices[0])[:5000], docs_embeddings[:5000]

    # Search over the internal documents base
    if search_internal:
        # NOTE(review): the index is built from self.internal_docs while hits are resolved
        # through self.spec_internal_docs; this presumes both lists are aligned — confirm.
        distances_internal, indices_internal = self.index_internal_docs.search(query_embeds, len(self.spec_internal_docs))
        pred_internal = [self.spec_internal_docs[x]['doc_name'] for x in indices_internal[0]]
        docs_embeddings_internal = [self.spec_internal_docs[x]['doc_embedding'] for x in indices_internal[0]]
        indices_internal = indices_internal[0]
        # Documents whose title contains 'КУП' get a 20% boost.
        scores_internal = [score * 1.2 if 'КУП' in title else score
                           for title, score in zip(pred_internal, distances_internal[0])]

        companies_files = search_company.find_nmd_docs(query, self.companies_map)
        pred_internal, scores_internal, indices_internal, docs_embeddings_internal = self.filter_results(
            pred_internal,
            scores_internal,
            indices_internal,
            docs_embeddings_internal,
            companies_files)

        combined = list(zip(pred_internal, scores_internal, indices_internal, docs_embeddings_internal))
        sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
        top_nmd = sorted_combined[:nmd_chunks]

        if 'ЕГДС' in query:
            # Two fixed chunks are pinned to the top for this keyword (hard-coded positions).
            if not [x for x in top_nmd if 'п.5. Положение о КУП_262 (ВНД)' in x]:
                ch262 = self.internal_docs[22976]
                ch262 = (ch262['doc_name'], 1.0, 22976, ch262['chunks_embeddings'][0])
                top_nmd.insert(0, ch262)
            if not [x for x in top_nmd if 'п.5. Положение о КУП_130 (ВНД)' in x]:
                ch130 = self.internal_docs[22844]
                ch130 = (ch130['doc_name'], 1.0, 22844, ch130['chunks_embeddings'][0])
                top_nmd.insert(1, ch130)
        top_nmd = top_nmd[:nmd_chunks]

        if top_nmd:
            preds_internal, scores_internal, indexes_internal, internal_docs_embeddings = (
                list(column) for column in zip(*top_nmd))
        else:
            # Nothing survived filtering: unpacking zip(*[]) would raise ValueError.
            preds_internal, scores_internal, indexes_internal, internal_docs_embeddings = [], [], [], []

        # Collect unique internal documents
        preds_internal, scores_internal, \
            indexes_internal, internal_docs_embeddings = self.get_uniq_relevant_docs(
                top_k=nmd_refs,
                query_refs_all=preds_internal,
                scores=scores_internal,
                indexes=indexes_internal,
                docs_embeddings=internal_docs_embeddings)

    # Filter by category only when some check-boxes are ticked
    if not full_base_search:
        preds, scores, indexes, docs_embeddings = self.filter_by_types(preds, scores, indexes,
                                                                       docs_embeddings, categories)

    # Apply type weights on top of the scores
    sorted_preds, sorted_scores, sorted_indexes, sorted_docs_embeddings = self.search_results_multiply_weights(
        pred=preds,
        scores=scores,
        indexes=indexes,
        docs_embeddings=docs_embeddings)
    sorted_preds = sorted_preds[:extra_chunks]
    sorted_scores = sorted_scores[:extra_chunks]
    sorted_indexes = sorted_indexes[:extra_chunks]
    sorted_docs_embeddings = sorted_docs_embeddings[:extra_chunks]

    # Collect unique external documents
    preds, scores, indexes, docs_embeddings = self.get_uniq_relevant_docs(
        top_k=extra_refs,
        query_refs_all=sorted_preds,
        scores=sorted_scores,
        indexes=sorted_indexes,
        docs_embeddings=sorted_docs_embeddings
    )

    if use_qe:
        try:
            prompt = LLM_PROMPT_KEYS.format(query=query)
            if llm_params is None:
                keyword_query = self.get_main_info_with_llm(prompt)
            else:
                llm_api = LlmApi(llm_params)
                keyword_query = await llm_api.predict(prompt)
            llm_responses.append(keyword_query)

            # Drop section [1] of the answer and the remaining section markers.
            keyword_query = re.sub(r'\[1\].*?(?=\[\d+\]|$)', '', keyword_query, flags=re.DOTALL).replace(' [2]', '').replace('[3]', '').strip()
            keyword_query_tokens = query_tokenization(keyword_query, self.tokenizer)
            keyword_query_embeds = query_embed_extraction(keyword_query_tokens,
                                                          self.model,
                                                          self.do_normalization)
            keyword_distances, keyword_indices = self.index_all_docs_with_accounting.search(
                keyword_query_embeds, len(self.all_docs_with_accounting))
            keyword_pred = [self.all_docs_with_accounting[x]['doc_name'] for x in keyword_indices[0]]
            keyword_docs_embeddings = [self.all_docs_with_accounting[x]['doc_embedding'] for x in
                                       keyword_indices[0]]

            if not full_base_search:
                keyword_preds, keyword_scores, \
                    keyword_indexes, keyword_docs_embeddings = self.filter_by_types(keyword_pred,
                                                                                    keyword_distances[0],
                                                                                    keyword_indices[0],
                                                                                    keyword_docs_embeddings,
                                                                                    categories)
            else:
                keyword_preds, keyword_scores, keyword_indexes = \
                    keyword_pred, keyword_distances[0], keyword_indices[0]

            keyword_preds, keyword_scores, \
                keyword_indexes, keyword_docs_embeddings = self.search_results_multiply_weights(
                    pred=keyword_preds, scores=keyword_scores,
                    indexes=keyword_indexes, docs_embeddings=keyword_docs_embeddings)

            keyword_unique_preds, keyword_unique_scores, \
                keyword_unique_indexes, keyword_unique_docs_embeddings = self.get_uniq_relevant_docs(
                    top_k=45,
                    query_refs_all=keyword_preds,
                    scores=keyword_scores,
                    indexes=keyword_indexes,
                    docs_embeddings=keyword_docs_embeddings)

            preds = dict(list(self.merge_dictionaries([preds, keyword_unique_preds]).items())[:30])
            scores = dict(list(self.merge_dictionaries([scores, keyword_unique_scores]).items())[:30])
            indexes = dict(list(self.merge_dictionaries([indexes, keyword_unique_indexes]).items())[:30])
        except Exception:
            # Best effort: keyword expansion is optional. `Exception` (not a bare except)
            # so that task cancellation still propagates.
            traceback.print_exc()
            print(f"Error applying keys (possibly the LLM is not available)")

    # Add the top internal documents to the result
    if search_internal:
        preds = self.merge_dictionaries([preds, preds_internal])
        scores = self.merge_dictionaries([scores, scores_internal])
        indexes = self.merge_dictionaries([indexes, indexes_internal])

    # Assemble chunk texts for the LLM
    texts_for_llm = []
    for key, idx_list in indexes.items():
        collected_text = []
        base = self.internal_docs if 'ВНД' in key else self.all_docs_with_accounting
        if re.search('Минфин|Бухгалтерский документ|ФНС|Судебный документ|Постановление Правительства|Федеральный закон', key):
            collected_text.append(self.construct_base(idx_list, base))
        else:
            for idx in idx_list:
                if idx < len(base):
                    collected_text.extend(base[idx]['doc_text'].split('\n'))
        texts_for_llm.append(self.remove_duplicate_paragraphs(collected_text))

    # Search for relevant consultations
    distances_consult, indices_consult = self.index_consult.search(query_embeds, len(self.all_consultations))
    predicted_consultations = {self.all_consultations[x]['doc_name']: self.all_consultations[x]['doc_text']
                               for x in indices_consult[0]}

    # Search for relevant explanations
    distances_explanations, indices_explanations = self.index_explanations.search(query_embeds, len(self.all_explanations))
    predicted_explanations = {self.all_explanations[x]['doc_name']: self.all_explanations[x]['doc_text']
                              for x in indices_explanations[0]}
    results = list(zip(list(predicted_explanations.keys()),
                       list(predicted_explanations.values()),
                       distances_explanations[0]))
    explanation_titles = self.rerank_by_avg_score(results)[:3]
    try:
        predicted_explanation = {explanation_title: self.explanations_for_llm[explanation_title]
                                 for explanation_title in explanation_titles}
    except Exception:
        predicted_explanation = {}
        print('The relevant document was not found in the system.')

    doc_names = [x.replace('ФЕДЕРАЛЬНЫЙ СТАНДАРТ БУХГАЛТЕРСКОГО УЧЕТА', 'Федеральный стандарт бухгалтерского учета ФСБУ')
                 for x in list(preds.keys())]
    return query, doc_names, texts_for_llm, dict(list(predicted_consultations.items())[:def_k]), \
        predicted_explanation, llm_responses
async def olympic_branch(self,
                         query: str = None,
                         sources: dict = None,
                         categories: dict = None,
                         llm_params: LlmParams = None):
    """Iteratively pick relevant sources with the help of ``olymp_think``.

    The first step is run on the ``sources`` passed in.  While the parsed
    step keeps returning a non-empty comment, up to four further rounds are
    executed: each one re-runs ``search_engine`` (query expansion disabled),
    keeps the first 20 results as the new candidate sources and asks
    ``olymp_think`` again.

    Args:
        query: Query text handed to ``olymp_think`` and ``search_engine``.
        sources: Mapping of source reference -> source text for the first
            step.
        categories: Forwarded unchanged to ``search_engine``.
        llm_params: Forwarded unchanged to ``olymp_think``.

    Returns:
        A tuple ``(saved_sources, saved_step_by_step, llm_responses)``:
        the accumulated ``{reference: text}`` of every chosen source, the
        list of zero-based choices made in each follow-up round (the first
        step is not included), and every raw ``olymp_think`` answer in
        call order.
    """
    max_rounds = 4       # upper bound on follow-up search rounds
    max_candidates = 20  # candidate sources shown per follow-up round

    # Collect every LLM answer so it can be sent to the front end.
    llm_responses = []
    saved_sources = {}
    saved_step_by_step = []

    def remember_chosen(candidates, chosen_positions):
        # Keep candidates whose position was chosen; the first text saved
        # for a reference wins over later ones.
        for position, ref in enumerate(candidates):
            if position in chosen_positions and ref not in saved_sources:
                saved_sources[ref] = candidates[ref]

    text = await self.olymp_think(query, sources, llm_params)
    llm_responses.append(text)
    comment, sources_choice = self.parse_step(text)
    # parse_step numbers are shifted by one to line up with enumerate().
    sources_choice = [source - 1 for source in sources_choice]
    remember_chosen(sources, sources_choice)

    # An empty comment means nothing more was requested, so no follow-up
    # round is run.  (Previously the "stop" counter value was overwritten
    # with 0 right away and one redundant round was always executed.)
    # NOTE(review): the comment text itself is never fed into the next
    # search; every round searches with the query returned by
    # search_engine -- confirm this is intended.
    rounds_done = 0
    while comment != '' and rounds_done < max_rounds:
        query, preds, texts_for_llm, _, _, _ = await self.search_engine(
            query, use_qe=False, categories=categories)
        sources = dict(zip(preds, texts_for_llm))
        sources = dict(islice(sources.items(), max_candidates))
        text = await self.olymp_think(query, sources, llm_params)
        llm_responses.append(text)
        comment, sources_choice = self.parse_step(text)
        sources_choice = [source - 1 for source in sources_choice]
        saved_step_by_step.append(sources_choice)
        remember_chosen(sources, sources_choice)
        rounds_done += 1
    return saved_sources, saved_step_by_step, llm_responses
async def search(self,
                 query: str = None,
                 use_qe: bool = False,
                 use_olympic: bool = False,
                 categories: dict = None,
                 llm_params: LlmParams = None):
    """Run the base search and, optionally, the "olympic" refinement.

    The query is first passed through ``substitute_definitions`` twice
    (once with ``for_llm=True``, once with ``for_llm=False``); the second
    variant goes to ``search_engine``, the first one is what gets returned.

    Args:
        query: Raw query text.
        use_qe: Forwarded to ``search_engine``.
        use_olympic: When true, the base results are additionally refined
            through ``olympic_branch``.
        categories: Forwarded to ``search_engine`` and ``olympic_branch``.
        llm_params: Forwarded to ``search_engine`` and ``olympic_branch``.

    Returns:
        A tuple ``(query_for_llm, preds, texts_for_llm,
        predicted_consultations, predicted_explanation, llm_responses)``.
    """
    # Query transformation
    lemmatized = self.lemmatize_dict(self.terms_dict)
    fast_terms, slow_terms = self.separate_one_word_searchable_dict(lemmatized)
    query_for_llm, _ = self.substitute_definitions(
        query, lemmatized, fast_terms, slow_terms, for_llm=True)
    query, _ = self.substitute_definitions(
        query, lemmatized, fast_terms, slow_terms, for_llm=False)

    # Base search
    engine_output = await self.search_engine(query, use_qe, categories, llm_params)
    (query, base_preds, base_texts_for_llm,
     predicted_consultations, predicted_explanation, llm_responses) = engine_output

    if not use_olympic:
        return (query_for_llm, base_preds, base_texts_for_llm,
                predicted_consultations, predicted_explanation, llm_responses)

    # Only the first 20 base results are offered to the olympic branch.
    sources = dict(zip(base_preds, base_texts_for_llm))
    sources = dict(islice(sources.items(), 20))
    # The base llm_responses are replaced by the olympic branch's answers.
    olymp_results, olymp_step_by_step, llm_responses = await self.olympic_branch(
        query, sources, categories, llm_params)
    olymp_preds = list(olymp_results.keys())
    olymp_texts_for_llm = list(olymp_results.values())

    if len(olymp_preds) > 45:
        # NOTE(review): the result of merge_dictionaries is sliced and then
        # used like a mapping -- verify its return type supports both.
        olymp_results = self.merge_dictionaries(olymp_step_by_step)[:45]
        preds = list(olymp_results.keys())
        texts_for_llm = list(olymp_results.values())
    else:
        # Olympic picks go first, base results fill the remainder (max 45).
        # NOTE(review): names and texts are de-duplicated independently;
        # confirm the two lists cannot get out of step.
        preds = self.remove_duplicates(olymp_preds + base_preds)[:45]
        texts_for_llm = self.remove_duplicates(
            olymp_texts_for_llm + base_texts_for_llm)[:45]

    return (query_for_llm, preds, texts_for_llm,
            predicted_consultations, predicted_explanation, llm_responses)