kyleledbetter commited on
Commit
7bd4255
·
1 Parent(s): e6d89e2

feat(): gradio gui

Browse files
Files changed (1) hide show
  1. app.py +91 -71
app.py CHANGED
@@ -1,88 +1,108 @@
 
1
  import requests
2
  import json
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
4
  import plotly.graph_objects as go
5
  import plotly.express as px
6
  import pandas as pd
7
  from sklearn.metrics import confusion_matrix
8
- from datasets import load_dataset
9
-
10
 
11
  def load_model(endpoint: str):
12
- tokenizer = AutoTokenizer.from_pretrained(endpoint)
13
- model = AutoModelForSequenceClassification.from_pretrained(endpoint)
14
- return tokenizer, model
15
 
16
 
17
  def test_model(tokenizer, model, test_data: list, label_map: dict):
18
- results = []
19
- for text, true_label in test_data:
20
- inputs = tokenizer(text, return_tensors="pt",
21
- truncation=True, padding=True)
22
- outputs = model(**inputs)
23
- pred_label = label_map[int(outputs.logits.argmax(dim=-1))]
24
- results.append((text, true_label, pred_label))
25
- return results
26
-
27
-
28
- def generate_report_card(results, label_map):
29
- true_labels = [r[1] for r in results]
30
- pred_labels = [r[2] for r in results]
31
-
32
- cm = confusion_matrix(true_labels, pred_labels,
33
- labels=list(label_map.values()))
34
-
35
- fig = go.Figure(
36
- data=go.Heatmap(
37
- z=cm,
38
- x=list(label_map.values()),
39
- y=list(label_map.values()),
40
- colorscale='Viridis',
41
- colorbar=dict(title='Number of Samples')
42
- ),
43
- layout=go.Layout(
44
- title='Confusion Matrix',
45
- xaxis=dict(title='Predicted Labels'),
46
- yaxis=dict(title='True Labels', autorange='reversed')
47
- )
48
- )
49
-
50
- fig.show()
51
 
 
 
 
 
52
 
53
- def load_sst2_data(split="test"):
54
- dataset = load_dataset("glue", "sst2", split=split)
55
- data = [(item["sentence"], "positive" if item["label"] == 1 else "negative")
56
- for item in dataset]
57
- return data
58
 
59
-
60
- # Define your model endpoint and label map
61
- # model_endpoint = "your-model-endpoint"
62
-
63
- # Modify this according to your model's labels
64
- # label_map = {0: "label0", 1: "label1"}
65
-
66
- model_endpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  label_map = {0: "negative", 1: "positive"}
68
 
69
- # Load the model and tokenizer
70
- tokenizer, model = load_model(model_endpoint)
71
-
72
- # Prepare your test data (list of tuples containing text and true label)
73
- #test_data = [
74
- # ("Sample text 1", "label0"),
75
- # ("Sample text 2", "label1"),
76
- # # Add more test samples here
77
- #]
78
-
79
- # Load the test data from the SST-2 dataset
80
- test_data = load_sst2_data()
81
- # Use a smaller subset of test_data for a quicker demonstration (optional)
82
- test_data = test_data[:100]
83
-
84
- # Test the model and generate results
85
- results = test_model(tokenizer, model, test_data, label_map)
86
-
87
- # Generate the visual report card
88
- generate_report_card(results, label_map)
 
1
+ import gradio as gr
2
  import requests
3
  import json
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ from datasets import load_dataset
6
+ import plotly.io as pio
7
  import plotly.graph_objects as go
8
  import plotly.express as px
9
  import pandas as pd
10
  from sklearn.metrics import confusion_matrix
 
 
11
 
12
  def load_model(endpoint: str):
13
+ tokenizer = AutoTokenizer.from_pretrained(endpoint)
14
+ model = AutoModelForSequenceClassification.from_pretrained(endpoint)
15
+ return tokenizer, model
16
 
17
 
18
  def test_model(tokenizer, model, test_data: list, label_map: dict):
19
+ results = []
20
+ for text, true_label in test_data:
21
+ inputs = tokenizer(text, return_tensors="pt",
22
+ truncation=True, padding=True)
23
+ outputs = model(**inputs)
24
+ pred_label = label_map[int(outputs.logits.argmax(dim=-1))]
25
+ results.append((text, true_label, pred_label))
26
+ return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ def generate_label_map(dataset):
29
+ num_labels = len(dataset.features["label"].names)
30
+ label_map = {i: label for i, label in enumerate(dataset.features["label"].names)}
31
+ return label_map
32
 
 
 
 
 
 
33
 
34
+ def generate_report_card(results, label_map):
35
+ true_labels = [r[1] for r in results]
36
+ pred_labels = [r[2] for r in results]
37
+
38
+ cm = confusion_matrix(true_labels, pred_labels,
39
+ labels=list(label_map.values()))
40
+
41
+ fig = go.Figure(
42
+ data=go.Heatmap(
43
+ z=cm,
44
+ x=list(label_map.values()),
45
+ y=list(label_map.values()),
46
+ colorscale='Viridis',
47
+ colorbar=dict(title='Number of Samples')
48
+ ),
49
+ layout=go.Layout(
50
+ title='Confusion Matrix',
51
+ xaxis=dict(title='Predicted Labels'),
52
+ yaxis=dict(title='True Labels', autorange='reversed')
53
+ )
54
+ )
55
+
56
+ fig.update_layout(height=600, width=800)
57
+
58
+ # return fig in new window
59
+ # fig.show() # uncomment this line to show the plot in a new window
60
+
61
+ # Convert the Plotly figure to an HTML string < i was trying this bc i couldn't get Plot() to work before
62
+ # plot_html = pio.to_html(fig, full_html=True, include_plotlyjs=True, config={
63
+ # "displayModeBar": False, "responsive": True})
64
+ #return plot_html
65
+ return fig
66
+
67
+ def app(model_endpoint: str, dataset_name: str, config_name: str, dataset_split: str, num_samples: int):
68
+ tokenizer, model = load_model(model_endpoint)
69
+
70
+ # Load the dataset
71
+ num_samples = int(num_samples) # Add this line to cast num_samples to an integer
72
+ dataset = load_dataset(
73
+ dataset_name, config_name, split=f"{dataset_split}[:{num_samples}]")
74
+ test_data = [(item["sentence"], dataset.features["label"].names[item["label"]])
75
+ for item in dataset]
76
+
77
+ label_map = generate_label_map(dataset)
78
+
79
+ results = test_model(tokenizer, model, test_data, label_map)
80
+ report_card = generate_report_card(results, label_map)
81
+
82
+ return report_card
83
+
84
+ interface = gr.Interface(
85
+ fn=app,
86
+ inputs=[
87
+ gr.inputs.Textbox(lines=1, label="Model Endpoint",
88
+ placeholder="ex: distilbert-base-uncased-finetuned-sst-2-english"),
89
+ gr.inputs.Textbox(lines=1, label="Dataset Name",
90
+ placeholder="ex: glue"),
91
+ gr.inputs.Textbox(lines=1, label="Config Name",
92
+ placeholder="ex: sst2"),
93
+ gr.inputs.Dropdown(
94
+ choices=["train", "validation", "test"], label="Dataset Split"),
95
+ gr.inputs.Number(default=100, label="Number of Samples"),
96
+ ],
97
+ # outputs=gr.outputs.Plotly(),
98
+ # outputs=gr.outputs.HTML(),
99
+ outputs=gr.Plot(),
100
+ title="Fairness and Bias Testing",
101
+ description="Enter a model endpoint and dataset to test for fairness and bias.",
102
+ )
103
+
104
+ # Define the label map globally
105
  label_map = {0: "negative", 1: "positive"}
106
 
107
+ if __name__ == "__main__":
108
+ interface.launch()