NCTCMumbai commited on
Commit
aad5b37
·
1 Parent(s): 5131264

Update backend/query_llm.py

Browse files
Files changed (1) hide show
  1. backend/query_llm.py +10 -52
backend/query_llm.py CHANGED
@@ -50,7 +50,7 @@ def format_prompt(message: str, api_kind: str):
50
  raise ValueError("API is not supported")
51
 
52
 
53
- def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 3000,
54
  top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
55
  """
56
  Generate a sequence of tokens based on a given prompt and history using Mistral client.
@@ -70,57 +70,15 @@ def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tok
70
 
71
  temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
72
  top_p = float(top_p)
73
- generate_kwargs=[
74
- gr.Textbox(
75
- label="System Prompt",
76
- max_lines=1,
77
- interactive=True,
78
- ),
79
- gr.Slider(
80
- label="Temperature",
81
- value=0.9,
82
- minimum=0.0,
83
- maximum=1.0,
84
- step=0.05,
85
- interactive=True,
86
- info="Higher values produce more diverse outputs",
87
- ),
88
- gr.Slider(
89
- label="Max new tokens",
90
- value=256,
91
- minimum=0,
92
- maximum=4048,
93
- step=64,
94
- interactive=True,
95
- info="The maximum numbers of new tokens",
96
- ),
97
- gr.Slider(
98
- label="Top-p (nucleus sampling)",
99
- value=0.90,
100
- minimum=0.0,
101
- maximum=1,
102
- step=0.05,
103
- interactive=True,
104
- info="Higher values sample more low-probability tokens",
105
- ),
106
- gr.Slider(
107
- label="Repetition penalty",
108
- value=1.2,
109
- minimum=1.0,
110
- maximum=2.0,
111
- step=0.05,
112
- interactive=True,
113
- info="Penalize repeated tokens",
114
- )
115
- ]
116
- # generate_kwargs = {
117
- # 'temperature': temperature,
118
- # 'max_new_tokens': max_new_tokens,
119
- # 'top_p': top_p,
120
- # 'repetition_penalty': repetition_penalty,
121
- # 'do_sample': True,
122
- # 'seed': 42,
123
- # }
124
 
125
  formatted_prompt = format_prompt(prompt, "hf")
126
 
 
50
  raise ValueError("API is not supported")
51
 
52
 
53
+ def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 4000,
54
  top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
55
  """
56
  Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
70
 
71
  temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
72
  top_p = float(top_p)
73
+
74
+ generate_kwargs = {
75
+ 'temperature': temperature,
76
+ 'max_new_tokens': max_new_tokens,
77
+ 'top_p': top_p,
78
+ 'repetition_penalty': repetition_penalty,
79
+ 'do_sample': True,
80
+ 'seed': 42,
81
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  formatted_prompt = format_prompt(prompt, "hf")
84