Spestly committed
Commit 425e273 · verified · 1 Parent(s): 7409be6

Update app.py

Files changed (1)
  1. app.py +43 -120
app.py CHANGED
@@ -1,130 +1,53 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
-import gc
-
-# Global model and tokenizer
-model = None
-tokenizer = None
-
-def load_model():
-    """Load the model and tokenizer into memory."""
-    global model, tokenizer
-    model_name = "Spestly/Athena-1-0.5B"  # Replace with a smaller or quantized model for better performance on CPU
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        low_cpu_mem_usage=True,
-        torch_dtype=torch.float32,  # Keep float32 for CPU usage
-        device_map="cpu"
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model.eval()
-
-def generate_response(input_text, chat_history):
-    """Generate a response using the model and update the chat history."""
-    global model, tokenizer
-
-    # Load model if not loaded
-    if model is None or tokenizer is None:
-        load_model()
-
-    try:
-        # Alpaca Chat Template
-        instruction = (
-            "You are an LLM called Athena. You are finetuned by Aayan Mishra (Spestly). Below is an instruction that describes a task. "
-            "Write a response that appropriately completes the request.\n\n"
-            f"### Instruction:\n{input_text}\n\n### Response:"
-        )
-        chat_history.append({"role": "user", "content": input_text})  # Add user input to chat history
-
-        # Tokenization
-        inputs = tokenizer(
-            instruction,
-            return_tensors="pt",
-            truncation=True,
-            max_length=256  # Limit input length for CPU performance
-        )
-
-        # Generate response
-        with torch.no_grad():
-            outputs = model.generate(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                max_new_tokens=100,  # Limit the number of tokens generated
-                do_sample=True,
-                top_k=40,  # Adjust top_k for faster response
-                top_p=0.85,  # Adjust top_p for faster sampling
-                temperature=0.7,
-                pad_token_id=tokenizer.pad_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                repetition_penalty=1.2,
-                num_beams=1  # Use single beam for faster processing
-            )
-
-        # Decode response
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
-        response = response.split("### Response:")[-1].strip()  # Extract model's response
-
-        # Update chat history
-        chat_history.append({"role": "assistant", "content": response})
-
-        # Manual garbage collection for CPU usage
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        return chat_history
-
-    except Exception as e:
-        return chat_history + [{"role": "error", "content": f"Error: {str(e)}"}]
-
-# Gradio UI
-def render_chat(chat_history):
-    """
-    Render the chat history into a format that mimics the ChatGPT UI.
-    """
-    chat_ui = ""
-    for entry in chat_history:
-        if entry["role"] == "user":
-            chat_ui += f'<div class="user-message"><b>User:</b> {entry["content"]}</div>'
-        elif entry["role"] == "assistant":
-            chat_ui += f'<div class="assistant-message"><b>Athena:</b> {entry["content"]}</div>'
-        elif entry["role"] == "error":
-            chat_ui += f'<div class="error-message"><b>Error:</b> {entry["content"]}</div>'
-    return chat_ui
-
-with gr.Blocks(css="""
-body { background-color: #202123; color: white; font-family: 'Arial', sans-serif; margin: 0; padding: 0; }
-.chat-container { background-color: #333; border-radius: 10px; padding: 15px; max-height: 500px; overflow-y: auto; }
-.user-message { text-align: left; margin: 10px 0; padding: 10px; background-color: #444; border-radius: 10px; }
-.assistant-message { text-align: left; margin: 10px 0; padding: 10px; background-color: #555; border-radius: 10px; }
-.error-message { text-align: center; color: red; margin: 10px 0; padding: 10px; border: 1px solid red; border-radius: 10px; }
-.input-container { position: fixed; bottom: 0; width: 100%; background-color: #202123; padding: 15px; border-top: 1px solid #444; }
-.input-box { width: calc(100% - 30px); padding: 10px; border-radius: 10px; border: 1px solid #444; background-color: #333; color: white; }
-.submit-button { background-color: #10a37f; color: white; border: none; padding: 10px; border-radius: 10px; cursor: pointer; }
-.submit-button:hover { background-color: #0e8d69; }
-""") as demo:
-    gr.Markdown("<h1 style='text-align: center;'>Athena-1 1.5B</h1>")
-
-    # Chat history
-    chat_history = gr.State([])
-    chat_display = gr.HTML(label="Chat History")
-
-    # Input box
-    with gr.Row(elem_id="input-container"):
-        user_input = gr.Textbox(placeholder="Type your message here...", elem_id="input-box")
-        submit_button = gr.Button("Submit", elem_id="submit-button")
-
-    # Submit button functionality
-    submit_button.click(
-        generate_response,
-        inputs=[user_input, chat_history],
-        outputs=[chat_history]
-    ).then(
-        render_chat,
-        inputs=[chat_history],
-        outputs=[chat_display]
-    )
-
-# Launch the app
-demo.launch()
+
+# Load model and tokenizer
+model_name = "Spestly/Athena-1-0.5B"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, low_cpu_mem_usage=True)
+
+# Set to evaluation mode
+model.eval()
+
+def generate_response(message, history):
+    instruction = (
+        "You are an LLM called Athena. You are finetuned by Aayan Mishra. You are NOT trained by Anthropic. "
+        "You are a Qwen 2.5 fine-tune. Your purpose is to help the user accomplish their request to the best of your abilities. "
+        "Below is an instruction that describes a task. Answer it clearly and concisely.\n\n"
+        f"### Instruction:\n{message}\n\n### Response:"
+    )
+
+    inputs = tokenizer(instruction, return_tensors="pt")
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=100,
+            num_return_sequences=1,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True
+        )
+
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    response = response.split("### Response:")[-1].strip()
+
+    return response
+
+iface = gr.ChatInterface(
+    generate_response,
+    chatbot=gr.Chatbot(height=600, type="messages"),
+    textbox=gr.Textbox(placeholder="Type your message here...", container=False, scale=7),
+    title="Athena-1",
+    description="Chat with Athena-1 0.5B",
+    theme="soft",
+    examples=[
+        "Can you give me a good salsa recipe?",
+        "What are Neural Networks?",
+        "What is the capital of Australia?",
+    ],
+    type="messages"
+)
+
+iface.launch()
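
A note on what the rewrite drops: the removed version truncated prompts (max_length=256) and passed explicit pad/eos token ids to generate(), while the new version tokenizes and generates with defaults. If long inputs or pad-token warnings become a problem on a CPU Space, a lightly hardened variant of generate_response might look like the sketch below; it reuses the tokenizer and model loaded above, and the 256-token cap and eos-as-pad fallback are assumptions carried over from the removed code, not part of this commit.

def generate_response(message, history):
    instruction = (
        "You are an LLM called Athena. ... "  # same Alpaca-style prompt as in the new app.py
        f"### Instruction:\n{message}\n\n### Response:"
    )

    # Truncate long prompts, as the removed version did (assumed 256-token cap)
    inputs = tokenizer(instruction, return_tensors="pt", truncation=True, max_length=256)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            # Fall back to eos when the tokenizer defines no pad token
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response.split("### Response:")[-1].strip()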
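The system prompt describes Athena as a Qwen 2.5 fine-tune, so its tokenizer most likely ships a chat template; if it does, the hand-rolled "### Instruction / ### Response" prompt could be replaced with tokenizer.apply_chat_template, which would also let the history argument that gr.ChatInterface passes in carry over between turns. A minimal sketch, assuming the template exists and that history arrives as a list of {"role", "content"} dicts (the "messages" format configured above):

SYSTEM_PROMPT = (
    "You are an LLM called Athena, a Qwen 2.5 fine-tune by Aayan Mishra. "
    "Help the user accomplish their request clearly and concisely."
)

def generate_response(message, history):
    # Prepend the system prompt and fold in prior turns from the Gradio history
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages += [{"role": m["role"], "content": m["content"]} for m in history]
    messages.append({"role": "user", "content": message})

    # Let the tokenizer's own chat template format the conversation
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

    with torch.no_grad():
        outputs = model.generate(input_ids, max_new_tokens=100, do_sample=True, temperature=0.7, top_p=0.9)

    # Decode only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()

Slicing the output at input_ids.shape[-1] replaces the string split on "### Response:", which would stop working once the prompt is template-formatted.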