Spaces:
Sleeping
Sleeping
include changes to track iteration callbacks
Browse files
main.py
CHANGED
@@ -14,7 +14,7 @@ from rewards import get_reward_losses
|
|
14 |
from training import LatentNoiseTrainer, get_optimizer
|
15 |
|
16 |
|
17 |
-
def main(args):
|
18 |
seed_everything(args.seed)
|
19 |
bf.makedirs(f"{args.save_dir}/logs/{args.task}")
|
20 |
# Set up logging and name settings
|
@@ -99,9 +99,10 @@ def main(args):
|
|
99 |
save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
|
100 |
os.makedirs(f"{save_dir}", exist_ok=True)
|
101 |
best_image, total_init_rewards, total_best_rewards = trainer.train(
|
102 |
-
latents, args.prompt, optimizer, save_dir
|
103 |
)
|
104 |
best_image.save(f"{save_dir}/best_image.png")
|
|
|
105 |
elif args.task == "example-prompts":
|
106 |
fo = open("assets/example_prompts.txt", "r")
|
107 |
prompts = fo.readlines()
|
@@ -271,4 +272,4 @@ def main(args):
|
|
271 |
|
272 |
if __name__ == "__main__":
|
273 |
args = parse_args()
|
274 |
-
main(args)
|
|
|
14 |
from training import LatentNoiseTrainer, get_optimizer
|
15 |
|
16 |
|
17 |
+
def main(args, progress_callback=None):
|
18 |
seed_everything(args.seed)
|
19 |
bf.makedirs(f"{args.save_dir}/logs/{args.task}")
|
20 |
# Set up logging and name settings
|
|
|
99 |
save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
|
100 |
os.makedirs(f"{save_dir}", exist_ok=True)
|
101 |
best_image, total_init_rewards, total_best_rewards = trainer.train(
|
102 |
+
latents, args.prompt, optimizer, save_dir, progress_callback=progress_callback
|
103 |
)
|
104 |
best_image.save(f"{save_dir}/best_image.png")
|
105 |
+
return best_image, total_init_rewards, total_best_rewards
|
106 |
elif args.task == "example-prompts":
|
107 |
fo = open("assets/example_prompts.txt", "r")
|
108 |
prompts = fo.readlines()
|
|
|
272 |
|
273 |
if __name__ == "__main__":
|
274 |
args = parse_args()
|
275 |
+
main(args)
|