pvanand commited on
Commit
bdd570c
·
verified ·
1 Parent(s): e7a043e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +23 -27
main.py CHANGED
@@ -54,57 +54,53 @@ sys_prompts = {
54
  },
55
  }
56
 
57
- @app.post("/generate_report")
58
- @cache(expire=604800) # Set cache expiration to 7 days (7 * 24 * 60 * 60 seconds)
59
- async def generate_report(
60
- topic: str = Query(default="market research", description="input query to generate Report"),
61
- description: str = Query(default="", description="additional context for report"),
62
- user_id: str = Query(default="", description="unique user id"),
63
- user_name: str = Query(default="", description="user name"),
64
- internet: bool = Query(default=True, description="Enable Internet search"),
65
- output_format: str = Query(default="Tabular Report", description="Output format for the report", enum=["Chat", "Full Text Report", "Tabular Report", "Tables only"]),
66
  data_format: str = Query(default="Structured data", description="Type of data to extract from the internet", enum=["No presets", "Structured data", "Quantitative data"])
67
- ):
68
- query_str = topic
69
- internet_status = "online" if internet else "offline"
70
- sys_prompt_output_format = sys_prompts[internet_status][output_format]
 
 
 
71
  optimized_search_query = ""
72
  all_text_with_urls = [("","")]
73
 
74
  # Combine query with user keywords
75
- if internet:
76
- search_query = description
77
- # Search for relevant URLs
78
  try:
79
  urls, optimized_search_query = search_brave(search_query, num_results=4)
80
- # Fetch and extract content from the URLs
81
- all_text_with_urls = fetch_and_extract_content(data_format, urls, query_str)
82
- # Prepare the prompt for generating the report
83
  additional_context = limit_tokens(str(all_text_with_urls))
84
- prompt = f"#### COMPLETE THE TASK: {description} #### IN THE CONTEXT OF ### CONTEXT: {query_str} USING THE #### SCRAPED DATA:{additional_context}"
85
  except Exception as e:
86
- internet = False
87
  print("failed to search/scrape results, falling back to LLM response")
88
 
89
- if not internet:
90
- prompt = f"#### COMPLETE THE TASK: {description} #### IN THE CONTEXT OF ### CONTEXT: {query_str}"
91
 
92
  md_report = together_response(prompt, model=llm_default_medium, SysPrompt=sys_prompt_output_format)
93
 
94
- if user_id != "test":
95
- insert_data(user_id, query_str, description, str(all_text_with_urls), md_report)
96
 
97
  references_html = dict()
98
  for text, url in all_text_with_urls:
99
  references_html[url] = str(md_to_html(text))
100
 
101
- # Return the generated report
102
  return {
103
  "report": md_to_html(md_report),
104
  "references": references_html,
105
  "search_query": optimized_search_query
106
  }
107
-
108
  app.add_middleware(
109
  CORSMiddleware,
110
  allow_origins=["*"],
 
54
  },
55
  }
56
 
57
+ class ReportParams(BaseModel):
58
+ topic: str = Query(default="market research", description="input query to generate report")
59
+ description: str = Query(default="", description="additional context for report")
60
+ user_id: str = Query(default="", description="unique user id")
61
+ user_name: str = Query(default="", description="user name")
62
+ internet: bool = Query(default=True, description="Enable Internet search")
63
+ output_format: str = Query(default="Tabular Report", description="Output format for the report", enum=["Chat", "Full Text Report", "Tabular Report", "Tables only"])
 
 
64
  data_format: str = Query(default="Structured data", description="Type of data to extract from the internet", enum=["No presets", "Structured data", "Quantitative data"])
65
+
66
+ @app.post("/generate_report")
67
+ @cache(expire=604800) # Cache expiration set to 7 days
68
+ async def generate_report(params: ReportParams):
69
+ query_str = params.topic
70
+ internet_status = "online" if params.internet else "offline"
71
+ sys_prompt_output_format = sys_prompts[internet_status][params.output_format]
72
  optimized_search_query = ""
73
  all_text_with_urls = [("","")]
74
 
75
  # Combine query with user keywords
76
+ if params.internet:
77
+ search_query = params.description
 
78
  try:
79
  urls, optimized_search_query = search_brave(search_query, num_results=4)
80
+ all_text_with_urls = fetch_and_extract_content(params.data_format, urls, query_str)
 
 
81
  additional_context = limit_tokens(str(all_text_with_urls))
82
+ prompt = f"#### COMPLETE THE TASK: {params.description} #### IN THE CONTEXT OF ### CONTEXT: {query_str} USING THE #### SCRAPED DATA:{additional_context}"
83
  except Exception as e:
84
+ params.internet = False
85
  print("failed to search/scrape results, falling back to LLM response")
86
 
87
+ if not params.internet:
88
+ prompt = f"#### COMPLETE THE TASK: {params.description} #### IN THE CONTEXT OF ### CONTEXT: {query_str}"
89
 
90
  md_report = together_response(prompt, model=llm_default_medium, SysPrompt=sys_prompt_output_format)
91
 
92
+ if params.user_id != "test":
93
+ insert_data(params.user_id, query_str, params.description, str(all_text_with_urls), md_report)
94
 
95
  references_html = dict()
96
  for text, url in all_text_with_urls:
97
  references_html[url] = str(md_to_html(text))
98
 
 
99
  return {
100
  "report": md_to_html(md_report),
101
  "references": references_html,
102
  "search_query": optimized_search_query
103
  }
 
104
  app.add_middleware(
105
  CORSMiddleware,
106
  allow_origins=["*"],