amaye15 commited on
Commit
60283f6
·
1 Parent(s): 6ea28ef

webhook - complete

Browse files
Files changed (1) hide show
  1. app.py +28 -14
app.py CHANGED
@@ -6,8 +6,8 @@ from pathlib import Path
6
  from huggingface_hub import WebhooksServer, WebhookPayload
7
  from datasets import Dataset, load_dataset, disable_caching
8
  from fastapi import BackgroundTasks, Response, status
9
- from huggingface_hub.utils import build_hf_headers, get_session
10
 
 
11
  disable_caching()
12
 
13
  # Set up the logger
@@ -23,7 +23,7 @@ logger.addHandler(console_handler)
23
 
24
  # Environment variables
25
  DS_NAME = "amaye15/object-segmentation"
26
- DATA_DIR = "data"
27
  TARGET_REPO = "amaye15/object-segmentation-processed"
28
  WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")
29
 
@@ -31,10 +31,13 @@ WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")
31
  def get_data():
32
  """
33
  Generator function to stream data from the dataset.
 
 
 
34
  """
35
  ds = load_dataset(
36
  DS_NAME,
37
- cache_dir=os.path.join(os.getcwd(), DATA_DIR),
38
  streaming=True,
39
  download_mode="force_redownload",
40
  )
@@ -46,16 +49,18 @@ def get_data():
46
  def process_and_push_data():
47
  """
48
  Function to process and push new data to the target repository.
49
- """
50
- p = os.path.join(os.getcwd(), DATA_DIR)
51
-
52
- if os.path.exists(p):
53
- shutil.rmtree(p)
54
 
55
- os.mkdir(p)
 
 
 
 
 
56
 
 
57
  ds_processed = Dataset.from_generator(get_data)
58
  ds_processed.push_to_hub(TARGET_REPO)
 
59
  logger.info("Data processed and pushed to the hub.")
60
  gc.collect()
61
 
@@ -70,6 +75,9 @@ async def handle_repository_changes(
70
  ):
71
  """
72
  Webhook endpoint that triggers data processing when the dataset is updated.
 
 
 
73
  """
74
  logger.info(
75
  f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
@@ -79,13 +87,19 @@ async def handle_repository_changes(
79
 
80
 
81
  def _process_webhook():
82
- logger.info(f"Loading new dataset...")
83
- # dataset = load_dataset(DS_NAME)
84
- logger.info(f"Loaded new dataset")
 
 
 
 
 
 
85
 
86
- logger.info(f"Processing and updating dataset...")
87
  process_and_push_data()
88
- logger.info(f"Processing and updating dataset completed!")
89
 
90
 
91
  if __name__ == "__main__":
 
6
  from huggingface_hub import WebhooksServer, WebhookPayload
7
  from datasets import Dataset, load_dataset, disable_caching
8
  from fastapi import BackgroundTasks, Response, status
 
9
 
10
+ # Disable caching globally for Hugging Face datasets
11
  disable_caching()
12
 
13
  # Set up the logger
 
23
 
24
  # Environment variables
25
  DS_NAME = "amaye15/object-segmentation"
26
+ DATA_DIR = Path("data") # Use pathlib for path handling
27
  TARGET_REPO = "amaye15/object-segmentation-processed"
28
  WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")
29
 
 
31
  def get_data():
32
  """
33
  Generator function to stream data from the dataset.
34
+
35
+ Uses streaming to avoid loading the entire dataset into memory at once,
36
+ which is useful for handling large datasets.
37
  """
38
  ds = load_dataset(
39
  DS_NAME,
40
+ cache_dir=DATA_DIR,
41
  streaming=True,
42
  download_mode="force_redownload",
43
  )
 
49
  def process_and_push_data():
50
  """
51
  Function to process and push new data to the target repository.
 
 
 
 
 
52
 
53
+ Removes existing data directory if it exists, recreates it, processes
54
+ the dataset, and pushes the processed dataset to the hub.
55
+ """
56
+ if DATA_DIR.exists():
57
+ shutil.rmtree(DATA_DIR)
58
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
59
 
60
+ # Process data using the generator and push it to the hub
61
  ds_processed = Dataset.from_generator(get_data)
62
  ds_processed.push_to_hub(TARGET_REPO)
63
+
64
  logger.info("Data processed and pushed to the hub.")
65
  gc.collect()
66
 
 
75
  ):
76
  """
77
  Webhook endpoint that triggers data processing when the dataset is updated.
78
+
79
+ Adds a task to the background task queue to process the dataset
80
+ asynchronously.
81
  """
82
  logger.info(
83
  f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
 
87
 
88
 
89
  def _process_webhook():
90
+ """
91
+ Private function to handle the processing of the dataset when a webhook
92
+ is triggered.
93
+
94
+ Loads the dataset, processes it, and pushes the processed data to the hub.
95
+ """
96
+ logger.info("Loading new dataset...")
97
+ # Dataset loading is handled inside process_and_push_data, no need to load here
98
+ logger.info("Loaded new dataset")
99
 
100
+ logger.info("Processing and updating dataset...")
101
  process_and_push_data()
102
+ logger.info("Processing and updating dataset completed!")
103
 
104
 
105
  if __name__ == "__main__":