nikravan committed
Commit 80d5294 · verified · 1 Parent(s): 4f09fd3

Update app.py

Files changed (1):
  1. app.py +19 -23
app.py CHANGED
@@ -1,4 +1,4 @@
-import os
+
 import json
 import subprocess
 from threading import Thread
@@ -10,7 +10,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-MODEL_ID = "nikravan/Marco_o1_q4"
+MODEL_ID = "nikravan/Marco-O1-q4"
 CHAT_TEMPLATE = "ChatML"
 MODEL_NAME = MODEL_ID.split("/")[-1]
 CONTEXT_LENGTH = 16000
@@ -120,26 +120,22 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 
 # Create Gradio interface
-with gr.Blocks(theme=gr.themes.Soft(primary_hue=COLOR)) as demo:
-    chatbot = gr.Chatbot(label=EMOJI + " " + MODEL_NAME, latex_delimiters=latex_delimiters_set)
-    system_prompt = gr.Textbox("You are a code assistant.", label="System prompt")
-    temperature = gr.Slider(0, 1, 0.3, label="Temperature")
-    max_new_tokens = gr.Slider(128, 4096, 1024, label="Max new tokens")
-    top_k = gr.Slider(1, 80, 40, label="Top K sampling")
-    repetition_penalty = gr.Slider(0, 2, 1.1, label="Repetition penalty")
-    top_p = gr.Slider(0, 1, 0.95, label="Top P sampling")
-    message = gr.Textbox(label="User Input")
-    submit = gr.Button("Submit")
-
-    def respond(message, chatbot_history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
-        response = predict(message, chatbot_history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p)
-        chatbot_history.append((message, "".join(response)))
-        return chatbot_history
+gr.ChatInterface(
+    predict,
+    title=EMOJI + " " + MODEL_NAME,
+    description=DESCRIPTION,
 
-    submit.click(
-        respond,
-        inputs=[message, chatbot, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p],
-        outputs=chatbot
-    )
 
-demo.launch()
+
+    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
+    additional_inputs=[
+        gr.Textbox("You are a code assistant.", label="System prompt"),
+        gr.Slider(0, 1, 0.3, label="Temperature"),
+        gr.Slider(128, 4096, 1024, label="Max new tokens"),
+        gr.Slider(1, 80, 40, label="Top K sampling"),
+        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
+        gr.Slider(0, 1, 0.95, label="Top P sampling"),
+    ],
+    theme=gr.themes.Soft(primary_hue=COLOR),
+).queue().launch()
+
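
Note on the refactor: gr.ChatInterface invokes its callback as fn(message, history, *additional_inputs), so the seven controls listed under additional_inputs are passed positionally after the chat history. Judging from the removed respond() wrapper, the repo's existing predict already matches that calling convention and can be handed to gr.ChatInterface directly. A minimal sketch of the expected contract (the function body here is a hypothetical placeholder, not this repo's actual generation code):

# Sketch only -- parameter order must mirror the additional_inputs list above:
# system prompt, temperature, max new tokens, top-k, repetition penalty, top-p.
def predict(message, history, system_prompt, temperature, max_new_tokens,
            top_k, repetition_penalty, top_p):
    # Hypothetical body: the real app builds a ChatML prompt from
    # system_prompt, history, and message, then streams model tokens.
    partial = ""
    for token in ("stub ", "response"):  # stand-in for a real token stream
        partial += token
        yield partial  # yielding gives ChatInterface incremental streaming updates

The trailing .queue() enables Gradio's request queue, which streamed (generator) responses rely on in the Gradio 3.x line this Space appears to target.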