Ilyas KHIAT commited on
Commit
b66e2f4
·
1 Parent(s): 6c65306

changing base prompt

Browse files
Files changed (4) hide show
  1. main.py +14 -4
  2. models.py +1 -0
  3. rag.py +9 -1
  4. requirements.txt +1 -1
main.py CHANGED
@@ -10,6 +10,8 @@ from rag import *
10
  from fastapi.responses import StreamingResponse
11
  import json
12
  from prompts import *
 
 
13
 
14
  load_dotenv()
15
 
@@ -42,12 +44,16 @@ class StyleWriter(BaseModel):
42
  style: str
43
  tonality: str
44
 
 
 
45
  class UserInput(BaseModel):
46
  prompt: str
47
  enterprise_id: str
48
  stream: Optional[bool] = False
49
  messages: Optional[list[dict]] = []
50
  style_tonality: Optional[StyleWriter] = None
 
 
51
 
52
 
53
  class EnterpriseData(BaseModel):
@@ -175,11 +181,11 @@ def generate_answer(user_input: UserInput):
175
  context = ""
176
 
177
  if user_input.style_tonality is None:
178
- prompt_formated = prompt_reformatting(template_prompt,context,prompt)
179
- answer = generate_response_via_langchain(prompt, model="gpt-4o",stream=user_input.stream,context = context , messages=user_input.messages,template=template_prompt)
180
  else:
181
- prompt_formated = prompt_reformatting(template_prompt,context,prompt,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality)
182
- answer = generate_response_via_langchain(prompt, model="gpt-4o",stream=user_input.stream,context = context , messages=user_input.messages,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality,template=template_prompt)
183
 
184
  if user_input.stream:
185
  return StreamingResponse(stream_generator(answer,prompt_formated), media_type="application/json")
@@ -192,6 +198,10 @@ def generate_answer(user_input: UserInput):
192
 
193
  except Exception as e:
194
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
 
 
 
 
195
 
196
 
197
 
 
10
  from fastapi.responses import StreamingResponse
11
  import json
12
  from prompts import *
13
+ from typing import Literal
14
+ from models import *
15
 
16
  load_dotenv()
17
 
 
44
  style: str
45
  tonality: str
46
 
47
+ models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]
48
+
49
  class UserInput(BaseModel):
50
  prompt: str
51
  enterprise_id: str
52
  stream: Optional[bool] = False
53
  messages: Optional[list[dict]] = []
54
  style_tonality: Optional[StyleWriter] = None
55
+ marque: Optional[str] = None
56
+ model: Literal["gpt-4o","gpt-4o-mini","mistral-large-latest"] = "gpt-4o"
57
 
58
 
59
  class EnterpriseData(BaseModel):
 
181
  context = ""
182
 
183
  if user_input.style_tonality is None:
184
+ prompt_formated = prompt_reformatting(template_prompt,context,prompt,enterprise_name=getattr(user_input,"marque",""))
185
+ answer = generate_response_via_langchain(prompt, model=getattr(user_input,"model","gpt-4o"),stream=user_input.stream,context = context , messages=user_input.messages,template=template_prompt,enterprise_name=getattr(user_input,"marque",""))
186
  else:
187
+ prompt_formated = prompt_reformatting(template_prompt,context,prompt,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality,enterprise_name=getattr(user_input,"marque",""))
188
+ answer = generate_response_via_langchain(prompt,model=getattr(user_input,"model","gpt-4o"),stream=user_input.stream,context = context , messages=user_input.messages,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality,template=template_prompt,enterprise_name=getattr(user_input,"marque",""))
189
 
190
  if user_input.stream:
191
  return StreamingResponse(stream_generator(answer,prompt_formated), media_type="application/json")
 
198
 
199
  except Exception as e:
200
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
201
+
202
+ @app.get("/models")
203
+ def get_models():
204
+ return {"models": models}
205
 
206
 
207
 
models.py ADDED
@@ -0,0 +1 @@
 
 
1
+ models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]
rag.py CHANGED
@@ -7,6 +7,7 @@ from langchain_core.documents import Document
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.output_parsers import StrOutputParser
9
  from langchain_core.prompts import PromptTemplate
 
10
  from uuid import uuid4
11
 
12
  import unicodedata
@@ -105,9 +106,16 @@ def generate_response_via_langchain(query: str, stream: bool = False, model: str
105
 
106
 
107
  prompt = PromptTemplate.from_template(template)
 
 
 
108
 
109
  # Initialize the OpenAI LLM with the specified model
110
- llm = ChatOpenAI(model=model,temperature=0)
 
 
 
 
111
 
112
  # Create an LLM chain with the prompt and the LLM
113
  llm_chain = prompt | llm | StrOutputParser()
 
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.output_parsers import StrOutputParser
9
  from langchain_core.prompts import PromptTemplate
10
+ from langchain_mistralai import ChatMistralAI
11
  from uuid import uuid4
12
 
13
  import unicodedata
 
106
 
107
 
108
  prompt = PromptTemplate.from_template(template)
109
+
110
+ print(f"model: {model}")
111
+ print(f"marque: {enterprise_name}")
112
 
113
  # Initialize the OpenAI LLM with the specified model
114
+ if model.startswith("gpt"):
115
+ llm = ChatOpenAI(model=model,temperature=0)
116
+ if model.startswith("mistral"):
117
+ llm = ChatMistralAI(model=model,temperature=0)
118
+
119
 
120
  # Create an LLM chain with the prompt and the LLM
121
  llm_chain = prompt | llm | StrOutputParser()
requirements.txt CHANGED
@@ -13,4 +13,4 @@ langchain
13
  langchain-openai
14
  langchain-community
15
  langchain-pinecone
16
-
 
13
  langchain-openai
14
  langchain-community
15
  langchain-pinecone
16
+ langchain_mistralai