nisten committed
Commit aaeb784 · verified · 1 parent: ff1da0a

Update app.py

Files changed (1): app.py (+65, -18)
app.py CHANGED
@@ -2,16 +2,15 @@ import gradio as gr
 import torch
 import subprocess
 import sys
+import os
 
 # Force install the specific transformers version from the GitHub PR
 subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# Load model and tokenizer
+# Define model name
 model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
-model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype="auto").cuda().eval()
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 # Define prompts
 system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
@@ -23,42 +22,90 @@ user_prompt = '<|user|>\n'
 assistant_prompt = '<|assistant|>\n'
 prompt_suffix = "<|end|>\n"
 
-def generate_response(message, history):
+# Function to load model and tokenizer
+def load_model_and_tokenizer(model_name):
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+    # Check for CUDA availability
+    if torch.cuda.is_available():
+        print("CUDA is available. Using GPU.")
+        device = "cuda"
+    else:
+        print("CUDA is not available. Using CPU.")
+        device = "cpu"
+
+    # Load model
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
+    ).to(device).eval()
+
+    return model, tokenizer, device
+
+# Function to generate response
+def generate_response(message, history, model, tokenizer, device):
     full_prompt = f"{system_prompt}\n{user_prompt}{message}{prompt_suffix}{assistant_prompt}"
 
-    inputs = tokenizer(full_prompt, return_tensors="pt").to("cuda:0")
-    generate_ids = model.generate(
-        **inputs,
-        max_new_tokens=4000,
-        do_sample=True,
-        temperature=0.7,
-        eos_token_id=tokenizer.eos_token_id,
-    )
+    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
+    with torch.no_grad():
+        generate_ids = model.generate(
+            **inputs,
+            max_new_tokens=1000,
+            do_sample=True,
+            temperature=0.7,
+            eos_token_id=tokenizer.eos_token_id,
+        )
     response = tokenizer.batch_decode(generate_ids[:, inputs['input_ids'].shape[1]:],
                                       skip_special_tokens=True,
                                       clean_up_tokenization_spaces=False)[0]
     return response.strip()
 
+# Function to set client for session
+def set_client_for_session(request: gr.Request):
+    x_ip_token = request.headers.get('x-ip-token', '')
+    return {"X-IP-Token": x_ip_token}
+
 # Set up Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Karpathy Chatbot")
+    gr.Markdown("#Karpathy Chatbot")
     chatbot = gr.Chatbot()
     msg = gr.Textbox()
    clear = gr.Button("Clear")
+
+    # States
+    model_state = gr.State()
+    tokenizer_state = gr.State()
+    device_state = gr.State()
+    headers_state = gr.State()
+
+    def initialize_model(headers):
+        if not model_state.value:
+            model, tokenizer, device = load_model_and_tokenizer(model_name)
+            return model, tokenizer, device
+        return model_state.value, tokenizer_state.value, device_state.value
 
     def user(user_message, history):
         return "", history + [[user_message, None]]
 
-    def bot(history):
+    def bot(history, model, tokenizer, device):
         user_message = history[-1][0]
-        bot_message = generate_response(user_message, history)
+        bot_message = generate_response(user_message, history, model, tokenizer, device)
         history[-1][1] = bot_message
         return history
 
     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot, chatbot, chatbot
+        initialize_model, headers_state, [model_state, tokenizer_state, device_state]
+    ).then(
+        bot, [chatbot, model_state, tokenizer_state, device_state], chatbot
     )
     clear.click(lambda: None, None, chatbot, queue=False)
 
-demo.queue()
-demo.launch(debug=True, share=True)
+demo.load(set_client_for_session, None, headers_state)
+
+if __name__ == "__main__":
+    if os.environ.get("SPACE_ID"):
+        demo.queue(api_open=False)
+        demo.launch(debug=True)
+    else:
+        demo.launch(debug=True, share=True)
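For reference, a minimal usage sketch of the two helpers this commit factors out, outside the Gradio wiring. It assumes the definitions from the updated app.py (load_model_and_tokenizer, generate_response, model_name, and the prompt constants) are already in scope; generate_response accepts a history argument but does not use it, so an empty list is passed. This is an illustrative sketch, not part of the commit.

# Hypothetical usage; assumes the helpers from the updated app.py are defined in this session.
model, tokenizer, device = load_model_and_tokenizer(model_name)

# generate_response accepts a chat history but ignores it, so [] suffices here.
reply = generate_response("Why does my loss keep going to NaN?", [], model, tokenizer, device)
print(reply)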