File size: 2,652 Bytes
a378000
e921d65
 
4daa30f
246a775
a378000
e921d65
 
955daea
e921d65
 
0f0204b
 
 
 
 
 
e921d65
a378000
e921d65
 
 
 
 
 
 
 
 
 
a378000
 
4daa30f
a49fee8
4daa30f
 
955daea
 
4daa30f
955daea
 
246a775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efae727
246a775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import time
import requests
from io import BytesIO
from urllib.parse import quote
from dataclasses import dataclass
import pandas as pd
from PIL import Image
import gradio as gr
from huggingface_hub import get_token


def check_image(image):
    """Guard an event handler against a missing image input.

    Args:
        image: Value of the image input; ``None`` means nothing was uploaded.

    Raises:
        gr.Error: When ``image`` is ``None``, with a message shown to the user.
    """
    if image is not None:
        return
    raise gr.Error("Oops! It looks like you forgot to upload an image.")


def load_image_from_url(url):
    """Download an image from ``url`` and return it converted to RGB.

    Args:
        url: URL entered by the user. An empty string or ``None`` means
            "no URL".

    Returns:
        The downloaded image converted with ``.convert("RGB")``, or
        ``gr.Image(interactive=True)`` when ``url`` is empty.

    Raises:
        gr.Error: If the request fails, the server answers with an HTTP error
            status, or the payload cannot be opened/decoded as an image.
    """
    if not url:  # empty or None
        return gr.Image(interactive=True)
    try:
        response = requests.get(url, timeout=5)
        # Surface 4xx/5xx responses directly instead of trying to decode an
        # error page as an image.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open() is lazy: pixel data is only decoded on conversion, so a
        # truncated/corrupt file must be converted inside the try block to be
        # reported as a gr.Error rather than an uncaught OSError.
        return image.convert("RGB")
    except Exception as e:  # UI boundary: turn any failure into a user error
        raise gr.Error("Unable to load image from URL") from e


def load_badges(n):
    """Build an HTML paragraph holding two shields.io badges.

    Args:
        n: Value rendered on the second badge (presumably the number of
            flagged images — confirm with the caller).

    Returns:
        An HTML ``<p>`` snippet with one ``<img>`` per badge, separated by
        non-breaking spaces.
    """
    badges = [
        "https://img.shields.io/badge/version-beta-blue",
        # Emojis are percent-encoded so they are valid inside the URL path.
        f"https://img.shields.io/badge/{quote('🖼️')}{quote('🚩')}-{n}-green",
    ]
    # "&nbsp;" needs its terminating semicolon to be a well-formed entity.
    images = "&nbsp;".join(f'<img alt="" src="{badge}">' for badge in badges)
    return f"""
        <p style="display: flex">
        {images}
        </p>
        """


@dataclass
class FlaggedCounter:
    """Count flagged images in a dataset.

    The count comes from ``./flagged/<name>/data.csv`` when that file exists
    (presumably written by Gradio's local flagging — confirm), otherwise from
    the datasets-server ``/size`` endpoint.

    Attributes:
        dataset_name: Dataset id such as ``"user/dataset"``.
        headers: HTTP headers for the API request. When ``None``, they are
            built from the locally stored Hugging Face token, if any.
        trials: Maximum number of API attempts made by ``from_query``.
        delay: Seconds to wait between unsuccessful API attempts.
    """

    dataset_name: str
    headers: dict = None
    trials: int = 10
    delay: float = 5.0

    def __post_init__(self):
        self.API_URL = (
            f"https://datasets-server.huggingface.co/size?dataset={self.dataset_name}"
        )
        if self.headers is None:
            token = get_token()
            # get_token() returns None when no token is configured; sending
            # "Bearer None" is an invalid credential, so omit auth instead.
            self.headers = {"Authorization": f"Bearer {token}"} if token else {}

    def query(self):
        """Request the dataset size from the API and return the decoded JSON."""
        response = requests.get(self.API_URL, headers=self.headers, timeout=5)
        return response.json()

    def from_query(self, data=None):
        """Count flagged images via the API. Might be slow.

        Retries up to ``self.trials`` times, sleeping ``self.delay`` seconds
        between unsuccessful attempts (not after the last one).

        Args:
            data: Not used as input; each attempt issues a fresh query. Only
                printed if the very first attempt fails before returning.
                Kept (now optional) for backward compatibility.

        Returns:
            The API's ``num_rows`` once it is positive, otherwise 0.
        """
        for i in range(self.trials):
            try:
                data = self.query()
                if "error" not in data:
                    num_rows = data["size"]["dataset"]["num_rows"]
                    if num_rows > 0:
                        print(f"[{i+1}/{self.trials}] {data}")
                        return num_rows
            except (KeyError, TypeError, ValueError):
                # Unexpected payload shape or non-JSON body: retry.
                pass
            except requests.exceptions.RequestException:
                # Network-level failure: retry.
                pass
            print(f"[{i+1}/{self.trials}] {data}")
            if i + 1 < self.trials:
                time.sleep(self.delay)

        return 0

    def from_csv(self):
        """Count flagged images from CSV. Fast but relies on local files."""
        dataset_name = self.dataset_name.split("/")[-1]
        df = pd.read_csv(f"./flagged/{dataset_name}/data.csv")
        return len(df)

    def count(self):
        """Count flagged images, preferring the local CSV over the API."""
        try:
            return self.from_csv()
        except FileNotFoundError:
            # No extra up-front query: from_query issues its own requests
            # inside its retry/exception handling.
            return self.from_query()