[email protected]
commited on
Commit
·
ea077e1
1
Parent(s):
89033ee
Implement model integration strategy and selector for multiple AI models
Browse files
- model/ModelIntegrations.py +37 -0
- model/ModelStrategy.py +6 -0
- model/selector.py +43 -0
- pages/chatbot.py +3 -0
- rag.py +2 -5
- requirements.txt +3 -0
model/ModelIntegrations.py ADDED
@@ -0,0 +1,37 @@
+from .ModelStrategy import ModelStrategy
+
+from langchain_community.chat_models import ChatOpenAI
+from langchain_mistralai.chat_models import ChatMistralAI
+from langchain_anthropic import ChatAnthropic
+from langchain_ollama import ChatOllama
+
+class MistralModel(ModelStrategy):
+    def get_model(self, model_name):
+        return ChatMistralAI(model=model_name)
+
+
+class OpenAIModel(ModelStrategy):
+    def get_model(self, model_name):
+        return ChatOpenAI(model=model_name)
+
+
+class AnthropicModel(ModelStrategy):
+    def get_model(self, model_name):
+        return ChatAnthropic(model=model_name)
+
+
+class OllamaModel(ModelStrategy):
+    def get_model(self, model_name):
+        return ChatOllama(model=model_name)
+
+class ModelManager():
+    def __init__(self):
+        self.models = {
+            "mistral": MistralModel(),
+            "openai": OpenAIModel(),
+            "anthropic": AnthropicModel(),
+            "ollama": OllamaModel()
+        }
+
+    def get_model(self, provider, model_name):
+        return self.models[provider].get_model(model_name)
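Taken together, ModelStrategy and the four concrete classes form a classic Strategy pattern: callers ask ModelManager for a chat model by provider key and never import a LangChain class directly. A minimal usage sketch, assuming the matching API key (e.g. MISTRAL_API_KEY) is already set in the environment:

from model.ModelIntegrations import ModelManager

manager = ModelManager()

# Look up the "mistral" strategy and build the requested chat model;
# an unknown provider key raises KeyError.
llm = manager.get_model("mistral", "open-mixtral-8x7b")
print(llm.invoke("Bonjour !").content)

One caveat: recent LangChain releases deprecate importing ChatOpenAI from langchain_community.chat_models; the maintained class lives in langchain_openai, which the requirements.txt hunk below already pulls in.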
model/ModelStrategy.py ADDED
@@ -0,0 +1,6 @@
+from abc import ABC, abstractmethod
+
+class ModelStrategy(ABC):
+    @abstractmethod
+    def get_model(self, model_name):
+        pass
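The ABC is what makes the contract enforceable: a provider class that forgets to override get_model cannot even be instantiated. A quick sketch of the failure mode:

from model.ModelStrategy import ModelStrategy

class BrokenModel(ModelStrategy):
    pass  # no get_model override

BrokenModel()  # TypeError: Can't instantiate abstract class BrokenModel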
model/selector.py ADDED
@@ -0,0 +1,43 @@
+import streamlit as st
+from .ModelIntegrations import ModelManager
+
+def ModelSelector():
+    # Dictionary of models per provider
+    model_providers = {
+        "Mistral": {
+            "mistral-large-latest": "mistral.mistral-large-latest",
+            "open-mixtral-8x7b": "mistral.open-mixtral-8x7b",
+        },
+        "OpenAI": {
+            "gpt-4o": "openai.gpt-4o",
+        },
+        "Anthropic": {
+            "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620",
+            "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229",
+            "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229",
+        },
+        # "Ollama": {
+        #     "llama3": "ollama.llama3"
+        # }
+    }
+
+    # Build one list of option names, grouped by provider ("provider - model")
+    model_options = []
+    model_mapping = {}
+
+    for provider, models in model_providers.items():
+        for model_name, model_instance in models.items():
+            option_name = f"{provider} - {model_name}"
+            model_options.append(option_name)
+            model_mapping[option_name] = model_instance
+
+    # Pick a model through a single select box
+    selected_model_option = st.selectbox("Choisissez votre modèle", options=model_options)
+
+    # Show the selected model
+    st.write(f"Current model: {model_mapping[selected_model_option]}")
+    # Defensive read: chatbot.py calls ModelSelector() before checking for "assistant"
+    if st.session_state.get("assistant"):
+        splitter = model_mapping[selected_model_option].split(".")
+        st.session_state["assistant"].setModel(ModelManager().get_model(splitter[0], splitter[1]))
+
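The selector round-trips each choice through a single "provider.model" string and splits it on dots. That is safe for every id listed above, since none of the model names contains a dot, but a dotted name such as Ollama's llama3.1 would be silently truncated by split("."). Splitting on the first dot only would be a more robust variant if the commented-out Ollama block is ever re-enabled:

provider, model_name = "ollama.llama3.1".split(".", 1)
# provider == "ollama", model_name == "llama3.1"
# whereas "ollama.llama3.1".split(".")[1] would yield just "llama3"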
pages/chatbot.py CHANGED
@@ -1,5 +1,6 @@
 import streamlit as st
 from streamlit_chat import message
+from model import selector
 
 def display_messages():
     for i, (msg, is_user) in enumerate(st.session_state["messages"]):
@@ -22,6 +23,8 @@ def process_input():
 def page():
     st.subheader("Posez vos questions")
 
+    selector.ModelSelector()
+
     if "assistant" not in st.session_state:
         st.text("Assistant non initialisé")
 
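Note the ordering in page(): selector.ModelSelector() runs before the page verifies that "assistant" is present in st.session_state, which is why the selector reads the key defensively. A hypothetical alternative (the Rag() construction here is an assumption; the real initialization happens elsewhere in the app) would be to guarantee the assistant exists first:

import streamlit as st
from model import selector
from rag import Rag

def page():
    st.subheader("Posez vos questions")

    # Assumption: create the assistant up front so the selector can rely on it.
    if "assistant" not in st.session_state:
        st.session_state["assistant"] = Rag()

    selector.ModelSelector()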
rag.py CHANGED
@@ -19,7 +19,6 @@ from prompt_template import base_template
 # load .env in local dev
 load_dotenv()
 env_api_key = os.environ.get("MISTRAL_API_KEY")
-llm_model = "open-mixtral-8x7b"
 
 class Rag:
     document_vector_store = None
@@ -28,7 +27,7 @@
 
     def __init__(self, vectore_store=None):
 
-        self.model = ChatMistralAI(model=llm_model)
+        # self.model = ChatMistralAI(model=llm_model)
         self.embedding = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=env_api_key)
 
         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100, length_function=len)
@@ -73,11 +72,9 @@
         )
 
     def ask(self, query: str, messages: list):
-
+        print(self.model)
         self.chain = self.prompt | self.model | StrOutputParser()
 
-        print("messages ", messages)
-
         # Retrieve the context document
         if self.retriever is None:
             documentContext = ''
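One thing the diff does not show is a setModel method on Rag, even though model/selector.py calls st.session_state["assistant"].setModel(...). Presumably it is defined elsewhere in the class or in a neighboring commit; a minimal sketch of what it would need to do, given that ask() rebuilds the chain on every call:

class Rag:
    # ... existing attributes and methods ...

    def setModel(self, model):
        # Swap the underlying chat model; ask() re-creates self.chain
        # from self.model each time, so nothing else needs invalidating.
        self.model = model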
requirements.txt CHANGED
@@ -17,3 +17,6 @@ langchain-openai
 langchain-community
 langchain-pinecone
 langchain_mistralai
+langchain_anthropic
+langchain_ollama
+pyyaml
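The new integrations authenticate the same way the Mistral one does, through environment variables that load_dotenv() picks up in rag.py. Assuming the default key names each LangChain integration reads, a local .env would look roughly like this (Ollama needs a running local server rather than a key):

MISTRAL_API_KEY=...
OPENAI_API_KEY=...
ANTHROPIC_API_KEY=...
# Ollama: no key needed; serves at http://localhost:11434 by default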