NCTCMumbai committed on
Commit
5131264
·
1 Parent(s): b11281d

Update backend/query_llm.py

Browse files
Files changed (1) hide show
  1. backend/query_llm.py +51 -9
backend/query_llm.py CHANGED
@@ -70,15 +70,57 @@ def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tok
70
 
71
  temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
72
  top_p = float(top_p)
73
-
74
- generate_kwargs = {
75
- 'temperature': temperature,
76
- 'max_new_tokens': max_new_tokens,
77
- 'top_p': top_p,
78
- 'repetition_penalty': repetition_penalty,
79
- 'do_sample': True,
80
- 'seed': 42,
81
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  formatted_prompt = format_prompt(prompt, "hf")
84
 
 
70
 
71
  temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
72
  top_p = float(top_p)
73
+ generate_kwargs=[
74
+ gr.Textbox(
75
+ label="System Prompt",
76
+ max_lines=1,
77
+ interactive=True,
78
+ ),
79
+ gr.Slider(
80
+ label="Temperature",
81
+ value=0.9,
82
+ minimum=0.0,
83
+ maximum=1.0,
84
+ step=0.05,
85
+ interactive=True,
86
+ info="Higher values produce more diverse outputs",
87
+ ),
88
+ gr.Slider(
89
+ label="Max new tokens",
90
+ value=256,
91
+ minimum=0,
92
+ maximum=4048,
93
+ step=64,
94
+ interactive=True,
95
+ info="The maximum numbers of new tokens",
96
+ ),
97
+ gr.Slider(
98
+ label="Top-p (nucleus sampling)",
99
+ value=0.90,
100
+ minimum=0.0,
101
+ maximum=1,
102
+ step=0.05,
103
+ interactive=True,
104
+ info="Higher values sample more low-probability tokens",
105
+ ),
106
+ gr.Slider(
107
+ label="Repetition penalty",
108
+ value=1.2,
109
+ minimum=1.0,
110
+ maximum=2.0,
111
+ step=0.05,
112
+ interactive=True,
113
+ info="Penalize repeated tokens",
114
+ )
115
+ ]
116
+ # generate_kwargs = {
117
+ # 'temperature': temperature,
118
+ # 'max_new_tokens': max_new_tokens,
119
+ # 'top_p': top_p,
120
+ # 'repetition_penalty': repetition_penalty,
121
+ # 'do_sample': True,
122
+ # 'seed': 42,
123
+ # }
124
 
125
  formatted_prompt = format_prompt(prompt, "hf")
126