mrfakename commited on
Commit
ef90edf
·
verified ·
1 Parent(s): d9c8497

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (2) hide show
  1. README_REPO.md +21 -0
  2. model/trainer.py +24 -20
README_REPO.md CHANGED
@@ -72,6 +72,27 @@ An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discuss
72
 
73
  Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  ## Inference
76
 
77
  The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
 
72
 
73
  Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
74
 
75
+ ## Wandb Logging
76
+
77
+ By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).
78
+
79
+ To turn on wandb logging, you can either:
80
+
81
+ 1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
82
+ 2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:
83
+
84
+ On Mac & Linux:
85
+
86
+ ```
87
+ export WANDB_API_KEY=<YOUR WANDB API KEY>
88
+ ```
89
+
90
+ On Windows:
91
+
92
+ ```
93
+ set WANDB_API_KEY=<YOUR WANDB API KEY>
94
+ ```
95
+
96
  ## Inference
97
 
98
  The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
model/trainer.py CHANGED
@@ -50,31 +50,35 @@ class Trainer:
50
 
51
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
52
 
 
 
 
53
  self.accelerator = Accelerator(
54
- log_with = "wandb",
55
  kwargs_handlers = [ddp_kwargs],
56
  gradient_accumulation_steps = grad_accumulation_steps,
57
  **accelerate_kwargs
58
  )
59
-
60
- if exists(wandb_resume_id):
61
- init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
62
- else:
63
- init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
64
- self.accelerator.init_trackers(
65
- project_name = wandb_project,
66
- init_kwargs=init_kwargs,
67
- config={"epochs": epochs,
68
- "learning_rate": learning_rate,
69
- "num_warmup_updates": num_warmup_updates,
70
- "batch_size": batch_size,
71
- "batch_size_type": batch_size_type,
72
- "max_samples": max_samples,
73
- "grad_accumulation_steps": grad_accumulation_steps,
74
- "max_grad_norm": max_grad_norm,
75
- "gpus": self.accelerator.num_processes,
76
- "noise_scheduler": noise_scheduler}
77
- )
 
78
 
79
  self.model = model
80
 
 
50
 
51
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
52
 
53
+ logger = "wandb" if wandb.api.api_key else None
54
+ print(f"Using logger: {logger}")
55
+
56
  self.accelerator = Accelerator(
57
+ log_with = logger,
58
  kwargs_handlers = [ddp_kwargs],
59
  gradient_accumulation_steps = grad_accumulation_steps,
60
  **accelerate_kwargs
61
  )
62
+
63
+ if logger == "wandb":
64
+ if exists(wandb_resume_id):
65
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
66
+ else:
67
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
68
+ self.accelerator.init_trackers(
69
+ project_name = wandb_project,
70
+ init_kwargs=init_kwargs,
71
+ config={"epochs": epochs,
72
+ "learning_rate": learning_rate,
73
+ "num_warmup_updates": num_warmup_updates,
74
+ "batch_size": batch_size,
75
+ "batch_size_type": batch_size_type,
76
+ "max_samples": max_samples,
77
+ "grad_accumulation_steps": grad_accumulation_steps,
78
+ "max_grad_norm": max_grad_norm,
79
+ "gpus": self.accelerator.num_processes,
80
+ "noise_scheduler": noise_scheduler}
81
+ )
82
 
83
  self.model = model
84