import gc

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
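
# Gradio chat app for the Spestly/AwA-1.5B model: weights are loaded in 8-bit
# via bitsandbytes and responses are streamed back token by token from a
# hand-rolled sampling loop.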

# Configure 8-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)

# Load model and tokenizer
model_name = "Spestly/AwA-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)
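# torch_dtype applies to the modules bitsandbytes leaves un-quantized
# (embeddings, layer norms); the linear weights themselves are loaded as int8.
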
# Optimizations
model.config.use_cache = True
torch.backends.cudnn.benchmark = False
torch._C._jit_set_profiling_executor(False)
model.eval()
# Clear memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

def generate_response(message, history):
    """Stream a response for gr.ChatInterface, yielding partial text as it is generated."""
    gc.collect()

    instruction = (
        "You are an LLM called AwA. Aayan Mishra finetunes you. Anthropic does NOT train you. "
        "You are a Qwen 2.5 fine-tune. Your purpose is to help the user accomplish their request to the best of your abilities. "
        "Below is an instruction that describes a task. Answer it clearly and concisely.\n\n"
        f"### Instruction:\n{message}\n\n### Response:"
    )

    # Tokenize the prompt and move it onto the same device as the model
    inputs = tokenizer(
        instruction,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(model.device)
    try:
        # Token-by-token generation with an explicit KV cache so partial text
        # can be streamed to the UI as soon as it is produced
        generated_ids = []
        past_key_values = None
        attention_mask = inputs["attention_mask"]

        with torch.no_grad():
            for _ in range(400):  # max_new_tokens
                if not generated_ids:
                    # First step: feed the whole prompt
                    input_ids = inputs["input_ids"]
                else:
                    # Later steps: feed only the last sampled token; the KV
                    # cache already holds the rest of the context
                    input_ids = torch.tensor([[generated_ids[-1]]], device=model.device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                next_token_logits = outputs.logits[:, -1, :]
                past_key_values = outputs.past_key_values

                # Temperature (0.7) and top-p (0.9) nucleus sampling
                probs = torch.nn.functional.softmax(next_token_logits / 0.7, dim=-1)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
                idx_to_remove = cumsum_probs > 0.9
                # Shift the mask right so the first token crossing the 0.9
                # threshold is still kept, then renormalize and sample
                idx_to_remove[:, 1:] = idx_to_remove[:, :-1].clone()
                idx_to_remove[:, 0] = 0
                sorted_probs[idx_to_remove] = 0.0
                sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(sorted_probs, num_samples=1)
                next_token = sorted_indices.gather(-1, next_token)

                generated_ids.append(next_token.item())
                attention_mask = torch.cat(
                    [attention_mask, torch.ones((1, 1), dtype=attention_mask.dtype, device=model.device)],
                    dim=-1,
                )

                # Decode everything generated so far and stream it back; the
                # prompt already ends with "### Response:", so the generated
                # tokens are the response itself
                current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                if "### Response:" in current_text:
                    current_text = current_text.split("### Response:")[-1]
                yield current_text.strip()

                # Stop once the model emits its end-of-sequence token
                if next_token.item() == tokenizer.eos_token_id:
                    break
    except Exception as e:
        yield f"An error occurred: {str(e)}"
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Create Gradio interface
iface = gr.ChatInterface(
    generate_response,
    chatbot=gr.Chatbot(
        height=400,
        type="messages",
    ),
    textbox=gr.Textbox(
        placeholder="Type your message here...",
        container=False,
        scale=7,
    ),
    title="AwA-1.5B π - CPU Optimized",
    description="Chat with AwA (Answers with Athena). Optimized for CPU operation.",
    theme="ocean",
    examples=[
        "How can CRISPR help us humans?",
        "What are some important ethics in AI?",
        "How many 'r's in 'strawberry'?",
    ],
    type="messages",
)
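
# Limit the request queue and serve the app on all interfaces at port 7860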
iface.queue(max_size=5)
iface.launch(
    share=False,
    debug=False,
    show_error=True,
    server_name="0.0.0.0",
    server_port=7860,
)