# nn-search-full / semantic_search.py
import gc
import traceback
from legal_info_search_utils.rules_utils import use_rules
from itertools import islice
import os
import torch
import numpy as np
from faiss import IndexFlatIP
from datasets import Dataset as dataset
from transformers import AutoTokenizer, AutoModel
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
import requests
import re
import json
import pymorphy3
from torch.cuda.amp import autocast
from elasticsearch_module import search_company
import torch.nn.functional as F
import pickle
from llm.prompts import LLM_PROMPT_QE, LLM_PROMPT_OLYMPIC, LLM_PROMPT_KEYS
from llm.vllm_api import LlmApi, LlmParams
from huggingface import dataset_utils
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "./legal_info_search_data/")
# Hugging Face access token. If set, the model is loaded from the HF Hub.
hf_token = os.environ.get("HF_TOKEN", None)
hf_dataset = os.environ.get("HF_DATASET", None)
hf_model_name = os.environ.get("HF_MODEL_NAME", "")
if hf_token is not None and hf_dataset is not None:
global_data_path = dataset_utils.get_global_data_path()+global_data_path
print(f"Global data path: {global_data_path}")
global_model_path = os.environ.get("GLOBAL_MODEL_PATH", "./models/20240202_204910_ep8")
data_path_consult = global_data_path + "external_data"
internal_docs_data_path = global_data_path + "nmd_full"
spec_internal_docs_data_path = global_data_path + "nmd_short"
accounting_data_path = global_data_path + "bu"
companies_map_path = global_data_path + "companies_map/companies_map.json"
dict_path = global_data_path + "dict/dict_20241030.pkl"
general_nmd_path = global_data_path + "companies_map/general_nmd.json"
consultations_dataset_path = global_data_path + "consult_data"
explanations_dataset_path = global_data_path + "explanations"
explanations_for_llm_path = global_data_path + "explanations_for_llm/explanations_for_llm.json"
rules_list_path = global_data_path + "rules_list/terms.txt"
db_data_types = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина', 'Письмо ФНС',
'Приказ ФНС', 'Постановление Правительства', 'Судебный документ', 'ВНД', 'Бухгалтерский документ']
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
headers = {'Content-Type': 'application/json'}
def_k = 15
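# Semantic search over external legal documents, internal regulatory documents (ВНД),
# consultations and explanations, backed by FAISS inner-product indexes over
# transformer embeddings.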
class SemanticSearch:
def __init__(self, do_normalization: bool = True):
self.device = device
self.do_normalization = do_normalization
self.load_model()
        # Main database (FAISS indexes)
self.full_base_search = True
self.index_consult = IndexFlatIP(self.embedding_dim)
self.index_explanations = IndexFlatIP(self.embedding_dim)
self.index_all_docs_with_accounting = IndexFlatIP(self.embedding_dim)
self.index_internal_docs = IndexFlatIP(self.embedding_dim)
self.spec_index_internal_docs = IndexFlatIP(self.embedding_dim)
self.index_teaser = IndexFlatIP(self.embedding_dim)
self.load_data()
        # Embedding processing helper
def process_embeddings(docs):
embeddings = torch.cat([torch.unsqueeze(torch.Tensor(x['doc_embedding']), 0) for x in docs], dim=0)
if self.do_normalization:
embeddings = F.normalize(embeddings, dim=-1).numpy()
return embeddings
        # ВНД (internal documents) database
self.internal_docs_embeddings = process_embeddings(self.internal_docs)
self.index_internal_docs.add(self.internal_docs_embeddings)
self.spec_internal_docs_embeddings = process_embeddings(self.spec_internal_docs)
self.spec_index_internal_docs.add(self.spec_internal_docs_embeddings)
self.all_docs_with_accounting_embeddings = process_embeddings(self.all_docs_with_accounting)
self.index_all_docs_with_accounting.add(self.all_docs_with_accounting_embeddings)
        # Consultations database
self.consult_embeddings = process_embeddings(self.all_consultations)
self.index_consult.add(self.consult_embeddings)
        # Explanations database
self.explanations_embeddings = process_embeddings(self.all_explanations)
self.index_explanations.add(self.explanations_embeddings)
@staticmethod
def get_main_info_with_llm(prompt: str):
response = requests.post(
url=llm_api_endpoint,
json={'prompt': ' [INST] ' + prompt + ' [/INST]',
'temperature': 0.0,
'n_predict': 2500.0,
'top_p': 0.95,
'min_p': 0.05,
'repeat_penalty': 1.2,
'stop': []})
answer = response.json()['content']
return answer
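    # Group chunk-level hits by base document (chunk suffix "_<n>" stripped), keep at most
    # `scores_to_take` scores per document in input order, and return the document names
    # sorted by their average score (descending).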
@staticmethod
def rerank_by_avg_score(refs, scores_to_take=3):
docs = {}
regex = r'_(\d{1,3})$'
refs = [(re.sub(regex, '', ref[0]), ref[1], float(ref[2])) for ref in refs]
for ref in refs:
if ref[0] not in docs.keys():
docs[ref[0]] = {'contents': [ref[1]], 'scores': [ref[2]]}
elif len(docs[ref[0]]['scores']) < scores_to_take:
docs[ref[0]]['contents'].append(ref[1])
docs[ref[0]]['scores'].append(ref[2])
for ref in docs:
docs[ref]['avg_score'] = np.mean(docs[ref]['scores'])
sorted_docs = sorted(docs.items(), key=lambda x: x[1]['avg_score'], reverse=True)
result_refs = [ref[0] for ref in sorted_docs]
return result_refs
async def olymp_think(self, query, sources, llm_params: LlmParams = None):
sources_text = ''
res = ''
for i, source in enumerate(sources):
sources_text += f'Источник [{i + 1}]: {sources[source]}\n'
        # If llm_params is not provided, fall back to the old Mixtral-based flow
        # TODO: add an API wrapper for Mixtral (is it needed?)
if llm_params is None:
step = LLM_PROMPT_OLYMPIC.format(query=query, sources=sources_text)
res = self.get_main_info_with_llm(step)
else:
llm_api = LlmApi(llm_params)
query_for_trim = LLM_PROMPT_OLYMPIC.format(query=query, sources='')
trimmed_sources_result = await llm_api.trim_sources(sources_text, query_for_trim)
prompt = LLM_PROMPT_OLYMPIC.format(query=query, sources=trimmed_sources_result["result"])
res = await llm_api.predict(prompt)
return res
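    # Parse an LLM answer produced with the OLYMPIC prompt: the free-text comment between
    # the "(4)" and "(5)" markers (discarded if the answer contains "$$") and the bracketed
    # source numbers that follow "(5)".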
@staticmethod
def parse_step(text):
        step4_start = text.find('(4)')
        if step4_start == -1:
            step4_start = 0
step5_start = text.find('(5)')
if step5_start == -1:
step5_start = 0
if step4_start + 3 < step5_start:
extracted_comment = text[step4_start + 3:step5_start]
else:
extracted_comment = ''
if '$$' in text:
extracted_comment = ''
extracted_content = re.findall(r'\[(.*?)\]', text[step5_start:])
extracted_numbers = []
for item in extracted_content:
if item.isdigit():
extracted_numbers.append(int(item))
return extracted_comment, extracted_numbers
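    # Lemmatize a query with pymorphy3 (all-uppercase tokens such as abbreviations are kept
    # as-is) and strip surrounding punctuation from each lemma.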
@staticmethod
def lemmatize_query(text):
morph = pymorphy3.MorphAnalyzer()
signs = ',.<>?;\'\":}{!)(][-'
words = text.split()
lemmas = []
for word in words:
if not word.isupper():
word = morph.parse(word)[0].normal_form
lemmas.append(word)
for i, lemma in enumerate(lemmas):
while lemma[0] in signs and len(lemma) > 1:
lemma = lemma[1:]
lemmas[i] = lemma
while lemma[-1] in signs and len(lemma) > 1:
lemma = lemma[:-1]
lemmas[i] = lemma
return " ".join(lemmas)
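    # A term may be matched by its first word only when that first word is unique across the
    # dictionary; otherwise every term sharing that first word is marked as not
    # one-word-searchable.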
@staticmethod
def mark_for_one_word_dict(lem_dict):
terms_first_word = set()
first_word_matching_names = {}
first_word_names_to_remove = {}
for name in lem_dict:
first_word = name.split()[0]
if first_word in terms_first_word:
lem_dict[name]['one_word_searchable'] = False
first_word_names_to_remove[first_word] = first_word_matching_names[first_word]
else:
terms_first_word.add(first_word)
first_word_matching_names[first_word] = name
for first_word in first_word_names_to_remove:
name = first_word_names_to_remove[first_word]
lem_dict[name]['one_word_searchable'] = False
return lem_dict
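    # Build a lemmatized copy of the terms dictionary keyed by the normal form of each term
    # (uppercase terms are kept verbatim) and mark which entries can be matched by a single word.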
def lemmatize_dict(self, terms_dict):
lem_dict = {}
morph = pymorphy3.MorphAnalyzer()
for name in terms_dict:
if not name.isupper():
lem_name = morph.parse(name)[0].normal_form
else:
lem_name = name
lem_dict[lem_name] = {}
lem_dict[lem_name]['name'] = name
lem_dict[lem_name]['definitions'] = terms_dict[name]['definitions']
lem_dict[lem_name]['titles'] = terms_dict[name]['titles']
lem_dict[lem_name]['sources'] = terms_dict[name]['sources']
lem_dict[lem_name]['is_multi_def'] = terms_dict[name]['is_multi_def']
lem_dict[lem_name]['one_word_searchable'] = True
lem_dict = self.mark_for_one_word_dict(lem_dict)
return lem_dict
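    # Split the lemmatized dictionary into a "fast" part (one-word-searchable terms) and a
    # "slow" part (everything else).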
@staticmethod
def separate_one_word_searchable_dict(lem_dict):
lem_dict_fast = {}
lem_dict_slow = {}
for name in lem_dict:
if lem_dict[name]['one_word_searchable']:
lem_dict_fast[name] = {}
lem_dict_fast[name]['name'] = lem_dict[name]['name']
lem_dict_fast[name]['definitions'] = lem_dict[name]['definitions']
lem_dict_fast[name]['titles'] = lem_dict[name]['titles']
lem_dict_fast[name]['sources'] = lem_dict[name]['sources']
lem_dict_fast[name]['is_multi_def'] = lem_dict[name]['is_multi_def']
else:
lem_dict_slow[name] = {}
lem_dict_slow[name]['name'] = lem_dict[name]['name']
lem_dict_slow[name]['definitions'] = lem_dict[name]['definitions']
lem_dict_slow[name]['titles'] = lem_dict[name]['titles']
lem_dict_slow[name]['sources'] = lem_dict[name]['sources']
lem_dict_slow[name]['is_multi_def'] = lem_dict[name]['is_multi_def']
return lem_dict_fast, lem_dict_slow
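    # Map a lemmatized phrase back to the corresponding span of the original text; returns
    # False if the alignment cannot be recovered.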
@staticmethod
def extract_original_phrase(original_text, lemmatized_text, lemmatized_phrase):
words = original_text.split()
words_lem = lemmatized_text.split()
words_lem_phrase = lemmatized_phrase.split()
for i, word in enumerate(words_lem):
if word == words_lem_phrase[0]:
words_full = ' '.join(words_lem[i:i + len(words_lem_phrase)])
if words_full == lemmatized_phrase:
original_phrase = ' '.join(words[i:i + len(words_lem_phrase)])
return original_phrase
return False
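    # Find dictionary terms in the query and either append their definitions as extra context
    # (for_llm=True) or inline them in parentheses right after the matched phrase (for_llm=False).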
def substitute_definitions(self, original_text, lem_dict, lem_dict_fast, lem_dict_slow, for_llm=False):
lemmatized_text = self.lemmatize_query(original_text)
found_phrases = set()
phrases_to_add1 = []
phrases_to_add2 = []
words = lemmatized_text.split()
        sorted_lem_dict = sorted(lem_dict_slow.items(), key=lambda x: len(x[0]),
                                 reverse=True)  # the sort by length could be hoisted out of this method to save a few milliseconds
for lemmatized_phrase_tuple in sorted_lem_dict:
lemmatized_phrase = lemmatized_phrase_tuple[0]
is_new_phrase = True
is_one_word = True
lem_phrase_words = lemmatized_phrase.split()
if len(lem_phrase_words) > 1:
is_one_word = False
if lemmatized_phrase in lemmatized_text and not is_one_word:
if lemmatized_phrase in found_phrases:
is_new_phrase = False
else:
found_phrases.add(lemmatized_phrase)
original_phrase = self.extract_original_phrase(original_text, lemmatized_text, lemmatized_phrase)
phrases_to_add2.append((lemmatized_phrase, original_phrase))
if is_one_word and lemmatized_phrase in words:
for phrase in found_phrases:
if lemmatized_phrase in phrase:
is_new_phrase = False
if is_new_phrase:
found_phrases.add(lemmatized_phrase)
original_phrase = self.extract_original_phrase(original_text, lemmatized_text, lemmatized_phrase)
phrases_to_add2.append((lemmatized_phrase, original_phrase))
for word in words:
is_new_phrase = True
if word in lem_dict_fast:
for phrase in found_phrases:
if word in phrase:
is_new_phrase = False
break
if is_new_phrase:
found_phrases.add(word)
original_phrase = self.extract_original_phrase(original_text, lemmatized_text, word)
phrases_to_add1.append((word, original_phrase))
phrases_to_add = phrases_to_add1 + phrases_to_add2
definition_num = 0
definitions_info = []
substituted_text = original_text
try:
if for_llm:
for term, original_phrase in phrases_to_add:
if lem_dict[term]['is_multi_def']:
                        definition_num = 0  # context-dependent selection of the right definition could go here
term_start = original_text.find(original_phrase)
if type(lem_dict[term]['definitions']) is list:
definitions_info.append(f"{term}-{lem_dict[term]['definitions'][definition_num]}")
else:
definitions_info.append(f"{term}-{lem_dict[term]['definitions']}")
if definitions_info:
definitions_str = ", ".join(definitions_info)
substituted_text = f"{original_text}. Дополнительная информация: {definitions_str}"
else:
substituted_text = original_text
else:
for term, original_phrase in phrases_to_add:
if lem_dict[term]['is_multi_def']:
                        # context-dependent selection of the right definition could go here
definition_num = 0
term_start = substituted_text.find(original_phrase)
if type(lem_dict[term]['definitions']) is list:
substituted_text = substituted_text[:term_start + len(
original_phrase)] + f" ({lem_dict[term]['definitions'][definition_num]})" + substituted_text[
term_start + len(
original_phrase):]
else:
substituted_text = substituted_text[:term_start + len(
original_phrase)] + f" ({lem_dict[term]['definitions']})" + substituted_text[
term_start + len(
original_phrase):]
except Exception as e:
print(f'error processing\n {original_text}\n {term}: {e}')
return substituted_text, phrases_to_add
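    # Keep only search hits whose document type is among the categories enabled by the user
    # (ВНД documents are matched by name).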
def filter_by_types(self,
pred: list[str] = None,
scores: list[float] = None,
indexes: list[int] = None,
docs_embeddings: list = None,
ctgs: dict = None):
ctgs = [ctg for ctg in ctgs.keys() if ctgs[ctg]]
        filtered_pred, filtered_scores, filtered_indexes, filtered_docs_embeddings = [], [], [], []
        for doc_name, score, index, doc_embedding in zip(pred, scores, indexes, docs_embeddings):
            if ('ВНД' in doc_name and 'ВНД' in ctgs) or self.all_docs_with_accounting[index]['doc_type'] in ctgs:
                filtered_pred.append(doc_name)
                filtered_scores.append(score)
                filtered_indexes.append(index)
                filtered_docs_embeddings.append(doc_embedding)
        return filtered_pred, filtered_scores, filtered_indexes, filtered_docs_embeddings
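    # Annotate every document with a coarse type derived from its name
    # (НКРФ, ГКРФ, letters, orders, court documents, ВНД, accounting documents, ...).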
def get_types_of_docs(self, all_docs):
def type_determiner(doc_name):
names = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина', 'Письмо ФНС', 'Приказ ФНС',
'Постановление Правительства', 'Судебный документ', 'ВНД', 'Бухгалтерский документ']
for ctg in list(names):
if ctg in doc_name:
return ctg
for doc in all_docs:
doc_type = type_determiner(doc['doc_name'])
doc['doc_type'] = doc_type
return all_docs
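    # Load the tokenizer and encoder either from the Hugging Face Hub (when HF_TOKEN and
    # HF_MODEL_NAME are set) or from the local checkpoint in GLOBAL_MODEL_PATH.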
def load_model(self):
if hf_token and hf_model_name:
self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=hf_token)
self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=hf_token).to(self.device)
else:
self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
self.max_len = self.tokenizer.max_len_single_sentence
self.embedding_dim = self.model.config.hidden_size
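    # Load dictionaries, company maps and the document/consultation/explanation datasets
    # from disk, and set the per-document-type score weights.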
def load_data(self):
with open(dict_path, "rb") as f:
self.terms_dict = pickle.load(f)
with open(companies_map_path, "r", encoding='utf-8') as f:
self.companies_map = json.load(f)
with open(general_nmd_path, "r", encoding='utf-8') as f:
self.general_nmd = json.load(f)
with open(explanations_for_llm_path, "r", encoding='utf-8') as f:
self.explanations_for_llm = json.load(f)
with open(rules_list_path, 'r', encoding='utf-8') as f:
self.rules_list = f.read().splitlines()
self.all_docs_info = dataset.load_from_disk(data_path_consult).to_list() # ONLY EXTERNAL DOCS
self.internal_docs = dataset.load_from_disk(internal_docs_data_path).to_list()
self.accounting_docs = dataset.load_from_disk(accounting_data_path).to_list()
self.spec_internal_docs = dataset.load_from_disk(spec_internal_docs_data_path).to_list()
self.all_docs_with_accounting = self.all_docs_info + self.accounting_docs
self.all_docs_with_accounting = self.get_types_of_docs(self.all_docs_with_accounting)
self.type_weights_nu = {'НКРФ': 1,
'ТКРФ': 1,
'ГКРФ': 1,
'Письмо Минфина': 0.9,
'Письмо ФНС': 0.6,
'Приказ ФНС': 1,
'Постановление Правительства': 1,
'Федеральный закон': 0.9,
'Судебный документ': 0.2,
'ВНД': 0.2,
'Бухгалтерский документ': 0.7,
'Закон Красноярского края': 1.2,
'Правила заполнения': 1.2,
'Правила ведения': 1.2}
self.all_consultations = dataset.load_from_disk(consultations_dataset_path).to_list()
self.all_explanations = dataset.load_from_disk(explanations_dataset_path).to_list()
@staticmethod
def remove_duplicate_paragraphs(paragraphs):
unique_paragraphs = []
seen = set()
for paragraph in paragraphs:
stripped_paragraph = paragraph.strip()
if stripped_paragraph and stripped_paragraph not in seen:
unique_paragraphs.append(paragraph)
seen.add(stripped_paragraph)
return '\n'.join(unique_paragraphs)
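    # Rebuild readable document text from chunk indices: neighbouring chunks of the same
    # document are merged on their overlapping text, and truncated segments are marked
    # with ellipses.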
@staticmethod
def construct_base(idx_list, base):
concatenated_text = ""
seen_ids = set()
pattern = re.compile(r'_(\d{1,3})')
def find_overlap(a: str, b: str) -> int:
max_overlap = min(len(a), len(b))
for i in range(max_overlap, 0, -1):
if a[-i:] == b[:i]:
return i
return 0
def add_ellipsis(text: str) -> str:
if not text:
return text
segments = text.split('\n\n')
processed_segments = []
for segment in segments:
if segment and not (
segment[0].isupper() or segment[0].isdigit() or segment[0] in ['•', '-', '—', '.']):
segment = '...' + segment
if segment and not (segment.endswith('.') or segment.endswith(';')):
segment += '...'
processed_segments.append(segment)
return '\n\n'.join(processed_segments)
for current_index in idx_list:
if current_index in seen_ids:
continue
start_index = max(0, current_index - 2)
end_index = min(len(base), current_index + 3)
current_name_base = pattern.sub('', base[current_index]['doc_name'])
current_doc_text = base[current_index]['doc_text']
texts_to_concatenate = [current_doc_text]
for i in range(current_index - 1, start_index - 1, -1):
if i in seen_ids:
continue
surrounding_name_base = pattern.sub('', base[i]['doc_name'])
if current_name_base != surrounding_name_base:
break
surrounding_text = base[i]['doc_text']
overlap_length = find_overlap(surrounding_text, texts_to_concatenate[0])
if overlap_length == 0:
break
new_text = surrounding_text + texts_to_concatenate[0][overlap_length:]
texts_to_concatenate[0] = new_text
seen_ids.add(i)
for i in range(current_index + 1, end_index):
if i in seen_ids:
continue
surrounding_name_base = pattern.sub('', base[i]['doc_name'])
if current_name_base != surrounding_name_base:
break
surrounding_text = base[i]['doc_text']
overlap_length = find_overlap(texts_to_concatenate[-1], surrounding_text)
if overlap_length == 0:
break
new_text = texts_to_concatenate[-1] + surrounding_text[overlap_length:]
texts_to_concatenate[-1] = new_text
seen_ids.add(i)
combined_text = ' '.join(texts_to_concatenate)
concatenated_text += combined_text + '\n\n'
seen_ids.add(current_index)
concatenated_text = add_ellipsis(concatenated_text)
return concatenated_text.rstrip('\n')
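    # Multiply raw similarity scores by per-document-type weights and re-sort by the
    # weighted score; hits whose name matches no known type are dropped.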
def search_results_multiply_weights(self,
pred: list[str] = None,
scores: list[float] = None,
indexes: list[int] = None,
docs_embeddings: list = None) -> tuple[list[str], list[float], list[int], list]:
if pred is None or scores is None or indexes is None or docs_embeddings is None:
return [], [], [], []
weights = self.type_weights_nu
weighted_scores = [(weights.get(ctg, 0) * score, prediction, idx, emb)
for prediction, score, idx, emb in zip(pred, scores, indexes, docs_embeddings)
for ctg in weights if ctg in prediction]
weighted_scores.sort(reverse=True, key=lambda x: x[0])
if weighted_scores:
sorted_scores, sorted_preds, sorted_indexes, sorted_docs_embeddings = zip(*weighted_scores)
else:
sorted_scores, sorted_preds, sorted_indexes, sorted_docs_embeddings = [], [], [], []
return list(sorted_preds), list(sorted_scores), list(sorted_indexes), list(sorted_docs_embeddings)
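    # Group chunk hits by base document name (chunk suffix "_<n>" stripped), keep at most
    # `top_k` distinct documents and at most 20 chunks per document, ordered by chunk number.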
def get_uniq_relevant_docs(self,
top_k: int,
query_refs_all: list[str],
scores: list[float],
indexes: list[int],
docs_embeddings: list[list[float]]
) -> tuple[dict[str, list[str]], dict[str, list[float]], dict[str, list[int]], dict[str, list[list[float]]]]:
regex = r'_\d{1,3}'
base_ref_dict = {}
for i, ref in enumerate(query_refs_all):
base_ref = re.sub(regex, '', ref)
base_ref = base_ref.strip()
if base_ref not in base_ref_dict:
if len(base_ref_dict) >= top_k:
continue
base_ref_dict[base_ref] = {
'refs': [],
'scores': [],
'indexes': [],
'embeddings': []
}
base_ref_dict[base_ref]['refs'].append(ref)
base_ref_dict[base_ref]['scores'].append(scores[i])
base_ref_dict[base_ref]['indexes'].append(indexes[i])
base_ref_dict[base_ref]['embeddings'].append(docs_embeddings[i])
def get_suffix_number(ref: str):
match = re.findall(regex, ref)
if match:
match = re.findall(regex, ref)[0].replace('_', '')
return int(match)
return None
for base_ref, data in base_ref_dict.items():
refs = data['refs']
scores_list = data['scores']
indexes_list = data['indexes']
embeddings_list = data['embeddings']
combined = list(zip(refs, scores_list, indexes_list, embeddings_list))
def sort_key(item):
ref = item[0]
suffix = get_suffix_number(ref)
return (0 if suffix is None else 1, suffix if suffix is not None else -1)
combined_sorted = sorted(combined, key=sort_key)
sorted_refs, sorted_scores, sorted_indexes, sorted_embeddings = zip(*combined_sorted)
base_ref_dict[base_ref]['refs'] = list(sorted_refs)[:20]
base_ref_dict[base_ref]['scores'] = list(sorted_scores)[:20]
base_ref_dict[base_ref]['indexes'] = list(sorted_indexes)[:20]
base_ref_dict[base_ref]['embeddings'] = list(sorted_embeddings)[:20]
unique_refs = {k: v['refs'] for k, v in base_ref_dict.items()}
filtered_scores = {k: v['scores'] for k, v in base_ref_dict.items()}
filtered_indexes = {k: v['indexes'] for k, v in base_ref_dict.items()}
filtered_docs_embeddings = {k: v['embeddings'] for k, v in base_ref_dict.items()}
return unique_refs, filtered_scores, filtered_indexes, filtered_docs_embeddings
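    # Keep internal (ВНД) hits whose name matches either the general NMD list or one of the
    # company-specific documents detected in the query.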
def filter_results(self, pred_internal, scores_internal, indices_internal, docs_embeddings_internal, companies_files):
filt_pred_internal, filt_scores_internal, \
filt_indices_internal, filt_docs_embeddings_internal = list(), list(), list(), list()
def add_data(pred, ind, score, emb):
filt_pred_internal.append(pred)
filt_indices_internal.append(ind)
filt_scores_internal.append(score)
filt_docs_embeddings_internal.append(emb)
for pred, score, ind, emb in zip(pred_internal, scores_internal, indices_internal, docs_embeddings_internal):
if [doc for doc in self.general_nmd if doc in pred]:
add_data(pred, ind, score, emb)
continue
for company in companies_files:
if company in pred:
add_data(pred, ind, score, emb)
return filt_pred_internal, filt_scores_internal, filt_indices_internal, filt_docs_embeddings_internal
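    # Interleave several dictionaries entry by entry (first items of every dict, then second
    # items, and so on); if a key repeats, the value written last wins.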
@staticmethod
def merge_dictionaries(dicts: list = None):
merged_dict = {}
max_length = max(len(d) for d in dicts)
for i in range(max_length):
for d in dicts:
keys = list(d.keys())
values = list(d.values())
if i < len(keys):
merged_dict[keys[i]] = values[i]
return merged_dict
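    # True only when `key` is the single category switched on in `dictionary`.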
@staticmethod
def check_specific_key(dictionary, key):
if key in dictionary and dictionary[key] is True:
for k, v in dictionary.items():
if k != key and v is True:
return False
return True
return False
@staticmethod
def remove_duplicates(input_list):
unique_dict = {}
for item in input_list:
unique_dict[item] = None
return list(unique_dict.keys())
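    # Main retrieval pipeline: embed the query, search the external and internal FAISS
    # indexes, apply type weights and category filters, optionally expand the query with
    # LLM-generated keywords (use_qe), and assemble chunk texts, relevant consultations
    # and explanations for the LLM.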
async def search_engine(self,
query: str = None,
use_qe: bool = False,
categories: dict = None,
llm_params: LlmParams = None):
if True in list(categories.values()) and not all(categories.values()):
self.full_base_search = False
if self.check_specific_key(categories, 'ВНД'):
nmd_chunks = 120
nmd_refs = 45
extra_chunks = 1
extra_refs = 1
elif not categories['ВНД']:
extra_chunks = 120
extra_refs = 45
nmd_chunks = 1
nmd_refs = 1
else:
nmd_chunks = 60
nmd_refs = 23
extra_chunks = 60
extra_refs = 23
else:
self.full_base_search = True
nmd_chunks = 50
nmd_refs = 15
extra_chunks = 75
extra_refs = 30
        # LLM responses to be sent to the frontend
llm_responses = []
        # Tokenize and embed the query
query_tokens = query_tokenization(query, self.tokenizer)
query_embeds = query_embed_extraction(query_tokens, self.model, self.do_normalization)
        # Search over the external document database
distances, indices = self.index_all_docs_with_accounting.search(query_embeds, len(self.all_docs_with_accounting))
pred = [self.all_docs_with_accounting[x]['doc_name'] for x in indices[0]]
docs_embeddings = [self.all_docs_with_accounting[x]['doc_embedding'] for x in indices[0]]
preds, scores, indexes, docs_embeddings = pred[:5000], list(distances[0])[:5000], \
list(indices[0])[:5000], docs_embeddings[:5000]
if not re.search('[Кк]расноярск', query):
self.type_weights_nu['Закон Красноярского края'] = 0
else:
self.type_weights_nu['Закон Красноярского края'] = 1.2
if not use_rules(query, self.rules_list):
self.type_weights_nu['Правила ведения'] = 0
self.type_weights_nu['Правила заполнения'] = 0
else:
self.type_weights_nu['Правила ведения'] = 1.2
self.type_weights_nu['Правила заполнения'] = 1.2
        # Search over the internal (ВНД) document database
if self.full_base_search or categories['ВНД']:
distances_internal, indices_internal = self.index_internal_docs.search(query_embeds, len(self.spec_internal_docs))
pred_internal = [self.spec_internal_docs[x]['doc_name'] for x in indices_internal[0]]
docs_embeddings_internal = [self.spec_internal_docs[x]['doc_embedding'] for x in indices_internal[0]]
indices_internal = indices_internal[0]
scores_internal = []
for title, score in zip(pred_internal, distances_internal[0]):
if 'КУП' in title:
scores_internal.append(score*1.2)
else:
scores_internal.append(score)
companies_files = search_company.find_nmd_docs(query, self.companies_map)
pred_internal, scores_internal, indices_internal, docs_embeddings_internal = self.filter_results(pred_internal,
scores_internal,
indices_internal,
docs_embeddings_internal,
companies_files)
combined = list(zip(pred_internal, scores_internal, indices_internal, docs_embeddings_internal))
sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
top_nmd = sorted_combined[:nmd_chunks]
            if 'ЕГДС' in query:
                if not [x for x in top_nmd if 'п.5. Положение о КУП_262 (ВНД)' in x[0]]:
                    ch262 = self.internal_docs[22976]
                    ch262 = (ch262['doc_name'], 1.0, 22976, ch262['chunks_embeddings'][0])
                    top_nmd.insert(0, ch262)
                if not [x for x in top_nmd if 'п.5. Положение о КУП_130 (ВНД)' in x[0]]:
                    ch130 = self.internal_docs[22844]
                    ch130 = (ch130['doc_name'], 1.0, 22844, ch130['chunks_embeddings'][0])
                    top_nmd.insert(1, ch130)
top_nmd = top_nmd[:nmd_chunks]
preds_internal, scores_internal, indexes_internal, internal_docs_embeddings = zip(*top_nmd)
preds_internal, scores_internal, indexes_internal, internal_docs_embeddings = list(preds_internal), \
list(scores_internal), \
list(indexes_internal), \
list(internal_docs_embeddings)
            # Collect unique internal documents
unique_preds_internal, unique_scores_internal, unique_indexes_internal, \
unique_docs_embeddings_internal = self.get_uniq_relevant_docs(
top_k=nmd_refs,
query_refs_all=preds_internal,
scores=scores_internal,
indexes=indexes_internal,
docs_embeddings=internal_docs_embeddings)
preds_internal, scores_internal, \
indexes_internal, internal_docs_embeddings = unique_preds_internal, unique_scores_internal,\
unique_indexes_internal, unique_docs_embeddings_internal
        # Filter (or not) by category depending on which checkboxes are ticked
if not self.full_base_search:
preds, scores, indexes, docs_embeddings = self.filter_by_types(preds, scores, indexes,
docs_embeddings, categories)
        # Apply document-type weights on top of the similarity scores
sorted_preds, sorted_scores, sorted_indexes, sorted_docs_embeddings = self.search_results_multiply_weights(
pred=preds,
scores=scores,
indexes=indexes,
docs_embeddings=docs_embeddings)
sorted_preds, sorted_scores, sorted_indexes, sorted_docs_embeddings = sorted_preds[:extra_chunks], \
sorted_scores[:extra_chunks], \
sorted_indexes[:extra_chunks], \
sorted_docs_embeddings[:extra_chunks]
        # Collect unique external documents
unique_preds, unique_scores, unique_indexes, unique_docs_embeddings = self.get_uniq_relevant_docs(
top_k=extra_refs,
query_refs_all=sorted_preds,
scores=sorted_scores,
indexes=sorted_indexes,
docs_embeddings=sorted_docs_embeddings
)
preds, scores, indexes, docs_embeddings = unique_preds, unique_scores, unique_indexes, unique_docs_embeddings
if use_qe:
try:
prompt = LLM_PROMPT_KEYS.format(query=query)
if llm_params is None:
keyword_query = self.get_main_info_with_llm(prompt)
else:
llm_api = LlmApi(llm_params)
keyword_query = await llm_api.predict(prompt)
llm_responses.append(keyword_query)
keyword_query = re.sub(r'\[1\].*?(?=\[\d+\]|$)', '', keyword_query, flags=re.DOTALL).replace(' [2]', '').replace('[3]', '').strip()
keyword_query_tokens = query_tokenization(keyword_query, self.tokenizer)
keyword_query_embeds = query_embed_extraction(keyword_query_tokens,
self.model,
self.do_normalization)
keyword_distances, keyword_indices = self.index_all_docs_with_accounting.search(
keyword_query_embeds, len(self.all_docs_with_accounting))
keyword_pred = [self.all_docs_with_accounting[x]['doc_name'] for x in keyword_indices[0]]
keyword_docs_embeddings = [self.all_docs_with_accounting[x]['doc_embedding'] for x in
keyword_indices[0]]
if not self.full_base_search:
keyword_preds, keyword_scores, \
keyword_indexes, keyword_docs_embeddings = self.filter_by_types(keyword_pred,
keyword_distances[0],
keyword_indices[0],
keyword_docs_embeddings,
categories)
else:
keyword_preds, keyword_scores, \
keyword_indexes, keyword_docs_embeddings = keyword_pred, keyword_distances[0], \
keyword_indices[0],keyword_docs_embeddings
keyword_preds, keyword_scores, \
keyword_indexes, keyword_docs_embeddings = self.search_results_multiply_weights(
pred=keyword_preds, scores=keyword_scores,
indexes=keyword_indexes, docs_embeddings=keyword_docs_embeddings)
keyword_unique_preds, keyword_unique_scores, \
keyword_unique_indexes, keyword_unique_docs_embeddings = self.get_uniq_relevant_docs(
top_k=45,
query_refs_all=keyword_preds,
scores=keyword_scores,
indexes=keyword_indexes,
docs_embeddings=keyword_docs_embeddings)
preds = dict(list(self.merge_dictionaries([preds, keyword_unique_preds]).items())[:30])
scores = dict(list(self.merge_dictionaries([scores, keyword_unique_scores]).items())[:30])
indexes = dict(list(self.merge_dictionaries([indexes, keyword_unique_indexes]).items())[:30])
            except Exception:
                traceback.print_exc()
                print("Error applying keys (possibly the LLM is not available)")
        if self.full_base_search or categories['ВНД']:
            # Merge the top internal documents into the final result set
            preds = self.merge_dictionaries([preds, preds_internal])
            scores = self.merge_dictionaries([scores, scores_internal])
            indexes = self.merge_dictionaries([indexes, indexes_internal])
        # Assemble readable chunk texts for the LLM
texts_for_llm, docs, teasers = [], [], []
for key, idx_list in indexes.items():
collected_text = []
if 'ВНД' in key:
base = self.internal_docs
else:
base = self.all_docs_with_accounting
if re.search('Минфин|Бухгалтерский документ|ФНС|Судебный документ|Постановление Правительства|Федеральный закон', key):
text = self.construct_base(idx_list, base)
collected_text.append(text)
else:
for idx in idx_list:
if idx < len(base):
for text in base[idx]['doc_text'].split('\n'):
collected_text.append(text)
collected_text = self.remove_duplicate_paragraphs(collected_text)
texts_for_llm.append(collected_text)
        # Search for relevant consultations
distances_consult, indices_consult = self.index_consult.search(query_embeds, len(self.all_consultations))
predicted_consultations = {self.all_consultations[x]['doc_name']: self.all_consultations[x]['doc_text']
for x in indices_consult[0]}
        # Search for relevant explanations
distances_explanations, indices_explanations = self.index_explanations.search(query_embeds, len(self.all_explanations))
predicted_explanations = {self.all_explanations[x]['doc_name']: self.all_explanations[x]['doc_text']
for x in indices_explanations[0]}
results = list(zip(list(predicted_explanations.keys()),
list(predicted_explanations.values()),
distances_explanations[0]))
explanation_titles = self.rerank_by_avg_score(results)[:3]
try:
predicted_explanation = {explanation_title: self.explanations_for_llm[explanation_title] for explanation_title in explanation_titles}
        except Exception:
predicted_explanation = {}
print('The relevant document was not found in the system.')
        preds_out = [x.replace('ФЕДЕРАЛЬНЫЙ СТАНДАРТ БУХГАЛТЕРСКОГО УЧЕТА',
                               'Федеральный стандарт бухгалтерского учета ФСБУ')
                     for x in list(preds.keys())]
        return query, preds_out, texts_for_llm, \
            dict(list(predicted_consultations.items())[:def_k]), \
            predicted_explanation, llm_responses
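    # Iterative "olympic" refinement: ask the LLM which of the retrieved sources answer the
    # query, re-run the search while the LLM still reports missing information (at most 4
    # extra passes), and accumulate the selected sources.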
async def olympic_branch(self,
query: str = None,
sources: dict = None,
categories: dict = None,
llm_params: LlmParams = None):
        # Collect all LLM responses to send to the frontend
llm_responses = []
text = await self.olymp_think(query, sources, llm_params)
llm_responses.append(text)
saved_sources = {}
saved_step_by_step = []
comment1, sources_choice = self.parse_step(text)
sources_choice = [source - 1 for source in sources_choice]
for idx, ref in enumerate(sources):
if idx in sources_choice and ref not in saved_sources.keys():
saved_sources.update({ref: sources[ref]})
should_continue = True
        if comment1 == '':
            # The LLM found nothing to clarify on the first pass: skip the refinement loop
            count = 4
        else:
            count = 0
while count < 4:
query, preds, \
texts_for_llm, predicted_consultations, \
predicted_explanation, skip_llm_responses = await self.search_engine(query, use_qe=False, categories=categories)
sources = dict(map(lambda i,j: (i,j), preds, texts_for_llm))
sources = dict(islice(sources.items(), 20))
text = await self.olymp_think(query, sources, llm_params)
llm_responses.append(text)
comment2, sources_choice = self.parse_step(text)
sources_choice = [source - 1 for source in sources_choice]
saved_step_by_step.append(sources_choice)
for idx, ref in enumerate(sources):
if idx in sources_choice and ref not in saved_sources.keys():
saved_sources.update({ref: sources[ref]})
if comment2 == '':
break
comment1 = comment2
count += 1
return saved_sources, saved_step_by_step, llm_responses
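    # Public entry point: substitute term definitions into the query, run the base search
    # and, when use_olympic is set, refine the result set with the olympic branch.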
async def search(self,
query: str = None,
use_qe: bool = False,
use_olympic: bool = False,
categories: dict = None,
llm_params: LlmParams = None):
        # Query preprocessing: substitute term definitions
lem_dict = self.lemmatize_dict(self.terms_dict)
lem_dict_fast, lem_dict_slow = self.separate_one_word_searchable_dict(lem_dict)
query_for_llm, _ = self.substitute_definitions(query, lem_dict, lem_dict_fast, lem_dict_slow, for_llm=True)
query, _ = self.substitute_definitions(query, lem_dict, lem_dict_fast, lem_dict_slow, for_llm=False)
        # Base search
query, base_preds, base_texts_for_llm, \
predicted_consultations, predicted_explanation, llm_responses = await self.search_engine(query, use_qe, categories, llm_params)
if use_olympic:
sources = dict(map(lambda i,j: (i,j), base_preds, base_texts_for_llm))
sources = dict(islice(sources.items(), 20))
olymp_results, olymp_step_by_step, llm_responses = await self.olympic_branch(query, sources, categories, llm_params)
olymp_preds, olymp_texts_for_llm = list(olymp_results.keys()), list(olymp_results.values())
if len(olymp_preds) <= 45:
preds = olymp_preds + base_preds
preds = self.remove_duplicates(preds)[:45]
texts_for_llm = olymp_texts_for_llm + base_texts_for_llm
texts_for_llm = self.remove_duplicates(texts_for_llm)[:45]
return query_for_llm, preds, texts_for_llm, predicted_consultations, predicted_explanation, llm_responses
else:
                # More than 45 olympic hits: keep only the first 45 saved sources
                olymp_results = dict(islice(olymp_results.items(), 45))
preds, texts_for_llm = list(olymp_results.keys()), list(olymp_results.values())
return query_for_llm, preds, texts_for_llm, predicted_consultations, predicted_explanation, llm_responses
else:
return query_for_llm, base_preds, base_texts_for_llm, predicted_consultations, predicted_explanation, llm_responses
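
# --- Minimal usage sketch (illustrative only; not part of the service wiring) ---
# Assumes the datasets under GLOBAL_DATA_PATH and the encoder under GLOBAL_MODEL_PATH are
# available locally. The query string and the all-enabled `categories` dict below are
# made-up examples; with use_qe and use_olympic disabled, no LLM endpoint is required.
if __name__ == "__main__":
    import asyncio

    engine = SemanticSearch()
    categories = {doc_type: True for doc_type in db_data_types}
    query_for_llm, preds, texts_for_llm, consultations, explanations, llm_responses = asyncio.run(
        engine.search(query="порядок учета основных средств",
                      use_qe=False,
                      use_olympic=False,
                      categories=categories))
    print(query_for_llm)
    print(preds[:5])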