"""Streamlit chat UI for running Spestly's Athena models (Qwen 2.5 fine-tunes) on CPU."""

import gc
import os
import re

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Available model families, their size variants, and Hugging Face repo paths.
MODELS = {
    "athena-1": {
        "name": "🧠 Athena-1",
        "sizes": {
            "0.5B": "Spestly/Athena-1-0.5B",
            "1.5B": "Spestly/Athena-1-1.5B",
        },
        "emoji": "🧠",
        "experimental": False,
    },
    "athena-2": {
        "name": "🚀 Athena-2",
        "sizes": {
            "0.5B": "Spestly/Athena-2-0.5B",
            "1.5B": "Spestly/Athena-2-1.5B",
        },
        "emoji": "🚀",
        "experimental": False,
    },
    "athena-3": {
        "name": "🧪 Athena-3",
        "sizes": {
            "0.5B": "Spestly/Athena-3-500M",
            "1.5B": "Spestly/Athena-3-1.5B",
            "3B": "Spestly/Athena-3-3B",
        },
        "emoji": "🧪",
        "experimental": True,
    },
}


class AthenaInferenceApp:
    def __init__(self):
        # Keep the loaded model and the chat history across Streamlit reruns.
        if "current_model" not in st.session_state:
            st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []
        st.set_page_config(
            page_title="Athena Model Inference",
            page_icon="🤖",
            layout="wide",
            menu_items={
                "Get Help": "https://huggingface.co/collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86",
                "Report a bug": "https://huggingface.co/Spestly/Athena-1-1.5B/discussions/new",
                "About": "Athena Model Inference Platform",
            },
        )

    def clear_memory(self):
        """Optimize memory management for CPU inference."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def load_model(self, model_key, model_size):
        try:
            self.clear_memory()
            # Release any previously loaded model before loading a new one.
            # Reset to None (rather than deleting the keys) so later lookups
            # cannot raise KeyError if this load fails.
            if st.session_state.current_model["model"] is not None:
                st.session_state.current_model["model"] = None
                st.session_state.current_model["tokenizer"] = None
                self.clear_memory()

            model_path = MODELS[model_key]["sizes"][model_size]

            # Load the Qwen-compatible tokenizer and model.
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            if tokenizer.pad_token is None:
                # Some causal-LM tokenizers define no pad token; fall back to EOS
                # so padding and generation do not fail.
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="cpu",           # Force CPU usage
                torch_dtype=torch.float32,  # Use float32 for CPU
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )

            # Update session state
            st.session_state.current_model.update({
                "tokenizer": tokenizer,
                "model": model,
                "config": {
                    "name": f"{MODELS[model_key]['name']} {model_size}",
                    "path": model_path,
                },
            })
            return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
        except Exception as e:
            return f"❌ Error: {str(e)}"

    def respond(self, message, max_tokens, temperature, top_p, top_k):
        if not st.session_state.current_model["model"]:
            return "⚠️ Please select and load a model first"
        try:
            # System instruction that guides the model's behaviour.
            system_instruction = (
                "You are Athena, a helpful AI assistant trained by Spestly. "
                "You are a Qwen 2.5 fine-tune."
            )
            prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"
            inputs = st.session_state.current_model["tokenizer"](
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True,
            )

            with torch.no_grad():
                output = st.session_state.current_model["model"].generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                    eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
                )

            response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
            # The decoded text still contains the prompt; keep only the generated answer.
            return response.split("### Response:")[-1].strip()
        except Exception as e:
            return f"⚠️ Generation Error: {str(e)}"
        finally:
            self.clear_memory()

    def main(self):
        st.title("🦉 AthenaUI")

        with st.sidebar:
            st.header("🛠 Model Selection")
            model_key = st.selectbox(
                "Choose Athena Variant",
                list(MODELS.keys()),
                format_func=lambda x: f"{MODELS[x]['name']} {'🧪' if MODELS[x]['experimental'] else ''}",
            )
            model_size = st.selectbox(
                "Choose Model Size",
                list(MODELS[model_key]["sizes"].keys()),
            )
            if st.button("Load Model"):
                with st.spinner("Loading model... This may take a few minutes."):
                    status = self.load_model(model_key, model_size)
                st.success(status)

            st.header("🔧 Generation Parameters")
            max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
            temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
            top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
            top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)

            if st.button("Clear Chat History"):
                st.session_state.chat_history = []
                st.rerun()

            st.markdown("*💬 All in one chat UI for Athena models!*")

        # Replay the conversation so far.
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Handle a new user message.
        if prompt := st.chat_input("Message Athena..."):
            st.session_state.chat_history.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            with st.chat_message("assistant"):
                with st.spinner("Generating response..."):
                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k)
                st.markdown(response)
            st.session_state.chat_history.append({"role": "assistant", "content": response})


def run():
    try:
        app = AthenaInferenceApp()
        app.main()
    except Exception as e:
        st.error(f"⚠️ Application Error: {str(e)}")


if __name__ == "__main__":
    run()
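
# --- Usage note ----------------------------------------------------------------
# A minimal way to launch this UI, assuming the script is saved as `app.py`
# (the filename is an assumption; nothing in the script depends on it):
#
#     pip install streamlit torch transformers
#     streamlit run app.py
#
# Inference runs entirely on CPU (device_map="cpu", float32), so the smaller
# 0.5B checkpoints are the most responsive choice on machines without a GPU.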