# nn-search-full / fastapi_app.py
# Author: muryshev — commit 85dfc4f ("removed logs"), 5.98 kB
import datetime
import json
import os
import traceback
import uuid
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from semantic_search import SemanticSearch
from transaction_maps_search import TransactionMapsSearch
from llm.vllm_api import LlmParams
# Directory that receives the per-request JSON log files; override it with
# the LOGS_BASE_PATH environment variable. Nothing here creates the directory
# at import time.
LOGS_BASE_PATH = os.environ.get("LOGS_BASE_PATH", "logs")

# Request logging is switched on only when ENABLE_LOGS is exactly "1".
ENABLE_LOGS = os.environ.get("ENABLE_LOGS", "0") == "1"
class Query(BaseModel):
    """Request body accepted by ``POST /search``.

    Every field has a default, but ``/search`` rejects a request whose
    ``query`` text is empty.
    """

    # Text to search for.
    query: str = ''
    # Number of neighbours requested. It is passed as ``k_neighbours`` to the
    # transaction-map search and recorded in request logs; the semantic
    # search call does not receive it.
    top: int = 10
    # Switches forwarded positionally to SemanticSearch.search; presumably
    # "query expansion" and the "olympic" pipeline — TODO confirm there.
    use_qe: bool = False
    use_olympic: bool = False
    # Setting either flag routes /search to the transaction-map search
    # instead of the semantic search.
    find_transaction_maps_by_question: bool = False
    find_transaction_maps_by_operation: bool = False
    # Caller-supplied identifier; /search only records it in request logs.
    request_id: str = ''
    # Document-category switches forwarded to SemanticSearch.search
    # (presumably a filter over the corpus — TODO confirm). The keys name
    # Russian document types: Tax Code, Civil Code, Labour Code, federal law,
    # Ministry of Finance letter, Federal Tax Service letter, FTS order,
    # government decree, court document, internal regulation (ВНД) and
    # accounting document.
    categories: dict = {'НКРФ': False,
                        'ГКРФ': False,
                        'ТКРФ': False,
                        'Федеральный закон': False,
                        'Письмо Минфина': False,
                        'Письмо ФНС': False,
                        'Приказ ФНС': False,
                        'Постановление Правительства': False,
                        'Судебный документ': False,
                        'ВНД': False,
                        'Бухгалтерский документ': False}
    # Optional LLM parameters forwarded to SemanticSearch.search. Declared
    # Optional so that an explicit JSON ``null`` validates just like omitting
    # the field (a bare ``LlmParams`` annotation rejects ``null`` in pydantic v2).
    llm_params: Optional[LlmParams] = None
# NOTE(review): the module-level SemanticSearch() instance is disabled here,
# so this section binds no `search` name; only the transaction-map engine is
# constructed at import time.
transaction_maps_search = TransactionMapsSearch()

# The application title doubles as its OpenAPI description.
_APP_NAME = "multistep-semantic-search-app"
app = FastAPI(title=_APP_NAME, description=_APP_NAME, version="0.1.0")
def log_query_result(query, top, request_id, result):
    """Persist one search request and its response as a JSON file.

    Does nothing unless ``ENABLE_LOGS`` is true. Otherwise one file is written
    into ``LOGS_BASE_PATH`` (the directory is created if needed). The file name
    starts with a ``%Y%m%d%H%M%S`` timestamp, so names still sort by second.

    Args:
        query: The search text.
        top: The requested number of results.
        request_id: Caller-supplied request identifier.
        result: The response returned to the client; must be JSON-serialisable.

    Raises:
        TypeError: If ``result`` cannot be serialised to JSON. No file is
            written in that case.
        OSError: If the directory or file cannot be created or written.
    """
    if not ENABLE_LOGS:
        return
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    log_data = {
        "timestamp": timestamp,
        "query": query,
        "top": top,
        "request_id": request_id,
        "result": result,
    }
    # Serialise before touching the filesystem. A non-serialisable result then
    # fails here instead of leaving a truncated .json file behind.
    payload = json.dumps(log_data, indent=2, ensure_ascii=False)
    os.makedirs(LOGS_BASE_PATH, exist_ok=True)
    # Timestamps only have one-second resolution, so two requests in the same
    # second would overwrite each other. The random suffix keeps every
    # request's file distinct.
    log_file_path = os.path.join(LOGS_BASE_PATH, f"{timestamp}_{uuid.uuid4().hex}.json")
    with open(log_file_path, 'w', encoding='utf-8') as log_file:
        log_file.write(payload)
# SemanticSearch engine shared by every /search request; built on first use.
_semantic_search = None


def _get_semantic_search():
    """Return the shared SemanticSearch instance, constructing it on first call.

    No module-level ``search`` object is created at import time, so the
    semantic branch of ``/search`` obtains its engine here. The constructor
    runs synchronously inside the async handler on the first semantic request.
    Later calls reuse the same instance.
    """
    global _semantic_search
    if _semantic_search is None:
        _semantic_search = SemanticSearch()
    return _semantic_search


@app.post('/search')
async def search_route(query: Query) -> dict:
    """Run a search for ``query`` and return the JSON response body.

    Two modes:

    * Transaction-map search, when either ``find_transaction_maps_by_*`` flag
      is set. Returns ``{'transaction_maps_results': ...}``.
    * Semantic search otherwise. Returns the (possibly modified) query, the
      matched documents, consultations, explanations and LLM responses.

    The response is passed to ``log_query_result`` before being returned.

    Raises:
        HTTPException: 400 when the query text is empty or whitespace, or when
            any ``ValueError`` is raised while searching or logging; 500 for
            any other failure.
    """
    try:
        question = query.query
        # Whitespace-only text is treated the same as an empty query.
        if not question or not question.strip():
            raise ValueError("Query parameter 'query' is required and cannot be empty.")

        by_question = query.find_transaction_maps_by_question
        by_operation = query.find_transaction_maps_by_operation
        if by_question or by_operation:
            # Only the by-question flag is forwarded. Presumably
            # search_transaction_map treats False as "search by operation" —
            # TODO confirm against TransactionMapsSearch.
            transaction_maps_results, _answer = transaction_maps_search.search_transaction_map(
                query=question,
                find_transaction_maps_by_question=by_question,
                k_neighbours=query.top)
            response = {'transaction_maps_results': transaction_maps_results}
        else:
            (modified_query, titles, concat_docs, relevant_consultations,
             predicted_explanation, llm_responses) = await _get_semantic_search().search(
                question, query.use_qe, query.use_olympic, query.categories, query.llm_params)
            results = [{'title': str(title), 'text_for_llm': str(doc)}
                       for title, doc in zip(titles, concat_docs)]
            consultations = [{'title': key, 'text': value}
                             for key, value in relevant_consultations.items()]
            explanations = [{'title': key, 'text': value}
                            for key, value in predicted_explanation.items()]
            response = {'query': modified_query, 'results': results,
                        'consultations': consultations, 'explanations': explanations,
                        'llm_responses': llm_responses}
        log_query_result(question, query.top, query.request_id, response)
        return response
    except ValueError as ve:
        traceback.print_exception(type(ve), ve, ve.__traceback__)
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        traceback.print_exception(type(e), e, e.__traceback__)
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") from e
@app.get('/health')
def health():
    """Liveness probe that unconditionally reports the service as up."""
    return dict(status="ok")
@app.get('/read_logs')
def read_logs():
    """Return every ``*.json`` log record stored in ``LOGS_BASE_PATH``.

    Files are read in file-name order. Names written by ``log_query_result``
    begin with a ``%Y%m%d%H%M%S`` timestamp, so this is chronological order.
    Returns an empty list when the directory does not exist yet (for example,
    before anything has been logged).

    NOTE(review): this endpoint exposes every logged query and response
    without any access control.
    """
    if not os.path.isdir(LOGS_BASE_PATH):
        return []
    logs = []
    for log_file in sorted(os.listdir(LOGS_BASE_PATH)):
        if not log_file.endswith(".json"):
            continue
        with open(os.path.join(LOGS_BASE_PATH, log_file), 'r', encoding='utf-8') as file:
            logs.append(json.load(file))
    return logs
@app.get('/analyze_logs')
def analyze_logs():
    """Return the logs whose results disagree for the same ``(query, top)``.

    Log records in ``LOGS_BASE_PATH`` are grouped by their ``query`` and
    ``top`` fields. Every record in a group is returned when the group holds
    more than one distinct ``result``. Groups with identical results are
    omitted. Returns an empty list when the directory does not exist.
    """
    if not os.path.isdir(LOGS_BASE_PATH):
        return []

    logs_by_query_top = {}
    for log_file in sorted(os.listdir(LOGS_BASE_PATH)):
        if not log_file.endswith(".json"):
            continue
        with open(os.path.join(LOGS_BASE_PATH, log_file), 'r', encoding='utf-8') as file:
            log_data = json.load(file)
        # A tuple key cannot be confused the way a "query_top" string can when
        # the query itself contains underscores.
        key = (log_data.get("query", ""), log_data.get("top", ""))
        logs_by_query_top.setdefault(key, []).append(log_data)

    invalid_logs = []
    for logs in logs_by_query_top.values():
        # sort_keys canonicalises dict ordering, so only genuine content
        # differences count. A record without "result" is compared as None.
        distinct_results = {json.dumps(log.get("result"), sort_keys=True) for log in logs}
        if len(distinct_results) > 1:
            invalid_logs.extend(logs)
    return invalid_logs