import itertools
from typing import List

import gradio as gr
import joblib
import matplotlib as mpl
mpl.use('Agg')  # non-interactive backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset

# Precomputed label distributions and metadata for each demo dataset
cached_artifacts = joblib.load("cached_data.pkl")

laion = load_dataset("society-ethics/laion2B-en_continents", split="train").to_pandas()
medmcqa = load_dataset("society-ethics/medmcqa_age_gender_custom", split="train").to_pandas()
stack = load_dataset("society-ethics/the-stack-tabs_spaces", split="train").to_pandas() \
    .drop(columns=["max_stars_repo_licenses", "max_issues_repo_licenses", "max_forks_repo_licenses"])

cached_artifacts["laion"]["text"] = {
    "title": "Disaggregating by continent with a built-in module",
    "description": """
The [`laion/laion2B-en` dataset](https://huggingface.co/datasets/laion/laion2B-en), created by [LAION](https://laion.ai), is used to train image generation models such as [Stable Diffusion](https://huggingface.co/spaces/stabilityai/stable-diffusion). The dataset contains pairs of images and captions, but we might also be curious about the distribution of specific topics, such as continents, mentioned in the captions.

The original dataset doesn't contain metadata about continents, but we can attempt to infer it from the `TEXT` feature with `disaggregators`. Note that several factors contribute to a high rate of false positives; for example, country and city names are frequently used as names of fashion products.
""",
    "visualization": """
This view shows the relative proportion of each label in the disaggregated dataset. For this dataset, we've disaggregated by only one category (continent), but that category has many possible values. Many rows haven't been flagged with any continent (check "None" to see them!), and this disaggregator never assigns *Multiple* continents to the same row.

To see examples of individual rows, click over to the "Inspect" tab!
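
The proportions in this chart can also be computed directly from the disaggregated rows. Here is a minimal pandas sketch, assuming each label is stored as a boolean column named `<category>.<label>`; the helper `label_proportions` and the toy continent columns are hypothetical, and this is not necessarily how the demo computes its cached distributions. As in the chart, rows with exactly one label count toward that label, while rows with several labels or none fall into "Multiple" and "None".

```python
import pandas as pd


def label_proportions(df: pd.DataFrame, category: str) -> pd.Series:
    # Boolean columns belonging to this category, e.g. "continent.europe"
    cols = [c for c in df.columns if c.startswith(category + ".")]
    flags = df[cols]
    n_flags = flags.sum(axis=1)

    # Rows with exactly one label count toward that label
    counts = flags[n_flags == 1].sum()
    counts.index = [c.split(".", 1)[1] for c in counts.index]

    # Rows with several labels or with none get their own buckets
    counts["Multiple"] = int((n_flags > 1).sum())
    counts["None"] = int((n_flags == 0).sum())
    return counts / len(df)


# Toy data with hypothetical continent columns
rows = pd.DataFrame({
    "continent.africa": [True, False, False, False],
    "continent.europe": [False, True, False, False],
    "continent.asia": [False, False, False, False],
})
print(label_proportions(rows, "continent"))
```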
""",
    "code": """
```python
from disaggregators import Disaggregator
disaggregator = Disaggregator("continent", column="TEXT")
# Note: this demo used a subset of the dataset
from datasets import load_dataset
ds = load_dataset("laion/laion2B-en", split="train", streaming=True).map(disaggregator)
```
"""
}
cached_artifacts["medmcqa"]["text"] = {
    "title": "Overriding configurations for built-in modules",
    "description": """
Meta's [Galactica model](https://galactica.org) is trained on a large-scale scientific corpus, which includes the [`medmcqa` dataset](https://huggingface.co/datasets/medmcqa) of medical entrance exam questions. MedMCQA has a `question` feature that often contains a case scenario in which a hypothetical patient presents with a condition.

The original dataset doesn't contain metadata about age or binary gender, but we can infer both with the `age` and `gender` modules. If a module doesn't have the label options you'd like, such as additional genders or specific age buckets, you can override the module's configuration. In this example, we've configured the `age` module to use [NIH's MeSH age groups](https://www.ncbi.nlm.nih.gov/mesh/68009273).
""",
    "visualization": """
Since we've disaggregated the MedMCQA dataset by *two* categories (age and binary gender), we can click on "Age + Gender" to visualize the proportions of the *intersections* of each group.

There are two things to note about this example:

1. The disaggregators for age and gender can flag a row as having more than one age or gender; we've grouped these rows here as "Multiple".
2. If you look at the data through the "Inspect" tab, you'll notice that there are some false positives. `disaggregators` is in early development, and these modules are in a very early "proof of concept" stage! Keep an eye out as we develop more sophisticated algorithms for disaggregation, and [join us over on GitHub](https://github.com/huggingface/disaggregators) if you'd like to contribute ideas, documentation, or code.
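
The intersection view can be approximated by collapsing each category's boolean columns into a single value per row (a label, "Multiple", or "None") and cross-tabulating the two categories. Below is a minimal pandas sketch under the same assumption that labels are stored as `<category>.<label>` boolean columns; the `collapse` helper and the age and gender columns are hypothetical toy data, not the demo's own code.

```python
import pandas as pd


def collapse(df: pd.DataFrame, category: str) -> pd.Series:
    # Reduce the "<category>.<label>" boolean columns to one value per row
    cols = [c for c in df.columns if c.startswith(category + ".")]

    def pick(row: pd.Series) -> str:
        hits = [c.split(".", 1)[1] for c in cols if row[c]]
        if len(hits) > 1:
            return "Multiple"
        return hits[0] if hits else "None"

    return df.apply(pick, axis=1)


rows = pd.DataFrame({
    "age.adult": [True, True, False, True],
    "age.child": [False, True, False, False],
    "gender.female": [True, False, False, True],
    "gender.male": [False, False, True, True],
})

# Share of all rows falling into each (age, gender) combination
table = pd.crosstab(collapse(rows, "age"), collapse(rows, "gender"), normalize="all")
print(table)
```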
""",
    "code": """
```python
from disaggregators import Disaggregator
from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig


class MeSHAgeLabels(AgeLabels):
    INFANT = "infant"
    CHILD_PRESCHOOL = "child_preschool"
    CHILD = "child"
    ADOLESCENT = "adolescent"
    ADULT = "adult"
    MIDDLE_AGED = "middle_aged"
    AGED = "aged"
    AGED_80_OVER = "aged_80_over"


age_config = AgeConfig(
    labels=MeSHAgeLabels,
    ages=[list(MeSHAgeLabels)],
    breakpoints=[0, 2, 5, 12, 18, 44, 64, 79]
)

age = Age(config=age_config, column="question")
disaggregator = Disaggregator([age, "gender"], column="question")

from datasets import load_dataset
ds = load_dataset("medmcqa", split="train").map(disaggregator)
```
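
Before running the disaggregator, it can help to check which bucket a given age lands in under these breakpoints. The sketch below assumes each breakpoint is the inclusive lower bound of the label at the same position; that reading is an assumption made for illustration, and the `Age` module's actual boundary handling may differ. The `age_bucket` helper is hypothetical.

```python
from bisect import bisect_right

# Same order as the labels and breakpoints in the configuration above
labels = ["infant", "child_preschool", "child", "adolescent",
          "adult", "middle_aged", "aged", "aged_80_over"]
breakpoints = [0, 2, 5, 12, 18, 44, 64, 79]


def age_bucket(age: float) -> str:
    # Assumption: label i covers breakpoints[i] <= age < breakpoints[i + 1]
    if age < breakpoints[0]:
        raise ValueError("age must be non-negative")
    return labels[bisect_right(breakpoints, age) - 1]


for age in [1, 4, 30, 70, 85]:
    print(age, age_bucket(age))
```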
"""
}
cached_artifacts["stack"]["text"] = {
    "title": "Creating custom disaggregators",
    "description": """
[The BigCode Project](https://www.bigcode-project.org/) recently released [`bigcode/the-stack`](https://huggingface.co/datasets/bigcode/the-stack), which contains over 6TB of permissively-licensed source code files covering 358 programming languages. One of the included languages is [JSX](https://reactjs.org/docs/introducing-jsx.html), an extension to JavaScript designed for use with the [React UI library](https://reactjs.org/docs/introducing-jsx.html). Let's ask some questions about the React code in this dataset!

1. React lets developers define UI components [as functions or as classes](https://reactjs.org/docs/components-and-props.html#function-and-class-components). Which style is more popular in this dataset?
2. Programmers have long argued over using [tabs or spaces](https://www.youtube.com/watch?v=SsoOG6ZeyUI). Who's winning?

`disaggregators` makes it easy to add your own disaggregation modules. See the code snippet below for an example 🤗
""",
    "visualization": """
Like the MedMCQA example, this dataset has been disaggregated by more than one category. Using multiple disaggregation modules lets us get insights into interesting *intersections* of the subpopulations in our datasets.
""",
    "code": """
```python
from disaggregators import Disaggregator, DisaggregationModuleLabels, CustomDisaggregator


class TabsSpacesLabels(DisaggregationModuleLabels):
    TABS = "tabs"
    SPACES = "spaces"


class TabsSpaces(CustomDisaggregator):
    module_id = "tabs_spaces"
    labels = TabsSpacesLabels

    def __call__(self, row, *args, **kwargs):
        if "\\t" in row[self.column]:
            return {self.labels.TABS: True, self.labels.SPACES: False}
        else:
            return {self.labels.TABS: False, self.labels.SPACES: True}


class ReactComponentLabels(DisaggregationModuleLabels):
    CLASS = "class"
    FUNCTION = "function"


class ReactComponent(CustomDisaggregator):
    module_id = "react_component"
    labels = ReactComponentLabels

    def __call__(self, row, *args, **kwargs):
        if "extends React.Component" in row[self.column] or "extends Component" in row[self.column]:
            return {self.labels.CLASS: True, self.labels.FUNCTION: False}
        else:
            return {self.labels.CLASS: False, self.labels.FUNCTION: True}


disaggregator = Disaggregator([TabsSpaces, ReactComponent], column="content")

# Note: this demo used a subset of the dataset
from datasets import load_dataset
ds = load_dataset("bigcode/the-stack", data_dir="data/jsx", split="train", streaming=True).map(disaggregator)
```
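
Because each custom module's `__call__` maps a row to a dictionary of label flags, its logic can be checked on a few hand-written snippets before running it over the full dataset. The sketch below re-implements the `ReactComponent` check as a plain function; `react_component_flags` is a hypothetical helper, not part of `disaggregators`, so it runs without the library installed.

```python
def react_component_flags(content: str) -> dict:
    # Same heuristic as ReactComponent above: class components extend
    # React.Component (or an imported Component); everything else counts as a function
    is_class = "extends React.Component" in content or "extends Component" in content
    return {"class": is_class, "function": not is_class}


class_style = "class Greeting extends React.Component { render() { return <h1>Hi</h1>; } }"
function_style = "function Greeting() { return <h1>Hi</h1>; }"

print(react_component_flags(class_style))     # {'class': True, 'function': False}
print(react_component_flags(function_style))  # {'class': False, 'function': True}
```

Note that this heuristic labels any file without a class component as `function`, including files that define no component at all.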
"""
}


def create_plot(selected_fields, available_fields, distributions, feature_names, plot=None):
    plt.close('all')

    # "Multiple" and "None" are pseudo-fields that are counted separately from the real labels
    clean_fields = [field for field in selected_fields if field not in ["Multiple", "None"]]
    extra_options = [field for field in selected_fields if field in ["Multiple", "None"]]

    # Move this category's index levels to the front so that partial-key lookups work
    distributions = distributions.reorder_levels(
        sorted(list(available_fields)) + [idx for idx in distributions.index.names if idx not in available_fields]
    )
    distributions = distributions.sort_index()

    def get_tuple(field):
        # One-hot key: True only at the position of `field`
        return tuple(True if field == x else False for x in sorted(available_fields))

    masks = [get_tuple(field) for field in sorted(clean_fields)]
    data = [distributions.get(mask, 0) for mask in masks]
    # A partial key returns a Series over the remaining levels, so collapse it to one count
    data = [x.sum() if type(x) != int else x for x in data]

    if "Multiple" in extra_options:
        # All keys in which more than one label of this category is True
        masks_mult = [el for el in itertools.product((True, False), repeat=len(available_fields)) if el.count(True) > 1]
        data = data + [sum([distributions.get(mask, pd.Series(dtype=float)).sum() for mask in masks_mult])]

    if "None" in extra_options:
        # The key in which no label of this category is True
        none_mask = tuple(False for x in available_fields)
        data = data + [distributions.get(none_mask, pd.Series(dtype=float)).sum()]

    fig, ax = plt.subplots()
    title = "Distribution "
    size = 0.3
    cmap = plt.colormaps["Set3"]
    outer_colors = cmap(np.arange(len(data)))
    total_sum = sum(data)
    all_fields = sorted(clean_fields) + sorted(extra_options)
    labels = [f"{feature_names.get(c, c)}\n{round(data[i] / total_sum * 100, 2)}%" for i, c in enumerate(all_fields)]
    ax.pie(data, radius=1, labels=labels, colors=outer_colors,
           wedgeprops=dict(width=size, edgecolor='w'))
    ax.set(aspect="equal", title=title)

    if plot is None:
        return gr.Plot(plt)
    else:
        new_plot = plot.update(plt)
        return new_plot


# TODO: Consolidate with the other plot function...
def create_nested_plot(selected_outer, available_outer, selected_inner, available_inner, distributions, feature_names, plot=None):
    plt.close('all')

    clean_outer = [field for field in selected_outer if field not in ["Multiple", "None"]]
    extra_outer = [field for field in selected_outer if field in ["Multiple", "None"]]
    clean_inner = [field for field in selected_inner if field not in ["Multiple", "None"]]
    extra_inner = [field for field in selected_inner if field in ["Multiple", "None"]]

    distributions = distributions.reorder_levels(
        sorted(list(available_outer)) + sorted(list(available_inner)) + sorted([idx for idx in distributions.index.names if idx not in (available_outer + available_inner)])
    )
    distributions = distributions.sort_index()

    def get_tuple(field, field_options):
        return tuple(True if field == x else False for x in sorted(field_options))

    masks_outer = [get_tuple(field, available_outer) for field in sorted(clean_outer)]
    masks_inner = [get_tuple(field, available_inner) for field in sorted(clean_inner)]
    data_inner = [[distributions.get(m_o + mask, 0) for mask in masks_inner] for m_o in masks_outer]

    masks_mult_inner = []
    masks_none_inner = []
    if "Multiple" in extra_inner:
        masks_mult_inner = [el for el in itertools.product((True, False), repeat=len(available_inner)) if el.count(True) > 1]
        masks_mult = [m_o + m_i for m_i in masks_mult_inner for m_o in masks_outer]
        mult_inner_count = [distributions.get(mask, pd.Series(dtype=float)).sum() for mask in masks_mult]
        data_inner = [di + [mult_inner_count[idx]] for idx, di in enumerate(data_inner)]
    if "None" in extra_inner:
        masks_none_inner = tuple(False for x in available_inner)
        masks_none = [m_o + masks_none_inner for m_o in masks_outer]
        none_inner_count = [distributions.get(mask, pd.Series(dtype=float)).sum() for mask in masks_none]
        data_inner = [di + [none_inner_count[idx]] for idx, di in enumerate(data_inner)]
        if len(available_inner) > 0:
            masks_none_inner = [masks_none_inner]

    if "Multiple" in extra_outer:
        masks_mult = [el for el in itertools.product((True, False), repeat=len(available_outer)) if el.count(True) > 1]
        data_inner = data_inner + [[
            sum([distributions.get(mask + mask_inner, pd.Series(dtype=float)).sum() for mask in masks_mult])
            for mask_inner in (masks_inner + masks_mult_inner + masks_none_inner)
        ]]
    if "None" in extra_outer:
        none_mask_outer = tuple(False for x in available_outer)
        data_inner = data_inner + [[distributions.get(none_mask_outer + mask, pd.Series(dtype=float)).sum() for mask in (masks_inner + masks_mult_inner + masks_none_inner)]]

    fig, ax = plt.subplots()
    title = "Distribution "
    size = 0.3
    cmap = plt.colormaps["Set3"]
    cmap2 = plt.colormaps["Set2"]
    outer_colors = cmap(np.arange(len(data_inner)))
    inner_colors = cmap2(np.arange(len(data_inner[0])))
    total_sum = sum(sum(data_inner, []))
    data_outer = [sum(x) for x in data_inner]
    all_fields_outer = sorted(clean_outer) + sorted(extra_outer)
    clean_labels_outer = [f"{feature_names.get(c, c)}\n{round(data_outer[i] / total_sum * 100, 2)}%" for i, c in enumerate(all_fields_outer)]
    clean_labels_inner = [feature_names[c] for c in sorted(clean_inner)]
    ax.pie(data_outer, radius=1, labels=clean_labels_outer, colors=outer_colors,
           wedgeprops=dict(width=size, edgecolor='w'))
    patches, _ = ax.pie(list(itertools.chain(*data_inner)), radius=1 - size, colors=inner_colors,
                        wedgeprops=dict(width=size, edgecolor='w'))
    ax.set(aspect="equal", title=title)
    fig.legend(handles=patches, labels=clean_labels_inner + sorted(extra_inner), loc="lower left")

    if plot is None:
        return gr.Plot(plt)
    else:
        new_plot = plot.update(plt)
        return new_plot


def select_new_base_plot(plot, disagg_check, disagg_by, artifacts):
    if disagg_by == "Both":
        disaggs = sorted(list(artifacts["disaggregators"]))
        all_choices = sorted([[x for x in artifacts["data_fields"] if x.startswith(d)] for d in disaggs], key=len, reverse=True)
        selected_choices = list(artifacts["data_fields"])
        choices = selected_choices + [f"{disagg}.{extra}" for disagg in disaggs for extra in ["Multiple", "None"]]

        # Map feature names to labels
        choices = [artifacts["feature_names"].get(x, x) for x in choices]
        selected_choices = [artifacts["feature_names"].get(x, x) for x in selected_choices]

        # Choose new options
        new_check = disagg_check.update(choices=sorted(choices), value=selected_choices)

        # Generate plot
        new_plot = create_nested_plot(
            all_choices[0], all_choices[0],
            all_choices[1], all_choices[1],
            artifacts["distributions"],
            artifacts["feature_names"],
            plot=plot
        )

        return new_plot, new_check
    else:
        selected_choices = [field for field in artifacts["data_fields"] if field.startswith(disagg_by)]
        choices = selected_choices + ["Multiple", "None"]

        # Map feature names to labels
        choices_for_check = [artifacts["feature_names"].get(x, x) for x in choices]
        selected_choices_for_check = [artifacts["feature_names"].get(x, x) for x in selected_choices]

        # Choose new options
        new_check = disagg_check.update(choices=choices_for_check, value=selected_choices_for_check)

        # Generate plot
        new_plot = create_plot(
            sorted(selected_choices), sorted(selected_choices), artifacts["distributions"], artifacts["feature_names"],
            plot=plot
        )

        return new_plot, new_check


def select_new_sub_plot(plot, disagg_check, disagg_by, artifacts):
    if disagg_by == "Both":
        disaggs = sorted(list(artifacts["disaggregators"]))
        all_choices = sorted([[x for x in artifacts["data_fields"] if x.startswith(d)] for d in disaggs], key=len, reverse=True)

        choice1 = all_choices[0][0].split(".")[0]
        choice2 = all_choices[1][0].split(".")[0]

        check1 = [dc for dc in disagg_check if dc.startswith(choice1)]
        check2 = [dc for dc in disagg_check if dc.startswith(choice2)]

        check1 = ["Multiple" if c == f"{c.split('.')[0]}.Multiple" else c for c in check1]
        check1 = ["None" if c == f"{c.split('.')[0]}.None" else c for c in check1]
        check2 = ["Multiple" if c == f"{c.split('.')[0]}.Multiple" else c for c in check2]
        check2 = ["None" if c == f"{c.split('.')[0]}.None" else c for c in check2]

        new_plot = create_nested_plot(
            check1, all_choices[0],
            check2, all_choices[1],
            artifacts["distributions"],
            artifacts["feature_names"],
            plot=plot
        )

        return new_plot
    else:
        selected_choices = [field for field in artifacts["data_fields"] if field.startswith(disagg_by)]

        # Generate plot
        new_plot = create_plot(
            disagg_check, selected_choices, artifacts["distributions"], artifacts["feature_names"],
            plot=plot
        )

        return new_plot


def visualization_filter(plot, artifacts, default_value, intersect=False):
    def map_labels_to_fields(labels: List[str]):
        # Reverse lookup from display label to field name; "Multiple"/"None" entries pass through unchanged
        field_names = list(artifacts["feature_names"].keys())
        display_names = list(artifacts["feature_names"].values())
        return [
            field_names[display_names.index(x)] if not any([extra in x for extra in ["Multiple", "None"]]) else x
            for x in labels
        ]

    def map_category_to_disaggregator(category: str):  # e.g. Gender, Age, Gender + Age -> gender, age, Both
        return list(artifacts["feature_names"].keys())[list(artifacts["feature_names"].values()).index(category)]

    choices = sorted(list(artifacts["disaggregators"]))
    if intersect:
        choices = choices + ["Both"]

    # Map categories to nice names
    choices = [artifacts["feature_names"][c] for c in choices]

    disagg_radio = gr.Radio(
        label="Disaggregate by...",
        choices=choices,
        value=artifacts["feature_names"][default_value],
        interactive=True
    )

    selected_choices = [field for field in artifacts["data_fields"] if field.startswith(f"{default_value}.")]
    choices = selected_choices + ["Multiple", "None"]

    # Map feature names to labels
    choices = [artifacts["feature_names"].get(x, x) for x in choices]
    selected_choices = [artifacts["feature_names"].get(x, x) for x in selected_choices]

    disagg_check = gr.CheckboxGroup(
        label="Features",
        choices=choices,
        interactive=True,
        value=selected_choices,
    )

    disagg_radio.change(
        lambda x: select_new_base_plot(plot, disagg_check, map_category_to_disaggregator(x), artifacts),
        inputs=[disagg_radio],
        outputs=[plot, disagg_check]
    )

    disagg_check.change(
        lambda x, y: select_new_sub_plot(plot, map_labels_to_fields(x), map_category_to_disaggregator(y), artifacts),
        inputs=[disagg_check, disagg_radio],
        outputs=[plot]
    )


def generate_components(dataset, artifacts, intersect=True):
    gr.Markdown(f"### {artifacts['text']['title']}")
    gr.Markdown(artifacts['text']['description'])

    with gr.Accordion(label="💻 Click me to see the code!", open=False):
        gr.Markdown(artifacts["text"]["code"])

    with gr.Tab("Visualize"):
        with gr.Row(elem_id="visualization-window"):
            with gr.Column():
                disagg_by = sorted(list(artifacts["disaggregators"]))[0]
                selected_choices = [field for field in artifacts["data_fields"] if field.startswith(disagg_by)]
                plot = create_plot(
                    sorted(selected_choices),
                    sorted(selected_choices),
                    artifacts["distributions"],
                    artifacts["feature_names"]
                )
            with gr.Column():
                gr.Markdown("### Visualization")
                gr.Markdown(artifacts["text"]["visualization"])
                visualization_filter(plot, artifacts, disagg_by, intersect=intersect)

    with gr.Tab("Inspect"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Data Inspector")
                gr.Markdown("This tab lets you filter the disaggregated dataset and inspect individual rows. Set as many filters as you like, then click \"Apply filters\" to fetch up to 10 random rows that match *all* of the filters you've selected.")
                filter_groups = gr.CheckboxGroup(choices=sorted(list(artifacts["data_fields"])), label="Filters")
                fetch_subset = gr.Button("Apply filters")

                sample_dataframe = gr.State(value=dataset.sample(10))

                def fetch_new_samples(filters):
                    if len(filters) == 0:
                        new_dataset = dataset.sample(10)
                    else:
                        # Keep only the rows in which every selected field is True
                        filter_query = " & ".join([f"`{f}`" for f in filters])
                        new_dataset = dataset.query(filter_query)
                        if new_dataset.shape[0] > 0:
                            # sample() raises if asked for more rows than are available
                            new_dataset = new_dataset.sample(min(10, new_dataset.shape[0]))
                    new_samples = [[
                        x[1][artifacts["column"]],
                        ", ".join([col for col in artifacts["data_fields"] if x[1][col]]),
                    ] for x in new_dataset.iterrows()]
                    return sample_rows.update(samples=new_samples), new_dataset

                sample_rows = gr.Dataset(
                    samples=[[
                        x[1][artifacts["column"]],
                        ", ".join([col for col in artifacts["data_fields"] if x[1][col]]),
                    ] for x in sample_dataframe.value.iterrows()],
                    components=[gr.Textbox(visible=False), gr.Textbox(visible=False)],
                    type="index"
                )

            with gr.Column(scale=1):
                row_inspector = gr.DataFrame(
                    wrap=True,
                    visible=False
                )

        fetch_subset.click(
            fetch_new_samples,
            inputs=[filter_groups],
            outputs=[sample_rows, sample_dataframe],
        )

        sample_rows.click(
            lambda df, index: row_inspector.update(visible=True, value=df.iloc[index].reset_index()),
            inputs=[sample_dataframe, sample_rows],
            outputs=[row_inspector]
        )


with gr.Blocks(css="#visualization-window {flex-direction: row-reverse;}") as demo:
    gr.Markdown("# Exploring Disaggregated Data with 🤗 Disaggregators")

    with gr.Accordion("About this demo 👀"):
        gr.Markdown("## What's in your dataset?")
        gr.Markdown("""
Addressing fairness and bias in machine learning models is [more important than ever](https://www.vice.com/en/article/bvm35w/this-tool-lets-anyone-see-the-bias-in-ai-image-generators)!
One form of fairness is equal performance across different groups or features.
To measure this, evaluation datasets must be disaggregated across the different groups of interest.
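
For example, a disaggregated evaluation can be as simple as computing a metric separately for each group and comparing the results. The minimal pandas sketch below uses made-up predictions and a hypothetical `group` column; a large accuracy gap between groups is one signal of unequal performance.

```python
import pandas as pd

# Hypothetical evaluation results: one row per example
results = pd.DataFrame({
    "group": ["a", "a", "a", "b", "b", "b"],
    "label": [1, 0, 1, 1, 0, 0],
    "prediction": [1, 0, 1, 0, 1, 0],
})

results["correct"] = results["label"] == results["prediction"]

# Accuracy per group, plus the gap between the best and worst group
per_group = results.groupby("group")["correct"].mean()
gap = per_group.max() - per_group.min()
print(per_group)
print("accuracy gap:", gap)
```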
""")
        gr.Markdown("The `disaggregators` library ([GitHub](https://github.com/huggingface/disaggregators)) provides an interface and a collection of modules to help you disaggregate datasets by different groups. Click through each of the tabs below to see it in action!")
        gr.Markdown("""
After tinkering with the demo, you can install 🤗 Disaggregators with:

```bash
pip install disaggregators
```
""")
        gr.Markdown("Each tab below shows a feature of `disaggregators` applied to a different dataset. The first tab covers the built-in disaggregation modules, the second shows how to override the configurations of existing modules, and the third shows how to add your own custom modules.")

    with gr.Tab("🐊 LAION: Built-in Modules Example"):
        generate_components(laion, cached_artifacts["laion"], intersect=False)
    with gr.Tab("🔧 MedMCQA: Configuration Example"):
        generate_components(medmcqa, cached_artifacts["medmcqa"])
    with gr.Tab("🎡 The Stack: Custom Disaggregation Example"):
        generate_components(stack, cached_artifacts["stack"])

    with gr.Accordion(label="💡 How is this calculated?", open=False):
        gr.Markdown("""
## Continent
Continents are inferred by identifying geographic terms and their related countries using [geograpy3](https://github.com/somnathrakshit/geograpy3). The results are then mapped to [their respective continents](https://github.com/bigscience-workshop/data_sourcing/blob/master/sourcing_sprint/resources/country_regions.json).
## Age
Ages are inferred by using [spaCy](https://spacy.io) to seek "date" tokens in strings.
## Gender
Binary gender is inferred by checking for words against the [md_gender_bias](https://huggingface.co/datasets/md_gender_bias) dataset.
```
@inproceedings{dinan-etal-2020-multi,
    title = "Multi-Dimensional Gender Bias Classification",
    author = "Dinan, Emily and
      Fan, Angela and
      Wu, Ledell and
      Weston, Jason and
      Kiela, Douwe and
      Williams, Adina",
    booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
    year = "2020",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.emnlp-main.23",
    doi = "10.18653/v1/2020.emnlp-main.23",
}
```
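
The word-list idea can be sketched in a few lines. The word lists and the `gender_flags` helper below are tiny, hypothetical stand-ins (the real module draws on md_gender_bias, and its matching rules may differ); the second example shows how a single text can match both lists, which is the case this demo groups as "Multiple".

```python
# Hypothetical stand-ins for the real word lists
FEMALE_WORDS = {"she", "her", "woman", "mother"}
MALE_WORDS = {"he", "him", "man", "father"}


def gender_flags(text: str) -> dict:
    # Lowercase and strip surrounding punctuation from each whitespace-separated word
    words = {w.strip(".,;:!?()").lower() for w in text.split()}
    return {"female": bool(words & FEMALE_WORDS), "male": bool(words & MALE_WORDS)}


print(gender_flags("A 40-year-old woman presents with fever."))
print(gender_flags("A man brings his mother to the clinic."))
```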
## Learn more!
Visit the [GitHub repository](https://github.com/huggingface/disaggregators) to learn about using the `disaggregators` library and to leave feedback 🤗
""")
demo.launch()