nisten commited on
Commit
159c2ce
·
verified ·
1 Parent(s): b04ca7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -16
app.py CHANGED
@@ -32,49 +32,45 @@ system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
32
  "while always answering questions in full first principles analysis type of thinking "
33
  "without using any analogies and always showing full working code or output in his answers.")
34
 
35
- # Define a chat template
36
- chat_template = {
37
- "system": "<|system|>{content}<|end|>",
38
- "user": "<|user|>{content}<|end|>",
39
- "assistant": "<|assistant|>{content}<|end|>",
40
- }
41
 
42
  @spaces.GPU
43
  def generate_response(message, history, temperature, max_new_tokens):
44
  if model is None or tokenizer is None:
45
  return "Model or tokenizer not loaded properly. Please check the logs."
46
 
47
- messages = [{"role": "system", "content": system_prompt},
48
- {"role": "user", "content": message}]
49
 
50
- inputs = tokenizer.apply_chat_template(messages, chat_template=chat_template, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
51
 
52
  with torch.no_grad():
53
  generate_ids = model.generate(
54
- inputs,
55
  max_new_tokens=max_new_tokens,
56
  do_sample=True,
57
  temperature=temperature,
58
  eos_token_id=tokenizer.eos_token_id,
59
  )
60
- response = tokenizer.decode(generate_ids[0, inputs.shape[1]:], skip_special_tokens=True)
61
  return response.strip()
62
 
63
  css = """
64
  #output {
65
- height: 900px;
66
  overflow: auto;
67
- border: 1px solid #ccc;
68
  }
69
  """
70
 
71
  with gr.Blocks(css=css) as demo:
72
  gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE")
73
  chatbot = gr.Chatbot(elem_id="output")
74
- msg = gr.Textbox(label="Your message")
75
  with gr.Row():
76
  temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
77
- max_new_tokens = gr.Slider(minimum=50, maximum=4000, value=1000, step=50, label="Max New Tokens")
78
  clear = gr.Button("Clear")
79
 
80
  def user(user_message, history):
@@ -93,4 +89,4 @@ with gr.Blocks(css=css) as demo:
93
 
94
  if __name__ == "__main__":
95
  demo.queue(api_open=True)
96
- demo.launch(debug=True, show_api=True, share=True)
 
32
  "while always answering questions in full first principles analysis type of thinking "
33
  "without using any analogies and always showing full working code or output in his answers.")
34
 
35
+ # Define a chat template as a string
36
+ chat_template = "<|system|>{system_message}<|end|><|user|>{user_message}<|end|><|assistant|>"
 
 
 
 
37
 
38
  @spaces.GPU
39
  def generate_response(message, history, temperature, max_new_tokens):
40
  if model is None or tokenizer is None:
41
  return "Model or tokenizer not loaded properly. Please check the logs."
42
 
43
+ # Construct the full prompt
44
+ full_prompt = chat_template.format(system_message=system_prompt, user_message=message)
45
 
46
+ inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
47
 
48
  with torch.no_grad():
49
  generate_ids = model.generate(
50
+ inputs.input_ids,
51
  max_new_tokens=max_new_tokens,
52
  do_sample=True,
53
  temperature=temperature,
54
  eos_token_id=tokenizer.eos_token_id,
55
  )
56
+ response = tokenizer.decode(generate_ids[0, inputs.input_ids.shape[1]:], skip_special_tokens=True)
57
  return response.strip()
58
 
59
  css = """
60
  #output {
61
+ height: 1000px;
62
  overflow: auto;
63
+ border: 2px solid #ccc;
64
  }
65
  """
66
 
67
  with gr.Blocks(css=css) as demo:
68
  gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE")
69
  chatbot = gr.Chatbot(elem_id="output")
70
+ msg = gr.Textbox(label="Meow")
71
  with gr.Row():
72
  temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
73
+ max_new_tokens = gr.Slider(minimum=50, maximum=4000, value=2000, step=50, label="Max New Tokens")
74
  clear = gr.Button("Clear")
75
 
76
  def user(user_message, history):
 
89
 
90
  if __name__ == "__main__":
91
  demo.queue(api_open=True)
92
+ demo.launch(debug=True, show_api=True, share=True)