nisten committed
Commit 2b0dd1e · verified · 1 Parent(s): 0cb4dc1

Update app.py
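(Summary of the diff: swaps the plain Human/Assistant prompt for <|user|>/<|assistant|>/<|end|> chat markers, installs flash-attn at startup, loads the model with flash_attention_2 and auto dtype, decodes only the newly generated tokens, and adds a title, queueing, and debug mode to the Gradio app.)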

Files changed (1)
  1. app.py +31 -22
app.py CHANGED
@@ -1,37 +1,45 @@
 import gradio as gr
 import torch
-import os
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import subprocess
 
-# Set the device to GPU if available, otherwise use CPU
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+# Install flash attention
+subprocess.run('pip install --upgrade --force-reinstall --no-deps --no-build-isolation transformers torch flash-attn ', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-# Load the model and tokenizer
+# Load model and tokenizer
 model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(DEVICE)
+model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype="auto", _attn_implementation="flash_attention_2").cuda().eval()
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
-# Define the system prompt
+# Define prompts
 system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
                  "who is stuck inside a step function machine and remembers and counts everything he says "
                  "while always answering questions in full first principles analysis type of thinking "
                  "without using any analogies and always showing full working code or output in his answers.")
 
-# Define a function for generating text
-def generate_text(prompt, history):
-    full_prompt = f"{system_prompt}\n\nHuman: {prompt}\n\nAssistant:"
-    inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
-
-    with torch.no_grad():
-        outputs = model.generate(**inputs, max_new_tokens=4000, do_sample=True, temperature=0.5)
-
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    assistant_response = response.split("Assistant:")[-1].strip()
-
-    return assistant_response
+user_prompt = '<|user|>\n'
+assistant_prompt = '<|assistant|>\n'
+prompt_suffix = "<|end|>\n"
+
+def generate_response(message, history):
+    full_prompt = f"{system_prompt}\n{user_prompt}{message}{prompt_suffix}{assistant_prompt}"
+
+    inputs = tokenizer(full_prompt, return_tensors="pt").to("cuda:0")
+    generate_ids = model.generate(
+        **inputs,
+        max_new_tokens=1000,
+        do_sample=True,
+        temperature=0.7,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+    response = tokenizer.batch_decode(generate_ids[:, inputs['input_ids'].shape[1]:],
+                                      skip_special_tokens=True,
+                                      clean_up_tokenization_spaces=False)[0]
+    return response.strip()
 
-# Set up the Gradio chat interface
+# Set up Gradio interface
 with gr.Blocks() as demo:
+    gr.Markdown("# Pissed Off Karpathy Chatbot")
     chatbot = gr.Chatbot()
     msg = gr.Textbox()
     clear = gr.Button("Clear")
@@ -41,7 +49,7 @@ with gr.Blocks() as demo:
 
    def bot(history):
        user_message = history[-1][0]
-        bot_message = generate_text(user_message, history)
+        bot_message = generate_response(user_message, history)
        history[-1][1] = bot_message
        return history
 
@@ -50,4 +58,5 @@ with gr.Blocks() as demo:
    )
    clear.click(lambda: None, None, chatbot, queue=False)
 
-    demo.launch(share=True)
+    demo.queue()
+    demo.launch(debug=True, share=True)
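
Two caveats on the new version, for anyone adapting this Space.

First, passing env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"} to subprocess.run replaces the child's entire environment rather than adding one variable, so the pip subprocess runs without PATH, HOME, or any CUDA variables and can fail to find pip at all. The usual pattern is to merge the flag into the inherited environment:

import os
import subprocess

# Merge the flag into the existing environment instead of replacing it,
# so the child process keeps PATH and the rest of its inherited variables.
env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
subprocess.run("pip install flash-attn --no-build-isolation", env=env, shell=True, check=True)

Second, the <|user|> / <|assistant|> / <|end|> markers are hard-coded and may not be the exact chat format the OLMoE tokenizer was trained with. A minimal sketch of the more robust alternative, assuming the allenai/OLMoE-1B-7B-0924-Instruct checkpoint ships a chat template and accepts a system role (if it rejects system messages, fold the persona into the first user turn), is to let the tokenizer render the prompt:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924-Instruct")

system_prompt = "Adopt the persona of hilariously pissed off Andrej Karpathy ..."  # truncated; same text as app.py
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": "Why is the sky blue?"},
]

# apply_chat_template renders the conversation with the template stored in the
# tokenizer config; add_generation_prompt appends the assistant marker so
# generation continues from the assistant's turn.
full_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(full_prompt, return_tensors="pt").to("cuda:0")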