Ilyas KHIAT committed
Commit · b66e2f4 · 1 Parent(s): 6c65306
changing base prompt
main.py
CHANGED
@@ -10,6 +10,8 @@ from rag import *
 from fastapi.responses import StreamingResponse
 import json
 from prompts import *
+from typing import Literal
+from models import *
 
 load_dotenv()
 
@@ -42,12 +44,16 @@ class StyleWriter(BaseModel):
     style: str
     tonality: str
 
+models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]
+
 class UserInput(BaseModel):
     prompt: str
     enterprise_id: str
     stream: Optional[bool] = False
     messages: Optional[list[dict]] = []
     style_tonality: Optional[StyleWriter] = None
+    marque: Optional[str] = None
+    model: Literal["gpt-4o","gpt-4o-mini","mistral-large-latest"] = "gpt-4o"
 
 
 class EnterpriseData(BaseModel):
@@ -175,11 +181,11 @@ def generate_answer(user_input: UserInput):
         context = ""
 
         if user_input.style_tonality is None:
-            prompt_formated = prompt_reformatting(template_prompt,context,prompt)
-            answer = generate_response_via_langchain(prompt, model="gpt-4o",stream=user_input.stream,context = context , messages=user_input.messages,template=template_prompt)
+            prompt_formated = prompt_reformatting(template_prompt,context,prompt,enterprise_name=getattr(user_input,"marque",""))
+            answer = generate_response_via_langchain(prompt, model=getattr(user_input,"model","gpt-4o"),stream=user_input.stream,context = context , messages=user_input.messages,template=template_prompt,enterprise_name=getattr(user_input,"marque",""))
         else:
-            prompt_formated = prompt_reformatting(template_prompt,context,prompt,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality)
-            answer = generate_response_via_langchain(prompt,model="gpt-4o",stream=user_input.stream,context = context , messages=user_input.messages,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality,template=template_prompt)
+            prompt_formated = prompt_reformatting(template_prompt,context,prompt,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality,enterprise_name=getattr(user_input,"marque",""))
+            answer = generate_response_via_langchain(prompt,model=getattr(user_input,"model","gpt-4o"),stream=user_input.stream,context = context , messages=user_input.messages,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality,template=template_prompt,enterprise_name=getattr(user_input,"marque",""))
 
         if user_input.stream:
            return StreamingResponse(stream_generator(answer,prompt_formated), media_type="application/json")
@@ -192,6 +198,10 @@ def generate_answer(user_input: UserInput):
 
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+@app.get("/models")
+def get_models():
+    return {"models": models}
 
 
 
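Taken together, the main.py changes expose per-request model selection: the new `model` field is validated against the `Literal` list by Pydantic, so an out-of-range value is rejected with a 422 before the handler runs, and the new `/models` endpoint lets clients discover the allowed names. A minimal client sketch follows; the port and the `/generate` route are assumptions, since the decorator on generate_answer is not visible in this diff:

# Hypothetical client for the new model-selection fields.
# Assumed: the app runs on localhost:8000 and generate_answer is
# mounted at POST /generate (its route is not shown in this diff).
import requests

BASE = "http://localhost:8000"

# New in this commit: list the model names the backend accepts.
models = requests.get(f"{BASE}/models").json()["models"]
print(models)  # ["gpt-4o", "gpt-4o-mini", "mistral-large-latest"]

payload = {
    "prompt": "Présente la nouvelle gamme en deux phrases.",
    "enterprise_id": "ent_123",       # placeholder id
    "model": "mistral-large-latest",  # must match one of the Literal values
    "marque": "Acme",                 # new optional brand field
}
print(requests.post(f"{BASE}/generate", json=payload).json())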
models.py
ADDED
@@ -0,0 +1 @@
+models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]
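Worth noting: with this commit main.py both star-imports models.py and re-declares the same `models` list inline, so the two copies can drift apart. A possible tightening (a sketch, not part of the commit):

# main.py — keep models.py as the single source of truth
from models import models  # explicit import instead of `from models import *`
# ...and drop the duplicated `models = [...]` literal in main.py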
rag.py
CHANGED
@@ -7,6 +7,7 @@ from langchain_core.documents import Document
 from langchain_openai import ChatOpenAI
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import PromptTemplate
+from langchain_mistralai import ChatMistralAI
 from uuid import uuid4
 
 import unicodedata
@@ -105,9 +106,16 @@ def generate_response_via_langchain(query: str, stream: bool = False, model: str
 
 
     prompt = PromptTemplate.from_template(template)
+
+    print(f"model: {model}")
+    print(f"marque: {enterprise_name}")
 
     # Initialize the OpenAI LLM with the specified model
-    llm = ChatOpenAI(model=model,temperature=0)
+    if model.startswith("gpt"):
+        llm = ChatOpenAI(model=model,temperature=0)
+    if model.startswith("mistral"):
+        llm = ChatMistralAI(model=model,temperature=0)
+
 
     # Create an LLM chain with the prompt and the LLM
     llm_chain = prompt | llm | StrOutputParser()
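One caveat in the new dispatch: if `model` matches neither prefix, `llm` is never bound, and the later `prompt | llm | StrOutputParser()` raises a NameError. The `Literal` annotation on UserInput protects the HTTP path, but not other callers of generate_response_via_langchain. A defensive variant (a sketch; the ValueError fallback is an addition, not in the diff):

from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI

def build_llm(model: str, temperature: float = 0):
    # Same prefix routing as the commit, with an explicit failure path
    # so an unsupported name fails fast instead of as a NameError later.
    if model.startswith("gpt"):
        return ChatOpenAI(model=model, temperature=temperature)
    if model.startswith("mistral"):
        return ChatMistralAI(model=model, temperature=temperature)
    raise ValueError(f"Unsupported model: {model!r}")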
requirements.txt
CHANGED
@@ -13,4 +13,4 @@ langchain
 langchain-openai
 langchain-community
 langchain-pinecone
-
+langchain_mistralai
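A packaging note: pip normalizes underscores and hyphens in project names, so `langchain_mistralai` resolves to the `langchain-mistralai` distribution on PyPI. At runtime, ChatMistralAI reads its key from the MISTRAL_API_KEY environment variable, which the existing `load_dotenv()` call in main.py will pick up from a local .env file alongside the OpenAI key.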