nisten commited on
Commit
e203e91
·
verified ·
1 Parent(s): 159c2ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -23
app.py CHANGED
@@ -4,8 +4,9 @@ import torch
4
  import subprocess
5
  import sys
6
 
7
- # Force install the specific transformers version from the GitHub PR
8
  subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "accelerate", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
 
9
 
10
  from transformers import OlmoeForCausalLM, AutoTokenizer
11
 
@@ -19,7 +20,8 @@ try:
19
  trust_remote_code=True,
20
  torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
21
  low_cpu_mem_usage=True,
22
- device_map="auto"
 
23
  )
24
  tokenizer = AutoTokenizer.from_pretrained(model_name)
25
  except Exception as e:
@@ -32,29 +34,24 @@ system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
32
  "while always answering questions in full first principles analysis type of thinking "
33
  "without using any analogies and always showing full working code or output in his answers.")
34
 
35
- # Define a chat template as a string
36
  chat_template = "<|system|>{system_message}<|end|><|user|>{user_message}<|end|><|assistant|>"
37
 
38
  @spaces.GPU
39
  def generate_response(message, history, temperature, max_new_tokens):
40
  if model is None or tokenizer is None:
41
- return "Model or tokenizer not loaded properly. Please check the logs."
42
-
43
- # Construct the full prompt
44
  full_prompt = chat_template.format(system_message=system_prompt, user_message=message)
45
-
46
  inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
47
 
48
- with torch.no_grad():
49
- generate_ids = model.generate(
50
- inputs.input_ids,
51
- max_new_tokens=max_new_tokens,
52
- do_sample=True,
53
- temperature=temperature,
54
- eos_token_id=tokenizer.eos_token_id,
55
- )
56
- response = tokenizer.decode(generate_ids[0, inputs.input_ids.shape[1]:], skip_special_tokens=True)
57
- return response.strip()
58
 
59
  css = """
60
  #output {
@@ -65,7 +62,7 @@ css = """
65
  """
66
 
67
  with gr.Blocks(css=css) as demo:
68
- gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE")
69
  chatbot = gr.Chatbot(elem_id="output")
70
  msg = gr.Textbox(label="Meow")
71
  with gr.Row():
@@ -78,15 +75,17 @@ with gr.Blocks(css=css) as demo:
78
 
79
  def bot(history, temp, max_tokens):
80
  user_message = history[-1][0]
81
- bot_message = generate_response(user_message, history, temp, max_tokens)
82
- history[-1][1] = bot_message
83
- return history
 
 
84
 
85
  msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
86
  bot, [chatbot, temperature, max_new_tokens], chatbot
87
  )
88
- clear.click(lambda: None, None, chatbot, queue=False)
89
 
90
  if __name__ == "__main__":
91
  demo.queue(api_open=True)
92
- demo.launch(debug=True, show_api=True, share=True)
 
4
  import subprocess
5
  import sys
6
 
7
+ # Install required packages
8
  subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "accelerate", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
9
+ subprocess.run('pip install flash-attn --no-build-isolation --no-deps', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
10
 
11
  from transformers import OlmoeForCausalLM, AutoTokenizer
12
 
 
20
  trust_remote_code=True,
21
  torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
22
  low_cpu_mem_usage=True,
23
+ device_map="auto",
24
+ _attn_implementation="flash_attention_2" # Enable Flash Attention 2
25
  )
26
  tokenizer = AutoTokenizer.from_pretrained(model_name)
27
  except Exception as e:
 
34
  "while always answering questions in full first principles analysis type of thinking "
35
  "without using any analogies and always showing full working code or output in his answers.")
36
 
 
37
  chat_template = "<|system|>{system_message}<|end|><|user|>{user_message}<|end|><|assistant|>"
38
 
39
  @spaces.GPU
40
  def generate_response(message, history, temperature, max_new_tokens):
41
  if model is None or tokenizer is None:
42
+ yield "Model or tokenizer not loaded properly. Please check the logs."
43
+ return
44
+
45
  full_prompt = chat_template.format(system_message=system_prompt, user_message=message)
 
46
  inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
47
 
48
+ streamer = tokenizer.stream(inputs.input_ids, model, temperature=temperature, max_new_tokens=max_new_tokens)
49
+
50
+ collected_tokens = []
51
+ for token in streamer:
52
+ collected_tokens.append(token)
53
+ partial_text = tokenizer.decode(collected_tokens, skip_special_tokens=True)
54
+ yield partial_text.strip()
 
 
 
55
 
56
  css = """
57
  #output {
 
62
  """
63
 
64
  with gr.Blocks(css=css) as demo:
65
+ gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (Now with Flash Attention 2!)")
66
  chatbot = gr.Chatbot(elem_id="output")
67
  msg = gr.Textbox(label="Meow")
68
  with gr.Row():
 
75
 
76
  def bot(history, temp, max_tokens):
77
  user_message = history[-1][0]
78
+ bot_message = ""
79
+ for token in generate_response(user_message, history, temp, max_tokens):
80
+ bot_message = token
81
+ history[-1][1] = bot_message
82
+ yield history
83
 
84
  msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
85
  bot, [chatbot, temperature, max_new_tokens], chatbot
86
  )
87
+ clear.click(lambda: None, None, chatbot, queue=True)
88
 
89
  if __name__ == "__main__":
90
  demo.queue(api_open=True)
91
+ demo.launch(debug=True, show_api=True)