# freQuensy23's picture
# Init
# 3a6b931
from gradio import FlaggingCallback, utils
import csv
import datetime
import os
import re
import secrets
from pathlib import Path
from typing import Any, List, Union
class CSVLogger(FlaggingCallback):
    """
    The default implementation of the FlaggingCallback abstract class. Each flagged
    sample (both the input and output data) is logged to a CSV file with headers on the machine running the gradio app.
    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=CSVLogger())
    Guides: using_flagging
    """

    def __init__(self):
        """Create the logger; its state is supplied later through `setup`."""

    def setup(
        self,
        components: List[Any],
        flagging_dir: Union[str, Path],
    ):
        """
        Remember the components to log and make sure the log directory exists.

        Parameters:
            components: Components whose samples are written by `flag`, in column order.
            flagging_dir: Directory that receives the CSV log; created if it is missing.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str = "",
        username: Union[str, None] = None,
        filename="log.csv",
    ) -> int:
        """
        Append one flagged sample to the CSV log and return the number of logged rows.

        Parameters:
            flag_data: One sample per component, in the same order as the components given to `setup`.
            flag_option: Value stored in the "flag" column.
            username: Value stored in the "username" column; None is stored as an empty string.
            filename: Name of the log file inside the flagging directory. Characters that are
                unsafe in file names are replaced with "-".
        Returns:
            The number of CSV rows in the log file after this write, minus one for the header row.
        """
        flagging_dir = Path(self.flagging_dir)
        filename = re.sub(r"[/\\?%*:|\"<>\x7F\x00-\x1F]", "-", filename)
        log_filepath = flagging_dir / filename
        # Decided before any write: a header row is only emitted for a brand-new log.
        is_new = not log_filepath.exists()

        column_names = [
            self._column_name(component, idx)
            for idx, component in enumerate(self.components)
        ]
        headers = column_names + ["flag", "username", "timestamp"]

        csv_data = []
        for name, component, sample in zip(column_names, self.components, flag_data):
            if utils.is_update(sample):
                csv_data.append(str(sample))
            elif sample is None:
                csv_data.append("")
            else:
                # Each component gets its own sub-directory, named after its column,
                # as the `save_dir` handed to `deserialize`.
                csv_data.append(
                    component.deserialize(sample, save_dir=flagging_dir / name)
                )
        csv_data.append(flag_option)
        csv_data.append(username if username is not None else "")
        csv_data.append(str(datetime.datetime.now()))

        rows = []
        if is_new:
            rows.append(utils.sanitize_list_for_csv(headers))
        rows.append(utils.sanitize_list_for_csv(csv_data))

        try:
            with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
                csv.writer(csvfile).writerows(rows)
        except OSError:
            # workaround "OSError: [Errno 95] Operation not supported" with
            # open(log_filepath, "a") on some cloud mounted directory
            self._rewrite_with_rows(log_filepath, rows)

        with open(log_filepath, "r", newline="", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile)) - 1
        return line_count

    @staticmethod
    def _column_name(component: Any, idx: int) -> str:
        """Return the component's `label`, or "component <idx>" when it has no truthy label."""
        return getattr(component, "label", None) or f"component {idx}"

    @staticmethod
    def _rewrite_with_rows(log_filepath: Path, rows: List[Any]) -> None:
        """
        Add `rows` to the log without opening it in append mode.

        The existing content plus the new rows are written to a temporary sibling
        file, which then replaces the log. Only plain file creation and a rename are
        needed, and no shell is involved, so unusual characters in the path are harmless.
        """
        tmp_log_filepath = f"{log_filepath}.tmp_{secrets.token_hex(16)}"
        try:
            with open(tmp_log_filepath, "w", newline="", encoding="utf-8") as tmpfile:
                if log_filepath.exists():
                    # newline="" on both sides copies the existing text unchanged.
                    with open(
                        log_filepath, "r", newline="", encoding="utf-8"
                    ) as existing:
                        for chunk in iter(lambda: existing.read(65536), ""):
                            tmpfile.write(chunk)
                csv.writer(tmpfile).writerows(rows)
            os.replace(tmp_log_filepath, log_filepath)
        finally:
            # After a successful replace the temporary file is gone; this only
            # cleans up when something above failed part-way.
            if os.path.exists(tmp_log_filepath):
                os.remove(tmp_log_filepath)