jadechoghari commited on
Commit
893807d
·
verified ·
1 Parent(s): 0f80d94

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +61 -35
pipeline.py CHANGED
@@ -1,68 +1,94 @@
1
  from diffusers import DiffusionPipeline
2
- import torch
3
  import os
4
- import yaml
5
- from .audioldm_train.utilities.tools import build_dataset_json_from_list
6
- from .infer.infer_mos5 import infer # Importing the infer function
7
 
8
- class QAMDTModel(DiffusionPipeline):
9
 
10
- def __init__(self, config_yaml, list_inference, reload_from_ckpt=None):
11
  """
12
- Initialize the MOS Diffusion pipeline.
13
 
14
  Args:
15
  config_yaml (str): Path to the YAML configuration file.
16
  list_inference (str): Path to the file containing inference prompts.
17
  reload_from_ckpt (str, optional): Checkpoint path to reload from.
 
18
  """
19
  super().__init__()
20
 
21
- # we load and process the yaml config
 
 
 
22
  self.config_yaml = config_yaml
23
  self.list_inference = list_inference
24
  self.reload_from_ckpt = reload_from_ckpt
25
-
26
- # we load the yaml config
27
  config_yaml_path = os.path.join(self.config_yaml)
28
- self.configs = yaml.load(open(config_yaml_path, "r"), Loader=yaml.FullLoader)
29
-
30
- # override checkpoint if provided--
31
  if self.reload_from_ckpt is not None:
32
  self.configs["reload_from_ckpt"] = self.reload_from_ckpt
33
 
34
- self.dataset_key = build_dataset_json_from_list(self.list_inference)
35
  self.exp_name = os.path.basename(self.config_yaml.split(".")[0])
36
  self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  @torch.no_grad()
39
  def __call__(self, *args, **kwargs):
40
  """
41
  Run the MOS Diffusion Pipeline. This method calls the infer function from infer_mos5.py.
42
-
43
- Args:
44
- *args: Additional arguments.
45
- **kwargs: Keyword arguments that may contain overrides for configurations.
46
-
47
- Returns:
48
- None. Inference is performed and samples are generated.
49
  """
50
- # here call the infer function to perform the inference
 
51
  infer(
52
- dataset_key=self.dataset_key,
53
- configs=self.configs,
54
- config_yaml_path=self.config_yaml,
55
- exp_group_name=self.exp_group_name,
56
  exp_name=self.exp_name
57
  )
58
 
59
- # # This is an example of how to use the pipeline
60
- # if __name__ == "__main__":
61
- # pipeline = MOSDiffusionPipeline(
62
- # config_yaml="audioldm_train/config/mos_as_token/qa_mdt.yaml",
63
- # list_inference="/content/qa-mdt/test_prompts/good_prompts_1.lst",
64
- # reload_from_ckpt="/content/qa-mdt/checkpoint_389999.ckpt"
65
- # )
 
66
 
67
- # # Run the pipeline
68
- # pipeline()
 
1
  from diffusers import DiffusionPipeline
 
2
  import os
3
+ import sys
4
+ from huggingface_hub import HfApi, hf_hub_download
5
+ from .tools import build_dataset_json_from_list
6
 
7
+ class MOSDiffusionPipeline(DiffusionPipeline):
8
 
9
+ def __init__(self, config_yaml, list_inference, reload_from_ckpt=None, base_folder=None):
10
  """
11
+ Initialize the MOS Diffusion pipeline and download the necessary files/folders.
12
 
13
  Args:
14
  config_yaml (str): Path to the YAML configuration file.
15
  list_inference (str): Path to the file containing inference prompts.
16
  reload_from_ckpt (str, optional): Checkpoint path to reload from.
17
+ base_folder (str, optional): Base folder to store downloaded files. Defaults to the current working directory.
18
  """
19
  super().__init__()
20
 
21
+
22
+ self.base_folder = base_folder if base_folder else os.getcwd()
23
+ self.repo_id = "jadechoghari/qa-mdt"
24
+ self.download_required_folders()
25
  self.config_yaml = config_yaml
26
  self.list_inference = list_inference
27
  self.reload_from_ckpt = reload_from_ckpt
 
 
28
  config_yaml_path = os.path.join(self.config_yaml)
29
+ self.configs = self.load_yaml(config_yaml_path)
 
 
30
  if self.reload_from_ckpt is not None:
31
  self.configs["reload_from_ckpt"] = self.reload_from_ckpt
32
 
33
+ self.dataset_key = self.build_dataset_json_from_list(self.list_inference)
34
  self.exp_name = os.path.basename(self.config_yaml.split(".")[0])
35
  self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))
36
 
37
+ def download_required_folders(self):
38
+ """
39
+ Downloads the necessary folders from the Hugging Face Hub if they are not already available locally.
40
+ """
41
+ api = HfApi()
42
+
43
+ files = api.list_repo_files(repo_id=self.repo_id)
44
+
45
+ required_folders = ["audioldm_train", "checkpoints", "infer", "log", "taming", "test_prompts"]
46
+
47
+ files_to_download = [f for f in files if any(f.startswith(folder) for folder in required_folders)]
48
+
49
+ for file in files_to_download:
50
+ local_file_path = os.path.join(self.base_folder, file)
51
+ if not os.path.exists(local_file_path):
52
+ downloaded_file = hf_hub_download(repo_id=self.repo_id, filename=file)
53
+
54
+ os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
55
+
56
+ os.rename(downloaded_file, local_file_path)
57
+
58
+ sys.path.append(self.base_folder)
59
+
60
+ def load_yaml(self, yaml_path):
61
+ """
62
+ Helper method to load the YAML configuration.
63
+ """
64
+ import yaml
65
+ with open(yaml_path, "r") as f:
66
+ return yaml.safe_load(f)
67
+
68
+
69
  @torch.no_grad()
70
  def __call__(self, *args, **kwargs):
71
  """
72
  Run the MOS Diffusion Pipeline. This method calls the infer function from infer_mos5.py.
 
 
 
 
 
 
 
73
  """
74
+ from infer_mos5 import infer
75
+
76
  infer(
77
+ dataset_key=self.dataset_key,
78
+ configs=self.configs,
79
+ config_yaml_path=self.config_yaml,
80
+ exp_group_name=self.exp_group_name,
81
  exp_name=self.exp_name
82
  )
83
 
84
+ # Example of how to use the pipeline
85
+ if __name__ == "__main__":
86
+ pipeline = MOSDiffusionPipeline(
87
+ config_yaml="audioldm_train/config/mos_as_token/qa_mdt.yaml",
88
+ list_inference="test_prompts/good_prompts_1.lst",
89
+ reload_from_ckpt="checkpoints/checkpoint_389999.ckpt",
90
+ base_folder=None
91
+ )
92
 
93
+ # Run the pipeline
94
+ pipeline()