yonishafir commited on
Commit
4012979
·
1 Parent(s): bf5b182

allow lora to be loaded w/o loading eveyr time

Browse files
Files changed (1) hide show
  1. app.py +45 -19
app.py CHANGED
@@ -163,6 +163,9 @@ def make_canny_condition(image, min_val=100, max_val=200, w_bilateral=True):
163
 
164
  default_negative_prompt = "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
165
 
 
 
 
166
  # Load face detection and recognition package
167
  app = FaceAnalysis(name='antelopev2', root='./', providers=['CPUExecutionProvider'])
168
  app.prepare(ctx_id=0, det_size=(640, 640))
@@ -229,11 +232,9 @@ Loras_dict = {
229
  "":"",
230
  "Vangogh_Vanilla": "bold, dramatic brush strokes, vibrant colors, swirling patterns, intense, emotionally charged paintings of",
231
  "Avatar_internlm": "2d anime sketch avatar of",
232
- # "Tomer_Hanuka_V3": "Fluid lines",
233
- "Storyboards": "Illustration style for storyboarding",
234
- "3D_illustration": "3D object illustration, abstract",
235
- # "beetl_general_death_style_v2": "a pale, dead, unnatural color face with dark circles around the eyes",
236
- "Characters": "gaming vector Art"
237
  }
238
 
239
  lora_names = Loras_dict.keys()
@@ -302,22 +303,47 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
302
 
303
  generator = torch.Generator(device=device).manual_seed(seed)
304
 
305
- if lora_name != "":
306
- lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors")
307
- pipe.load_lora_weights(lora_path)
308
- pipe.fuse_lora(lora_scale)
309
- pipe.enable_lora()
 
 
 
 
 
 
 
310
 
311
- lora_prefix = Loras_dict[lora_name]
 
 
 
 
 
312
 
313
- prompt = f"{lora_prefix} {prompt}"
 
 
 
 
314
 
315
- print("Using LoRA: ", lora_name)
316
-
 
 
 
 
 
 
 
 
 
317
 
318
  print("Start inference...")
319
  images = pipe(
320
- prompt = prompt,
321
  negative_prompt = default_negative_prompt,
322
  image_embeds = face_emb,
323
  image = [face_kps, canny_img] if canny_scale>0.0 else face_kps,
@@ -332,10 +358,10 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
332
  num_images_per_prompt=num_images,
333
  ).images #[0]
334
 
335
- if lora_name != "":
336
- pipe.disable_lora()
337
- pipe.unfuse_lora()
338
- pipe.unload_lora_weights()
339
 
340
  gc.collect()
341
  torch.cuda.empty_cache()
 
163
 
164
  default_negative_prompt = "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
165
 
166
+ # Global variable to track the currently loaded LoRA
167
+ current_lora_name = None
168
+
169
  # Load face detection and recognition package
170
  app = FaceAnalysis(name='antelopev2', root='./', providers=['CPUExecutionProvider'])
171
  app.prepare(ctx_id=0, det_size=(640, 640))
 
232
  "":"",
233
  "Vangogh_Vanilla": "bold, dramatic brush strokes, vibrant colors, swirling patterns, intense, emotionally charged paintings of",
234
  "Avatar_internlm": "2d anime sketch avatar of",
235
+ "Storyboards": "Illustration style for storyboarding.",
236
+ "3D_illustration": "3D object illustration, abstract.",
237
+ "Characters": "gaming vector Art."
 
 
238
  }
239
 
240
  lora_names = Loras_dict.keys()
 
303
 
304
  generator = torch.Generator(device=device).manual_seed(seed)
305
 
306
+ # if lora_name != "":
307
+ # lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors")
308
+ # pipe.load_lora_weights(lora_path)
309
+ # pipe.fuse_lora(lora_scale)
310
+ # pipe.enable_lora()
311
+
312
+ # lora_prefix = Loras_dict[lora_name]
313
+
314
+ # prompt = f"{lora_prefix} {prompt}"
315
+
316
+ # print("Using LoRA: ", lora_name)
317
+
318
 
319
+ if lora_name != current_lora_name: # Check if LoRA needs to be changed
320
+ if current_lora_name is not None: # If a LoRA is already loaded, unload it
321
+ pipe.disable_lora()
322
+ pipe.unfuse_lora()
323
+ pipe.unload_lora_weights()
324
+ print(f"Unloaded LoRA: {current_lora_name}")
325
 
326
+ if lora_name != "": # Load the new LoRA if specified
327
+ lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors")
328
+ pipe.load_lora_weights(lora_path)
329
+ pipe.fuse_lora(lora_scale)
330
+ pipe.enable_lora()
331
 
332
+ lora_prefix = Loras_dict[lora_name]
333
+
334
+ print(f"Loaded new LoRA: {lora_name}")
335
+
336
+ # Update the current LoRA name
337
+ current_lora_name = lora_name
338
+
339
+ if lora_name != "":
340
+ full_prompt = f"{lora_prefix} {prompt}"
341
+ else:
342
+ full_prompt = prompt
343
 
344
  print("Start inference...")
345
  images = pipe(
346
+ prompt = full_prompt,
347
  negative_prompt = default_negative_prompt,
348
  image_embeds = face_emb,
349
  image = [face_kps, canny_img] if canny_scale>0.0 else face_kps,
 
358
  num_images_per_prompt=num_images,
359
  ).images #[0]
360
 
361
+ # if lora_name != "":
362
+ # pipe.disable_lora()
363
+ # pipe.unfuse_lora()
364
+ # pipe.unload_lora_weights()
365
 
366
  gc.collect()
367
  torch.cuda.empty_cache()