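#
# Prompt Enhancing Playground
# A Gradio space for comparing prompt-enhancing adapters fine-tuned on top of
# several small instruct models.
#
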
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
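
# Each adapter repo is mapped to the base model it was fine-tuned from; the
# base is loaded first and the adapter weights are applied on top (see
# load_model below).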
choices_base_models = {
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing': 'meta-llama/Llama-3.2-3B-Instruct',
    'groloch/gemma-2-2b-it-PromptEnhancing': 'google/gemma-2-2b-it',
    'groloch/Qwen2.5-3B-Instruct-PromptEnhancing': 'Qwen/Qwen2.5-3B-Instruct',
    'groloch/Ministral-3b-instruct-PromptEnhancing': 'ministral/Ministral-3b-instruct'
}
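
# Role marker that each chat template places before the model's reply; used in
# generate() to strip the prompt from the decoded output.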
choices_gen_token = {
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing': 'assistant',
    'groloch/gemma-2-2b-it-PromptEnhancing': 'model',
    'groloch/Qwen2.5-3B-Instruct-PromptEnhancing': 'assistant',
    # The original entry repeated the repo id here, which would make the split in
    # generate() return the whole decoded text; 'assistant' assumes the standard
    # chat-template role name for this model.
    'groloch/Ministral-3b-instruct-PromptEnhancing': 'assistant'
}
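
# Adapters whose base models are gated on the Hub and require an access token.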
gated_models = [
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing',
    'groloch/gemma-2-2b-it-PromptEnhancing'
]
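
# Lazily initialized model state, reused across calls while the adapter
# selection is unchanged.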
previous_choice = ''
model = None
tokenizer = None
logged_in = False

def load_model(adapter_repo_id: str):
    """Load the base model and tokenizer, then apply the selected adapter."""
    global model, tokenizer
    base_repo_id = choices_base_models[adapter_repo_id]
    tokenizer = AutoTokenizer.from_pretrained(base_repo_id)
    model = AutoModelForCausalLM.from_pretrained(base_repo_id, torch_dtype=torch.bfloat16)
    model.load_adapter(adapter_repo_id)

def generate(prompt_to_enhance: str,
             choice: str,
             max_tokens: float,
             temperature: float,
             top_p: float,
             repetition_penalty: float,
             access_token: str
             ):
    if not prompt_to_enhance:
        raise gr.Error('Please enter a prompt')

    # Gated base models cannot be downloaded anonymously, so the token check and
    # login happen before any attempt to load the model.
    if choice in gated_models and not access_token:
        raise gr.Error('Please enter your access token (in Additional inputs) if you are using one of the '
                       f'following models: {", ".join(gated_models)}. Make sure you have access to those models.')

    global logged_in
    if not logged_in and choice in gated_models:
        login(access_token)
        logged_in = True

    # Only reload when the selected adapter has actually changed
    global previous_choice
    if choice != previous_choice:
        previous_choice = choice
        load_model(choice)

    chat = [
        {'role': 'user', 'content': prompt_to_enhance}
    ]
    # With tokenize=False, apply_chat_template returns the formatted prompt as a
    # string, so return_tensors has no effect and is omitted.
    prompt = tokenizer.apply_chat_template(chat,
                                           tokenize=False,
                                           add_generation_prompt=True)
    encoding = tokenizer(prompt, return_tensors='pt')

    generation_config = model.generation_config
    generation_config.do_sample = True
    generation_config.max_new_tokens = int(max_tokens)
    generation_config.temperature = float(temperature)
    generation_config.top_p = float(top_p)
    generation_config.num_return_sequences = 1
    generation_config.pad_token_id = tokenizer.eos_token_id
    generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.repetition_penalty = float(repetition_penalty)

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config
        )

    # Keep only the text after the last role marker, i.e. the generated answer
    return tokenizer.decode(outputs[0], skip_special_tokens=True).split(choices_gen_token[choice])[-1]

#
# Inputs
#
model_choice = gr.Dropdown(
    label='Model choice',
    choices=list(choices_base_models),
    value='groloch/Llama-3.2-3B-Instruct-PromptEnhancing'
)
input_prompt = gr.Text(
    label='Prompt to enhance'
)

#
# Additional inputs
#
input_max_tokens = gr.Number(
    label='Max generated tokens',
    value=64,
    minimum=16,
    maximum=128
)
input_temperature = gr.Number(
    label='Temperature',
    value=0.3,
    minimum=0.0,
    maximum=1.5,
    step=0.05
)
input_top_p = gr.Number(
    label='Top p',
    value=0.9,
    minimum=0.0,
    maximum=1.0,
    step=0.05
)
input_repetition_penalty = gr.Number(
    label='Repetition penalty',
    value=2.0,
    minimum=0.0,
    maximum=5.0,
    step=0.1
)
input_access_token = gr.Text(
    label='Access token for gated models',
    value=''
)

demo = gr.Interface(
    generate,
    title='Prompt Enhancing Playground',
    description='This space is a tool to compare the different prompt-enhancing models I have fine-tuned. '
                'Feel free to experiment as much as you want!\n'
                'If you want to run this locally, you can download the GPU version (see the files).',
    inputs=[input_prompt, model_choice],
    additional_inputs=[input_max_tokens,
                       input_temperature,
                       input_top_p,
                       input_repetition_penalty,
                       input_access_token
                       ],
    outputs=['text']
)

if __name__ == "__main__":
    demo.launch()