# NOTE(review): the lines below were a scraped Hugging Face Spaces banner
# ("Spaces: Running on T4") captured during export — not part of the program.
import datetime
import json
import os
import traceback
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from llm.common import LlmParams, LlmPredictParams
from llm.deepinfra_api import DeepInfraApi
from semantic_search import SemanticSearch
from transaction_maps_search import TransactionMapsSearch
# --- Runtime configuration, read once from the environment at import time ---

# Logging is opt-in: set ENABLE_LOGS=1 to write one JSON file per query.
ENABLE_LOGS = os.getenv("ENABLE_LOGS", "0") == "1"
# Directory where per-query log files are written.
LOGS_BASE_PATH = os.getenv("LOGS_BASE_PATH", "logs")
# Create the logs directory up front; exist_ok=True avoids the race between
# the old os.path.exists() check and makedirs() when two workers start at once.
if ENABLE_LOGS:
    os.makedirs(LOGS_BASE_PATH, exist_ok=True)

# LLM backend endpoint and credentials.
LLM_API_URL = os.getenv("LLM_API_URL", "")
LLM_API_KEY = os.getenv("LLM_API_KEY", "")
# When LLM_USE_DEEPINFRA=1, inference goes through the DeepInfra API client.
LLM_USE_DEEPINFRA = os.getenv("LLM_USE_DEEPINFRA", "") == "1"
print(LLM_USE_DEEPINFRA)
class Query(BaseModel):
    """Request payload for the search endpoint."""

    # Free-text question; required in practice (validated in search_route).
    query: str = ''
    # Number of results / nearest neighbours to return.
    top: int = 10
    # Enable query expansion.
    use_qe: bool = False
    use_olympic: bool = False
    # Transaction-map search modes (by question text or by operation).
    find_transaction_maps_by_question: bool = False
    find_transaction_maps_by_operation: bool = False
    # Caller-supplied id, echoed into the log files for correlation.
    request_id: str = ''
    # Per-category filters over Russian legal/accounting document types.
    categories: dict = {'НКРФ': False,
                        'ГКРФ': False,
                        'ТКРФ': False,
                        'Федеральный закон': False,
                        'Письмо Минфина': False,
                        'Письмо ФНС': False,
                        'Приказ ФНС': False,
                        'Постановление Правительства': False,
                        'Судебный документ': False,
                        'ВНД': False,
                        'Бухгалтерский документ': False}
    # Optional override of the server-side LLM parameters; None means use
    # the defaults built in search_route. The annotation must be Optional:
    # a plain `LlmParams = None` default is rejected by pydantic v2.
    llm_params: Optional[LlmParams] = None
# Heavy search backends, instantiated once at process startup.
# NOTE(review): `search` was commented out here, yet search_route() calls
# `search.search(...)` on its non-transaction-map path, which raised
# NameError. Restored — confirm this was not disabled on purpose.
search = SemanticSearch()
transaction_maps_search = TransactionMapsSearch()

app = FastAPI(
    title="multistep-semantic-search-app",
    description="multistep-semantic-search-app",
    version="0.1.0",
)
def log_query_result(query, top, request_id, result):
    """Persist one query/result pair as a JSON file under LOGS_BASE_PATH.

    No-op unless logging was enabled via ENABLE_LOGS. Each call writes its
    own file named after the current timestamp.

    :param query: the question text as received.
    :param top: requested number of results.
    :param request_id: caller-supplied correlation id (may be empty).
    :param result: the full response dict returned to the client.
    """
    if not ENABLE_LOGS:
        return
    # %f appends microseconds: with second resolution, two requests in the
    # same second produced the same filename and overwrote each other.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
    log_file_path = os.path.join(LOGS_BASE_PATH, f"{timestamp}.json")
    log_data = {
        "timestamp": timestamp,
        "query": query,
        "top": top,
        "request_id": request_id,
        "result": result
    }
    with open(log_file_path, 'w', encoding='utf-8') as log_file:
        # ensure_ascii=False keeps Cyrillic queries human-readable on disk.
        json.dump(log_data, log_file, indent=2, ensure_ascii=False)
async def search_route(query: Query) -> dict:
    """Handle one search request.

    Dispatches either to the transaction-map search (when one of the
    find_transaction_maps_* flags is set) or to the full multistep semantic
    search, and logs the response when logging is enabled.

    :param query: validated request payload.
    :returns: response dict (shape depends on the branch taken).
    :raises HTTPException: 400 for an empty query, 500 for any other error.
    """
    # Server-side default LLM parameters. The request may carry llm_params,
    # but they are currently ignored (see note below).
    default_llm_params = LlmParams(
        url=LLM_API_URL,
        api_key=LLM_API_KEY,
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        predict_params=LlmPredictParams(
            temperature=0.15, top_p=0.95, min_p=0.05, seed=42,
            repetition_penalty=1.2, presence_penalty=1.1, max_tokens=6000))
    try:
        question = getattr(query, "query", None)
        if not question:
            raise ValueError("Query parameter 'query' is required and cannot be empty.")
        top = getattr(query, "top", 15)
        use_qe = getattr(query, "use_qe", False)
        request_id = getattr(query, "request_id", None)
        categories = getattr(query, "categories", None)
        use_olympic = getattr(query, "use_olympic", False)
        find_transaction_maps_by_question = getattr(query, "find_transaction_maps_by_question", False)
        find_transaction_maps_by_operation = getattr(query, "find_transaction_maps_by_operation", False)
        # NOTE(review): the client-supplied llm_params are read but not yet
        # honoured — the server defaults always win. Confirm intent.
        request_llm_params = getattr(query, "llm_params", None)
        llm_params = default_llm_params

        # LLM client: only constructed for the DeepInfra backend. Initialize
        # to None first — previously the name was unbound when
        # LLM_USE_DEEPINFRA was off, raising UnboundLocalError at the call
        # site below.
        llm_api = None
        if LLM_USE_DEEPINFRA:
            llm_api = DeepInfraApi(llm_params)

        if find_transaction_maps_by_question or find_transaction_maps_by_operation:
            transaction_maps_results, answer = await transaction_maps_search.search_transaction_map(
                query=question,
                find_transaction_maps_by_question=find_transaction_maps_by_question,
                k_neighbours=top,
                llm_api=llm_api)
            response = {'transaction_maps_results': transaction_maps_results}
        else:
            modified_query, titles, concat_docs, \
            relevant_consultations, predicted_explanation, \
            llm_responses = await search.search(question, use_qe, use_olympic, categories, llm_params)
            results = [{'title': str(item1), 'text_for_llm': str(item2)} for item1, item2 in
                       zip(titles, concat_docs)]
            consultations = [{'title': key, 'text': value} for key, value in relevant_consultations.items()]
            explanations = [{'title': key, 'text': value} for key, value in predicted_explanation.items()]
            response = {'query': modified_query, 'results': results,
                        'consultations': consultations, 'explanations': explanations, 'llm_responses': llm_responses}
        log_query_result(question, top, request_id, response)
        return response
    except ValueError as ve:
        traceback.print_exception(type(ve), ve, ve.__traceback__)
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        traceback.print_exception(type(e), e, e.__traceback__)
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
def health():
    """Liveness probe: the service always reports itself healthy."""
    status_payload = {"status": "ok"}
    return status_payload
def read_logs():
    """Load every JSON log file under LOGS_BASE_PATH into a list.

    Files that cannot be opened or parsed are skipped, so a single corrupt
    or half-written log no longer breaks the whole listing.

    :returns: list of log dicts (order follows os.listdir, i.e. arbitrary).
    """
    logs = []
    for log_file in os.listdir(LOGS_BASE_PATH):
        if not log_file.endswith(".json"):
            continue
        path = os.path.join(LOGS_BASE_PATH, log_file)
        try:
            with open(path, 'r', encoding='utf-8') as file:
                logs.append(json.load(file))
        except (OSError, json.JSONDecodeError):
            # Best-effort read: ignore unreadable or corrupt log files.
            continue
    return logs
def analyze_logs():
    """Find log entries whose results disagree for the same (query, top).

    Groups all log files by query and top, then returns every entry that
    belongs to a group whose members do not all share an identical result
    — i.e. the same request produced different answers over time.

    :returns: flat list of the log dicts from all inconsistent groups.
    """
    logs_by_query_top = {}
    for log_file in os.listdir(LOGS_BASE_PATH):
        if log_file.endswith(".json"):
            with open(os.path.join(LOGS_BASE_PATH, log_file), 'r', encoding='utf-8') as file:
                log_data = json.load(file)
                query = log_data.get("query", "")
                top = log_data.get("top", "")
                # Group logs by query and top
                key = f"{query}_{top}"
                logs_by_query_top.setdefault(key, []).append(log_data)
    # A group is inconsistent when its serialized results are not all equal.
    # sort_keys=True makes the comparison insensitive to dict key order,
    # which previously produced false positives for identical results.
    invalid_logs = []
    for logs in logs_by_query_top.values():
        if len({json.dumps(log['result'], sort_keys=True) for log in logs}) > 1:
            invalid_logs.extend(logs)
    return invalid_logs