yonishafir committed
Commit 2c91a3f · verified · 1 Parent(s): acfcea0

Update app.py

Files changed (1):
  app.py (+66 -21)
app.py CHANGED
@@ -163,11 +163,17 @@ app.prepare(ctx_id=0, det_size=(640, 640))
 
 # download checkpoints
 print("Downloading checkpoints")
-hf_hub_download(repo_id="briaai/ID_preservation_2.3_auraFaceEnc", filename="checkpoint_105000/controlnet/config.json", local_dir="./checkpoints")
-hf_hub_download(repo_id="briaai/ID_preservation_2.3_auraFaceEnc", filename="checkpoint_105000/controlnet/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
-hf_hub_download(repo_id="briaai/ID_preservation_2.3_auraFaceEnc", filename="checkpoint_105000/ip-adapter.bin", local_dir="./checkpoints")
-hf_hub_download(repo_id="briaai/ID_preservation_2.3_auraFaceEnc", filename="image_encoder/pytorch_model.bin", local_dir="./checkpoints")
-hf_hub_download(repo_id="briaai/ID_preservation_2.3_auraFaceEnc", filename="image_encoder/config.json", local_dir="./checkpoints")
+hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="checkpoint_105000/controlnet/config.json", local_dir="./checkpoints")
+hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="checkpoint_105000/controlnet/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
+hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="checkpoint_105000/ip-adapter.bin", local_dir="./checkpoints")
+
+hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="image_encoder/pytorch_model.bin", local_dir="./checkpoints")
+hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="image_encoder/config.json", local_dir="./checkpoints")
+
+# Download LoRA weights
+hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="LoRAs/3D_avatar/pytorch_lora_weights.safetensors", local_dir=".")
+hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="LoRAs/coloringbook/pytorch_lora_weights.safetensors", local_dir=".")
+hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="LoRAs/One_line_portraits_Light/pytorch_lora_weights.safetensors", local_dir=".")
 
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
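Note: the per-file downloads above could also be fetched in two pattern-filtered calls with huggingface_hub's snapshot_download. A minimal sketch, not part of this commit; the glob patterns are assumptions derived from the paths above and may pull extra files that live under the same folders:

    from huggingface_hub import snapshot_download

    # Equivalent of the per-file hf_hub_download calls: filtered snapshots
    # land in the same directories the rest of app.py expects.
    snapshot_download(
        repo_id="briaai/BRIA-2.3-ID_Preservation",
        allow_patterns=["checkpoint_105000/*", "image_encoder/*"],
        local_dir="./checkpoints",
    )
    snapshot_download(
        repo_id="briaai/BRIA-2.3-ID_Preservation",
        allow_patterns=["LoRAs/*/pytorch_lora_weights.safetensors"],
        local_dir=".",
    )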
@@ -176,6 +182,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 face_adapter = f"./checkpoints/checkpoint_105000/ip-adapter.bin"
 controlnet_path = f"./checkpoints/checkpoint_105000/controlnet"
 base_model_path = f'briaai/BRIA-2.3'
+lora_base_path = f"./LoRAs"
+
 resolution = 1024
 
 # Load ControlNet models
@@ -206,13 +214,19 @@ pipe.load_ip_adapter_instantid(face_adapter)
 
 clip_embeds=None
 
+Loras_dict = {
+    "": "",
+    "One_line_portraits_Light": "An illustration of ",
+    "3D_avatar": "An illustration of ",
+    "coloringbook": "An illustration of "
+}
 
+lora_names = Loras_dict.keys()
-
 
 @spaces.GPU
-# def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, lora_name, lora_scale, progress=gr.Progress(track_tqdm=True)):
-def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, progress=gr.Progress(track_tqdm=True)):
-    # global CURRENT_LORA_NAME  # Use the global variable to track LoRA
+def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, lora_name, lora_scale, progress=gr.Progress(track_tqdm=True)):
+# def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, progress=gr.Progress(track_tqdm=True)):
+    global CURRENT_LORA_NAME  # Use the global variable to track LoRA
 
     if image_path is None:
         raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
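Note: the new global CURRENT_LORA_NAME statement implies a module-level variable that this diff never shows being defined. Presumably app.py initializes it alongside the other module-level state; a sketch of the assumed definition:

    # Assumed (not shown in this diff): must exist before generate_image() runs.
    CURRENT_LORA_NAME = None  # None = nothing fused yet; "" = LoRA explicitly disabled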
@@ -239,9 +253,6 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
     files = [
         ('file', ('image_name.jpeg', image_file, 'image/jpeg'))  # Specify file name, file-like object, and MIME type
     ]
-    # headers = {
-    #     'api_token': 'a10d6386dd6a11ebba800242ac130004'
-    # }
     headers = {
         'api_token': os.getenv('BRIA_RMBG_TOKEN')  # Securely retrieve the token
     }
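Note: this hunk replaces a hard-coded RMBG token with os.getenv('BRIA_RMBG_TOKEN'). One caveat: os.getenv returns None when the secret is unset, so the request would silently carry api_token=None. A defensive variant, a sketch rather than what the commit does:

    import os

    rmbg_token = os.getenv("BRIA_RMBG_TOKEN")
    if rmbg_token is None:
        # Fail fast with a visible UI error instead of sending api_token=None.
        raise gr.Error("BRIA_RMBG_TOKEN is not configured as a Space secret.")
    headers = {"api_token": rmbg_token}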
@@ -269,7 +280,32 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
 
     generator = torch.Generator(device=device).manual_seed(seed)
 
-    full_prompt = prompt
+    # full_prompt = prompt
+    if lora_name != CURRENT_LORA_NAME:  # Check if LoRA needs to be changed
+        if CURRENT_LORA_NAME is not None:  # If a LoRA is already loaded, unload it
+            pipe.disable_lora()
+            pipe.unfuse_lora()
+            pipe.unload_lora_weights()
+            print(f"Unloaded LoRA: {CURRENT_LORA_NAME}")
+
+        if lora_name != "":  # Load the new LoRA if specified
+            # pipe.enable_model_cpu_offload()
+            lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors")
+            pipe.load_lora_weights(lora_path)
+            pipe.fuse_lora(lora_scale=lora_scale)
+            pipe.enable_lora()
+
+            # lora_prefix = Loras_dict[lora_name]
+
+            print(f"Loaded new LoRA: {lora_name}")
+
+        # Update the current LoRA name
+        CURRENT_LORA_NAME = lora_name
+
+    if lora_name != "":
+        full_prompt = f"{Loras_dict[lora_name]} {prompt}"
+    else:
+        full_prompt = prompt
 
     print("Start inference...")
     images = pipe(
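Note: the block above mutates shared pipeline state on every request whose lora_name differs from the last one. The same diffusers calls (load_lora_weights, fuse_lora, unfuse_lora, unload_lora_weights) can be isolated in a helper; a refactoring sketch in which the helper name and the truthiness checks are mine, not the commit's:

    import os

    def switch_lora(pipe, current_name, new_name, lora_scale, lora_base_path="./LoRAs"):
        """Unfuse the currently fused LoRA (if any), then fuse `new_name` into `pipe`."""
        if new_name == current_name:
            return current_name  # nothing to change
        if current_name:  # a LoRA is fused; restore the base weights first
            pipe.disable_lora()
            pipe.unfuse_lora()
            pipe.unload_lora_weights()
        if new_name:  # fuse the requested adapter at the requested strength
            path = os.path.join(lora_base_path, new_name, "pytorch_lora_weights.safetensors")
            pipe.load_lora_weights(path)
            pipe.fuse_lora(lora_scale=lora_scale)
            pipe.enable_lora()
        return new_name

generate_image would then reduce the whole block to CURRENT_LORA_NAME = switch_lora(pipe, CURRENT_LORA_NAME, lora_name, lora_scale). Either way the fuse is global state: two concurrent requests selecting different LoRAs will race on a shared Space.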
@@ -341,19 +377,21 @@ with gr.Blocks(css=css) as demo:
                 placeholder="Enter your prompt here",
                 info="Describe what you want to generate or modify in the image."
             )
-
+
+            lora_name = gr.Dropdown(choices=lora_names, label="LoRA", value="", info="Select a LoRA name from the list; not selecting any will disable LoRA.")
+
             submit = gr.Button("Submit", variant="primary")
 
             with gr.Accordion(open=False, label="Advanced Options"):
                 num_steps = gr.Slider(
-                    label="Number of sample steps",
+                    label="Number of diffusion steps",
                     minimum=1,
                     maximum=100,
                     step=1,
                     value=30,
                 )
                 guidance_scale = gr.Slider(
-                    label="Guidance scale",
+                    label="cfg scale",
                     minimum=0.1,
                     maximum=10.0,
                     step=0.1,
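Note: choices=lora_names hands the dropdown a dict_keys view (lora_names = Loras_dict.keys() above). Gradio iterates over whatever it is given, so this generally works, but passing an explicit list is the more common idiom; a sketch:

    lora_name = gr.Dropdown(
        choices=list(Loras_dict.keys()),  # explicit list instead of a dict_keys view
        label="LoRA",
        value="",  # empty selection keeps LoRA disabled
        info="Select a LoRA name from the list; not selecting any will disable LoRA.",
    )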
@@ -367,27 +405,33 @@ with gr.Blocks(css=css) as demo:
                     value=1,
                 )
                 ip_adapter_scale = gr.Slider(
-                    label="ip adapter scale",
+                    label="ID Adapter scale",
                     minimum=0.0,
                     maximum=1.0,
                     step=0.01,
                     value=0.8,
                 )
                 kps_scale = gr.Slider(
-                    label="kps control scale",
+                    label="lnmks ControlNet scale",
                     minimum=0.0,
                     maximum=1.0,
                     step=0.01,
                     value=0.6,
                 )
                 canny_scale = gr.Slider(
-                    label="canny control scale",
+                    label="canny ControlNet scale",
                     minimum=0.0,
                     maximum=1.0,
                     step=0.01,
                     value=0.4,
                 )
-
+                lora_scale = gr.Slider(
+                    label="LoRA scale",
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.01,
+                    value=0.7,
+                )
                 seed = gr.Slider(
                     label="Seed",
                     minimum=0,
@@ -409,7 +453,8 @@ with gr.Blocks(css=css) as demo:
         api_name=False,
     ).then(
         fn=generate_image,
-        inputs=[img_file, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale],
+        # inputs=[img_file, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale],
+        inputs=[img_file, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, lora_name, lora_scale],
         # outputs=[gallery]
         outputs=gallery
     )
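Note: Gradio passes inputs positionally, so the components in the inputs list must line up one-for-one with generate_image's parameters. The commit keeps the two in sync by appending lora_name and lora_scale at the end of both the function signature and the inputs list.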
 