"""FastAPI service exposing semantic search and transaction-map search.

Endpoints:
    POST /search        -- run a search; optionally persist request + response.
    GET  /health        -- liveness probe.
    GET  /read_logs     -- return every stored search log.
    GET  /analyze_logs  -- return logs whose results differ for the same (query, top).
"""
import datetime
import json
import os
import traceback
import uuid
from collections import defaultdict
from typing import Iterator, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from semantic_search import SemanticSearch
from transaction_maps_search import TransactionMapsSearch
from llm.vllm_api import LlmParams

# Directory where per-request search logs are written as JSON files.
LOGS_BASE_PATH = os.getenv("LOGS_BASE_PATH", "logs")

# Request logging is opt-in: set ENABLE_LOGS=1 to persist every /search result.
ENABLE_LOGS = os.getenv("ENABLE_LOGS", "0") == "1"


class Query(BaseModel):
    """Request body for ``POST /search``."""

    query: str = ''
    top: int = 10
    use_qe: bool = False
    use_olympic: bool = False
    find_transaction_maps_by_question: bool = False
    find_transaction_maps_by_operation: bool = False
    request_id: str = ''
    # Per-category filter flags passed through to the semantic search; all off by default.
    categories: dict = {'НКРФ': False,
                        'ГКРФ': False,
                        'ТКРФ': False,
                        'Федеральный закон': False,
                        'Письмо Минфина': False,
                        'Письмо ФНС': False,
                        'Приказ ФНС': False,
                        'Постановление Правительства': False,
                        'Судебный документ': False,
                        'ВНД': False,
                        'Бухгалтерский документ': False}
    llm_params: Optional[LlmParams] = None


# The semantic search engine is built on first use (see _get_semantic_search)
# so the service can start and answer transaction-map queries without
# constructing it up front.
search: Optional[SemanticSearch] = None
transaction_maps_search = TransactionMapsSearch()

app = FastAPI(
    title="multistep-semantic-search-app",
    description="multistep-semantic-search-app",
    version="0.1.0",
)


def _get_semantic_search() -> SemanticSearch:
    """Return the shared SemanticSearch instance, creating it on the first call."""
    global search
    if search is None:
        search = SemanticSearch()
    return search


def log_query_result(query, top, request_id, result):
    """Persist one search request and its response as a JSON file.

    Does nothing unless ENABLE_LOGS is set. Logging is best-effort: a failure
    (unwritable directory, non-JSON-serializable result, ...) is printed to
    stderr and swallowed so it never turns a successful search into an error.

    Args:
        query: The search question.
        top: Requested number of results.
        request_id: Caller-supplied request identifier.
        result: The response dict returned to the client.
    """
    if not ENABLE_LOGS:
        return
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    log_data = {
        "timestamp": timestamp,
        "query": query,
        "top": top,
        "request_id": request_id,
        "result": result
    }
    try:
        # Serialize before opening the file so a serialization error cannot
        # leave a truncated JSON file behind for the log endpoints to choke on.
        payload = json.dumps(log_data, indent=2, ensure_ascii=False)
        os.makedirs(LOGS_BASE_PATH, exist_ok=True)
        # The random suffix stops requests that finish within the same second
        # from overwriting each other's log file.
        file_name = f"{timestamp}_{uuid.uuid4().hex}.json"
        with open(os.path.join(LOGS_BASE_PATH, file_name), 'w', encoding='utf-8') as log_file:
            log_file.write(payload)
    except (OSError, TypeError, ValueError):
        traceback.print_exc()


@app.post('/search')
async def search_route(query: Query) -> dict:
    """Run a transaction-map search or a multistep semantic search.

    Transaction-map search is used when either ``find_transaction_maps_by_*``
    flag is set; otherwise the semantic search runs. The response is passed
    to ``log_query_result`` before being returned.

    Raises:
        HTTPException: 400 for an empty query or any ValueError raised while
            searching; 500 for any other exception.
    """
    try:
        question = query.query
        if not question:
            raise ValueError("Query parameter 'query' is required and cannot be empty.")

        if query.find_transaction_maps_by_question or query.find_transaction_maps_by_operation:
            # NOTE(review): this call is synchronous inside an async route, so it
            # blocks the event loop while it runs — consider a threadpool if slow.
            # The second returned value is not part of this endpoint's response.
            transaction_maps_results, _answer = transaction_maps_search.search_transaction_map(
                query=question,
                find_transaction_maps_by_question=query.find_transaction_maps_by_question,
                k_neighbours=query.top)
            response = {'transaction_maps_results': transaction_maps_results}
        else:
            # NOTE(review): `top` is not forwarded to the semantic search, so it
            # only affects the transaction-map branch and the log — confirm intended.
            (modified_query, titles, concat_docs,
             relevant_consultations, predicted_explanation,
             llm_responses) = await _get_semantic_search().search(
                question, query.use_qe, query.use_olympic, query.categories, query.llm_params)
            results = [{'title': str(title), 'text_for_llm': str(doc)}
                       for title, doc in zip(titles, concat_docs)]
            consultations = [{'title': key, 'text': value}
                             for key, value in relevant_consultations.items()]
            explanations = [{'title': key, 'text': value}
                            for key, value in predicted_explanation.items()]
            response = {'query': modified_query,
                        'results': results,
                        'consultations': consultations,
                        'explanations': explanations,
                        'llm_responses': llm_responses}

        log_query_result(question, query.top, query.request_id, response)
        return response
    except ValueError as ve:
        traceback.print_exc()
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") from e


@app.get('/health')
def health():
    """Liveness probe."""
    return {"status": "ok"}


def _iter_log_entries() -> Iterator:
    """Yield the parsed contents of every ``*.json`` file in LOGS_BASE_PATH.

    Files are visited in sorted name order. A missing log directory yields
    nothing; a file that cannot be read or parsed is printed to stderr and
    skipped so one bad file does not break the log endpoints.
    """
    try:
        file_names = sorted(os.listdir(LOGS_BASE_PATH))
    except FileNotFoundError:
        return
    for file_name in file_names:
        if not file_name.endswith(".json"):
            continue
        try:
            with open(os.path.join(LOGS_BASE_PATH, file_name), 'r', encoding='utf-8') as file:
                entry = json.load(file)
        except (OSError, ValueError):
            traceback.print_exc()
            continue
        yield entry


@app.get('/read_logs')
def read_logs():
    """Return the contents of every stored search log (empty list if there are none)."""
    return list(_iter_log_entries())


@app.get('/analyze_logs')
def analyze_logs():
    """Return logs whose results are inconsistent for the same (query, top).

    Logs are grouped by their ``query`` and ``top`` fields. For every group
    whose stored ``result`` values are not all identical, all of that group's
    logs are returned so the differing runs can be compared side by side.
    """
    logs_by_query_top = defaultdict(list)
    for log_data in _iter_log_entries():
        if not isinstance(log_data, dict):
            continue
        # A tuple key avoids ambiguity from joining free-text queries with "_".
        key = (str(log_data.get("query", "")), str(log_data.get("top", "")))
        logs_by_query_top[key].append(log_data)

    inconsistent_logs = []
    for logs in logs_by_query_top.values():
        # Canonical JSON (sorted keys) so dict key order is not counted as a difference.
        distinct_results = {json.dumps(log.get('result'), sort_keys=True) for log in logs}
        if len(distinct_results) > 1:
            inconsistent_logs.extend(logs)
    return inconsistent_logs