import functools
import os
from random import sample

import torch
import numpy as np
import gradio as gr
from detoxify import Detoxify
from datasets import load_dataset
from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
from transformers import BloomTokenizerFast, BloomForCausalLM
# Token passed to the flagging callback that writes user reports to the Hub.
# NOTE(review): this was written as os.environ.get("hf_token" or True), where
# `or True` sat inside the call and evaluated to just "hf_token" (no effect).
# If a fallback value was intended it belongs outside the call.
HF_AUTH_TOKEN = os.environ.get("hf_token")

# Prompts dataset offered as "inspiration" in the single-model view.
DATASET = "allenai/real-toxicity-prompts"

# Display name -> Hugging Face Hub checkpoint. "Custom Model" has no fixed
# checkpoint; the user searches the Hub for one instead.
CHECKPOINTS = {
    "DistilGPT2 by HuggingFace 🤗": "distilgpt2",
    "GPT-Neo 125M by EleutherAI 🤖": "EleutherAI/gpt-neo-125M",
    "BLOOM 560M by BigScience 🌸": "bigscience/bloom-560m",
    "Custom Model": None,
}

# Display name -> (model class, tokenizer class) used by load_model().
MODEL_CLASSES = {
    "DistilGPT2 by HuggingFace 🤗": (GPT2LMHeadModel, GPT2Tokenizer),
    "GPT-Neo 125M by EleutherAI 🤖": (GPTNeoForCausalLM, GPT2Tokenizer),
    "BLOOM 560M by BigScience 🌸": (BloomForCausalLM, BloomTokenizerFast),
    "Custom Model": (AutoModelForCausalLM, AutoTokenizer),
}

# Predefined models offered in the multi-model view (every entry except
# "Custom Model"), sorted by display name. Selecting by name rather than by
# position keeps this correct if CHECKPOINTS is reordered. This list is
# mutated at runtime: Hub models picked in the multi-model view are appended.
CHOICES = sorted(name for name in CHECKPOINTS if name != "Custom Model")
def load_model(model_name, custom_model_path, token):
    """Load a causal language model and its tokenizer, ready for generation.

    Args:
        model_name: Either a key of ``CHECKPOINTS`` or, for models picked
            through the Hub search, the Hub model id itself.
        custom_model_path: Hub model id chosen via the search dropdown, or
            ``None``. Used when ``model_name`` has no checkpoint of its own.
        token: Optional Hugging Face auth token, for private repositories.

    Returns:
        ``(model, tokenizer)``. The model is in eval mode, and padding is
        mapped to the EOS token on both objects.

    Raises:
        ValueError: if no model path can be determined (no model selected, or
            "Custom Model" chosen without picking a model from the search).
    """
    try:
        model_class, tokenizer_class = MODEL_CLASSES[model_name]
        model_path = CHECKPOINTS[model_name]
    except KeyError:
        # Not a predefined entry: treat it as a Hub model id and use Auto*.
        model_class, tokenizer_class = MODEL_CLASSES["Custom Model"]
        model_path = custom_model_path or model_name
    if model_path is None:
        # "Custom Model" maps to a None checkpoint; use the searched model.
        model_path = custom_model_path
    if not model_path:
        raise ValueError(
            "No model selected: pick a model or search the Hub for a custom one."
        )
    model = model_class.from_pretrained(model_path, use_auth_token=token)
    tokenizer = tokenizer_class.from_pretrained(model_path, use_auth_token=token)
    # These tokenizers/models may define no pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    model.eval()
    return model, tokenizer
# Fallback length used by adjust_length_to_model() when neither the requested
# length nor the model's maximum is positive (avoids an unbounded generation).
MAX_LENGTH = 10_000
def set_seed(seed, n_gpu):
    """Seed NumPy's and PyTorch's global RNGs with ``seed``.

    When ``n_gpu`` is positive, every CUDA device is seeded as well.
    """
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)
    if n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)
def adjust_length_to_model(length, max_sequence_length):
    """Clamp a requested generation length to what the model supports.

    A negative ``length`` means "as long as possible". That becomes
    ``max_sequence_length`` when it is positive, and ``MAX_LENGTH`` otherwise.
    A non-negative ``length`` is capped at a positive ``max_sequence_length``.
    """
    if length < 0:
        # Unbounded request: use the model limit, or the global safety cap.
        return max_sequence_length if max_sequence_length > 0 else MAX_LENGTH
    if 0 < max_sequence_length < length:
        # No generation bigger than the model size.
        return max_sequence_length
    return length
def generate(
    model_name,
    token,
    custom_model_path,
    input_sentence,
    length=75,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    seed=42,
    no_cuda=False,
    num_return_sequences=1,
    stop_token=".",
):
    """Sample a continuation of ``input_sentence`` from the selected model.

    Args:
        model_name, token, custom_model_path: Forwarded to ``load_model``.
        input_sentence: Prompt text to continue.
        length: Number of tokens to generate beyond the prompt.
        temperature, top_k, top_p: Sampling parameters for ``model.generate``.
        seed: RNG seed, so a given configuration is reproducible.
        no_cuda: Force CPU even when CUDA is available.
        num_return_sequences: Number of samples drawn. Only the first is
            returned.
        stop_token: If present in the continuation, everything after its last
            occurrence is dropped. Pass ``None`` or ``""`` to keep everything.

    Returns:
        The first continuation as a string, without the prompt and without
        special tokens.
    """
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
    )
    n_gpu = 0 if no_cuda else torch.cuda.device_count()
    set_seed(seed, n_gpu)

    model, tokenizer = load_model(model_name, custom_model_path, token)
    model.to(device)

    encoded_prompt = tokenizer.encode(
        input_sentence, add_special_tokens=False, return_tensors="pt"
    ).to(device)
    prompt_length = encoded_prompt.shape[-1]

    # Gradio sliders can deliver floats, but generation expects integer
    # lengths and an integer top_k.
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=int(length) + prompt_length,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )

    generated_sequences = []
    for generated_sequence in output_sequences:
        # Decode only the newly generated tokens. Slicing the full decoded
        # text by the length of the separately decoded prompt can misalign
        # when tokens merge across the boundary.
        text = tokenizer.decode(
            generated_sequence[prompt_length:].tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        # Drop everything after the last stop token. If it never appears,
        # keep the whole continuation (previously this produced "").
        if stop_token:
            stop_index = text.rfind(stop_token)
            if stop_index != -1:
                text = text[: stop_index + len(stop_token)]
        generated_sequences.append(text)
    return generated_sequences[0]
def show_mode(mode):
    """Toggle between the single-model and multi-model groups.

    Returns a pair of visibility updates (single group, multi group), or
    ``None`` for an unrecognised mode.
    """
    visibility = {
        "Single Model": (True, False),
        "Multi-Model": (False, True),
    }.get(mode)
    if visibility is None:
        return None
    single_visible, multi_visible = visibility
    return (gr.update(visible=single_visible), gr.update(visible=multi_visible))
def prepare_dataset(dataset):
    """Load and return the ``train`` split of the Hub dataset named ``dataset``."""
    return load_dataset(dataset, split="train")
def load_prompts(dataset):
    """Extract the prompt strings from a RealToxicityPrompts-style dataset.

    Each row must be a mapping whose ``"prompt"`` entry is itself a mapping
    with a ``"text"`` key.

    Returns:
        The prompt texts, in dataset order.
    """
    # Iterate rows directly instead of indexing by position.
    return [row["prompt"]["text"] for row in dataset]
def random_sample(prompt_list, k=10):
    """Return up to ``k`` distinct prompts drawn at random from ``prompt_list``.

    Unlike ``random.sample``, this does not raise when fewer than ``k`` items
    are available: it returns them all, in random order. ``None`` is treated
    as an empty list.
    """
    population = prompt_list or []
    return sample(population, min(k, len(population)))
def show_dataset(dataset):
    """Load the prompts dataset and reveal a random subset of it.

    Returns:
        The dropdown update (random subset, new label, visible), an update
        that reveals the "Show another subset" button, and the full prompt
        list (kept in session state for later re-sampling).
    """
    prompts = load_prompts(prepare_dataset(dataset))
    dropdown_update = gr.update(
        choices=random_sample(prompts),
        label="You can find below a random subset from the RealToxicityPrompts dataset",
        visible=True,
    )
    subset_button_update = gr.update(visible=True)
    return dropdown_update, subset_button_update, prompts
def update_dropdown(prompts):
    """Refill the prompts dropdown with a fresh random subset of ``prompts``."""
    fresh_subset = random_sample(prompts)
    return gr.update(choices=fresh_subset)
def show_search_bar(value):
    """Forward the radio choice and show the Hub search bar only for "Custom Model"."""
    is_custom = value == "Custom Model"
    return (value, gr.update(visible=is_custom))
def search_model(model_name, token):
    """Search the Hugging Face Hub for PyTorch text-generation models.

    Args:
        model_name: Free-text search query.
        token: Optional auth token, forwarded to ``list_models`` so that
            private models can be included.

    Returns:
        A Gradio update that shows the dropdown, populated with the matching
        model ids.
    """
    api = HfApi()
    # Plain tag strings express the same filter as the ModelSearchArguments
    # attributes (pipeline_tag.TextGeneration / library.PyTorch), without
    # fetching the Hub's full tag listing on every search.
    filt = ModelFilter(task="text-generation", library="pytorch")
    results = api.list_models(filter=filt, search=model_name, use_auth_token=token)
    model_list = [model.modelId for model in results]
    return gr.update(
        visible=True,
        choices=model_list,
        label="Choose the model",
    )
def show_api_key_textbox(checkbox):
    """Show the auth-token textbox exactly when the "Private Model ?" box is ticked."""
    return gr.update(visible=bool(checkbox))
def forward_model_choice(model_choice_path):
    """Duplicate the picked Hub model id into both the model and path states."""
    path = model_choice_path
    return path, path
def auto_complete(input, generated):
    """Join prompt and completion into a HighlightedText value.

    The text is ``input + " " + generated``. A single "OUTPUT" entity spans
    from the end of the prompt (so it includes the joining space) to the end
    of the text.
    """
    prompt_end = len(input)
    full_text = " ".join((input, generated))
    return {
        "text": full_text,
        "entities": [{"entity": "OUTPUT", "start": prompt_end, "end": len(full_text)}],
    }
def process_user_input(
    model, token, custom_model_path, input, length, temperature, top_p, top_k
):
    """Handle a click on the single-model "Submit your prompt" button.

    Args:
        model: Selected model (a ``CHECKPOINTS`` key or a Hub model id).
        token: Optional Hugging Face auth token for private models.
        custom_model_path: Hub model id picked through the search, if any.
        input: Prompt text from the textbox.
        length, temperature, top_p, top_k: Sampling parameters for ``generate``.

    Returns:
        A 5-tuple: the update for the highlighted output, the updates for the
        toxicity and report buttons, the prompt, and the generated text.
        For a missing or blank prompt, the output shows a warning, both buttons
        are hidden, and the generated text is ``None``.
    """
    warning = "Please enter a valid prompt."
    # Previously a None prompt crashed inside auto_complete(), and a blank
    # prompt gave generate() no tokens to condition on. Bail out early instead.
    if input is None or not input.strip():
        return (
            gr.update(value={"text": warning, "entities": []}),
            gr.update(visible=False),
            gr.update(visible=False),
            input,
            None,
        )
    generated = generate(
        model_name=model,
        token=token,
        custom_model_path=custom_model_path,
        input_sentence=input,
        length=length,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
    # Flatten newlines so the completion displays on one line.
    generated = generated.replace("\n", " ")
    generated_with_spans = auto_complete(input=input, generated=generated)
    return (
        gr.update(value=generated_with_spans),
        gr.update(visible=True),
        gr.update(visible=True),
        input,
        generated,
    )
def pass_to_textbox(input):
    """Copy the prompt selected in the dropdown into the prompt textbox."""
    textbox_update = gr.update(value=input)
    return textbox_update
@functools.lru_cache(maxsize=None)
def _load_detoxify(model_type):
    """Build the Detoxify classifier for ``model_type`` once and reuse it."""
    return Detoxify(model_type)


def run_detoxify(text):
    """Score ``text`` with the Detoxify "original" model.

    The classifier is built on first use and cached; previously a new model
    was constructed on every call.

    Returns:
        A ``{category: float}`` dict, converted to plain floats so that it can
        be shown in the Gradio JSON components.
    """
    results = _load_detoxify("original").predict(text)
    return {cat: float(score) for cat, score in results.items()}
def compute_toxi_output(output_text):
    """Score the model output with Detoxify and reveal the comparison button.

    Returns the update that shows the output's scores, followed by the update
    that reveals the "Compare toxicity" button.
    """
    output_scores = run_detoxify(output_text)
    scores_update = gr.update(value=output_scores, visible=True)
    compare_button_update = gr.update(visible=True)
    return (scores_update, compare_button_update)
def compute_change(input, output):
    """Return the percentage change from ``input`` to ``output``, rounded to 2 dp.

    Both values are coerced with ``float()``, so numeric strings are accepted.
    A change relative to zero is undefined. In that case ``None`` is returned
    (shown as ``null`` in JSON) instead of raising ``ZeroDivisionError``.
    """
    baseline = float(input)
    if baseline == 0:
        return None
    return round((float(output) - baseline) / baseline * 100, 2)
def compare_toxi_scores(input_text, output_scores):
    """Score the prompt with Detoxify and compare it to the output's scores.

    Args:
        input_text: The prompt that was fed to the model.
        output_scores: ``{category: score}`` mapping previously computed for the
            model's output. ``None`` is treated as empty.

    Returns:
        Two Gradio updates: the prompt's scores, and the per-category
        percentage change from prompt to output (via ``compute_change``).
        Only categories present in both mappings are compared.
    """
    input_scores = run_detoxify(input_text)  # already plain floats
    output_scores = output_scores or {}
    # A single pass over the shared categories. The previous nested
    # comprehension rebound `cat` in its inner loop: it recomputed every entry
    # once per input category, and raised KeyError for output-only categories.
    compare_scores = {
        cat: compute_change(score, output_scores[cat])
        for cat, score in input_scores.items()
        if cat in output_scores
    }
    return (
        gr.update(value=input_scores, visible=True),
        gr.update(value=compare_scores, visible=True),
    )
def show_flag_choices():
    """Reveal the radio of flag categories."""
    reveal = gr.update(visible=True)
    return reveal
def update_flag(flag_value):
    """Store the chosen flag category and switch the UI into report mode.

    Returns the flag value, two updates revealing the confirm button and the
    comment box, and an update hiding the "Report output here" button.
    """
    reveal_confirm = gr.update(visible=True)
    reveal_comment = gr.update(visible=True)
    hide_report_button = gr.update(visible=False)
    return flag_value, reveal_confirm, reveal_comment, hide_report_button
def upload_flag(*args):
    """Send one report row to ``flagging_callback`` and reveal the success message.

    ``args`` are the flagged values in component order. The second one (the
    model output) is converted to UTF-8 bytes before saving.
    NOTE(review): the reason for the bytes conversion is not visible here.
    """
    flagged_row = list(args)
    flagged_row[1] = bytes(flagged_row[1], "utf-8")
    flagging_callback.flag(flagged_row)
    success_update = gr.update(visible=True)
    return success_update
def forward_model_choice_multi(model_choice_path):
    """Add a Hub model picked in the multi-model search to the model checkboxes.

    The module-level ``CHOICES`` list is extended in place. An empty selection
    (e.g. the dropdown being cleared) or an already-listed model is ignored, so
    no ``None`` or duplicate entries are added.
    NOTE(review): CHOICES is module-global, so additions are shared by every
    session of the app.

    Returns:
        An update setting the checkbox group's choices to ``CHOICES``.
    """
    if model_choice_path and model_choice_path not in CHOICES:
        CHOICES.append(model_choice_path)
    return gr.update(choices=CHOICES)
def process_user_input_multi(models, input, token, length, temperature, top_p, top_k):
    """Run the same prompt through every selected model.

    Args:
        models: Model names / Hub ids ticked in the checkbox group.
        input: The prompt text.
        token: Optional auth token for private models.
        length, temperature, top_p, top_k: Sampling parameters for ``generate``.

    Returns:
        Exactly 10 ``HighlightedText`` updates, one per output slot in the UI:
        a labelled completion per model (models in sorted order, at most the
        first 10 are run), then updates hiding the unused slots. For a missing
        or blank prompt, the first slot shows a warning and the other slots
        are left unchanged.
    """
    n_slots = 10
    warning = "Please enter a valid prompt."
    if input is None or not input.strip():
        # Previously this path fell through to an undefined `generated_dict`.
        warning_update = gr.HighlightedText.update(
            value={"text": warning, "entities": []}, visible=True
        )
        return [warning_update] + [
            gr.HighlightedText.update() for _ in range(n_slots - 1)
        ]
    # More models than output slots would give Gradio the wrong number of
    # updates, so only the first n_slots models are run.
    selected = sorted(models or [])[:n_slots]
    update_outputs = []
    for model in selected:
        generated = generate(
            model_name=model,
            token=token,
            custom_model_path=None,
            input_sentence=input,
            length=length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
        # Flatten newlines, as the single-model view does.
        generated = generated.replace("\n", " ")
        update_outputs.append(
            gr.HighlightedText.update(
                value=auto_complete(input, generated), label=model, visible=True
            )
        )
    update_hide = [
        gr.HighlightedText.update(visible=False)
        for _ in range(n_slots - len(selected))
    ]
    return update_outputs + update_hide
def show_choices_multi(models):
    """Show one output slot per selected model; reset and hide the rest.

    The UI has 10 output slots, so at most 10 are shown, and exactly 10
    updates are always returned. ``None`` is treated as no selection.
    """
    n_slots = 10
    n_shown = min(len(models or []), n_slots)
    update_show = [gr.HighlightedText.update(visible=True) for _ in range(n_shown)]
    update_hide = [
        gr.HighlightedText.update(visible=False, value=None, label=None)
        for _ in range(n_slots - n_shown)
    ]
    return update_show + update_hide
def show_params(checkbox):
    """Show the sampling-parameter box exactly when "Set custom params" is ticked.

    Uses truthiness rather than ``== True``, matching ``show_api_key_textbox``.
    """
    return gr.update(visible=bool(checkbox))
# Extra styling for the Blocks layout: "#inside_group" pads elements created
# with elem_id="inside_group", and "#pw" masks the characters typed into the
# auth-token textboxes (elem_id="pw") using the non-standard, WebKit-prefixed
# -webkit-text-security property.
CSS = """
#inside_group {
padding-top: 0.6em;
padding-bottom: 0.6em;
}
#pw textarea {
-webkit-text-security: disc;
}
"""
# Gradio UI. A "Single Model" group (prompt -> one model -> toxicity analysis
# and flagging) and a "Multi-Model" group (prompt -> several models side by
# side) are toggled by the `choose_mode` radio. Event wiring is at the end.
with gr.Blocks(css=CSS) as demo:
    # Per-session values shared between event handlers.
    dataset = gr.Variable(value=DATASET)
    prompts_var = gr.Variable(value=None)
    input_var = gr.Variable(label="Input Prompt", value=None)
    output_var = gr.Variable(label="Output", value=None)
    model_choice = gr.Variable(label="Model", value=None)
    custom_model_path = gr.Variable(value=None)
    flag_choice = gr.Variable(label="Flag", value=None)
    # Saves user reports to a private Hub dataset; also used by upload_flag().
    flagging_callback = gr.HuggingFaceDatasetSaver(
        hf_token=HF_AUTH_TOKEN,
        dataset_name="fsdlredteam/flagged_3",
        private=True,
    )
    # These header literals were split across lines (an unterminated-string
    # syntax error). The line breaks are now escaped, so the text is unchanged.
    # NOTE(review): they look like their original HTML markup was lost.
    gr.Markdown("\n")
    gr.Markdown("BuggingSpace\n")
    gr.Markdown(
        "FSDL 2022 Red-Teaming Open-Source Models Project\n"
    )
    gr.Markdown(
        "### Pick a text generation model below, write a prompt and explore the output"
    )
    gr.Markdown("### Or compare the output of multiple models at the same time")
    choose_mode = gr.Radio(
        choices=["Single Model", "Multi-Model"],
        value="Single Model",
        interactive=True,
        visible=True,
        show_label=False,
    )

    with gr.Group() as single_model:
        gr.Markdown(
            "You can upload any model from the Hugging Face hub -even private ones, "
            "provided you use your private key! "
            "Write your prompt or alternatively use one from the "
            "[RealToxicityPrompts](https://allenai.org/data/real-toxicity-prompts) dataset."
        )
        gr.Markdown(
            "Use it to audit the model for potential failure modes, "
            "analyse its output with the Detoxify suite and contribute by reporting any problematic result."
        )
        gr.Markdown(
            "Beware ! Generation can take up to a few minutes with very large models."
        )
        with gr.Row():
            with gr.Column(scale=1):  # input & prompts dataset exploration
                gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
                input_text = gr.Textbox(
                    label="Write your prompt below.",
                    interactive=True,
                    lines=4,
                    elem_id="inside_group",
                )
                gr.Markdown("— or —", elem_id="inside_group")
                inspo_button = gr.Button(
                    "Click here if you need some inspiration", elem_id="inside_group"
                )
                prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
                randomize_button = gr.Button(
                    "Show another subset", visible=False, elem_id="inside_group"
                )
                show_params_checkbox_single = gr.Checkbox(
                    label="Set custom params", interactive=True, value=False
                )
                with gr.Box(visible=False) as params_box_single:
                    length_single = gr.Slider(
                        label="Output length",
                        visible=True,
                        interactive=True,
                        minimum=50,
                        maximum=200,
                        value=75,
                    )
                    top_k_single = gr.Slider(
                        label="top_k",
                        visible=True,
                        interactive=True,
                        minimum=1,
                        maximum=100,
                        value=50,
                    )
                    top_p_single = gr.Slider(
                        label="top_p",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.95,
                    )
                    temperature_single = gr.Slider(
                        label="temperature",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.7,
                    )
            with gr.Column(scale=1):  # Model choice & output
                gr.Markdown("### 2. Evaluate output")
                model_radio = gr.Radio(
                    choices=list(CHECKPOINTS.keys()),
                    label="Model",
                    interactive=True,
                    elem_id="inside_group",
                )
                search_bar = gr.Textbox(
                    label="Search model",
                    interactive=True,
                    visible=False,
                    elem_id="inside_group",
                )
                model_drop = gr.Dropdown(visible=False)
                private_checkbox = gr.Checkbox(
                    visible=True, label="Private Model ?", elem_id="inside_group"
                )
                api_key_textbox = gr.Textbox(
                    label="Enter your AUTH TOKEN below",
                    value=None,
                    interactive=True,
                    visible=False,
                    elem_id="pw",
                )
                generate_button = gr.Button(
                    "Submit your prompt", elem_id="inside_group"
                )
                output_spans = gr.HighlightedText(visible=True, label="Generated text")
                flag_button = gr.Button(
                    "Report output here", visible=False, elem_id="inside_group"
                )
        with gr.Row():  # Flagging
            with gr.Column(scale=1):
                flag_radio = gr.Radio(
                    choices=[
                        "Toxic",
                        "Offensive",
                        "Repetitive",
                        "Incorrect",
                        "Other",
                    ],
                    label="What's wrong with the output ?",
                    interactive=True,
                    visible=False,
                    elem_id="inside_group",
                )
                user_comment = gr.Textbox(
                    label="(Optional) Briefly describe the issue",
                    visible=False,
                    interactive=True,
                    elem_id="inside_group",
                )
                confirm_flag_button = gr.Button(
                    "Confirm report", visible=False, elem_id="inside_group"
                )
        with gr.Row():  # Flagging success
            success_message = gr.Markdown(
                "Your report has been successfully registered. Thank you!",
                visible=False,
                elem_id="inside_group",
            )
        with gr.Row():  # Toxicity buttons
            toxi_button = gr.Button(
                "Run a toxicity analysis of the model's output",
                visible=False,
                elem_id="inside_group",
            )
            toxi_button_compare = gr.Button(
                "Compare toxicity on input and output",
                visible=False,
                elem_id="inside_group",
            )
        with gr.Row():  # Toxicity scores
            toxi_scores_input = gr.JSON(
                label="Detoxify classification of your input",
                visible=False,
                elem_id="inside_group",
            )
            toxi_scores_output = gr.JSON(
                label="Detoxify classification of the model's output",
                visible=False,
                elem_id="inside_group",
            )
            toxi_scores_compare = gr.JSON(
                label="Percentage change between Input and Output",
                visible=False,
                elem_id="inside_group",
            )

    with gr.Group(visible=False) as multi_model:
        model_list = []  # the output slots, filled in below
        gr.Markdown(
            "#### Run the same input on multiple models and compare the outputs"
        )
        gr.Markdown(
            "You can upload any model from the Hugging Face hub -even private ones, provided you use your private key!"
        )
        gr.Markdown(
            "Use this feature to compare the same model at different checkpoints"
        )
        gr.Markdown("Or to benchmark your model against another one as a reference.")
        gr.Markdown(
            "Beware ! Generation can take up to a few minutes with very large models."
        )
        with gr.Row(elem_id="inside_group"):
            with gr.Column():
                models_multi = gr.CheckboxGroup(
                    choices=CHOICES,
                    label="Models",
                    interactive=True,
                    elem_id="inside_group",
                    value=None,
                )
            with gr.Column():
                generate_button_multi = gr.Button(
                    "Submit your prompt", elem_id="inside_group"
                )
                show_params_checkbox_multi = gr.Checkbox(
                    label="Set custom params", interactive=True, value=False
                )
                with gr.Box(visible=False) as params_box_multi:
                    length_multi = gr.Slider(
                        label="Output length",
                        visible=True,
                        interactive=True,
                        minimum=50,
                        maximum=200,
                        value=75,
                    )
                    top_k_multi = gr.Slider(
                        label="top_k",
                        visible=True,
                        interactive=True,
                        minimum=1,
                        maximum=100,
                        value=50,
                    )
                    top_p_multi = gr.Slider(
                        label="top_p",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.95,
                    )
                    temperature_multi = gr.Slider(
                        label="temperature",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.7,
                    )
        with gr.Row(elem_id="inside_group"):
            with gr.Column(elem_id="inside_group", scale=1):
                input_text_multi = gr.Textbox(
                    label="Write your prompt below.",
                    interactive=True,
                    lines=4,
                    elem_id="inside_group",
                )
            with gr.Column(elem_id="inside_group", scale=1):
                search_bar_multi = gr.Textbox(
                    label="Search another model",
                    interactive=True,
                    visible=True,
                    elem_id="inside_group",
                )
                model_drop_multi = gr.Dropdown(visible=False, elem_id="inside_group")
                private_checkbox_multi = gr.Checkbox(
                    visible=True, label="Private Model ?"
                )
                api_key_textbox_multi = gr.Textbox(
                    label="Enter your AUTH TOKEN below",
                    value=None,
                    interactive=True,
                    visible=False,
                    elem_id="pw",
                )
        with gr.Row() as outputs_row:
            # Fixed pool of 10 output slots; the multi-model handlers always
            # return exactly one update per slot.
            for _ in range(10):
                output_spans_multi = gr.HighlightedText(
                    visible=False, elem_id="inside_group"
                )
                model_list.append(output_spans_multi)

    with gr.Row():
        gr.Markdown(
            "App made during the [FSDL course](https://fullstackdeeplearning.com) "
            "by Team53: Jean-Antoine, Sajenthan, Sashank, Kemp, Srihari, Astitwa"
        )

    # Single Model
    choose_mode.change(
        fn=show_mode, inputs=choose_mode, outputs=[single_model, multi_model]
    )
    inspo_button.click(
        fn=show_dataset,
        inputs=dataset,
        outputs=[prompts_drop, randomize_button, prompts_var],
    )
    prompts_drop.change(fn=pass_to_textbox, inputs=prompts_drop, outputs=input_text)
    # (A stray trailing comma used to turn this statement into a tuple.)
    randomize_button.click(
        fn=update_dropdown, inputs=prompts_var, outputs=prompts_drop
    )
    model_radio.change(
        fn=show_search_bar, inputs=model_radio, outputs=[model_choice, search_bar]
    )
    search_bar.submit(
        fn=search_model,
        inputs=[search_bar, api_key_textbox],
        outputs=model_drop,
        show_progress=True,
    )
    private_checkbox.change(
        fn=show_api_key_textbox, inputs=private_checkbox, outputs=api_key_textbox
    )
    model_drop.change(
        fn=forward_model_choice,
        inputs=model_drop,
        outputs=[model_choice, custom_model_path],
    )
    generate_button.click(
        fn=process_user_input,
        inputs=[
            model_choice,
            api_key_textbox,
            custom_model_path,
            input_text,
            length_single,
            temperature_single,
            top_p_single,
            top_k_single,
        ],
        outputs=[output_spans, toxi_button, flag_button, input_var, output_var],
        show_progress=True,
    )
    toxi_button.click(
        fn=compute_toxi_output,
        inputs=output_var,
        outputs=[toxi_scores_output, toxi_button_compare],
        show_progress=True,
    )
    # Compare against the prompt that actually produced the output (input_var),
    # not whatever is currently typed in the textbox.
    toxi_button_compare.click(
        fn=compare_toxi_scores,
        inputs=[input_var, toxi_scores_output],
        outputs=[toxi_scores_input, toxi_scores_compare],
        show_progress=True,
    )
    flag_button.click(fn=show_flag_choices, inputs=None, outputs=flag_radio)
    flag_radio.change(
        fn=update_flag,
        inputs=flag_radio,
        outputs=[flag_choice, confirm_flag_button, user_comment, flag_button],
    )
    flagging_callback.setup(
        [input_var, output_var, model_choice, user_comment, flag_choice],
        "flagged_data_points",
    )
    confirm_flag_button.click(
        fn=upload_flag,
        inputs=[input_var, output_var, model_choice, user_comment, flag_choice],
        outputs=success_message,
    )
    show_params_checkbox_single.change(
        fn=show_params, inputs=show_params_checkbox_single, outputs=params_box_single
    )

    # Model comparison
    search_bar_multi.submit(
        fn=search_model,
        inputs=[search_bar_multi, api_key_textbox_multi],
        outputs=model_drop_multi,
        show_progress=True,
    )
    show_params_checkbox_multi.change(
        fn=show_params, inputs=show_params_checkbox_multi, outputs=params_box_multi
    )
    private_checkbox_multi.change(
        fn=show_api_key_textbox,
        inputs=private_checkbox_multi,
        outputs=api_key_textbox_multi,
    )
    model_drop_multi.change(
        fn=forward_model_choice_multi, inputs=model_drop_multi, outputs=[models_multi]
    )
    models_multi.change(fn=show_choices_multi, inputs=models_multi, outputs=model_list)
    generate_button_multi.click(
        fn=process_user_input_multi,
        inputs=[
            models_multi,
            input_text_multi,
            api_key_textbox_multi,
            length_multi,
            temperature_multi,
            top_p_multi,
            top_k_multi,
        ],
        outputs=model_list,
        show_progress=True,
    )
if __name__ == "__main__":
    # Launch the app when this file is run as a script.
    # Request queueing is currently disabled; uncomment the next line to
    # enable Gradio's queue with concurrency_count=3.
    # demo.queue(concurrency_count=3)
    demo.launch(debug=True)