# Captured from a Hugging Face Space file view (Space status at capture time: "Runtime error").
# File size: 4,989 bytes; revision 596ba9a.
import argparse
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
# Model identifier passed to AutoTokenizer/AutoModelForCausalLM.from_pretrained
# (a Hugging Face Hub repo id); the model is loaded with trust_remote_code=True.
MODEL_PATH = 'THUDM/glm-4-9b-chat'
# NOTE(review): these ids look like GLM-4 chat special tokens (presumably
# [gMASK], <sop>, <|user|>, "\n" before the prompt and <|assistant|> after it;
# <|endoftext|>, <|user|>, <|observation|> as stop tokens) -- confirm against
# the tokenizer of MODEL_PATH.
_PROMPT_START_TOKENS = [151331, 151333, 151336, 198]
_PROMPT_END_TOKENS = [151337]
_EOS_TOKEN_IDS = [151329, 151336, 151338]
# Number of prompt positions occupied by the fixed framing tokens above.
_FRAMING_LEN = len(_PROMPT_START_TOKENS) + len(_PROMPT_END_TOKENS)
_WARMUP_TOKEN_LEN = 20
_WARMUP_MAX_NEW_TOKENS = 2048
_MAX_NEW_TOKENS = 512
# Seconds the streamer waits for the next chunk before giving up.
_STREAMER_TIMEOUT = 36000


def _load_model_and_tokenizer(device):
    """Load the MODEL_PATH tokenizer and a bf16 model on ``device`` in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        padding_side="left",
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device).eval()
    # For INT4 weight inference instead, load with
    # quantization_config=BitsAndBytesConfig(load_in_4bit=True) and
    # low_cpu_mem_usage=True, without the .to(device) call.
    return model, tokenizer


def _random_prompt_ids(token_len, vocab_size, device):
    """Return a (1, token_len) long tensor: framing tokens around random ids."""
    # Random ids come from [3, vocab_size - 200); presumably this avoids
    # special/reserved ids at both ends of the vocabulary -- TODO confirm.
    random_ids = torch.randint(
        3, vocab_size - 200, (token_len - _FRAMING_LEN,), dtype=torch.long
    )
    ids = _PROMPT_START_TOKENS + random_ids.tolist() + _PROMPT_END_TOKENS
    return torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)


def _timed_generate(model, tokenizer, input_ids):
    """Run one streamed greedy generation; return (prefill_s, decode_tokens_per_s)."""
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=_STREAMER_TIMEOUT,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = {
        "input_ids": input_ids,
        # Integer (long) mask, matching what a tokenizer returns.
        "attention_mask": torch.ones_like(input_ids),
        "max_new_tokens": _MAX_NEW_TOKENS,
        "do_sample": False,
        "repetition_penalty": 1.0,
        "eos_token_id": _EOS_TOKEN_IDS,
        "streamer": streamer,
    }
    outcome = {}

    def _worker():
        try:
            outcome["output"] = model.generate(**generate_kwargs)
        except Exception as exc:
            outcome["error"] = exc
            # Push the stop signal so the consumer loop below terminates now
            # instead of blocking until the streamer timeout expires.
            streamer.end()

    start = time.perf_counter()
    worker = Thread(target=_worker)
    worker.start()
    first_chunk_at = None
    for _ in streamer:
        if first_chunk_at is None:
            first_chunk_at = time.perf_counter()
    worker.join()
    end = time.perf_counter()

    if "error" in outcome:
        raise outcome["error"]
    if first_chunk_at is None:  # defensive: nothing was streamed at all
        first_chunk_at = end

    prefill_seconds = first_chunk_at - start
    # generate() returns prompt + new ids; the first new id comes from the
    # prefill pass, the remaining ones are the decode steps being measured.
    new_tokens = outcome["output"].shape[-1] - input_ids.shape[-1]
    decode_tokens = new_tokens - 1
    decode_seconds = end - first_chunk_at
    if decode_tokens > 0 and decode_seconds > 0:
        decode_speed = decode_tokens / decode_seconds
    else:
        decode_speed = 0.0
    return prefill_seconds, decode_speed


def stress_test(token_len, n, num_gpu):
    """Measure prefill latency and decode throughput of MODEL_PATH on random prompts.

    The model is loaded once and given one warm-up generation. Then ``n``
    timed greedy generations are run, each on a fresh random prompt of
    ``token_len`` tokens and producing up to 512 new tokens. Per-iteration
    and average figures are printed.

    Args:
        token_len: Prompt length in tokens, including the 5 framing tokens;
            must be >= 5.
        n: Number of timed iterations; must be >= 1.
        num_gpu: Selects device ``cuda:{num_gpu - 1}`` when CUDA is available
            and ``num_gpu > 0``; otherwise the CPU is used. Only a single
            device is ever used.

    Returns:
        Tuple ``(times, avg_first_token_time, decode_times, avg_decode_time)``:
        - ``times``: per-iteration prefill time in seconds (time to the first
          streamed chunk).
        - ``avg_first_token_time``: the mean of ``times``.
        - ``decode_times``: per-iteration decode speed in tokens/second.
        - ``avg_decode_time``: the mean of ``decode_times``.

    Raises:
        ValueError: If ``n < 1`` or ``token_len`` is shorter than the framing.
    """
    # Validate before the (expensive) model load.
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if token_len < _FRAMING_LEN:
        raise ValueError(f"token_len must be >= {_FRAMING_LEN}, got {token_len}")

    use_cuda = torch.cuda.is_available() and num_gpu > 0
    device = torch.device(f"cuda:{num_gpu - 1}" if use_cuda else "cpu")
    model, tokenizer = _load_model_and_tokenizer(device)
    vocab_size = tokenizer.vocab_size

    print("Warming up...")
    warmup_ids = _random_prompt_ids(_WARMUP_TOKEN_LEN, vocab_size, device)
    with torch.no_grad():
        model.generate(
            input_ids=warmup_ids,
            attention_mask=torch.ones_like(warmup_ids),
            max_new_tokens=_WARMUP_MAX_NEW_TOKENS,
            do_sample=False,
            repetition_penalty=1.0,
            eos_token_id=_EOS_TOKEN_IDS,
        )
    print("Warming up complete. Starting stress test...")

    times = []
    decode_times = []  # holds speeds (tokens/s); name kept for the return tuple
    for i in range(n):
        input_ids = _random_prompt_ids(token_len, vocab_size, device)
        prefill_seconds, decode_speed = _timed_generate(model, tokenizer, input_ids)
        times.append(prefill_seconds)
        decode_times.append(decode_speed)
        print(
            f"Iteration {i + 1}/{n} - Prefilling Time: {prefill_seconds:.4f} seconds"
            f" - Average Decode Speed: {decode_speed:.4f} tokens/second"
        )
        if device.type == "cuda":
            torch.cuda.empty_cache()

    avg_first_token_time = sum(times) / n
    avg_decode_time = sum(decode_times) / n
    print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
    print(f"Average Decode Speed over {n} iterations: {avg_decode_time:.4f} tokens/second")
    return times, avg_first_token_time, decode_times, avg_decode_time
def main():
    """Parse command-line options and run ``stress_test`` with them."""
    parser = argparse.ArgumentParser(description="Stress test for model inference")
    parser.add_argument(
        '--token_len', type=int, default=1000,
        help='Prompt length in tokens for each iteration (including 5 framing tokens)',
    )
    parser.add_argument(
        '--n', type=int, default=3,
        help='Number of iterations for the stress test',
    )
    # The value is not a GPU count: stress_test runs on the single device
    # cuda:<num_gpu - 1>, or on the CPU when it is 0 or CUDA is unavailable.
    parser.add_argument(
        '--num_gpu', type=int, default=1,
        help='Selects the single CUDA device cuda:<num_gpu - 1>; 0 (or no CUDA) runs on CPU',
    )
    args = parser.parse_args()
    stress_test(args.token_len, args.n, args.num_gpu)
# Entry point: run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
# (end of captured file)