Spaces:
Runtime error
Runtime error
""" | |
Use FastChat with Hugging Face generation APIs. | |
Usage: | |
python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5 | |
python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0 | |
""" | |
import argparse | |
import torch | |
from fastchat.model import load_model, get_conversation_template, add_model_args | |
def main(args): | |
# Load model | |
model, tokenizer = load_model( | |
args.model_path, | |
device=args.device, | |
num_gpus=args.num_gpus, | |
max_gpu_memory=args.max_gpu_memory, | |
load_8bit=args.load_8bit, | |
cpu_offloading=args.cpu_offloading, | |
revision=args.revision, | |
debug=args.debug, | |
) | |
# Build the prompt with a conversation template | |
msg = args.message | |
conv = get_conversation_template(args.model_path) | |
conv.append_message(conv.roles[0], msg) | |
conv.append_message(conv.roles[1], None) | |
prompt = conv.get_prompt() | |
# Run inference | |
inputs = tokenizer([prompt], return_tensors="pt").to(args.device) | |
output_ids = model.generate( | |
**inputs, | |
do_sample=True if args.temperature > 1e-5 else False, | |
temperature=args.temperature, | |
repetition_penalty=args.repetition_penalty, | |
max_new_tokens=args.max_new_tokens, | |
) | |
if model.config.is_encoder_decoder: | |
output_ids = output_ids[0] | |
else: | |
output_ids = output_ids[0][len(inputs["input_ids"][0]) :] | |
outputs = tokenizer.decode( | |
output_ids, skip_special_tokens=True, spaces_between_special_tokens=False | |
) | |
# Print results | |
print(f"{conv.roles[0]}: {msg}") | |
print(f"{conv.roles[1]}: {outputs}") | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
add_model_args(parser) | |
parser.add_argument("--temperature", type=float, default=0.7) | |
parser.add_argument("--repetition_penalty", type=float, default=1.0) | |
parser.add_argument("--max-new-tokens", type=int, default=1024) | |
parser.add_argument("--debug", action="store_true") | |
parser.add_argument("--message", type=str, default="Hello! Who are you?") | |
args = parser.parse_args() | |
# Reset default repetition penalty for T5 models. | |
if "t5" in args.model_path and args.repetition_penalty == 1.0: | |
args.repetition_penalty = 1.2 | |
main(args) | |