# AwA-1.5B / app.py
import gc

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Configure 8-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)
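# NOTE: llm_int8_threshold routes large-magnitude "outlier" activations through
# fp16 while the rest of the matmul runs in int8 (the LLM.int8() scheme);
# 8-bit loading via bitsandbytes generally assumes a CUDA GPU is available.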

# Load model and tokenizer
model_name = "Spestly/AwA-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

# Optimizations
model.config.use_cache = True
torch.backends.cudnn.benchmark = False
torch._C._jit_set_profiling_executor(False)
model.eval()

# Clear memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
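
# generate_response streams a reply token by token: it builds an instruction-style
# prompt, then runs a manual sampling loop (temperature 0.7, top-p 0.9) so partial
# text can be yielded to Gradio as it is produced.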
def generate_response(message, history):
    gc.collect()
    instruction = (
        "You are an LLM called AwA. Aayan Mishra finetunes you. Anthropic does NOT train you. "
        "You are a Qwen 2.5 fine-tune. Your purpose is to help the user accomplish their request to the best of your abilities. "
        "Below is an instruction that describes a task. Answer it clearly and concisely.\n\n"
        f"### Instruction:\n{message}\n\n### Response:"
    )
    inputs = tokenizer(
        instruction,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(model.device)
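
    # Decode one token at a time: each step feeds only the newest token and reuses
    # the cached attention states (past_key_values) from the previous steps.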
    try:
        generated_ids = []
        past_key_values = None
        attention_mask = inputs["attention_mask"]
        input_ids = inputs["input_ids"]
        with torch.no_grad():
            for _ in range(400):  # max_new_tokens
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                next_token_logits = outputs.logits[:, -1, :]
                past_key_values = outputs.past_key_values

                # Apply temperature (0.7) and top-p (0.9) nucleus sampling
                probs = torch.nn.functional.softmax(next_token_logits / 0.7, dim=-1)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
                idx_to_remove = cumsum_probs > 0.9
                # Shift the mask right so the first token past the threshold is kept
                idx_to_remove[:, 1:] = idx_to_remove[:, :-1].clone()
                idx_to_remove[:, 0] = False
                sorted_probs[idx_to_remove] = 0.0
                sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(sorted_probs, num_samples=1)
                next_token = sorted_indices.gather(-1, next_token)

                generated_ids.append(next_token.item())
                # Only the new token is fed next step; the KV cache covers the rest
                input_ids = next_token
                attention_mask = torch.cat(
                    [attention_mask, torch.ones((1, 1), dtype=attention_mask.dtype, device=attention_mask.device)],
                    dim=-1,
                )

                # Decode everything generated so far and stream it to the UI
                current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                yield current_text.split("### Response:")[-1].strip()

                # Check for end of generation
                if next_token.item() == tokenizer.eos_token_id:
                    break
    except Exception as e:
        yield f"An error occurred: {str(e)}"
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
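
# NOTE (sketch, not used by this app): the manual loop above reproduces what
# transformers' built-in sampling plus TextIteratorStreamer already provide.
# An equivalent streaming generator, with generate_response_streamer as a
# hypothetical name, could look like:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   def generate_response_streamer(message, history):
#       prompt = f"### Instruction:\n{message}\n\n### Response:"
#       inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#       streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#       Thread(target=model.generate, kwargs=dict(
#           **inputs, max_new_tokens=400, do_sample=True,
#           temperature=0.7, top_p=0.9, streamer=streamer,
#       )).start()
#       partial = ""
#       for chunk in streamer:
#           partial += chunk
#           yield partial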

# Create Gradio interface
iface = gr.ChatInterface(
    generate_response,
    chatbot=gr.Chatbot(
        height=400,
        type="messages",
    ),
    textbox=gr.Textbox(
        placeholder="Type your message here...",
        container=False,
        scale=7,
    ),
    title="AwA-1.5B 🔎 - CPU Optimized",
    description="Chat with AwA (Answers with Athena). Optimized for CPU operation.",
    theme="ocean",
    examples=[
        "How can CRISPR help us humans?",
        "What are some important ethics in AI?",
        "How many 'r's in 'strawberry'?",
    ],
    type="messages",
)

iface.queue(max_size=5)
iface.launch(
    share=False,
    debug=False,
    show_error=True,
    server_name="0.0.0.0",
    server_port=7860,
)
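
# Run locally with `python app.py`; per the launch() settings above, the UI is
# served on all interfaces at port 7860 (http://localhost:7860).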