import copy
import gc

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# Maps each fine-tuned adapter repo to the base model it was trained on.
choices_base_models = {
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing': 'meta-llama/Llama-3.2-3B-Instruct',
    'groloch/gemma-2-2b-it-PromptEnhancing': 'google/gemma-2-2b-it',
    'groloch/Qwen2.5-3B-Instruct-PromptEnhancing': 'Qwen/Qwen2.5-3B-Instruct',
    'groloch/Ministral-3b-instruct-PromptEnhancing': 'ministral/Ministral-3b-instruct',
}

# Role marker that precedes the answer in each model's decoded chat transcript.
# Kept for reference only: `generate` now decodes just the newly generated
# tokens, so it no longer splits the transcript on these markers.
# NOTE(review): the Ministral entry repeats the base repo id instead of a role
# marker; splitting on it used to return the whole transcript, prompt included.
choices_gen_token = {
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing': 'assistant',
    'groloch/gemma-2-2b-it-PromptEnhancing': 'model',
    'groloch/Qwen2.5-3B-Instruct-PromptEnhancing': 'assistant',
    'groloch/Ministral-3b-instruct-PromptEnhancing': 'ministral/Ministral-3b-instruct',
}

# Adapters whose base model requires a Hugging Face access token.
gated_models = [
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing',
    'groloch/gemma-2-2b-it-PromptEnhancing',
]

# Process-wide state: the adapter currently loaded (empty = none / last load
# failed), the loaded model and tokenizer, and whether `login()` succeeded.
previous_choice = ''
model = None
tokenizer = None
logged_in = False


def load_model(adapter_repo_id: str) -> None:
    """Load the base model for ``adapter_repo_id`` and attach the adapter.

    The currently loaded model and tokenizer are released first, so two models
    are never held at once. If loading fails, the globals stay ``None``.

    Raises:
        KeyError: if ``adapter_repo_id`` is not in ``choices_base_models``.
    """
    global model, tokenizer
    base_repo_id = choices_base_models[adapter_repo_id]

    # Drop the old references before loading so peak memory stays near one model.
    model = None
    tokenizer = None
    gc.collect()

    new_tokenizer = AutoTokenizer.from_pretrained(base_repo_id)
    new_model = AutoModelForCausalLM.from_pretrained(base_repo_id, torch_dtype=torch.bfloat16)
    new_model.load_adapter(adapter_repo_id)

    # Publish both together so the model and tokenizer always match.
    tokenizer = new_tokenizer
    model = new_model


def _stop_token_ids(config_eos, tokenizer_eos):
    """Merge the model's configured EOS id(s) with the tokenizer's EOS id.

    Instruct generation configs may list several stop ids (e.g. end-of-turn
    tokens), so they are kept instead of being replaced by the tokenizer's
    single EOS id. Returns a list of ids, or ``None`` if neither side has one.
    """
    if config_eos is None:
        ids = []
    elif isinstance(config_eos, int):
        ids = [config_eos]
    else:
        ids = list(config_eos)
    if tokenizer_eos is not None and tokenizer_eos not in ids:
        ids.append(tokenizer_eos)
    return ids or None


def _build_generation_config(base_config, eos_token_id, max_tokens, temperature,
                             top_p, repetition_penalty):
    """Return a copy of ``base_config`` updated with the UI's sampling settings.

    A temperature of 0 (or below) switches to greedy decoding, because
    transformers rejects a non-positive temperature when sampling.
    """
    # Copy so the model's own generation_config is not mutated across requests.
    config = copy.deepcopy(base_config)
    temperature = float(temperature)
    config.do_sample = temperature > 0
    config.max_new_tokens = int(max_tokens)
    if config.do_sample:
        config.temperature = temperature
        config.top_p = float(top_p)
    config.num_return_sequences = 1
    config.pad_token_id = eos_token_id
    config.eos_token_id = _stop_token_ids(config.eos_token_id, eos_token_id)
    config.repetition_penalty = float(repetition_penalty)
    return config


def generate(prompt_to_enhance: str,
             choice: str,
             max_tokens: float,
             temperature: float,
             top_p: float,
             repetition_penalty: float,
             access_token: str
             ):
    """Enhance ``prompt_to_enhance`` with the fine-tuned adapter ``choice``.

    Args:
        prompt_to_enhance: user prompt to rewrite.
        choice: adapter repo id (a key of ``choices_base_models``).
        max_tokens: maximum number of new tokens to generate.
        temperature: sampling temperature; 0 means greedy decoding.
        top_p: nucleus-sampling threshold (used only when sampling).
        repetition_penalty: must be strictly positive.
        access_token: Hugging Face token, required for ``gated_models``.

    Returns:
        The generated continuation only, without the prompt, whitespace-stripped.

    Raises:
        gr.Error: on missing or invalid inputs.
    """
    global previous_choice, logged_in

    if not prompt_to_enhance or not prompt_to_enhance.strip():
        raise gr.Error('Please enter a prompt')
    if choice not in choices_base_models:
        raise gr.Error('Please select a model')
    if float(repetition_penalty) <= 0:
        raise gr.Error('Repetition penalty must be strictly positive')

    # Credentials must be in place *before* loading: downloading a gated base
    # model without being logged in fails.
    if choice in gated_models:
        token = (access_token or '').strip()
        if not token:
            raise gr.Error(
                "Please enter your access token (in Additional inputs) if you're using one of "
                f"the following models: {', '.join(gated_models)}. "
                'Make sure you have access to those models.'
            )
        if not logged_in:
            login(token)
            logged_in = True

    if choice != previous_choice:
        # Forget the old choice first so a failed load is retried on the next
        # call instead of silently reusing whatever was loaded before.
        previous_choice = ''
        load_model(choice)
        previous_choice = choice

    chat = [{'role': 'user', 'content': prompt_to_enhance}]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains BOS; don't let the tokenizer add another.
    encoding = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)

    generation_config = _build_generation_config(
        model.generation_config,
        tokenizer.eos_token_id,
        max_tokens,
        temperature,
        top_p,
        repetition_penalty,
    )

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config,
        )

    # generate() returns prompt + continuation for causal LMs; decode only the
    # continuation instead of splitting on a role word the answer might contain.
    new_tokens = outputs[0][encoding.input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


#
# Inputs
#
model_choice = gr.Dropdown(
    label='Model choice',
    choices=list(choices_base_models),
    value='groloch/Llama-3.2-3B-Instruct-PromptEnhancing',
)
input_prompt = gr.Text(
    label='Prompt to enhance'
)

#
# Additional inputs
#
input_max_tokens = gr.Number(
    label='Max generated tokens',
    value=64,
    minimum=16,
    maximum=128
)
input_temperature = gr.Number(
    label='Temperature',
    value=0.3,
    minimum=0.0,
    maximum=1.5,
    step=0.05
)
input_top_p = gr.Number(
    label='Top p',
    value=0.9,
    minimum=0.0,
    maximum=1.0,
    step=0.05
)
input_repetition_penalty = gr.Number(
    label='Repetition penalty',
    value=2.0,
    minimum=0.0,
    maximum=5.0,
    step=0.1
)
input_access_token = gr.Text(
    label='Access token for gated models',
    value=''
)

demo = gr.Interface(
    generate,
    title='Prompt Enhancing Playground',
    description=(
        'This space is a tool to compare the different prompt enhancing model I have finetuned. '
        'Feel free to experiment as you want ! \n'
        'If you want to use this locally, you can download the gpu version (see in files)'
    ),
    inputs=[input_prompt, model_choice],
    additional_inputs=[
        input_max_tokens,
        input_temperature,
        input_top_p,
        input_repetition_penalty,
        input_access_token,
    ],
    outputs=['text']
)

if __name__ == "__main__":
    demo.launch()