Spaces:
Sleeping
Sleeping
File size: 5,416 Bytes
09d15e8 27da979 09d15e8 35700da 09d15e8 af12f2c 09d15e8 27da979 09d15e8 58acd65 af12f2c 09d15e8 27da979 5209ab6 27da979 09d15e8 27da979 09d15e8 0240ed4 09d15e8 27da979 5209ab6 27da979 ea9c426 09d15e8 27da979 09d15e8 27da979 09d15e8 27da979 09d15e8 2f48801 09d15e8 7c1503a 27da979 09d15e8 2f48801 09d15e8 6958233 2f48801 6958233 27da979 09d15e8 27da979 09d15e8 27da979 09d15e8 af12f2c 7c1503a 983046b 7c1503a af12f2c 5209ab6 27da979 5209ab6 09d15e8 5209ab6 ea9c426 27da979 5209ab6 27da979 5209ab6 27da979 09d15e8 58acd65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
from threading import Thread
import logging
import time
# Configure the root logger before the heavy imports below so that their
# start-up messages are emitted at INFO level with a timestamped format.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    level=logging.INFO,
)
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
# Checkpoint served by this demo and the device it will run on.
model_id = "pszemraj/tFINE-850m-24x24-v0.5-instruct-L1"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Running on device:\t {torch_device}")
logging.info(f"CPU threads:\t {torch.get_num_threads()}")

# Loading options differ per device: 8-bit weights with automatic device
# placement on GPU, bfloat16 weights on CPU.
if torch_device == "cuda":
    _load_options = {"load_in_8bit": True, "device_map": "auto"}
else:
    _load_options = {"torch_dtype": torch.bfloat16}
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_load_options)

# Compilation is best-effort: if it raises, log the reason and keep serving
# the uncompiled model.
try:
    model = torch.compile(model)
except Exception as e:
    logging.error(f"Unable to compile model:\t{e}")

tokenizer = AutoTokenizer.from_pretrained(model_id)
def run_generation(
    user_text,
    top_p,
    temperature,
    top_k,
    max_new_tokens,
    repetition_penalty=1.1,
    length_penalty=1.0,
    no_repeat_ngram_size=4,
    use_generation_config=False,
):
    """Stream a sampled completion for ``user_text``.

    This is a generator: after every chunk of text produced by the model it
    yields the full output accumulated so far, so a UI bound to it can
    refresh incrementally.

    Args:
        user_text: Prompt string to feed to the model.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.
        top_k: Number of highest-probability tokens kept for sampling.
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Penalty applied to already-generated tokens.
        length_penalty: Forwarded to ``model.generate`` unchanged in meaning.
        no_repeat_ngram_size: Size of n-grams that may not repeat.
        use_generation_config: Currently unused; kept so existing callers
            that pass it keep working.

    Yields:
        The model output accumulated so far, as a string.

    Returns:
        The final accumulated output (the generator's return value).
    """
    st = time.perf_counter()

    # Tokenize the prompt and move the tensors to the device the model uses.
    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)

    # Generation runs on a separate thread so the caller is not blocked; the
    # text is pulled from the streamer here. The streamer timeout is there to
    # surface a failure in the generation thread instead of waiting forever.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    # UI widgets may hand back an int for a whole-number position of a
    # float-valued control (e.g. a penalty of exactly 1), or a float for an
    # integer-valued one. Normalize every numeric knob to the type the
    # generation API expects, the same way temperature was already handled.
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        num_beams=1,
        top_p=float(top_p),
        temperature=float(temperature),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        length_penalty=float(length_penalty),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        renormalize_logits=True,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer and publish the running total.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output

    logging.info("Total rt:\t%s sec", round(time.perf_counter() - st, 3))
    return model_output
def reset_textbox():
    """Return a Gradio update that sets a component's value to an empty string."""
    cleared = gr.update(value="")
    return cleared
# Build the demo page: prompt and output on the left, sampling controls on
# the right, with both the Enter key and the Submit button starting a
# streamed generation.
with gr.Blocks() as demo:
    duplicate_link = (
        "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
    )
    gr.Markdown(
        "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
        "This demo showcases the use of the "
        "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
        "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
        f"[{model_id}](https://huggingface.co/{model_id}) and the Spaces free compute tier.\n\n"
        f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a "
        "template! 💛"
    )
    gr.Markdown("---")

    with gr.Row():
        # Wide column: prompt, streamed output and the submit button.
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                value="How to become a polar bear tamer?",
                label="User input",
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit", variant="primary")

        # Narrow column: generation parameters.
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=32,
                maximum=1024,
                value=256,
                step=32,
                interactive=True,
                label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05,
                maximum=1.0,
                value=0.95,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.4,
                value=0.3,
                step=0.05,
                interactive=True,
                label="Temperature",
            )
            repetition_penalty = gr.Slider(
                minimum=0.9,
                maximum=2.5,
                value=1.1,
                step=0.1,
                interactive=True,
                label="Repetition Penalty",
            )
            length_penalty = gr.Slider(
                minimum=0.8,
                maximum=1.5,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Length Penalty",
            )

    # Both triggers pass the same inputs, in the positional order expected by
    # run_generation; sharing one list keeps them from drifting apart.
    generation_inputs = [
        user_text,
        top_p,
        temperature,
        top_k,
        max_new_tokens,
        repetition_penalty,
        length_penalty,
    ]
    user_text.submit(run_generation, generation_inputs, model_output)
    button_submit.click(run_generation, generation_inputs, model_output)

# Queue requests (at most 10 waiting) and start serving the demo.
demo.queue(max_size=10).launch()