vilarin committed on
Commit
8d12c02
·
verified ·
1 Parent(s): 8d02472

Update app/webui/patch.py

Browse files
Files changed (1) hide show
  1. app/webui/patch.py +130 -130
app/webui/patch.py CHANGED
@@ -1,131 +1,131 @@
1
- # a monkey patch to use llama-index completion
2
- from typing import Union, Callable
3
- from functools import wraps
4
- from src.translation_agent.utils import *
5
-
6
-
7
- from llama_index.llms.groq import Groq
8
- from llama_index.llms.cohere import Cohere
9
- from llama_index.llms.openai import OpenAI
10
- from llama_index.llms.together import TogetherLLM
11
- from llama_index.llms.ollama import Ollama
12
- from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
13
-
14
- from llama_index.core import Settings
15
- from llama_index.core.llms import ChatMessage
16
-
17
-
18
- # Add your LLMs here
19
-
20
def model_load(
    endpoint: str,
    model: str,
    api_key: str = None,
    context_window: int = 4096,
    num_output: int = 512,
):
    """
    Create the LLM client for ``endpoint`` and install it in the llama-index ``Settings``.

    Args:
        endpoint (str): Provider name, compared case-sensitively. Supported values are
            "Groq", "Cohere", "OpenAI", "TogetherAI", "ollama" and "Huggingface".
        model (str): Model identifier handed to the provider client.
        api_key (str, optional): Credential handed to the provider client.
            It is not passed to the "ollama" client. Defaults to None.
        context_window (int, optional): Value stored in ``Settings.context_window``.
            Defaults to 4096.
        num_output (int, optional): Value stored in ``Settings.num_output``.
            Defaults to 512.

    Raises:
        ValueError: If ``endpoint`` is not one of the supported names.
    """
    if endpoint == "Groq":
        llm = Groq(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "Cohere":
        llm = Cohere(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "OpenAI":
        llm = OpenAI(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "TogetherAI":
        llm = TogetherLLM(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "ollama":
        # No credential is passed here; only the model and a request timeout.
        llm = Ollama(
            model=model,
            request_timeout=120.0,
        )
    elif endpoint == "Huggingface":
        # This client takes the credential as `token` and the model as `model_name`.
        llm = HuggingFaceInferenceAPI(
            model_name=model,
            token=api_key,
            task="text-generation",
        )
    else:
        # Without this branch `llm` would be unbound below and the caller would
        # get an UnboundLocalError that does not mention the offending endpoint.
        raise ValueError(f"Unsupported endpoint: {endpoint!r}")

    Settings.llm = llm
    # maximum input size to the LLM
    Settings.context_window = context_window
    # number of tokens reserved for text generation.
    Settings.num_output = num_output
63
-
64
-
65
-
66
def completion_wrapper(func: Callable) -> Callable:
    """
    Build a replacement for ``func`` that answers through ``Settings.llm``.

    ``func`` is never called by the returned function; it is only passed to
    ``functools.wraps`` so the replacement carries its metadata.

    Args:
        func (Callable): The function being replaced.

    Returns:
        Callable: A function taking ``prompt``, ``system_message``,
        ``temperature`` and ``json_mode``.
    """
    @wraps(func)
    def wrapper(
        prompt: str,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.3,
        json_mode: bool = False,
    ) -> Union[str, dict]:
        """
        Generate a completion with the LLM currently stored in ``Settings.llm``.

        Args:
            prompt (str): The user's prompt or query.
            system_message (str, optional): The system message to set the context for the assistant.
                Defaults to "You are a helpful assistant.".
            temperature (float, optional): Sampling temperature forwarded to ``llm.chat``.
                Defaults to 0.3.
            json_mode (bool, optional): When True, ``response_format={"type": "json_object"}``
                is added to the chat call. Ignored for ``HuggingFaceInferenceAPI``.
                Defaults to False.

        Returns:
            Union[str, dict]: ``response.message.content`` from the chat call, in
            every mode. In JSON mode this is presumably JSON text rather than a
            parsed dict; the value is returned as received.
        """
        llm = Settings.llm
        chat_kwargs = {"temperature": temperature, "top_p": 1}

        if llm.class_name() == "HuggingFaceInferenceAPI":
            # The system message is set on the client object itself (it stays set
            # after this call) and only the user turn is sent as a chat message.
            llm.system_prompt = system_message
            messages = [ChatMessage(role="user", content=prompt)]
        else:
            messages = [
                ChatMessage(role="system", content=system_message),
                ChatMessage(role="user", content=prompt),
            ]
            if json_mode:
                chat_kwargs["response_format"] = {"type": "json_object"}

        response = llm.chat(messages=messages, **chat_kwargs)
        return response.message.content

    return wrapper
129
-
130
# Keep the previously bound get_completion reachable as openai_completion, then
# rebind the module-level name get_completion to the llama-index based wrapper.
# NOTE(review): get_completion is not defined in this file; presumably it arrives
# via the star import from src.translation_agent.utils -- confirm.
- openai_completion = get_completion
131
  get_completion = completion_wrapper(openai_completion)
 
1
+ # a monkey patch to use llama-index completion
2
+ from typing import Union, Callable
3
+ from functools import wraps
4
+ from src.translation_agent.utils import *
5
+
6
+
7
+ from llama_index.llms.groq import Groq
8
+ from llama_index.llms.cohere import Cohere
9
+ from llama_index.llms.openai import OpenAI
10
+ from llama_index.llms.together import TogetherLLM
11
+ from llama_index.llms.ollama import Ollama
12
+ from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
13
+
14
+ from llama_index.core import Settings
15
+ from llama_index.core.llms import ChatMessage
16
+
17
+
18
+ # Add your LLMs here
19
+
20
def model_load(
    endpoint: str,
    model: str,
    api_key: str = None,
    context_window: int = 4096,
    num_output: int = 512,
):
    """
    Create the LLM client for ``endpoint`` and install it in the llama-index ``Settings``.

    Args:
        endpoint (str): Provider name, compared case-sensitively. Supported values are
            "Groq", "Cohere", "OpenAI", "TogetherAI", "ollama" and "Huggingface".
        model (str): Model identifier handed to the provider client.
        api_key (str, optional): Credential handed to the provider client.
            It is not passed to the "ollama" client. For "OpenAI", a missing or
            empty key falls back to the ``OPENAI_API_KEY`` environment variable.
            Defaults to None.
        context_window (int, optional): Value stored in ``Settings.context_window``.
            Defaults to 4096.
        num_output (int, optional): Value stored in ``Settings.num_output``.
            Defaults to 512.

    Raises:
        ValueError: If ``endpoint`` is not one of the supported names.
    """
    # Imported locally so the environment fallback does not depend on `os`
    # happening to be provided by a star import elsewhere in the module.
    import os

    if endpoint == "Groq":
        llm = Groq(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "Cohere":
        llm = Cohere(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "OpenAI":
        llm = OpenAI(
            model=model,
            # An explicit key wins; otherwise use the environment variable.
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
        )
    elif endpoint == "TogetherAI":
        llm = TogetherLLM(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "ollama":
        # No credential is passed here; only the model and a request timeout.
        llm = Ollama(
            model=model,
            request_timeout=120.0,
        )
    elif endpoint == "Huggingface":
        # This client takes the credential as `token` and the model as `model_name`.
        llm = HuggingFaceInferenceAPI(
            model_name=model,
            token=api_key,
            task="text-generation",
        )
    else:
        # Without this branch `llm` would be unbound below and the caller would
        # get an UnboundLocalError that does not mention the offending endpoint.
        raise ValueError(f"Unsupported endpoint: {endpoint!r}")

    Settings.llm = llm
    # maximum input size to the LLM
    Settings.context_window = context_window
    # number of tokens reserved for text generation.
    Settings.num_output = num_output
63
+
64
+
65
+
66
def completion_wrapper(func: Callable) -> Callable:
    """
    Build a replacement for ``func`` that answers through ``Settings.llm``.

    ``func`` is never called by the returned function; it is only passed to
    ``functools.wraps`` so the replacement carries its metadata.

    Args:
        func (Callable): The function being replaced.

    Returns:
        Callable: A function taking ``prompt``, ``system_message``,
        ``temperature`` and ``json_mode``.
    """
    @wraps(func)
    def wrapper(
        prompt: str,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.3,
        json_mode: bool = False,
    ) -> Union[str, dict]:
        """
        Generate a completion with the LLM currently stored in ``Settings.llm``.

        Args:
            prompt (str): The user's prompt or query.
            system_message (str, optional): The system message to set the context for the assistant.
                Defaults to "You are a helpful assistant.".
            temperature (float, optional): Sampling temperature forwarded to ``llm.chat``.
                Defaults to 0.3.
            json_mode (bool, optional): When True, ``response_format={"type": "json_object"}``
                is added to the chat call. Ignored for ``HuggingFaceInferenceAPI``.
                Defaults to False.

        Returns:
            Union[str, dict]: ``response.message.content`` from the chat call, in
            every mode. In JSON mode this is presumably JSON text rather than a
            parsed dict; the value is returned as received.
        """
        llm = Settings.llm
        chat_kwargs = {"temperature": temperature, "top_p": 1}

        if llm.class_name() == "HuggingFaceInferenceAPI":
            # The system message is set on the client object itself (it stays set
            # after this call) and only the user turn is sent as a chat message.
            llm.system_prompt = system_message
            messages = [ChatMessage(role="user", content=prompt)]
        else:
            messages = [
                ChatMessage(role="system", content=system_message),
                ChatMessage(role="user", content=prompt),
            ]
            if json_mode:
                chat_kwargs["response_format"] = {"type": "json_object"}

        response = llm.chat(messages=messages, **chat_kwargs)
        return response.message.content

    return wrapper
129
+
130
# Keep the previously bound get_completion reachable as openai_completion, then
# rebind the module-level name get_completion to the llama-index based wrapper.
# NOTE(review): get_completion is not defined in this file; presumably it arrives
# via the star import from src.translation_agent.utils -- confirm.
+ openai_completion = get_completion
131
  get_completion = completion_wrapper(openai_completion)