"""Gradio chat demo that streams completions from Mixtral-8x7B-Instruct."""

from collections.abc import AsyncIterator

from huggingface_hub import AsyncInferenceClient
import gradio as gr

client = AsyncInferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")


def format_prompt(
    prompt: str, history: list[list[str]], system_prompt: str
) -> str:
    """Build the ``[INST] ... [/INST]`` prompt for the whole conversation.

    Args:
        prompt: The user's new message.
        history: Earlier turns as ``(user_message, bot_response)`` pairs,
            oldest first. May be empty.
        system_prompt: Optional instruction text. When non-empty it is
            prepended (followed by ``":\\n"``) to the first user message of
            the conversation; when empty nothing is prepended.

    Returns:
        The concatenated prompt, ending with the new message wrapped in
        ``[INST] ... [/INST]``.
    """
    # Only emit the separator when there is a system prompt, so an empty one
    # does not leave a dangling ":\n" in front of the user's message.
    preamble = f"{system_prompt}:\n" if system_prompt else ""

    # The system prompt always rides on the very first user message, so the
    # replayed conversation matches what was sent on the first turn instead
    # of silently dropping the instruction once history exists.
    past_turns = "".join(
        f"[INST] {preamble if index == 0 else ''}{user_prompt} [/INST]"
        f"{bot_response} "
        for index, (user_prompt, bot_response) in enumerate(history)
    )
    current_preamble = "" if history else preamble
    return f"{past_turns}[INST] {current_preamble}{prompt} [/INST]"


async def generate(
    prompt: str,
    history: list[list[str]],
    system_prompt: str = "You're a helpful assistant.",
    temperature: float = 0.3,
    max_new_tokens: int = 4000,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
) -> AsyncIterator[str]:
    """Stream the model's reply to ``prompt``, yielding the text so far.

    Args:
        prompt: The user's new message.
        history: Earlier turns as ``(user_message, bot_response)`` pairs.
        system_prompt: Optional instruction text forwarded to
            ``format_prompt``.
        temperature: Sampling temperature; values below 0.01 are raised
            to 0.01.
        max_new_tokens: Upper bound on generated tokens; coerced to an
            integer and raised to at least 1.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated response text after each non-special token.
    """
    generate_kwargs = dict(
        # Sampling is enabled below, so keep the temperature strictly positive.
        temperature=max(float(temperature), 1e-2),
        # The UI slider allows 0 and may deliver a float; the request needs a
        # positive integer (presumably rejected by the backend otherwise).
        max_new_tokens=max(int(max_new_tokens), 1),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(
        prompt=prompt, history=history, system_prompt=system_prompt
    )

    stream = await client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=True,
    )

    output = ""
    async for response in stream:
        # Special tokens (e.g. end-of-sequence) are control markers, not
        # reply text, so they must not be shown in the chat.
        if response.token.special:
            continue
        output += response.token.text
        yield output


# Extra controls shown under the chat box; Gradio passes their values to
# ``generate`` positionally after ``prompt`` and ``history``, in this order.
additional_inputs = [
    gr.Textbox(
        label="System Prompt (optional)",
        value="You're a helpful assistant.",
        info="This is experimental",
        placeholder="system prompt",
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

chatbot = gr.Chatbot(
    avatar_images=["./user.png", "./bot.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)

demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=additional_inputs,
    chatbot=chatbot,
    title="🪷",
    description="Mixtral-8x7B-Instruct-v0.1",
    concurrency_limit=20,
)

demo.queue().launch()