jacklangerman commited on
Commit
edc7860
·
1 Parent(s): e5605f2

add streaming support

Browse files
Files changed (1) hide show
  1. hoho/hoho.py +30 -5
hoho/hoho.py CHANGED
@@ -3,8 +3,13 @@ import json
3
  import shutil
4
  from pathlib import Path
5
  from typing import Dict
 
6
 
7
  from PIL import ImageFile
 
 
 
 
8
  ImageFile.LOAD_TRUNCATED_IMAGES = True
9
 
10
  LOCAL_DATADIR = None
@@ -29,11 +34,11 @@ def setup(local_dir='./data/usm-training-data/data'):
29
  else:
30
  LOCAL_DATADIR = local_val_datadir
31
  print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
32
-
33
- # os.system("ls -lahtr")
34
- # os.system(f"ls -lahtr {LOCAL_DATADIR}")
35
 
36
- assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist"
 
 
 
37
  return LOCAL_DATADIR
38
 
39
 
@@ -286,7 +291,9 @@ def get_params():
286
  import webdataset as wds
287
  import numpy as np
288
 
289
- def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
 
 
290
  if LOCAL_DATADIR is None:
291
  raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
292
 
@@ -295,8 +302,24 @@ def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset
295
  local_dir = local_dir / split
296
 
297
  paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
  dataset = wds.WebDataset(paths)
 
300
  if decode is not None:
301
  dataset = dataset.decode(decode)
302
  else:
@@ -315,6 +338,8 @@ def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset
315
  return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
316
  elif split == 'val':
317
  return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
 
 
318
 
319
 
320
 
 
3
  import shutil
4
  from pathlib import Path
5
  from typing import Dict
6
+ import warnings
7
 
8
  from PIL import ImageFile
9
+
10
+ from huggingface_hub.utils._headers import build_hf_headers # note: using _headers
11
+
12
+
13
  ImageFile.LOAD_TRUNCATED_IMAGES = True
14
 
15
  LOCAL_DATADIR = None
 
34
  else:
35
  LOCAL_DATADIR = local_val_datadir
36
  print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
 
 
 
37
 
38
+ if not LOCAL_DATADIR.exists():
39
+ warnings.warn(f"Data directory {LOCAL_DATADIR} does not exist: creating it...")
40
+ LOCAL_DATADIR.mkdir(parents=True)
41
+
42
  return LOCAL_DATADIR
43
 
44
 
 
291
  import webdataset as wds
292
  import numpy as np
293
 
294
+
295
+ SHARD_IDS = {'train': (0, 25), 'val': (25, 26), 'public': (26, 27), 'private': (27, 32)}
296
+ def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset', stream=True):
297
  if LOCAL_DATADIR is None:
298
  raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
299
 
 
302
  local_dir = local_dir / split
303
 
304
  paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
305
+ msg = f'no tarfiles found in {local_dir}.'
306
+ if len(paths) == 0:
307
+ if stream:
308
+ if split=='all': split = 'train'
309
+ warnings.warn('streaming isn\'t using with \'all\': changing `split` to \'train\'')
310
+ warnings.warn(msg)
311
+ if split == 'val':
312
+ names = [f'data/val/inputs/hoho_v3_{i:03}-of-032.tar.gz' for i in range(*SHARD_IDS[split])]
313
+ elif split == 'train':
314
+ names = [f'data/train/hoho_v3_{i:03}-of-032.tar.gz' for i in range(*SHARD_IDS[split])]
315
+
316
+ auth = build_hf_headers()['authorization']
317
+ paths = [f"pipe:curl -L -s https://huggingface.co/datasets/usm3d/hoho-train-set/resolve/main/{name} -H 'Authorization: {auth}'" for name in names]
318
+ else:
319
+ raise FileNotFoundError(msg)
320
 
321
  dataset = wds.WebDataset(paths)
322
+
323
  if decode is not None:
324
  dataset = dataset.decode(decode)
325
  else:
 
338
  return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
339
  elif split == 'val':
340
  return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
341
+ else:
342
+ raise NotImplementedError('only train and val are implemented as hf datasets')
343
 
344
 
345