RohitGandikota commited on
Commit
655863b
Β·
1 Parent(s): 3c5bd3b

fixing training

Browse files
Files changed (2) hide show
  1. app.py +11 -11
  2. trainscripts/textsliders/demotrain.py +1 -1
app.py CHANGED
@@ -227,16 +227,16 @@ class Demo:
227
  )
228
 
229
  def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, pbar = gr.Progress(track_tqdm=True)):
230
- if target_concept is None:
231
- target_concept = ''
232
- if positive_prompt is None:
233
- positive_prompt = ''
234
- if negative_prompt is None:
235
- negative_prompt = ''
236
- if is_person is None:
237
- is_person = False
238
- else:
239
- is_person = True
240
  print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
241
 
242
  randn = torch.randint(1, 10000000, (1,)).item()
@@ -253,7 +253,7 @@ class Demo:
253
  attributes = 'white, black, asian, hispanic, indian, male, female'
254
 
255
  self.training = True
256
- train_xl(target=target_concept, postive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=rank, device=self.device, attributes=attributes, save_name=save_name)
257
  self.training = False
258
 
259
  torch.cuda.empty_cache()
 
227
  )
228
 
229
  def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, pbar = gr.Progress(track_tqdm=True)):
230
+ # if target_concept is None:
231
+ # target_concept = ''
232
+ # if positive_prompt is None:
233
+ # positive_prompt = ''
234
+ # if negative_prompt is None:
235
+ # negative_prompt = ''
236
+ # if is_person is None:
237
+ # is_person = False
238
+ # else:
239
+ # is_person = True
240
  print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
241
 
242
  randn = torch.randint(1, 10000000, (1,)).item()
 
253
  attributes = 'white, black, asian, hispanic, indian, male, female'
254
 
255
  self.training = True
256
+ train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=rank, device=self.device, attributes=attributes, save_name=save_name)
257
  self.training = False
258
 
259
  torch.cuda.empty_cache()
trainscripts/textsliders/demotrain.py CHANGED
@@ -411,7 +411,7 @@ def train(
411
  # train(config, prompts, device)
412
 
413
 
414
- def train_xl(target, postive, negative, lr, iterations, config_file, rank, device, attributes,save_name):
415
 
416
  config = config_util.load_config_from_yaml(config_file)
417
  randn = torch.randint(1, 10000000, (1,)).item()
 
411
  # train(config, prompts, device)
412
 
413
 
414
+ def train_xl(target, positive, negative, lr, iterations, config_file, rank, device, attributes,save_name):
415
 
416
  config = config_util.load_config_from_yaml(config_file)
417
  randn = torch.randint(1, 10000000, (1,)).item()