sagar007 commited on
Commit
985eabb
·
verified ·
1 Parent(s): 1d330bc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, LlamaForCausalLM
4
+
5
+ # Initialize model and tokenizer
6
+ model_id = 'akjindal53244/Llama-3.1-Storm-8B'
7
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
8
+ model = LlamaForCausalLM.from_pretrained(
9
+ model_id,
10
+ torch_dtype=torch.bfloat16,
11
+ device_map="auto",
12
+ use_flash_attention_2=True
13
+ )
14
+
15
+ # Function to format the prompt
16
+ def format_prompt(messages):
17
+ prompt = "<|begin_of_text|>"
18
+ for message in messages:
19
+ prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}<|eot_id|>"
20
+ prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
21
+ return prompt
22
+
23
+ # Function to generate response
24
+ def generate_response(message, history):
25
+ messages = [{"role": "system", "content": "You are a helpful assistant."}]
26
+ for human, assistant in history:
27
+ messages.append({"role": "user", "content": human})
28
+ messages.append({"role": "assistant", "content": assistant})
29
+ messages.append({"role": "user", "content": message})
30
+
31
+ prompt = format_prompt(messages)
32
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
33
+ generated_ids = model.generate(input_ids, max_new_tokens=256, temperature=0.7, do_sample=True, eos_token_id=tokenizer.eos_token_id)
34
+ response = tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
35
+ return response.strip()
36
+
37
+ # Create Gradio interface
38
+ iface = gr.ChatInterface(
39
+ generate_response,
40
+ title="Llama-3.1-Storm-8B Chatbot",
41
+ description="Chat with the Llama-3.1-Storm-8B model. Type your message and press Enter to send.",
42
+ )
43
+
44
+ # Launch the app
45
+ iface.launch()