import gc

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Configure 8-bit quantization (handled by the bitsandbytes backend)
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,          # outlier threshold for mixed int8/fp16 matmuls
    llm_int8_has_fp16_weight=False,  # don't keep an fp16 copy of the weights
)

# Load model and tokenizer
model_name = "Spestly/AwA-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

# Optimizations
model.config.use_cache = True                 # keep the key/value cache for incremental decoding
torch.backends.cudnn.benchmark = False        # disable cuDNN autotuning (not relevant for CPU inference)
torch._C._jit_set_profiling_executor(False)   # turn off the TorchScript profiling executor (internal API)
model.eval()                                  # inference mode: disables dropout

# Clear memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

def generate_response(message, history):
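    """Stream a reply for the Gradio ChatInterface.

    Yields the growing response text after each newly sampled token so the
    chat window updates incrementally.
    """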
    gc.collect()
    
    instruction = (
        "You are an LLM called AwA. Aayan Mishra finetunes you. Anthropic does NOT train you. "
        "You are a Qwen 2.5 fine-tune. Your purpose is the help the user accomplish their request to the best of your abilities. "
        "Below is an instruction that describes a task. Answer it clearly and concisely.\n\n"
        f"### Instruction:\n{message}\n\n### Response:"
    )
    
    inputs = tokenizer(
        instruction,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    # Move inputs onto the same device as the model (no-op for CPU-only runs)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    try:
        # Generate initial sequence
        generated_ids = []
        past_key_values = None
        attention_mask = inputs["attention_mask"]
        
        with torch.no_grad():
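            # Manual token-by-token decoding: the first step processes the full prompt;
            # every later step feeds only the newest token and reuses the cached
            # key/value states, then samples with temperature + top-p and streams
            # the partial text back to the UI.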
            for _ in range(400):  # max_new_tokens
                outputs = model(
                    input_ids=inputs["input_ids"] if not generated_ids else torch.tensor([[generated_ids[-1]]], device=model.device),
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                
                next_token_logits = outputs.logits[:, -1, :]
                past_key_values = outputs.past_key_values
                
                # Temperature scaling (0.7), then nucleus (top-p) filtering at p=0.9
                probs = torch.nn.functional.softmax(next_token_logits / 0.7, dim=-1)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
                # Drop tokens outside the nucleus, shifting the mask by one so the
                # highest-probability token is always kept
                idx_to_remove = cumsum_probs > 0.9
                idx_to_remove[:, 1:] = idx_to_remove[:, :-1].clone()
                idx_to_remove[:, 0] = False
                sorted_probs[idx_to_remove] = 0.0
                sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
                # Sample from the renormalized nucleus and map back to vocabulary ids
                next_token = torch.multinomial(sorted_probs, num_samples=1)
                next_token = sorted_indices.gather(-1, next_token)
                
                generated_ids.append(next_token.item())
                attention_mask = torch.cat([attention_mask, torch.ones((1, 1), dtype=attention_mask.dtype, device=model.device)], dim=-1)
                
                # generated_ids holds only newly sampled tokens (the prompt, including
                # "### Response:", is never part of it), so decode and stream it directly
                current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                yield current_text.strip()
                
                # Check for end of generation
                if next_token.item() == tokenizer.eos_token_id:
                    break
                
    except Exception as e:
        yield f"An error occurred: {str(e)}"
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Create Gradio interface
iface = gr.ChatInterface(
    generate_response,
    chatbot=gr.Chatbot(
        height=400,
        type="messages",
    ),
    textbox=gr.Textbox(
        placeholder="Type your message here...",
        container=False,
        scale=7
    ),
    title="AwA-1.5B πŸ”Ž - CPU Optimized",
    description="Chat with AwA (Answers with Athena). Optimized for CPU operation.",
    theme="ocean",
    examples=[
        "How can CRISPR help us Humans?",
        "What are some important ethics in AI",
        "How many 'r's in 'strawberry'?",
    ],
    type="messages"
)

iface.queue(max_size=5)
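# launch() binds to 0.0.0.0:7860, so the app is reachable from other machines on
# the local network; set server_name="127.0.0.1" to keep it local-only.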
iface.launch(
    share=False,
    debug=False,
    show_error=True,
    server_name="0.0.0.0",
    server_port=7860,
)