Gregor Betz
commited on
Only HF
Browse files
app.py
CHANGED
@@ -26,20 +26,15 @@ CLIENT_MODEL_KWARGS = {
|
|
26 |
}
|
27 |
|
28 |
GUIDE_KWARGS = {
|
29 |
-
"expert_model": "
|
30 |
# "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
31 |
-
|
32 |
-
# "accounts/fireworks/models/llama-v3-8b-instruct-hf",
|
33 |
-
# "accounts/fireworks/models/nous-hermes-2-mixtral-8x7b-dpo-fp8",
|
34 |
-
"inference_server_url": "https://api.fireworks.ai/inference/v1",
|
35 |
# "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct",
|
36 |
-
|
37 |
-
"llm_backend": "Fireworks",
|
38 |
"classifier_kwargs": {
|
39 |
"model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
|
40 |
-
"inference_server_url": "https://
|
41 |
-
|
42 |
-
"batch_size": 128,
|
43 |
},
|
44 |
}
|
45 |
|
@@ -190,7 +185,7 @@ async def bot(
|
|
190 |
if len(history_langchain_format) <= 1:
|
191 |
|
192 |
guide_kwargs = copy.deepcopy(GUIDE_KWARGS)
|
193 |
-
guide_kwargs["api_key"] = os.getenv("
|
194 |
guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN") # classifier api key
|
195 |
|
196 |
guide_config = RecursiveBalancingGuideConfig(**guide_kwargs)
|
|
|
26 |
}
|
27 |
|
28 |
GUIDE_KWARGS = {
|
29 |
+
"expert_model": "HuggingFaceH4/zephyr-7b-beta",
|
30 |
# "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
31 |
+
"inference_server_url": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
|
|
|
|
|
|
|
32 |
# "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct",
|
33 |
+
"llm_backend": "HFChat",
|
|
|
34 |
"classifier_kwargs": {
|
35 |
"model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
|
36 |
+
"inference_server_url": "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
|
37 |
+
"batch_size": 8,
|
|
|
38 |
},
|
39 |
}
|
40 |
|
|
|
185 |
if len(history_langchain_format) <= 1:
|
186 |
|
187 |
guide_kwargs = copy.deepcopy(GUIDE_KWARGS)
|
188 |
+
guide_kwargs["api_key"] = os.getenv("HF_TOKEN") # expert model api key
|
189 |
guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN") # classifier api key
|
190 |
|
191 |
guide_config = RecursiveBalancingGuideConfig(**guide_kwargs)
|