import gradio as gr
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Configure 8-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)

# Load model and tokenizer
model_name = "Spestly/AwA-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

# Optimizations
model.config.use_cache = True
torch.backends.cudnn.benchmark = False
torch._C._jit_set_profiling_executor(False)
model.eval()

# Clear memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


def generate_response(message, history):
    gc.collect()

    instruction = (
        "You are an LLM called AwA. Aayan Mishra finetunes you. Anthropic does NOT train you. "
        "You are a Qwen 2.5 fine-tune. Your purpose is to help the user accomplish their request to the best of your abilities. "
        "Below is an instruction that describes a task. Answer it clearly and concisely.\n\n"
        f"### Instruction:\n{message}\n\n### Response:"
    )

    inputs = tokenizer(
        instruction,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Move inputs onto the same device as the model (device_map="auto" may place it on GPU)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    try:
        # Manual token-by-token generation loop with KV caching
        generated_ids = []
        past_key_values = None
        attention_mask = inputs["attention_mask"]

        with torch.no_grad():
            for _ in range(400):  # max_new_tokens
                outputs = model(
                    # Feed the full prompt on the first step, then only the last generated token
                    input_ids=inputs["input_ids"]
                    if not generated_ids
                    else torch.tensor([[generated_ids[-1]]], device=model.device),
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                )

                next_token_logits = outputs.logits[:, -1, :]
                past_key_values = outputs.past_key_values

                # Apply temperature and top-p (nucleus) sampling
                probs = torch.nn.functional.softmax(next_token_logits / 0.7, dim=-1)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

                # Zero out tokens beyond the top-p (0.9) cumulative-probability cutoff
                idx_to_remove = cumsum_probs > 0.9
                idx_to_remove[:, 1:] = idx_to_remove[:, :-1].clone()
                idx_to_remove[:, 0] = 0
                sorted_probs[idx_to_remove] = 0.0
                sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

                # Sample the next token and map it back to the vocabulary index
                next_token = torch.multinomial(sorted_probs, num_samples=1)
                next_token = sorted_indices.gather(-1, next_token)

                generated_ids.append(next_token.item())
                attention_mask = torch.cat(
                    [
                        attention_mask,
                        torch.ones((1, 1), dtype=attention_mask.dtype, device=model.device),
                    ],
                    dim=-1,
                )

                # Decode the tokens generated so far and stream the partial response.
                # generated_ids holds only newly generated tokens (the prompt is not
                # included), so the decoded text is already the response.
                current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                yield current_text.strip()

                # Check for end of generation
                if next_token.item() == tokenizer.eos_token_id:
                    break

    except Exception as e:
        yield f"An error occurred: {str(e)}"
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Create Gradio interface
iface = gr.ChatInterface(
    generate_response,
    chatbot=gr.Chatbot(
        height=400,
        type="messages",
    ),
    textbox=gr.Textbox(
        placeholder="Type your message here...",
        container=False,
        scale=7,
    ),
    title="AwA-1.5B 🔎 - CPU Optimized",
    description="Chat with AwA (Answers with Athena). Optimized for CPU operation.",
    theme="ocean",
    examples=[
        "How can CRISPR help us Humans?",
        "What are some important ethics in AI?",
        "How many 'r's in 'strawberry'?",
    ],
    type="messages",
)

iface.queue(max_size=5)
iface.launch(
    share=False,
    debug=False,
    show_error=True,
    server_name="0.0.0.0",
    server_port=7860,
)