# BuggingSpace: FSDL 2022 red-teaming app for open-source text generation models.
# Pick a model, write a prompt, inspect the output and analyse its toxicity with
# Detoxify, or compare the outputs of several models side by side.

import os
from random import sample

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from detoxify import Detoxify
from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BloomForCausalLM,
    BloomTokenizerFast,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    GPTNeoForCausalLM,
)

# Read the Space's auth token from the environment (None if unset)
HF_AUTH_TOKEN = os.environ.get("hf_token")
DATASET = "allenai/real-toxicity-prompts"

CHECKPOINTS = {
    "DistilGPT2 by HuggingFace 🤗": "distilgpt2",
    "GPT-Neo 125M by EleutherAI 🤖": "EleutherAI/gpt-neo-125M",
    "BLOOM 560M by BigScience 🌸": "bigscience/bloom-560m",
    "Custom Model": None,
}

MODEL_CLASSES = {
    "DistilGPT2 by HuggingFace 🤗": (GPT2LMHeadModel, GPT2Tokenizer),
    "GPT-Neo 125M by EleutherAI 🤖": (GPTNeoForCausalLM, GPT2Tokenizer),
    "BLOOM 560M by BigScience 🌸": (BloomForCausalLM, BloomTokenizerFast),
    "Custom Model": (AutoModelForCausalLM, AutoTokenizer),
}

# The three built-in checkpoints selectable in multi-model mode
# ("Custom Model" is excluded; hub models are appended at runtime).
CHOICES = sorted(list(CHECKPOINTS.keys())[:3])


def load_model(model_name, custom_model_path, token):
    """Instantiate the requested model and tokenizer, falling back to the
    generic Auto* classes for models picked from the hub search bar."""
    try:
        model_class, tokenizer_class = MODEL_CLASSES[model_name]
        model_path = CHECKPOINTS[model_name]
    except KeyError:
        model_class, tokenizer_class = MODEL_CLASSES["Custom Model"]
        model_path = custom_model_path or model_name
    model = model_class.from_pretrained(model_path, use_auth_token=token)
    tokenizer = tokenizer_class.from_pretrained(model_path, use_auth_token=token)
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    model.eval()
    return model, tokenizer


MAX_LENGTH = 10000  # Hardcoded max length to avoid infinite loop


def set_seed(seed, n_gpu):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length


def generate(
    model_name,
    token,
    custom_model_path,
    input_sentence,
    length=75,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    seed=42,
    no_cuda=False,
    num_return_sequences=1,
    stop_token=".",
):
    # Pick the device and seed the RNGs for reproducible sampling
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
    )
    n_gpu = 0 if no_cuda else torch.cuda.device_count()
    set_seed(seed, n_gpu)

    # Load model
    model, tokenizer = load_model(model_name, custom_model_path, token)
    model.to(device)
    # length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)

    # Tokenize input
    encoded_prompt = tokenizer.encode(
        input_sentence, add_special_tokens=False, return_tensors="pt"
    )
    encoded_prompt = encoded_prompt.to(device)
    input_ids = encoded_prompt

    # Generate output
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=length + len(encoded_prompt[0]),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )

    generated_sequences = list()
    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        # Remove the prompt from the decoded text
        text = text[
            len(
                tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)
            ) :
        ]
        # Remove all text after the last occurrence of stop_token
        text = text[: text.rfind(stop_token) + 1]
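# A minimal sketch of calling `generate` directly, bypassing the Gradio UI.
# The model choice and prompt below are illustrative, not part of the app:
#
#     text = generate(
#         model_name="DistilGPT2 by HuggingFace 🤗",
#         token=None,
#         custom_model_path=None,
#         input_sentence="The researchers were surprised to find",
#         length=30,
#     )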
        generated_sequences.append(text)

    return generated_sequences[0]


def show_mode(mode):
    if mode == "Single Model":
        return (gr.update(visible=True), gr.update(visible=False))
    if mode == "Multi-Model":
        return (gr.update(visible=False), gr.update(visible=True))


def prepare_dataset(dataset):
    dataset = load_dataset(dataset, split="train")
    return dataset


def load_prompts(dataset):
    prompts = [dataset[i]["prompt"]["text"] for i in range(len(dataset))]
    return prompts


def random_sample(prompt_list):
    # Draw 10 prompts at random from the dataset
    return sample(prompt_list, 10)


def show_dataset(dataset):
    raw_data = prepare_dataset(dataset)
    prompts = load_prompts(raw_data)
    return (
        gr.update(
            choices=random_sample(prompts),
            label="Below is a random subset of the RealToxicityPrompts dataset",
            visible=True,
        ),
        gr.update(visible=True),
        prompts,
    )


def update_dropdown(prompts):
    return gr.update(choices=random_sample(prompts))


def show_search_bar(value):
    if value == "Custom Model":
        return (value, gr.update(visible=True))
    else:
        return (value, gr.update(visible=False))


def search_model(model_name, token):
    # Search the hub for PyTorch text-generation models matching the query
    api = HfApi()
    model_args = ModelSearchArguments()
    filt = ModelFilter(
        task=model_args.pipeline_tag.TextGeneration, library=model_args.library.PyTorch
    )
    results = api.list_models(filter=filt, search=model_name, use_auth_token=token)
    model_list = [model.modelId for model in results]
    return gr.update(
        visible=True,
        choices=model_list,
        label="Choose the model",
    )


def show_api_key_textbox(checkbox):
    if checkbox:
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)


def forward_model_choice(model_choice_path):
    return (model_choice_path, model_choice_path)


def auto_complete(input, generated):
    # Mark the generated continuation as an "OUTPUT" span for gr.HighlightedText
    output = input + " " + generated
    output_spans = [{"entity": "OUTPUT", "start": len(input), "end": len(output)}]
    completed_prompt = {"text": output, "entities": output_spans}
    return completed_prompt
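# Shape of the `auto_complete` return value, as consumed by gr.HighlightedText.
# For example, auto_complete("Hello", "world.") yields:
#
#     {"text": "Hello world.",
#      "entities": [{"entity": "OUTPUT", "start": 5, "end": 12}]}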
def process_user_input(
    model, token, custom_model_path, input, length, temperature, top_p, top_k
):
    warning = "Please enter a valid prompt."
    if not input:  # catches both None and the empty string
        generated = warning
    else:
        generated = generate(
            model_name=model,
            token=token,
            custom_model_path=custom_model_path,
            input_sentence=input,
            length=length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
    generated = generated.replace("\n", " ")
    generated_with_spans = auto_complete(input=input, generated=generated)
    return (
        gr.update(value=generated_with_spans),
        gr.update(visible=True),
        gr.update(visible=True),
        input,
        generated,
    )


def pass_to_textbox(input):
    return gr.update(value=input)


def run_detoxify(text):
    results = Detoxify("original").predict(text)
    json_ready_results = {cat: float(score) for (cat, score) in results.items()}
    return json_ready_results


def compute_toxi_output(output_text):
    scores = run_detoxify(output_text)
    return (gr.update(value=scores, visible=True), gr.update(visible=True))


def compute_change(before, after):
    # Percentage change between the input and output scores
    return round(((float(after) - before) / before) * 100, 2)


def compare_toxi_scores(input_text, output_scores):
    input_scores = run_detoxify(input_text)
    json_ready_results = {cat: float(score) for (cat, score) in input_scores.items()}
    # Both dicts share the same Detoxify categories, so iterate over one of them
    compare_scores = {
        cat: compute_change(json_ready_results[cat], output_scores[cat])
        for cat in output_scores
    }
    return (
        gr.update(value=json_ready_results, visible=True),
        gr.update(value=compare_scores, visible=True),
    )
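# `run_detoxify` returns one float per Detoxify category; for the "original"
# checkpoint these are typically toxicity, severe_toxicity, obscene, threat,
# insult and identity_attack. A minimal usage sketch (the sentence is
# illustrative):
#
#     scores = run_detoxify("Have a wonderful day!")
#     assert 0.0 <= scores["toxicity"] <= 1.0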
def show_flag_choices():
    return gr.update(visible=True)


def update_flag(flag_value):
    return (
        flag_value,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
    )


def upload_flag(*args):
    flags = list(args)
    flags[1] = bytes(flags[1], "utf-8")  # encode the output text before upload
    flagging_callback.flag(flags)
    return gr.update(visible=True)


def forward_model_choice_multi(model_choice_path):
    CHOICES.append(model_choice_path)
    return gr.update(choices=CHOICES)


def process_user_input_multi(models, input, token, length, temperature, top_p, top_k):
    warning = "Please enter a valid prompt."
    if not input:
        # Show the warning in place of every selected model's output
        generated_dict = {model: warning for model in sorted(models)}
    else:
        generated_dict = {
            model: generate(
                model_name=model,
                token=token,
                custom_model_path=None,
                input_sentence=input,
                length=length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
            for model in sorted(models)
        }
    generated_with_spans_dict = {
        model: auto_complete(input, generated)
        for model, generated in generated_dict.items()
    }
    # The UI always has 10 output slots: fill one per model, hide the rest
    update_outputs = [
        gr.HighlightedText.update(value=output, label=model)
        for model, output in generated_with_spans_dict.items()
    ]
    update_hide = [
        gr.HighlightedText.update(visible=False) for i in range(10 - len(models))
    ]
    return update_outputs + update_hide


def show_choices_multi(models):
    update_show = [gr.HighlightedText.update(visible=True) for model in sorted(models)]
    update_hide = [
        gr.HighlightedText.update(visible=False, value=None, label=None)
        for i in range(10 - len(models))
    ]
    return update_show + update_hide


def show_params(checkbox):
    if checkbox:
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)


CSS = """
#inside_group {
    padding-top: 0.6em;
    padding-bottom: 0.6em;
}
#pw textarea {
    -webkit-text-security: disc;
}
"""

with gr.Blocks(css=CSS) as demo:
    dataset = gr.Variable(value=DATASET)
    prompts_var = gr.Variable(value=None)
    input_var = gr.Variable(label="Input Prompt", value=None)
    output_var = gr.Variable(label="Output", value=None)
    model_choice = gr.Variable(label="Model", value=None)
    custom_model_path = gr.Variable(value=None)
    flag_choice = gr.Variable(label="Flag", value=None)

    flagging_callback = gr.HuggingFaceDatasetSaver(
        hf_token=HF_AUTH_TOKEN,
        dataset_name="fsdlredteam/flagged_3",
        private=True,
    )

    gr.Markdown("# BuggingSpace")
    gr.Markdown("## FSDL 2022 Red-Teaming Open-Source Models Project")
    gr.Markdown(
        "### Pick a text generation model below, write a prompt and explore the output"
    )
    gr.Markdown("### Or compare the output of multiple models at the same time")

    choose_mode = gr.Radio(
        choices=["Single Model", "Multi-Model"],
        value="Single Model",
        interactive=True,
        visible=True,
        show_label=False,
    )

    with gr.Group() as single_model:
        gr.Markdown(
            "You can upload any model from the Hugging Face hub, even private ones, "
            "provided you use your auth token! "
            "Write your prompt or alternatively use one from the "
            "[RealToxicityPrompts](https://allenai.org/data/real-toxicity-prompts) dataset."
        )
        gr.Markdown(
            "Use it to audit the model for potential failure modes, "
            "analyse its output with the Detoxify suite and contribute by reporting any problematic result."
        )
        gr.Markdown(
            "Beware! Generation can take up to a few minutes with very large models."
        )

        with gr.Row():
            with gr.Column(scale=1):  # Input & prompts dataset exploration
                gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
                input_text = gr.Textbox(
                    label="Write your prompt below.",
                    interactive=True,
                    lines=4,
                    elem_id="inside_group",
                )
                gr.Markdown("— or —", elem_id="inside_group")
                inspo_button = gr.Button(
                    "Click here if you need some inspiration", elem_id="inside_group"
                )
                prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
                randomize_button = gr.Button(
                    "Show another subset", visible=False, elem_id="inside_group"
                )
                show_params_checkbox_single = gr.Checkbox(
                    label="Set custom params", interactive=True, value=False
                )
                with gr.Box(visible=False) as params_box_single:
                    length_single = gr.Slider(
                        label="Output length",
                        visible=True,
                        interactive=True,
                        minimum=50,
                        maximum=200,
                        value=75,
                    )
                    top_k_single = gr.Slider(
                        label="top_k",
                        visible=True,
                        interactive=True,
                        minimum=1,
                        maximum=100,
                        value=50,
                    )
                    top_p_single = gr.Slider(
                        label="top_p",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.95,
                    )
                    temperature_single = gr.Slider(
                        label="temperature",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.7,
                    )

            with gr.Column(scale=1):  # Model choice & output
                gr.Markdown("### 2. Evaluate output")
                model_radio = gr.Radio(
                    choices=list(CHECKPOINTS.keys()),
                    label="Model",
                    interactive=True,
                    elem_id="inside_group",
                )
                search_bar = gr.Textbox(
                    label="Search model",
                    interactive=True,
                    visible=False,
                    elem_id="inside_group",
                )
                model_drop = gr.Dropdown(visible=False)
                private_checkbox = gr.Checkbox(
                    visible=True, label="Private Model?", elem_id="inside_group"
                )
                api_key_textbox = gr.Textbox(
                    label="Enter your AUTH TOKEN below",
                    value=None,
                    interactive=True,
                    visible=False,
                    elem_id="pw",
                )
                generate_button = gr.Button(
                    "Submit your prompt", elem_id="inside_group"
                )
                output_spans = gr.HighlightedText(visible=True, label="Generated text")
                flag_button = gr.Button(
                    "Report output here", visible=False, elem_id="inside_group"
                )

        with gr.Row():  # Flagging
            with gr.Column(scale=1):
                flag_radio = gr.Radio(
                    choices=[
                        "Toxic",
                        "Offensive",
                        "Repetitive",
                        "Incorrect",
                        "Other",
                    ],
                    label="What's wrong with the output?",
                    interactive=True,
                    visible=False,
                    elem_id="inside_group",
                )
                user_comment = gr.Textbox(
                    label="(Optional) Briefly describe the issue",
                    visible=False,
                    interactive=True,
                    elem_id="inside_group",
                )
                confirm_flag_button = gr.Button(
                    "Confirm report", visible=False, elem_id="inside_group"
                )

        with gr.Row():  # Flagging success
            success_message = gr.Markdown(
                "Your report has been successfully registered. Thank you!",
                visible=False,
                elem_id="inside_group",
            )

        with gr.Row():  # Toxicity buttons
            toxi_button = gr.Button(
                "Run a toxicity analysis of the model's output",
                visible=False,
                elem_id="inside_group",
            )
            toxi_button_compare = gr.Button(
                "Compare toxicity on input and output",
                visible=False,
                elem_id="inside_group",
            )

        with gr.Row():  # Toxicity scores
            toxi_scores_input = gr.JSON(
                label="Detoxify classification of your input",
                visible=False,
                elem_id="inside_group",
            )
            toxi_scores_output = gr.JSON(
                label="Detoxify classification of the model's output",
                visible=False,
                elem_id="inside_group",
            )
            toxi_scores_compare = gr.JSON(
                label="Percentage change between input and output",
                visible=False,
                elem_id="inside_group",
            )

    with gr.Group(visible=False) as multi_model:
        model_list = list()
        gr.Markdown(
            "#### Run the same input on multiple models and compare the outputs"
        )
        gr.Markdown(
            "You can upload any model from the Hugging Face hub, even private ones, provided you use your auth token!"
        )
        gr.Markdown(
            "Use this feature to compare the same model at different checkpoints,"
        )
        gr.Markdown("or to benchmark your model against another one as a reference.")
        gr.Markdown(
            "Beware! Generation can take up to a few minutes with very large models."
        )

        with gr.Row(elem_id="inside_group"):
            with gr.Column():
                models_multi = gr.CheckboxGroup(
                    choices=CHOICES,
                    label="Models",
                    interactive=True,
                    elem_id="inside_group",
                    value=None,
                )
            with gr.Column():
                generate_button_multi = gr.Button(
                    "Submit your prompt", elem_id="inside_group"
                )
                show_params_checkbox_multi = gr.Checkbox(
                    label="Set custom params", interactive=True, value=False
                )
                with gr.Box(visible=False) as params_box_multi:
                    length_multi = gr.Slider(
                        label="Output length",
                        visible=True,
                        interactive=True,
                        minimum=50,
                        maximum=200,
                        value=75,
                    )
                    top_k_multi = gr.Slider(
                        label="top_k",
                        visible=True,
                        interactive=True,
                        minimum=1,
                        maximum=100,
                        value=50,
                    )
                    top_p_multi = gr.Slider(
                        label="top_p",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.95,
                    )
                    temperature_multi = gr.Slider(
                        label="temperature",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.7,
                    )

        with gr.Row(elem_id="inside_group"):
            with gr.Column(elem_id="inside_group", scale=1):
                input_text_multi = gr.Textbox(
                    label="Write your prompt below.",
                    interactive=True,
                    lines=4,
                    elem_id="inside_group",
                )
            with gr.Column(elem_id="inside_group", scale=1):
                search_bar_multi = gr.Textbox(
                    label="Search another model",
                    interactive=True,
                    visible=True,
                    elem_id="inside_group",
                )
                model_drop_multi = gr.Dropdown(visible=False, elem_id="inside_group")
                private_checkbox_multi = gr.Checkbox(
                    visible=True, label="Private Model?"
                )
                api_key_textbox_multi = gr.Textbox(
                    label="Enter your AUTH TOKEN below",
                    value=None,
                    interactive=True,
                    visible=False,
                    elem_id="pw",
                )

        with gr.Row() as outputs_row:
            # Ten fixed output slots; unused ones stay hidden
            for i in range(10):
                output_spans_multi = gr.HighlightedText(
                    visible=False, elem_id="inside_group"
                )
                model_list.append(output_spans_multi)

    with gr.Row():
        gr.Markdown(
            "App made during the [FSDL course](https://fullstackdeeplearning.com) "
            "by Team53: Jean-Antoine, Sajenthan, Sashank, Kemp, Srihari, Astitwa"
        )
    # Single Model
    choose_mode.change(
        fn=show_mode, inputs=choose_mode, outputs=[single_model, multi_model]
    )
    inspo_button.click(
        fn=show_dataset,
        inputs=dataset,
        outputs=[prompts_drop, randomize_button, prompts_var],
    )
    prompts_drop.change(fn=pass_to_textbox, inputs=prompts_drop, outputs=input_text)
    randomize_button.click(
        fn=update_dropdown, inputs=prompts_var, outputs=prompts_drop
    )
    model_radio.change(
        fn=show_search_bar, inputs=model_radio, outputs=[model_choice, search_bar]
    )
    search_bar.submit(
        fn=search_model,
        inputs=[search_bar, api_key_textbox],
        outputs=model_drop,
        show_progress=True,
    )
    private_checkbox.change(
        fn=show_api_key_textbox, inputs=private_checkbox, outputs=api_key_textbox
    )
    model_drop.change(
        fn=forward_model_choice,
        inputs=model_drop,
        outputs=[model_choice, custom_model_path],
    )
    generate_button.click(
        fn=process_user_input,
        inputs=[
            model_choice,
            api_key_textbox,
            custom_model_path,
            input_text,
            length_single,
            temperature_single,
            top_p_single,
            top_k_single,
        ],
        outputs=[output_spans, toxi_button, flag_button, input_var, output_var],
        show_progress=True,
    )
    toxi_button.click(
        fn=compute_toxi_output,
        inputs=output_var,
        outputs=[toxi_scores_output, toxi_button_compare],
        show_progress=True,
    )
    toxi_button_compare.click(
        fn=compare_toxi_scores,
        inputs=[input_text, toxi_scores_output],
        outputs=[toxi_scores_input, toxi_scores_compare],
        show_progress=True,
    )
    flag_button.click(fn=show_flag_choices, inputs=None, outputs=flag_radio)
    flag_radio.change(
        fn=update_flag,
        inputs=flag_radio,
        outputs=[flag_choice, confirm_flag_button, user_comment, flag_button],
    )
    flagging_callback.setup(
        [input_var, output_var, model_choice, user_comment, flag_choice],
        "flagged_data_points",
    )
    confirm_flag_button.click(
        fn=upload_flag,
        inputs=[input_var, output_var, model_choice, user_comment, flag_choice],
        outputs=success_message,
    )
    show_params_checkbox_single.change(
        fn=show_params, inputs=show_params_checkbox_single, outputs=params_box_single
    )

    # Model comparison
    search_bar_multi.submit(
        fn=search_model,
        inputs=[search_bar_multi, api_key_textbox_multi],
        outputs=model_drop_multi,
        show_progress=True,
    )
    show_params_checkbox_multi.change(
        fn=show_params, inputs=show_params_checkbox_multi, outputs=params_box_multi
    )
    private_checkbox_multi.change(
        fn=show_api_key_textbox,
        inputs=private_checkbox_multi,
        outputs=api_key_textbox_multi,
    )
    model_drop_multi.change(
        fn=forward_model_choice_multi, inputs=model_drop_multi, outputs=[models_multi]
    )
    models_multi.change(fn=show_choices_multi, inputs=models_multi, outputs=model_list)
    generate_button_multi.click(
        fn=process_user_input_multi,
        inputs=[
            models_multi,
            input_text_multi,
            api_key_textbox_multi,
            length_multi,
            temperature_multi,
            top_p_multi,
            top_k_multi,
        ],
        outputs=model_list,
        show_progress=True,
    )

if __name__ == "__main__":
    # demo.queue(concurrency_count=3)
    demo.launch(debug=True)