Spaces:
Runtime error
Runtime error
RohitGandikota
commited on
Commit
Β·
655863b
1
Parent(s):
3c5bd3b
fixing training
Browse files- app.py +11 -11
- trainscripts/textsliders/demotrain.py +1 -1
app.py
CHANGED
@@ -227,16 +227,16 @@ class Demo:
|
|
227 |
)
|
228 |
|
229 |
def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, pbar = gr.Progress(track_tqdm=True)):
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
|
241 |
|
242 |
randn = torch.randint(1, 10000000, (1,)).item()
|
@@ -253,7 +253,7 @@ class Demo:
|
|
253 |
attributes = 'white, black, asian, hispanic, indian, male, female'
|
254 |
|
255 |
self.training = True
|
256 |
-
train_xl(target=target_concept,
|
257 |
self.training = False
|
258 |
|
259 |
torch.cuda.empty_cache()
|
|
|
227 |
)
|
228 |
|
229 |
def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, pbar = gr.Progress(track_tqdm=True)):
|
230 |
+
# if target_concept is None:
|
231 |
+
# target_concept = ''
|
232 |
+
# if positive_prompt is None:
|
233 |
+
# positive_prompt = ''
|
234 |
+
# if negative_prompt is None:
|
235 |
+
# negative_prompt = ''
|
236 |
+
# if is_person is None:
|
237 |
+
# is_person = False
|
238 |
+
# else:
|
239 |
+
# is_person = True
|
240 |
print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
|
241 |
|
242 |
randn = torch.randint(1, 10000000, (1,)).item()
|
|
|
253 |
attributes = 'white, black, asian, hispanic, indian, male, female'
|
254 |
|
255 |
self.training = True
|
256 |
+
train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=rank, device=self.device, attributes=attributes, save_name=save_name)
|
257 |
self.training = False
|
258 |
|
259 |
torch.cuda.empty_cache()
|
trainscripts/textsliders/demotrain.py
CHANGED
@@ -411,7 +411,7 @@ def train(
|
|
411 |
# train(config, prompts, device)
|
412 |
|
413 |
|
414 |
-
def train_xl(target,
|
415 |
|
416 |
config = config_util.load_config_from_yaml(config_file)
|
417 |
randn = torch.randint(1, 10000000, (1,)).item()
|
|
|
411 |
# train(config, prompts, device)
|
412 |
|
413 |
|
414 |
+
def train_xl(target, positive, negative, lr, iterations, config_file, rank, device, attributes,save_name):
|
415 |
|
416 |
config = config_util.load_config_from_yaml(config_file)
|
417 |
randn = torch.randint(1, 10000000, (1,)).item()
|