jadechoghari
commited on
Update pipeline.py
Browse files- pipeline.py +61 -35
pipeline.py
CHANGED
@@ -1,68 +1,94 @@
|
|
1 |
from diffusers import DiffusionPipeline
|
2 |
-
import torch
|
3 |
import os
|
4 |
-
import
|
5 |
-
from
|
6 |
-
from .
|
7 |
|
8 |
-
class
|
9 |
|
10 |
-
def __init__(self, config_yaml, list_inference, reload_from_ckpt=None):
|
11 |
"""
|
12 |
-
Initialize the MOS Diffusion pipeline.
|
13 |
|
14 |
Args:
|
15 |
config_yaml (str): Path to the YAML configuration file.
|
16 |
list_inference (str): Path to the file containing inference prompts.
|
17 |
reload_from_ckpt (str, optional): Checkpoint path to reload from.
|
|
|
18 |
"""
|
19 |
super().__init__()
|
20 |
|
21 |
-
|
|
|
|
|
|
|
22 |
self.config_yaml = config_yaml
|
23 |
self.list_inference = list_inference
|
24 |
self.reload_from_ckpt = reload_from_ckpt
|
25 |
-
|
26 |
-
# we load the yaml config
|
27 |
config_yaml_path = os.path.join(self.config_yaml)
|
28 |
-
self.configs =
|
29 |
-
|
30 |
-
# override checkpoint if provided--
|
31 |
if self.reload_from_ckpt is not None:
|
32 |
self.configs["reload_from_ckpt"] = self.reload_from_ckpt
|
33 |
|
34 |
-
self.dataset_key = build_dataset_json_from_list(self.list_inference)
|
35 |
self.exp_name = os.path.basename(self.config_yaml.split(".")[0])
|
36 |
self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
@torch.no_grad()
|
39 |
def __call__(self, *args, **kwargs):
|
40 |
"""
|
41 |
Run the MOS Diffusion Pipeline. This method calls the infer function from infer_mos5.py.
|
42 |
-
|
43 |
-
Args:
|
44 |
-
*args: Additional arguments.
|
45 |
-
**kwargs: Keyword arguments that may contain overrides for configurations.
|
46 |
-
|
47 |
-
Returns:
|
48 |
-
None. Inference is performed and samples are generated.
|
49 |
"""
|
50 |
-
|
|
|
51 |
infer(
|
52 |
-
dataset_key=self.dataset_key,
|
53 |
-
configs=self.configs,
|
54 |
-
config_yaml_path=self.config_yaml,
|
55 |
-
exp_group_name=self.exp_group_name,
|
56 |
exp_name=self.exp_name
|
57 |
)
|
58 |
|
59 |
-
#
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
66 |
|
67 |
-
#
|
68 |
-
|
|
|
1 |
from diffusers import DiffusionPipeline
|
|
|
2 |
import os
|
3 |
+
import sys
|
4 |
+
from huggingface_hub import HfApi, hf_hub_download
|
5 |
+
from .tools import build_dataset_json_from_list
|
6 |
|
7 |
+
class MOSDiffusionPipeline(DiffusionPipeline):
|
8 |
|
9 |
+
def __init__(self, config_yaml, list_inference, reload_from_ckpt=None, base_folder=None):
|
10 |
"""
|
11 |
+
Initialize the MOS Diffusion pipeline and download the necessary files/folders.
|
12 |
|
13 |
Args:
|
14 |
config_yaml (str): Path to the YAML configuration file.
|
15 |
list_inference (str): Path to the file containing inference prompts.
|
16 |
reload_from_ckpt (str, optional): Checkpoint path to reload from.
|
17 |
+
base_folder (str, optional): Base folder to store downloaded files. Defaults to the current working directory.
|
18 |
"""
|
19 |
super().__init__()
|
20 |
|
21 |
+
|
22 |
+
self.base_folder = base_folder if base_folder else os.getcwd()
|
23 |
+
self.repo_id = "jadechoghari/qa-mdt"
|
24 |
+
self.download_required_folders()
|
25 |
self.config_yaml = config_yaml
|
26 |
self.list_inference = list_inference
|
27 |
self.reload_from_ckpt = reload_from_ckpt
|
|
|
|
|
28 |
config_yaml_path = os.path.join(self.config_yaml)
|
29 |
+
self.configs = self.load_yaml(config_yaml_path)
|
|
|
|
|
30 |
if self.reload_from_ckpt is not None:
|
31 |
self.configs["reload_from_ckpt"] = self.reload_from_ckpt
|
32 |
|
33 |
+
self.dataset_key = self.build_dataset_json_from_list(self.list_inference)
|
34 |
self.exp_name = os.path.basename(self.config_yaml.split(".")[0])
|
35 |
self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))
|
36 |
|
37 |
+
def download_required_folders(self):
|
38 |
+
"""
|
39 |
+
Downloads the necessary folders from the Hugging Face Hub if they are not already available locally.
|
40 |
+
"""
|
41 |
+
api = HfApi()
|
42 |
+
|
43 |
+
files = api.list_repo_files(repo_id=self.repo_id)
|
44 |
+
|
45 |
+
required_folders = ["audioldm_train", "checkpoints", "infer", "log", "taming", "test_prompts"]
|
46 |
+
|
47 |
+
files_to_download = [f for f in files if any(f.startswith(folder) for folder in required_folders)]
|
48 |
+
|
49 |
+
for file in files_to_download:
|
50 |
+
local_file_path = os.path.join(self.base_folder, file)
|
51 |
+
if not os.path.exists(local_file_path):
|
52 |
+
downloaded_file = hf_hub_download(repo_id=self.repo_id, filename=file)
|
53 |
+
|
54 |
+
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
55 |
+
|
56 |
+
os.rename(downloaded_file, local_file_path)
|
57 |
+
|
58 |
+
sys.path.append(self.base_folder)
|
59 |
+
|
60 |
+
def load_yaml(self, yaml_path):
|
61 |
+
"""
|
62 |
+
Helper method to load the YAML configuration.
|
63 |
+
"""
|
64 |
+
import yaml
|
65 |
+
with open(yaml_path, "r") as f:
|
66 |
+
return yaml.safe_load(f)
|
67 |
+
|
68 |
+
|
69 |
@torch.no_grad()
|
70 |
def __call__(self, *args, **kwargs):
|
71 |
"""
|
72 |
Run the MOS Diffusion Pipeline. This method calls the infer function from infer_mos5.py.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
"""
|
74 |
+
from infer_mos5 import infer
|
75 |
+
|
76 |
infer(
|
77 |
+
dataset_key=self.dataset_key,
|
78 |
+
configs=self.configs,
|
79 |
+
config_yaml_path=self.config_yaml,
|
80 |
+
exp_group_name=self.exp_group_name,
|
81 |
exp_name=self.exp_name
|
82 |
)
|
83 |
|
84 |
+
# Example of how to use the pipeline
|
85 |
+
if __name__ == "__main__":
|
86 |
+
pipeline = MOSDiffusionPipeline(
|
87 |
+
config_yaml="audioldm_train/config/mos_as_token/qa_mdt.yaml",
|
88 |
+
list_inference="test_prompts/good_prompts_1.lst",
|
89 |
+
reload_from_ckpt="checkpoints/checkpoint_389999.ckpt",
|
90 |
+
base_folder=None
|
91 |
+
)
|
92 |
|
93 |
+
# Run the pipeline
|
94 |
+
pipeline()
|