"""Gradio demo that streams sampled text from a Hugging Face seq2seq model.

The model and tokenizer are loaded once at import time (8-bit weights on CUDA,
bfloat16 weights on CPU). A Gradio Blocks UI calls ``run_generation``, which
streams the model's output into a textbox as it is generated.
"""

from threading import Thread
import logging
import time

# Root-logger configuration; runs before the third-party imports below.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer

model_id = "pszemraj/tFINE-850m-24x24-v0.5-instruct-L1"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info("Running on device:\t %s", torch_device)
logging.info("CPU threads:\t %s", torch.get_num_threads())

# GPU: 8-bit quantized weights with automatic device placement.
# CPU: bfloat16 weights.
if torch_device == "cuda":
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id, load_in_8bit=True, device_map="auto"
    )
else:
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# NOTE(review): torch.compile presumably wraps the model lazily, so this
# try/except only catches errors raised while wrapping; compilation failures
# would surface on the first forward call instead. generate() below is reached
# through the wrapper's attribute forwarding -- confirm compilation actually
# takes effect for generation.
try:
    model = torch.compile(model)
except Exception as e:
    logging.error("Unable to compile model:\t%s", e)

tokenizer = AutoTokenizer.from_pretrained(model_id)


def run_generation(
    user_text,
    top_p,
    temperature,
    top_k,
    max_new_tokens,
    repetition_penalty=1.1,
    length_penalty=1.0,
    no_repeat_ngram_size=4,
    use_generation_config=False,
):
    """Stream a sampled completion of ``user_text``.

    ``model.generate`` runs on a worker thread while this generator pulls
    decoded text from a ``TextIteratorStreamer`` and yields the accumulated
    output after every chunk, so the UI can update incrementally.

    Numeric arguments are coerced to the exact Python types Transformers'
    logits processors validate. Gradio sliders can deliver whole-number values
    such as ``2.0`` as ``int``, which would otherwise be rejected.

    Args:
        user_text: Prompt string to tokenize and feed to the model.
        top_p: Nucleus-sampling probability mass (coerced to ``float``).
        temperature: Sampling temperature (coerced to ``float``).
        top_k: Number of highest-probability tokens kept (coerced to ``int``).
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        repetition_penalty: Repetition penalty (coerced to ``float``).
        length_penalty: Passed to ``generate`` (coerced to ``float``).
            NOTE(review): documented by Transformers as a beam-search
            parameter; with ``num_beams=1`` it likely has no effect -- confirm.
        no_repeat_ngram_size: N-gram size that may not repeat in the output.
        use_generation_config: Currently unused; kept so existing callers
            keep working.

    Yields:
        The generated text accumulated so far.

    Returns:
        The full generated text (as the generator's return value).

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread is
            re-raised here once the stream has been closed.
        If no text arrives within the streamer's 10 s timeout, the streamer
        itself raises (``queue.Empty`` in current Transformers -- confirm).
    """
    st = time.perf_counter()

    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)

    # Text is pulled from the streamer in this thread while generation runs on
    # a worker thread; the timeout guards against a worker that stalls without
    # ever closing the stream.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        num_beams=1,
        top_p=float(top_p),
        temperature=float(temperature),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        length_penalty=float(length_penalty),
        no_repeat_ngram_size=no_repeat_ngram_size,
        renormalize_logits=True,
    )

    worker_errors = []

    def _generate():
        # Run generation; on failure, record the error and close the stream so
        # the consuming loop below ends promptly instead of hitting the timeout.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            logging.exception("Generation failed")
            worker_errors.append(exc)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output

    worker.join()
    if worker_errors:
        raise worker_errors[0]

    logging.info("Total rt:\t%s sec", round(time.perf_counter() - st, 3))
    return model_output


def reset_textbox():
    """Return a Gradio update that clears a component's value.

    Not wired to any event in this script.
    """
    return gr.update(value="")


with gr.Blocks() as demo:
    duplicate_link = (
        "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
    )
    gr.Markdown(
        "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
        "This demo showcases the use of the "
        "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
        "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
        f"[{model_id}](https://huggingface.co/{model_id}) and the Spaces free compute tier.\n\n"
        f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a "
        "template! 💛"
    )

    gr.Markdown("---")
    with gr.Row():
        # Left column: prompt, streamed output, and submit button.
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                value="How to become a polar bear tamer?",
                label="User input",
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit", variant="primary")

        # Right column: sampling controls.
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=32,
                maximum=1024,
                value=256,
                step=32,
                interactive=True,
                label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05,
                maximum=1.0,
                value=0.95,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.4,
                value=0.3,
                step=0.05,
                interactive=True,
                label="Temperature",
            )
            repetition_penalty = gr.Slider(
                minimum=0.9,
                maximum=2.5,
                value=1.1,
                step=0.1,
                interactive=True,
                label="Repetition Penalty",
            )
            length_penalty = gr.Slider(
                minimum=0.8,
                maximum=1.5,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Length Penalty",
            )

    # Both pressing Enter in the prompt box and clicking Submit stream into the
    # same output box, using the same positional inputs of run_generation.
    generation_inputs = [
        user_text,
        top_p,
        temperature,
        top_k,
        max_new_tokens,
        repetition_penalty,
        length_penalty,
    ]
    user_text.submit(run_generation, generation_inputs, model_output)
    button_submit.click(run_generation, generation_inputs, model_output)

demo.queue(max_size=10).launch()