[email protected] commited on
Commit
ea077e1
·
1 Parent(s): 89033ee

Implement model integration strategy and selector for multiple AI models

Browse files
model/ModelIntegrations.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .ModelStrategy import ModelStrategy
2
+
3
+ from langchain_community.chat_models import ChatOpenAI
4
+ from langchain_mistralai.chat_models import ChatMistralAI
5
+ from langchain_anthropic import ChatAnthropic
6
+ from langchain_ollama import ChatOllama
7
+
8
+ class MistralModel(ModelStrategy):
9
+ def get_model(self, model_name):
10
+ return ChatMistralAI(model=model_name)
11
+
12
+
13
+ class OpenAIModel(ModelStrategy):
14
+ def get_model(self, model_name):
15
+ return ChatOpenAI(model=model_name)
16
+
17
+
18
+ class AnthropicModel(ModelStrategy):
19
+ def get_model(self, model_name):
20
+ return ChatAnthropic(model=model_name)
21
+
22
+
23
+ class OllamaModel(ModelStrategy):
24
+ def get_model(self, model_name):
25
+ return ChatOllama(model=model_name)
26
+
27
+ class ModelManager():
28
+ def __init__(self):
29
+ self.models = {
30
+ "mistral": MistralModel(),
31
+ "openai": OpenAIModel(),
32
+ "anthropic": AnthropicModel(),
33
+ "ollama": OllamaModel()
34
+ }
35
+
36
+ def get_model(self, provider, model_name):
37
+ return self.models[provider].get_model(model_name)
model/ModelStrategy.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ class ModelStrategy(ABC):
4
+ @abstractmethod
5
+ def get_model(self, model_name):
6
+ pass
model/selector.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from .ModelIntegrations import ModelManager
3
+
4
+ def ModelSelector():
5
+ # Dictionnaire des modèles par fournisseur
6
+ model_providers = {
7
+ "Mistral": {
8
+ "mistral-large-latest": "mistral.mistral-large-latest",
9
+ "open-mixtral-8x7b": "mistral.open-mixtral-8x7b",
10
+ },
11
+ "OpenAI": {
12
+ "gpt-4o": "openai.gpt-4o",
13
+ },
14
+ "Anthropic": {
15
+ "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620",
16
+ "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229",
17
+ "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229",
18
+ },
19
+ # "Ollama": {
20
+ # "llama3": "ollama.llama3"
21
+ # }
22
+ }
23
+
24
+ # Créer une liste avec les noms de modèle, groupés par fournisseur (fournisseur - modèle)
25
+ model_options = []
26
+ model_mapping = {}
27
+
28
+ for provider, models in model_providers.items():
29
+ for model_name, model_instance in models.items():
30
+ option_name = f"{provider} - {model_name}"
31
+ model_options.append(option_name)
32
+ model_mapping[option_name] = model_instance
33
+
34
+ # Sélection d'un modèle via un seul sélecteur
35
+ selected_model_option = st.selectbox("Choisissez votre modèle", options=model_options)
36
+
37
+ # Afficher le modèle sélectionné
38
+ st.write(f"Current model: {model_mapping[selected_model_option]}")
39
+
40
+ if(st.session_state["assistant"]):
41
+ splitter = model_mapping[selected_model_option].split(".")
42
+ st.session_state["assistant"].setModel(ModelManager().get_model(splitter[0], splitter[1]))
43
+
pages/chatbot.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  from streamlit_chat import message
 
3
 
4
  def display_messages():
5
  for i, (msg, is_user) in enumerate(st.session_state["messages"]):
@@ -22,6 +23,8 @@ def process_input():
22
  def page():
23
  st.subheader("Posez vos questions")
24
 
 
 
25
  if "assistant" not in st.session_state:
26
  st.text("Assistant non initialisé")
27
 
 
1
  import streamlit as st
2
  from streamlit_chat import message
3
+ from model import selector
4
 
5
  def display_messages():
6
  for i, (msg, is_user) in enumerate(st.session_state["messages"]):
 
23
  def page():
24
  st.subheader("Posez vos questions")
25
 
26
+ selector.ModelSelector()
27
+
28
  if "assistant" not in st.session_state:
29
  st.text("Assistant non initialisé")
30
 
rag.py CHANGED
@@ -19,7 +19,6 @@ from prompt_template import base_template
19
  # load .env in local dev
20
  load_dotenv()
21
  env_api_key = os.environ.get("MISTRAL_API_KEY")
22
- llm_model = "open-mixtral-8x7b"
23
 
24
  class Rag:
25
  document_vector_store = None
@@ -28,7 +27,7 @@ class Rag:
28
 
29
  def __init__(self, vectore_store=None):
30
 
31
- self.model = ChatMistralAI(model=llm_model)
32
  self.embedding = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=env_api_key)
33
 
34
  self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100, length_function=len)
@@ -73,11 +72,9 @@ class Rag:
73
  )
74
 
75
  def ask(self, query: str, messages: list):
76
-
77
  self.chain = self.prompt | self.model | StrOutputParser()
78
 
79
- print("messages ", messages)
80
-
81
  # Retrieve the context document
82
  if self.retriever is None:
83
  documentContext = ''
 
19
  # load .env in local dev
20
  load_dotenv()
21
  env_api_key = os.environ.get("MISTRAL_API_KEY")
 
22
 
23
  class Rag:
24
  document_vector_store = None
 
27
 
28
  def __init__(self, vectore_store=None):
29
 
30
+ # self.model = ChatMistralAI(model=llm_model)
31
  self.embedding = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=env_api_key)
32
 
33
  self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100, length_function=len)
 
72
  )
73
 
74
  def ask(self, query: str, messages: list):
75
+ print(self.model)
76
  self.chain = self.prompt | self.model | StrOutputParser()
77
 
 
 
78
  # Retrieve the context document
79
  if self.retriever is None:
80
  documentContext = ''
requirements.txt CHANGED
@@ -17,3 +17,6 @@ langchain-openai
17
  langchain-community
18
  langchain-pinecone
19
  langchain_mistralai
 
 
 
 
17
  langchain-community
18
  langchain-pinecone
19
  langchain_mistralai
20
+ langchain_anthropic
21
+ langchain_ollama
22
+ pyyaml