aldan.creo committed on
Commit
4db55cd
·
1 Parent(s): b448895
Files changed (2) hide show
  1. app.py +46 -22
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import json
2
  import logging
3
  import os
 
4
  from functools import partial
5
 
6
  import gradio as gr
@@ -14,9 +15,11 @@ logger.setLevel(logging.INFO)
14
  load_dotenv()
15
 
16
  # dataset = load_dataset("detection-datasets/coco")
17
- it_dataset = load_dataset(
18
- "imagenet-1k", split="train", streaming=True, trust_remote_code=True
19
- ).shuffle(42)
 
 
20
 
21
 
22
  def gen_from_iterable_dataset(iterable_ds):
@@ -26,6 +29,10 @@ def gen_from_iterable_dataset(iterable_ds):
26
  yield from iterable_ds
27
 
28
 
 
 
 
 
29
  # imagenet_categories_data.json is a JSON file containing a hierarchy of ImageNet categories.
30
  # We want to take all categories under "artifact, artefact".
31
  # Each node has this structure:
@@ -58,14 +65,17 @@ def filter_imgs_by_label(x):
58
  """
59
  Filter out the images that have label -1
60
  """
 
61
  return x["label"] in artifact_categories
62
 
63
 
64
- it_dataset = it_dataset.take(1000).filter(filter_imgs_by_label)
65
- dataset = Dataset.from_generator(
66
- partial(gen_from_iterable_dataset, it_dataset), features=it_dataset.features
67
- )
68
- dataset_iterable = iter(dataset)
 
 
69
 
70
 
71
  def get_user_prompt():
@@ -74,15 +84,11 @@ def get_user_prompt():
74
  machine_labels = []
75
  human_labels = []
76
  for i in range(3):
77
- data = next(dataset_iterable)
78
- logger.info(f"Data: {data}")
79
  images.append(data["image"])
80
  # Get the label as a human readable string
81
  machine_labels.append(data["label"])
82
- logger.info(dataset)
83
- human_label = dataset.features["label"].int2str(data["label"]) + str(
84
- data["label"]
85
- )
86
  human_labels.append(human_label)
87
  return {
88
  "images": images,
@@ -94,7 +100,7 @@ def get_user_prompt():
94
  hf_writer = gr.HuggingFaceDatasetSaver(
95
  hf_token=os.environ["HF_TOKEN"], dataset_name="acmc/maker-faire-bot", private=True
96
  )
97
- csv_writer = gr.CSVLogger(simplify_file_data=True)
98
 
99
  theme = gr.themes.Default(primary_hue="cyan", secondary_hue="fuchsia")
100
 
@@ -137,15 +143,14 @@ with gr.Blocks(theme=theme) as demo:
137
  btn = gr.Button("Change", variant="secondary")
138
 
139
  def change_image(user_prompt):
140
- data = next(dataset_iterable)
141
- logger.info(user_prompt)
142
  user_prompt = user_prompt.copy()
143
  user_prompt["images"][i] = data["image"]
144
  user_prompt["machine_labels"][i] = data["label"]
145
  user_prompt["human_labels"][i] = dataset.features["label"].int2str(
146
  data["label"]
147
  )
148
- logger.info(user_prompt)
149
  return (
150
  user_prompt,
151
  user_prompt["images"][i],
@@ -192,20 +197,39 @@ with gr.Blocks(theme=theme) as demo:
192
  submit_btn = gr.Button("Submit", variant="primary")
193
 
194
  def log_results(prompt, object, explanation):
195
- csv_writer.flag([prompt, object, explanation])
196
- hf_writer.flag([prompt, object, explanation])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  submit_btn.click(
199
  log_results,
200
  inputs=[user_prompt, user_answer_object, user_answer_explanation],
201
- preprocess=False,
202
  )
203
 
204
  new_prompt_btn = gr.Button("New Prompt", variant="secondary")
205
  new_prompt_btn.click(
206
  get_user_prompt,
207
  outputs=[user_prompt],
208
- preprocess=False,
209
  )
210
 
211
  gr.Markdown(
 
1
  import json
2
  import logging
3
  import os
4
+ import random
5
  from functools import partial
6
 
7
  import gradio as gr
 
15
  load_dotenv()
16
 
17
  # dataset = load_dataset("detection-datasets/coco")
18
+ it_dataset = (
19
+ load_dataset("imagenet-1k", split="train", streaming=True, trust_remote_code=True)
20
+ .shuffle(42)
21
+ .take(1000)
22
+ )
23
 
24
 
25
  def gen_from_iterable_dataset(iterable_ds):
 
29
  yield from iterable_ds
30
 
31
 
32
+ dataset = Dataset.from_generator(
33
+ partial(gen_from_iterable_dataset, it_dataset), features=it_dataset.features
34
+ )
35
+
36
  # imagenet_categories_data.json is a JSON file containing a hierarchy of ImageNet categories.
37
  # We want to take all categories under "artifact, artefact".
38
  # Each node has this structure:
 
65
  """
66
  Filter out the images that have label -1
67
  """
68
+ logger.info(f'label: {x["label"]} (present: {x["label"] in artifact_categories})')
69
  return x["label"] in artifact_categories
70
 
71
 
72
+ dataset = dataset.filter(filter_imgs_by_label)
73
+
74
+ logging.basicConfig(level=logging.INFO)
75
+ logger = logging.getLogger(__name__)
76
+ logger.setLevel(logging.INFO)
77
+
78
+ load_dotenv()
79
 
80
 
81
  def get_user_prompt():
 
84
  machine_labels = []
85
  human_labels = []
86
  for i in range(3):
87
+ data = dataset[random.randint(0, len(dataset) - 1)]
 
88
  images.append(data["image"])
89
  # Get the label as a human readable string
90
  machine_labels.append(data["label"])
91
+ human_label = dataset.features["label"].int2str(data["label"])
 
 
 
92
  human_labels.append(human_label)
93
  return {
94
  "images": images,
 
100
  hf_writer = gr.HuggingFaceDatasetSaver(
101
  hf_token=os.environ["HF_TOKEN"], dataset_name="acmc/maker-faire-bot", private=True
102
  )
103
+ csv_writer = gr.CSVLogger()
104
 
105
  theme = gr.themes.Default(primary_hue="cyan", secondary_hue="fuchsia")
106
 
 
143
  btn = gr.Button("Change", variant="secondary")
144
 
145
  def change_image(user_prompt):
146
+ logger.info(f"Current user prompt: {user_prompt}")
147
+ data = dataset[random.randint(0, len(dataset) - 1)]
148
  user_prompt = user_prompt.copy()
149
  user_prompt["images"][i] = data["image"]
150
  user_prompt["machine_labels"][i] = data["label"]
151
  user_prompt["human_labels"][i] = dataset.features["label"].int2str(
152
  data["label"]
153
  )
 
154
  return (
155
  user_prompt,
156
  user_prompt["images"][i],
 
197
  submit_btn = gr.Button("Submit", variant="primary")
198
 
199
  def log_results(prompt, object, explanation):
200
+ logger.info(f"logging - Prompt: {prompt}")
201
+ csv_writer.flag(
202
+ [
203
+ {
204
+ "machine_labels": prompt["machine_labels"],
205
+ "human_labels": prompt["human_labels"],
206
+ },
207
+ object,
208
+ explanation,
209
+ ]
210
+ )
211
+ hf_writer.flag(
212
+ [
213
+ {
214
+ "machine_labels": prompt["machine_labels"],
215
+ "human_labels": prompt["human_labels"],
216
+ },
217
+ object,
218
+ explanation,
219
+ ]
220
+ )
221
 
222
  submit_btn.click(
223
  log_results,
224
  inputs=[user_prompt, user_answer_object, user_answer_explanation],
225
+ preprocess=True,
226
  )
227
 
228
  new_prompt_btn = gr.Button("New Prompt", variant="secondary")
229
  new_prompt_btn.click(
230
  get_user_prompt,
231
  outputs=[user_prompt],
232
+ # preprocess=True,
233
  )
234
 
235
  gr.Markdown(
requirements.txt CHANGED
@@ -1,3 +1,2 @@
1
  datasets==2.19.0
2
- gradio==4.28.0
3
  python-dotenv==1.0.1
 
1
  datasets==2.19.0
 
2
  python-dotenv==1.0.1