nisten committed on
Commit e9acdad · verified · 1 Parent(s): e203e91

Update app.py

Files changed (1)
  1. app.py +24 -14
app.py CHANGED
@@ -5,8 +5,8 @@ import subprocess
  import sys
 
  # Install required packages
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "accelerate", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
- subprocess.run('pip install flash-attn --no-build-isolation --no-deps', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
  from transformers import OlmoeForCausalLM, AutoTokenizer
 
@@ -18,11 +18,12 @@ try:
  model = OlmoeForCausalLM.from_pretrained(
  model_name,
  trust_remote_code=True,
- torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
+ torch_dtype=torch.float16, # Using float16 for lower precision
  low_cpu_mem_usage=True,
  device_map="auto",
  _attn_implementation="flash_attention_2" # Enable Flash Attention 2
  )
+ model.gradient_checkpointing_enable() # Enable gradient checkpointing
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  except Exception as e:
  print(f"Error loading model: {e}")
@@ -45,13 +46,22 @@ def generate_response(message, history, temperature, max_new_tokens):
  full_prompt = chat_template.format(system_message=system_prompt, user_message=message)
  inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
 
- streamer = tokenizer.stream(inputs.input_ids, model, temperature=temperature, max_new_tokens=max_new_tokens)
-
- collected_tokens = []
- for token in streamer:
- collected_tokens.append(token)
- partial_text = tokenizer.decode(collected_tokens, skip_special_tokens=True)
- yield partial_text.strip()
+ try:
+ with torch.no_grad():
+ streamer = tokenizer.stream(inputs.input_ids, model, temperature=temperature, max_new_tokens=max_new_tokens)
+
+ collected_tokens = []
+ for token in streamer:
+ collected_tokens.append(token)
+ partial_text = tokenizer.decode(collected_tokens, skip_special_tokens=True)
+ yield partial_text.strip()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ yield "GPU memory exceeded. Try reducing the max tokens or using a smaller model."
+ else:
+ yield f"An error occurred: {str(e)}"
+ except Exception as e:
+ yield f"An unexpected error occurred: {str(e)}"
 
  css = """
  #output {
@@ -76,7 +86,7 @@ with gr.Blocks(css=css) as demo:
  def bot(history, temp, max_tokens):
  user_message = history[-1][0]
  bot_message = ""
- for token in generate_response(user_message, history, temp, max_tokens):
+ for token in generate_response(user_message, history[:-1], temp, max_tokens):
  bot_message = token
  history[-1][1] = bot_message
  yield history
@@ -84,8 +94,8 @@ with gr.Blocks(css=css) as demo:
  msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
  bot, [chatbot, temperature, max_new_tokens], chatbot
  )
- clear.click(lambda: None, None, chatbot, queue=True)
+ clear.click(lambda: None, None, chatbot, queue=False)
 
  if __name__ == "__main__":
- demo.queue(api_open=True)
- demo.launch(debug=True, show_api=True)
+ demo.queue(api_open=False, max_size=10) # Limiting queue size
+ demo.launch(debug=True, show_api=True, share=False) # Disabled sharing for security
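
Note on the flash-attn install step: passing env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"} to subprocess.run replaces the child's entire environment, so PATH and other variables are dropped and the shelled-out pip may fail to resolve. A common alternative is to extend the current environment instead; a minimal sketch of that pattern, assuming the same package and flags:

    import os
    import subprocess
    import sys

    # Keep the existing environment and only add the skip-build flag,
    # so pip, compilers, and CUDA paths remain visible to the child process.
    env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        env=env,
        check=True,
    )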
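
Note on the streaming call: tokenizer.stream(...) is not part of the standard transformers tokenizer API and presumably comes from the pinned fork (Muennighoff/transformers@olmoe). If that helper is unavailable, a roughly equivalent streaming loop can be written with transformers' TextIteratorStreamer; a minimal sketch, assuming stock generate() behavior and the model, tokenizer, inputs, temperature, and max_new_tokens objects from app.py:

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_generate(model, tokenizer, inputs, temperature, max_new_tokens):
        # Decode tokens incrementally while model.generate() runs in a background thread.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        partial_text = ""
        for new_text in streamer:
            partial_text += new_text
            yield partial_text.strip()
        thread.join()

The name stream_generate is illustrative only; the existing generate_response generator could yield from a loop like this directly.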