Spaces:
Running
Running
Kang Suhyun
commited on
[#104] Display error message for the context window exceeded error (#105)
Browse filesChanges:
- Added exception handling for `ContextWindowExceededError` in the relevant module.
- Logged the error to capture detailed error information.
- Display an error message to inform users about the error and how to resolve it.
- model.py +16 -8
- response.py +14 -4
model.py
CHANGED
@@ -29,6 +29,10 @@ DEFAULT_SUMMARIZE_INSTRUCTION = "Summarize the following text, maintaining the l
|
|
29 |
DEFAULT_TRANSLATE_INSTRUCTION = "Translate the following text from {source_lang} to {target_lang}." # pylint: disable=line-too-long
|
30 |
|
31 |
|
|
|
|
|
|
|
|
|
32 |
class Model:
|
33 |
|
34 |
def __init__(
|
@@ -49,14 +53,18 @@ class Model:
|
|
49 |
self.translate_instruction = translateInstruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long
|
50 |
|
51 |
def completion(self, messages: List, max_tokens: float = None) -> str:
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
60 |
|
61 |
|
62 |
supported_models: List[Model] = [
|
|
|
29 |
DEFAULT_TRANSLATE_INSTRUCTION = "Translate the following text from {source_lang} to {target_lang}." # pylint: disable=line-too-long
|
30 |
|
31 |
|
32 |
+
class ContextWindowExceededError(Exception):
|
33 |
+
pass
|
34 |
+
|
35 |
+
|
36 |
class Model:
|
37 |
|
38 |
def __init__(
|
|
|
53 |
self.translate_instruction = translateInstruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long
|
54 |
|
55 |
def completion(self, messages: List, max_tokens: float = None) -> str:
|
56 |
+
try:
|
57 |
+
response = litellm.completion(model=self.provider + "/" +
|
58 |
+
self.name if self.provider else self.name,
|
59 |
+
api_key=self.api_key,
|
60 |
+
api_base=self.api_base,
|
61 |
+
messages=messages,
|
62 |
+
max_tokens=max_tokens)
|
63 |
+
|
64 |
+
return response.choices[0].message.content
|
65 |
+
|
66 |
+
except litellm.ContextWindowExceededError as e:
|
67 |
+
raise ContextWindowExceededError() from e
|
68 |
|
69 |
|
70 |
supported_models: List[Model] = [
|
response.py
CHANGED
@@ -3,6 +3,7 @@ This module contains functions for generating responses using LLMs.
|
|
3 |
"""
|
4 |
|
5 |
import enum
|
|
|
6 |
from random import sample
|
7 |
from typing import List
|
8 |
from uuid import uuid4
|
@@ -11,9 +12,14 @@ from firebase_admin import firestore
|
|
11 |
import gradio as gr
|
12 |
|
13 |
from leaderboard import db
|
|
|
14 |
from model import Model
|
15 |
from model import supported_models
|
16 |
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def get_history_collection(category: str):
|
19 |
if category == Category.SUMMARIZE.value:
|
@@ -81,10 +87,14 @@ def get_responses(prompt: str, category: str, source_lang: str,
|
|
81 |
create_history(category, model.name, instruction, prompt, response)
|
82 |
responses.append(response)
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
88 |
|
89 |
model_names = [model.name for model in models]
|
90 |
|
|
|
3 |
"""
|
4 |
|
5 |
import enum
|
6 |
+
import logging
|
7 |
from random import sample
|
8 |
from typing import List
|
9 |
from uuid import uuid4
|
|
|
12 |
import gradio as gr
|
13 |
|
14 |
from leaderboard import db
|
15 |
+
from model import ContextWindowExceededError
|
16 |
from model import Model
|
17 |
from model import supported_models
|
18 |
|
19 |
+
logging.basicConfig()
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
logger.setLevel(logging.INFO)
|
22 |
+
|
23 |
|
24 |
def get_history_collection(category: str):
|
25 |
if category == Category.SUMMARIZE.value:
|
|
|
87 |
create_history(category, model.name, instruction, prompt, response)
|
88 |
responses.append(response)
|
89 |
|
90 |
+
except ContextWindowExceededError as e:
|
91 |
+
logger.exception("Context window exceeded for model %s.", model.name)
|
92 |
+
raise gr.Error(
|
93 |
+
"The prompt is too long. Please try again with a shorter prompt."
|
94 |
+
) from e
|
95 |
+
except Exception as e:
|
96 |
+
logger.exception("Failed to get response from model %s.", model.name)
|
97 |
+
raise gr.Error("Failed to get response. Please try again.") from e
|
98 |
|
99 |
model_names = [model.name for model in models]
|
100 |
|