import os
from typing import Optional
from urllib.parse import urlparse

import requests
from tqdm.auto import tqdm

MODEL_PATHS = {
    "clevr": "https://www.dropbox.com/s/bqpc3ymstz9m05z/clevr_model.pt",
    "celebahq": "https://www.dropbox.com/s/687wuamoud4cs9x/celeb_model.pt",
}

DATA_PATHS = {
    "clevr": "",
    "clevr_toy": "",
}


def download_model(
    dataset: str,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
    timeout: Optional[float] = 60,
) -> str:
    """Download the pretrained checkpoint for *dataset*, reusing a cached copy.

    Args:
        dataset: Key into ``MODEL_PATHS`` (e.g. ``"clevr"`` or ``"celebahq"``).
        cache_dir: Directory the checkpoint is stored in. Defaults to the
            current working directory (``'./'``); it is created if missing.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Seconds ``requests`` waits to connect / between received
            bytes before giving up; ``None`` waits forever.

    Returns:
        Path to the local checkpoint file.

    Raises:
        ValueError: If *dataset* is not a key of ``MODEL_PATHS``.
        requests.HTTPError: If the server answers with an error status.
    """
    if dataset not in MODEL_PATHS:
        raise ValueError(
            f"Unknown dataset name {dataset}. Known names are: {MODEL_PATHS.keys()}."
        )
    if cache_dir is None:
        cache_dir = './'
    url = MODEL_PATHS[dataset]
    os.makedirs(cache_dir, exist_ok=True)

    # Name the file after the last URL path segment; taking it from the parsed
    # path keeps any query string (such as Dropbox's "?dl=0") out of the name.
    filename = os.path.basename(urlparse(url).path)
    local_path = os.path.join(cache_dir, filename)
    if os.path.exists(local_path):
        return local_path

    # NOTE(review): presumably a wget user agent is sent so Dropbox returns the
    # raw file instead of an HTML preview page — TODO confirm.
    headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}

    # Stream into a temporary file and move it into place only after the whole
    # body was received, so a failed or interrupted download is never mistaken
    # for a cached checkpoint by the existence check above.
    tmp_path = local_path + ".part"
    try:
        with requests.get(url, stream=True, headers=headers, timeout=timeout) as r:
            # Without this, an HTTP error page would be saved as the model.
            r.raise_for_status()
            # A missing content-length gives an indeterminate progress bar.
            size = int(r.headers.get("content-length", "0")) or None
            with open(tmp_path, 'wb') as f, tqdm(
                total=size, unit="iB", unit_scale=True
            ) as pbar:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip empty chunks
                        pbar.update(len(chunk))
                        f.write(chunk)
        # os.replace (unlike os.rename) also overwrites on Windows.
        os.replace(tmp_path, local_path)
    except BaseException:
        # Clean up the partial file (also on Ctrl-C) before propagating.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise
    return local_path