kevinconka commited on
Commit
246a775
·
1 Parent(s): a378000

Refactor flagged image counting logic

Browse files
Files changed (2) hide show
  1. app.py +5 -9
  2. utils.py +50 -31
app.py CHANGED
@@ -6,7 +6,7 @@ from utils import (
6
  load_image_from_url,
7
  inference,
8
  load_badges,
9
- count_flagged_images_from_csv,
10
  )
11
  from flagging import myHuggingFaceDatasetSaver
12
 
@@ -42,16 +42,12 @@ model.agnostic = True # NMS class-agnostic
42
  # Flagging
43
  dataset_name = "SEA-AI/crowdsourced-sea-images"
44
  hf_writer = myHuggingFaceDatasetSaver(get_token(), dataset_name)
45
-
46
-
47
- def get_flagged_count():
48
- """Count flagged images in dataset."""
49
- return count_flagged_images_from_csv(dataset_name)
50
 
51
 
52
  theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)
53
  with gr.Blocks(theme=theme, css=css) as demo:
54
- badges = gr.HTML(load_badges(get_flagged_count()))
55
  title = gr.HTML(TITLE)
56
 
57
  with gr.Row():
@@ -115,11 +111,11 @@ with gr.Blocks(theme=theme, css=css) as demo:
115
  preprocess=False,
116
  show_api=False,
117
  ).then(
118
- lambda: load_badges(get_flagged_count()), [], badges, show_api=False
119
  )
120
 
121
  # called during initial load in browser
122
- demo.load(lambda: load_badges(get_flagged_count()), [], badges, show_api=False)
123
 
124
  if __name__ == "__main__":
125
  demo.queue().launch() # show_api=False)
 
6
  load_image_from_url,
7
  inference,
8
  load_badges,
9
+ FlaggedCounter,
10
  )
11
  from flagging import myHuggingFaceDatasetSaver
12
 
 
42
  # Flagging
43
  dataset_name = "SEA-AI/crowdsourced-sea-images"
44
  hf_writer = myHuggingFaceDatasetSaver(get_token(), dataset_name)
45
+ flagged_counter = FlaggedCounter(dataset_name)
 
 
 
 
46
 
47
 
48
  theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)
49
  with gr.Blocks(theme=theme, css=css) as demo:
50
+ badges = gr.HTML(load_badges(flagged_counter.count()))
51
  title = gr.HTML(TITLE)
52
 
53
  with gr.Row():
 
111
  preprocess=False,
112
  show_api=False,
113
  ).then(
114
+ lambda: load_badges(flagged_counter.count()), [], badges, show_api=False
115
  )
116
 
117
  # called during initial load in browser
118
+ demo.load(lambda: load_badges(flagged_counter.count()), [], badges, show_api=False)
119
 
120
  if __name__ == "__main__":
121
  demo.queue().launch() # show_api=False)
utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import time
2
  import requests
3
  from io import BytesIO
 
4
  import numpy as np
5
  import pandas as pd
6
  from PIL import Image
@@ -40,37 +41,6 @@ def inference(model, image):
40
  return annotator.im
41
 
42
 
43
- def count_flagged_images_via_api(dataset_name, trials=10):
44
- """Count flagged images via API. Might be slow."""
45
-
46
- headers = {"Authorization": f"Bearer {get_token()}"}
47
- API_URL = f"https://datasets-server.huggingface.co/size?dataset={dataset_name}"
48
-
49
- def query():
50
- response = requests.get(API_URL, headers=headers, timeout=5)
51
- return response.json()
52
-
53
- for i in range(trials):
54
- try:
55
- data = query()
56
- if "error" not in data and data["size"]["dataset"]["num_rows"] > 0:
57
- print(f"[{i+1}/{trials}] {data}")
58
- return data["size"]["dataset"]["num_rows"]
59
- except Exception:
60
- pass
61
- print(f"[{i+1}/{trials}] {data}")
62
- time.sleep(5)
63
-
64
- return 0
65
-
66
-
67
- def count_flagged_images_from_csv(dataset_name):
68
- """Count flagged images from CSV. Fast but relies on local files."""
69
- dataset_name = dataset_name.split("/")[-1]
70
- df = pd.read_csv(f"./flagged/{dataset_name}/data.csv")
71
- return len(df)
72
-
73
-
74
  def load_badges(n):
75
  """Load badges."""
76
  return f"""
@@ -80,3 +50,52 @@ def load_badges(n):
80
  <img alt="" src="https://img.shields.io/badge/%F0%9F%96%BC%EF%B8%8F-{n}-green">
81
  </p>
82
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import time
2
  import requests
3
  from io import BytesIO
4
+ from dataclasses import dataclass
5
  import numpy as np
6
  import pandas as pd
7
  from PIL import Image
 
41
  return annotator.im
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def load_badges(n):
45
  """Load badges."""
46
  return f"""
 
50
  <img alt="" src="https://img.shields.io/badge/%F0%9F%96%BC%EF%B8%8F-{n}-green">
51
  </p>
52
  """
53
+
54
+
55
+ @dataclass
56
+ class FlaggedCounter:
57
+ """Count flagged images in dataset."""
58
+
59
+ dataset_name: str
60
+ headers: dict = None
61
+
62
+ def __post_init__(self):
63
+ self.API_URL = (
64
+ f"https://datasets-server.huggingface.co/size?dataset={self.dataset_name}"
65
+ )
66
+ self.trials = 10
67
+ if self.headers is None:
68
+ self.headers = {"Authorization": f"Bearer {get_token()}"}
69
+
70
+ def query(self):
71
+ """Query API."""
72
+ response = requests.get(self.API_URL, headers=self.headers, timeout=5)
73
+ return response.json()
74
+
75
+ def from_query(self, data):
76
+ """Count flagged images via API. Might be slow."""
77
+ for i in range(self.trials):
78
+ try:
79
+ data = self.query()
80
+ if "error" not in data and data["size"]["dataset"]["num_rows"] > 0:
81
+ print(f"[{i+1}/{self.trials}] {data}")
82
+ return data["size"]["dataset"]["num_rows"]
83
+ except Exception:
84
+ pass
85
+ print(f"[{i+1}/{self.trials}] {data}")
86
+ time.sleep(5)
87
+
88
+ return 0
89
+
90
+ def from_csv(self):
91
+ """Count flagged images from CSV. Fast but relies on local files."""
92
+ dataset_name = self.dataset_name.split("/")[-1]
93
+ df = pd.read_csv(f"./flagged/{dataset_name}/data.csv")
94
+ return len(df)
95
+
96
+ def count(self):
97
+ """Count flagged images."""
98
+ try:
99
+ return self.from_csv()
100
+ except FileNotFoundError:
101
+ return self.from_query(self.query())