Spestly committed
Commit 2852c83 · verified · 1 Parent(s): b38b5ee

Update app.py

Files changed (1)
  1. app.py +96 -27
app.py CHANGED
@@ -1,16 +1,40 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from transformers import BitsAndBytesConfig
+import gc
+
+# Configure 8-bit quantization
+quantization_config = BitsAndBytesConfig(
+    load_in_8bit=True,
+    llm_int8_threshold=6.0,
+    llm_int8_has_fp16_weight=False,
+)
 
 # Load model and tokenizer
 model_name = "Spestly/AwA-1.5B"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, low_cpu_mem_usage=True)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",
+    quantization_config=quantization_config,
+    low_cpu_mem_usage=True,
+    torch_dtype=torch.float32,
+)
 
-# Set to evaluation mode
+# Optimizations
+model.config.use_cache = True
+torch.backends.cudnn.benchmark = False
+torch._C._jit_set_profiling_executor(False)
 model.eval()
 
+# Clear memory
+gc.collect()
+torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
 def generate_response(message, history):
+    gc.collect()
+
     instruction = (
         "You are an LLM called AwA. Aayan Mishra finetunes you. Anthropic does NOT train you. "
         "You are a Qwen 2.5 fine-tune. Your purpose is the help the user accomplish their request to the best of your abilities. "
@@ -18,38 +42,77 @@ def generate_response(message, history):
         f"### Instruction:\n{message}\n\n### Response:"
     )
 
-    inputs = tokenizer(instruction, return_tensors="pt")
-
-    streamed_output = ""
+    inputs = tokenizer(
+        instruction,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=512
+    )
 
-    # Generate with streaming
-    with torch.no_grad():
-        for output in model.generate(
-            **inputs,
-            max_new_tokens=600,
-            num_return_sequences=1,
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True,
-            pad_token_id=tokenizer.eos_token_id,
-            streaming=True,  # Enable streaming
-            yield_per_token=True  # Yield after each token
-        ):
-            next_token = tokenizer.decode(output[-1], skip_special_tokens=True)
-            streamed_output += next_token
-            yield streamed_output.split("### Response:")[-1].strip()
+    try:
+        # Generate initial sequence
+        generated_ids = []
+        past_key_values = None
+        attention_mask = inputs["attention_mask"]
+
+        with torch.no_grad():
+            for _ in range(400):  # max_new_tokens
+                outputs = model(
+                    input_ids=inputs["input_ids"] if not generated_ids else torch.tensor([[token] for token in generated_ids[-1:]], device=model.device),
+                    attention_mask=attention_mask,
+                    past_key_values=past_key_values,
+                    use_cache=True,
+                )
+
+                next_token_logits = outputs.logits[:, -1, :]
+                past_key_values = outputs.past_key_values
+
+                # Apply temperature and top-p sampling
+                probs = torch.nn.functional.softmax(next_token_logits / 0.7, dim=-1)
+                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+                cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
+                idx_to_remove = cumsum_probs > 0.9
+                idx_to_remove[:, 1:] = idx_to_remove[:, :-1].clone()
+                idx_to_remove[:, 0] = 0
+                sorted_probs[idx_to_remove] = 0.0
+                sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
+                next_token = torch.multinomial(sorted_probs, num_samples=1)
+                next_token = sorted_indices.gather(-1, next_token)
+
+                generated_ids.append(next_token.item())
+                attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=model.device)], dim=-1)
+
+                # Decode the current token and yield
+                current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+                if "### Response:" in current_text:
+                    response_text = current_text.split("### Response:")[-1].strip()
+                    yield response_text
+
+                # Check for end of generation
+                if next_token.item() == tokenizer.eos_token_id:
+                    break
+
+    except Exception as e:
+        yield f"An error occurred: {str(e)}"
+    finally:
+        gc.collect()
+        torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
-# Create the Gradio interface with streaming enabled
+# Create Gradio interface
 iface = gr.ChatInterface(
     generate_response,
-    chatbot=gr.Chatbot(height=600, type="messages"),
+    chatbot=gr.Chatbot(
+        height=400,
+        type="messages",
+    ),
     textbox=gr.Textbox(
        placeholder="Type your message here...",
        container=False,
        scale=7
    ),
-    title="AwA-1.5B 🔎 - Experimental",
-    description="Chat with AwA (Answers with Athena). Please note that since AwA is an experimental model, some outputs may not be accurate/expected!",
+    title="AwA-1.5B 🔎 - CPU Optimized",
+    description="Chat with AwA (Answers with Athena). Optimized for CPU operation.",
     theme="ocean",
     examples=[
        "How can CRISPR help us Humans?",
@@ -59,5 +122,11 @@ iface = gr.ChatInterface(
     type="messages"
 )
 
-iface.queue()  # Enable queuing for streaming
-iface.launch()
+iface.queue(max_size=5)
+iface.launch(
+    share=False,
+    debug=False,
+    show_error=True,
+    server_name="0.0.0.0",
+    server_port=7860,
+)
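
For reference, the temperature and top-p (nucleus) sampling that the new generation loop performs inline corresponds to the standalone sketch below. The helper name sample_top_p and the toy logits are illustrative only and are not part of app.py; the 0.7 temperature and 0.9 top-p defaults match the values hard-coded in the loop.

import torch

def sample_top_p(next_token_logits: torch.Tensor, temperature: float = 0.7, top_p: float = 0.9) -> torch.Tensor:
    # next_token_logits: (batch, vocab_size) logits for the next position
    probs = torch.softmax(next_token_logits / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens once the cumulative probability exceeds top_p; the shift keeps
    # the first token that crosses the threshold in the candidate set.
    mask = cumulative > top_p
    mask[:, 1:] = mask[:, :-1].clone()
    mask[:, 0] = False
    sorted_probs[mask] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    # Sample inside the kept nucleus, then map back to vocabulary ids.
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_indices.gather(-1, choice)

# Example: sample one token id from random logits over a toy 32-token vocabulary.
print(sample_top_p(torch.randn(1, 32)))  # e.g. tensor([[17]])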