import os
import gc
import shutil
import logging
from pathlib import Path

from huggingface_hub import WebhooksServer, WebhookPayload
from datasets import Dataset, load_dataset, disable_caching
from fastapi import BackgroundTasks, Response, status

# Disable caching globally for Hugging Face datasets
disable_caching()

# Set up the logger
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

# Set up the console handler with a simple format
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Environment variables
DS_NAME = "amaye15/object-segmentation"
DATA_DIR = Path("data")  # Use pathlib for path handling
TARGET_REPO = "amaye15/object-segmentation-processed"
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")


def get_data():
    """
    Generator function to stream data from the dataset.

    Uses streaming to avoid loading the entire dataset into memory at once,
    which is useful for handling large datasets.
    """
    ds = load_dataset(
        DS_NAME,
        cache_dir=DATA_DIR,
        streaming=True,
        download_mode="force_redownload",
    )
    for row in ds["train"]:
        yield row
    gc.collect()


def process_and_push_data():
    """
    Process and push new data to the target repository.

    Removes the existing data directory if it exists, recreates it,
    processes the dataset, and pushes the processed dataset to the Hub.
    """
    if DATA_DIR.exists():
        shutil.rmtree(DATA_DIR)
    DATA_DIR.mkdir(parents=True, exist_ok=True)

    # Process data using the generator and push it to the Hub
    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(TARGET_REPO)

    logger.info("Data processed and pushed to the hub.")
    gc.collect()


# Initialize the WebhooksServer (backed by a Gradio app) with the webhook secret
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)


@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """
    Webhook endpoint that triggers data processing when the dataset is updated.

    Adds a task to the background task queue to process the dataset asynchronously.
    """
    logger.info(
        f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
    )
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)


def _process_webhook():
    """
    Handle the processing of the dataset when a webhook is triggered.

    Dataset loading is handled inside process_and_push_data, so this function
    simply kicks off the processing and logs progress.
    """
    logger.info("Processing and updating dataset...")
    process_and_push_data()
    logger.info("Processing and updating dataset completed!")


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", show_error=True, server_port=7860)