import os
from urllib.parse import urlparse

# When COMFYUI_MANAGER_ARIA2_SERVER is set (and non-empty), downloads are
# delegated to a remote aria2 RPC server and `aria2` is rebound from the raw
# URL string to an `aria2p.API` instance; otherwise it stays None/'' and
# `download_url` falls back to torchvision.
aria2 = os.getenv('COMFYUI_MANAGER_ARIA2_SERVER')
HF_ENDPOINT = os.getenv('HF_ENDPOINT')

_HF_DEFAULT_PREFIX = 'https://huggingface.co'

if aria2:  # an empty string would otherwise crash below on a None hostname
    secret = os.getenv('COMFYUI_MANAGER_ARIA2_SECRET')
    url = urlparse(aria2)
    if not url.scheme or not url.hostname:
        raise ValueError(
            f"COMFYUI_MANAGER_ARIA2_SERVER must be a URL like 'http://host:6800', got {aria2!r}"
        )
    port = url.port
    host = url.scheme + '://' + url.hostname

    import aria2p

    client_kwargs = {'host': host, 'secret': secret}
    # Only pass a port when the URL names one, so aria2p keeps its own
    # default instead of receiving an explicit None.
    if port is not None:
        client_kwargs['port'] = port
    aria2 = aria2p.API(aria2p.Client(**client_kwargs))


def download_url(model_url: str, model_dir: str, filename: str):
    """Download *model_url* into *model_dir* as *filename*.

    Uses the configured aria2 server when one is available, otherwise
    torchvision's ``download_url``.
    """
    if aria2:
        return aria2_download_url(model_url, model_dir, filename)

    from torchvision.datasets.utils import download_url as torchvision_download_url

    return torchvision_download_url(model_url, model_dir, filename)


def aria2_find_task(dir: str, filename: str):
    """Return the first aria2 download that writes ``dir/filename``, or None.

    Metadata entries (e.g. torrent metadata files) are skipped. The parameter
    name ``dir`` shadows the builtin but is kept for keyword callers.
    """
    target = os.path.join(dir, filename)

    for download in aria2.get_downloads():
        for file in download.files:
            if file.is_metadata:
                continue
            if str(file.path) == target:
                return download
    return None


def _apply_hf_endpoint(model_url: str, endpoint) -> str:
    """Redirect a huggingface.co URL to *endpoint* (e.g. a mirror).

    Only a leading ``https://huggingface.co`` host component is rewritten;
    URLs that merely contain that text elsewhere are returned unchanged.
    A trailing slash on *endpoint* is dropped to avoid ``//`` in the result.
    """
    if not endpoint:
        return model_url
    prefix = _HF_DEFAULT_PREFIX
    if model_url == prefix or model_url.startswith(prefix + '/'):
        return endpoint.rstrip('/') + model_url[len(prefix):]
    return model_url


def _aria2_download_dir(model_dir: str, comfy_path) -> str:
    """Translate a local model directory into the directory aria2 writes to.

    A *model_dir* located under *comfy_path* (respecting path-component
    boundaries) has that prefix removed; a relative result is placed under
    ``/models``. NOTE(review): this presumes the aria2 server sees the
    ComfyUI tree mounted at its filesystem root — confirm with deployment.
    """
    root = (comfy_path or '').rstrip('/\\')
    if root and (model_dir == root or model_dir.startswith((root + '/', root + os.sep))):
        model_dir = model_dir[len(root):]
    return model_dir if model_dir.startswith('/') else os.path.join('/models', model_dir)


def aria2_download_url(model_url: str, model_dir: str, filename: str):
    """Download via the aria2 server, blocking with a progress bar until done.

    Reuses an existing aria2 task for the same target file if one exists.
    Waits while the task is active or still queued.

    Raises:
        RuntimeError: if aria2 reports the download as failed.
    """
    import manager_core as core
    import tqdm
    import time

    model_url = _apply_hf_endpoint(model_url, HF_ENDPOINT)
    download_dir = _aria2_download_dir(model_dir, core.comfy_path)

    download = aria2_find_task(download_dir, filename)
    if download is None:
        options = {'dir': download_dir, 'out': filename}
        download = aria2.add(model_url, options)[0]

    def in_progress(task) -> bool:
        # A queued (waiting) task has not finished yet either; returning
        # early would report success before the file exists.
        return task.is_active or task.is_waiting

    if in_progress(download):
        with tqdm.tqdm(
            total=download.total_length,
            bar_format='{l_bar}{bar}{r_bar}',
            desc=filename,
            unit='B',
            unit_scale=True,
        ) as progress_bar:
            while in_progress(download):
                # total_length is 0 until aria2 learns the size; resize then.
                if progress_bar.total == 0 and download.total_length != 0:
                    progress_bar.reset(download.total_length)
                progress_bar.update(download.completed_length - progress_bar.n)
                time.sleep(1)
                download.update()
            # Reflect the final byte count reported by the last refresh.
            progress_bar.update(download.completed_length - progress_bar.n)

    if download.has_failed:
        # Surface failures instead of returning as if the file were present,
        # matching the torchvision path, which raises on failure.
        raise RuntimeError(
            f'aria2 download of {model_url} into {download_dir} failed: {download.error_message}'
        )