fffiloni commited on
Commit
21cbd99
·
verified ·
1 Parent(s): feb3c15

include changes to track iteration callbacks

Browse files
Files changed (1) hide show
  1. main.py +4 -3
main.py CHANGED
@@ -14,7 +14,7 @@ from rewards import get_reward_losses
14
  from training import LatentNoiseTrainer, get_optimizer
15
 
16
 
17
- def main(args):
18
  seed_everything(args.seed)
19
  bf.makedirs(f"{args.save_dir}/logs/{args.task}")
20
  # Set up logging and name settings
@@ -99,9 +99,10 @@ def main(args):
99
  save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
100
  os.makedirs(f"{save_dir}", exist_ok=True)
101
  best_image, total_init_rewards, total_best_rewards = trainer.train(
102
- latents, args.prompt, optimizer, save_dir
103
  )
104
  best_image.save(f"{save_dir}/best_image.png")
 
105
  elif args.task == "example-prompts":
106
  fo = open("assets/example_prompts.txt", "r")
107
  prompts = fo.readlines()
@@ -271,4 +272,4 @@ def main(args):
271
 
272
  if __name__ == "__main__":
273
  args = parse_args()
274
- main(args)
 
14
  from training import LatentNoiseTrainer, get_optimizer
15
 
16
 
17
+ def main(args, progress_callback=None):
18
  seed_everything(args.seed)
19
  bf.makedirs(f"{args.save_dir}/logs/{args.task}")
20
  # Set up logging and name settings
 
99
  save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
100
  os.makedirs(f"{save_dir}", exist_ok=True)
101
  best_image, total_init_rewards, total_best_rewards = trainer.train(
102
+ latents, args.prompt, optimizer, save_dir, progress_callback=progress_callback
103
  )
104
  best_image.save(f"{save_dir}/best_image.png")
105
+ return best_image, total_init_rewards, total_best_rewards
106
  elif args.task == "example-prompts":
107
  fo = open("assets/example_prompts.txt", "r")
108
  prompts = fo.readlines()
 
272
 
273
  if __name__ == "__main__":
274
  args = parse_args()
275
+ main(args)