John6666 commited on
Commit
57302a8
·
verified ·
1 Parent(s): bfbd1cc

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +1 -1
  2. app.py +120 -79
  3. live_preview_helpers.py +2 -1
  4. modutils.py +66 -22
  5. requirements.txt +1 -1
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🏆😻
4
  colorFrom: red
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.7.0
8
  app_file: app.py
9
  pinned: true
10
  license: mit
 
4
  colorFrom: red
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: true
10
  license: mit
app.py CHANGED
@@ -13,6 +13,7 @@ import random
13
  import time
14
  import requests
15
  import pandas as pd
 
16
  from pathlib import Path
17
 
18
  from env import models, num_loras, num_cns, HF_TOKEN, single_file_base_models
@@ -41,10 +42,16 @@ controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
41
  #controlnet_model_union_repo = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
42
  dtype = torch.bfloat16
43
  #dtype = torch.float8_e4m3fn
44
- #device = "cuda" if torch.cuda.is_available() else "cpu"
 
45
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, token=HF_TOKEN)
46
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN)
47
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN)
 
 
 
 
 
48
  pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
49
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
50
  pipe_ip = AutoPipelineForInpainting.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
@@ -66,6 +73,7 @@ def unload_lora():
66
  pipe.unload_lora_weights()
67
  #pipe_i2i.unfuse_lora()
68
  pipe_i2i.unload_lora_weights()
 
69
  pipe_ip.unload_lora_weights()
70
  except Exception as e:
71
  print(e)
@@ -104,12 +112,19 @@ def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, mode
104
  if ".safetensors" in repo_id:
105
  safetensors_file = download_file_mod(repo_id)
106
  transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
107
- pipe = FluxControlNetPipeline.from_pretrained(single_file_base_model, transformer=transformer, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
 
 
 
108
  pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
109
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
110
  pipe_ip = FluxControlNetInpaintPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
111
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
112
  else:
 
 
 
 
113
  pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
114
  pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
115
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
@@ -124,14 +139,21 @@ def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, mode
124
  print(f"Loading model: {repo_id}")
125
  if ".safetensors" in repo_id:
126
  safetensors_file = download_file_mod(repo_id)
127
- transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
128
- pipe = DiffusionPipeline.from_pretrained(single_file_base_model, transformer=transformer, torch_dtype=dtype, token=HF_TOKEN)
 
 
 
129
  pipe_i2i = AutoPipelineForImage2Image.from_pretrained(single_file_base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
130
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
131
  pipe_ip = AutoPipelineForInpainting.from_pretrained(single_file_base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
132
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
133
  else:
134
- pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, token=HF_TOKEN)
 
 
 
 
135
  pipe_i2i = AutoPipelineForImage2Image.from_pretrained(repo_id, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
136
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
137
  pipe_ip = AutoPipelineForInpainting.from_pretrained(repo_id, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
@@ -158,6 +180,11 @@ def is_repo_public(repo_id: str):
158
  print(f"Error: Failed to connect {repo_id}. {e}")
159
  return False
160
 
 
 
 
 
 
161
  class calculateDuration:
162
  def __init__(self, activity_name=""):
163
  self.activity_name = activity_name
@@ -414,18 +441,19 @@ def remove_custom_lora(selected_indices, current_loras, gallery):
414
 
415
  @spaces.GPU(duration=70)
416
  @torch.inference_mode()
417
- def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, cn_on, progress=gr.Progress(track_tqdm=True)):
418
  global pipe, taef1, good_vae, controlnet, controlnet_union
419
  try:
420
- good_vae.to("cuda")
421
- taef1.to("cuda")
422
- generator = torch.Generator(device="cuda").manual_seed(int(float(seed)))
 
423
 
424
  with calculateDuration("Generating image"):
425
  # Generate image
426
  modes, images, scales = get_control_params()
427
  if not cn_on or len(modes) == 0:
428
- pipe.to("cuda")
429
  pipe.vae = taef1
430
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
431
  progress(0, desc="Start Inference.")
@@ -439,13 +467,14 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, cn_on, pr
439
  joint_attention_kwargs={"scale": 1.0},
440
  output_type="pil",
441
  good_vae=good_vae,
 
442
  ):
443
  yield img
444
  else:
445
- pipe.to("cuda")
446
  pipe.vae = good_vae
447
- if controlnet_union is not None: controlnet_union.to("cuda")
448
- if controlnet is not None: controlnet.to("cuda")
449
  pipe.enable_model_cpu_offload()
450
  progress(0, desc="Start Inference with ControlNet.")
451
  for img in pipe(
@@ -459,6 +488,7 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, cn_on, pr
459
  controlnet_conditioning_scale=scales,
460
  generator=generator,
461
  joint_attention_kwargs={"scale": 1.0},
 
462
  ).images:
463
  yield img
464
  except Exception as e:
@@ -467,20 +497,22 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, cn_on, pr
467
 
468
  @spaces.GPU(duration=70)
469
  @torch.inference_mode()
470
- def generate_image_to_image(prompt_mash, image_input_path_dict, image_strength, is_inpaint, blur_mask, blur_factor, steps, cfg_scale, width, height, seed, cn_on, progress=gr.Progress(track_tqdm=True)):
 
471
  global pipe_i2i, pipe_ip, good_vae, controlnet, controlnet_union
472
  try:
473
- good_vae.to("cuda")
474
- generator = torch.Generator(device="cuda").manual_seed(int(float(seed)))
475
  image_input_path = image_input_path_dict['background']
476
  mask_path = image_input_path_dict['layers'][0]
 
477
 
478
  with calculateDuration("Generating image"):
479
  # Generate image
480
  modes, images, scales = get_control_params()
481
  if not cn_on or len(modes) == 0:
482
  if is_inpaint: # Inpainting
483
- pipe_ip.to("cuda")
484
  pipe_ip.vae = good_vae
485
  image_input = load_image(image_input_path)
486
  mask_input = load_image(mask_path)
@@ -498,10 +530,11 @@ def generate_image_to_image(prompt_mash, image_input_path_dict, image_strength,
498
  generator=generator,
499
  joint_attention_kwargs={"scale": 1.0},
500
  output_type="pil",
 
501
  ).images[0]
502
  return final_image
503
  else:
504
- pipe_i2i.to("cuda")
505
  pipe_i2i.vae = good_vae
506
  image_input = load_image(image_input_path)
507
  progress(0, desc="Start I2I Inference.")
@@ -516,17 +549,18 @@ def generate_image_to_image(prompt_mash, image_input_path_dict, image_strength,
516
  generator=generator,
517
  joint_attention_kwargs={"scale": 1.0},
518
  output_type="pil",
 
519
  ).images[0]
520
  return final_image
521
  else:
522
  if is_inpaint: # Inpainting
523
- pipe_ip.to("cuda")
524
  pipe_ip.vae = good_vae
525
  image_input = load_image(image_input_path)
526
  mask_input = load_image(mask_path)
527
  if blur_mask: mask_input = pipe_ip.mask_processor.blur(mask_input, blur_factor=blur_factor)
528
- if controlnet_union is not None: controlnet_union.to("cuda")
529
- if controlnet is not None: controlnet.to("cuda")
530
  pipe_ip.enable_model_cpu_offload()
531
  progress(0, desc="Start Inpainting Inference with ControlNet.")
532
  final_image = pipe_ip(
@@ -544,14 +578,15 @@ def generate_image_to_image(prompt_mash, image_input_path_dict, image_strength,
544
  generator=generator,
545
  joint_attention_kwargs={"scale": 1.0},
546
  output_type="pil",
 
547
  ).images[0]
548
  return final_image
549
  else:
550
- pipe_i2i.to("cuda")
551
  pipe_i2i.vae = good_vae
552
  image_input = load_image(image_input_path['background'])
553
- if controlnet_union is not None: controlnet_union.to("cuda")
554
- if controlnet is not None: controlnet.to("cuda")
555
  pipe_i2i.enable_model_cpu_offload()
556
  progress(0, desc="Start I2I Inference with ControlNet.")
557
  final_image = pipe_i2i(
@@ -568,6 +603,7 @@ def generate_image_to_image(prompt_mash, image_input_path_dict, image_strength,
568
  generator=generator,
569
  joint_attention_kwargs={"scale": 1.0},
570
  output_type="pil",
 
571
  ).images[0]
572
  return final_image
573
  except Exception as e:
@@ -575,7 +611,7 @@ def generate_image_to_image(prompt_mash, image_input_path_dict, image_strength,
575
  raise gr.Error(f"I2I Inference Error: {e}") from e
576
 
577
  def run_lora(prompt, image_input, image_strength, task_type, blur_mask, blur_factor, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2,
578
- randomize_seed, seed, width, height, loras_state, lora_json, cn_on, translate_on, progress=gr.Progress(track_tqdm=True)):
579
  global pipe, pipe_i2i, pipe_ip
580
  if not selected_indices and not is_valid_lora(lora_json):
581
  gr.Info("LoRA isn't selected.")
@@ -690,10 +726,11 @@ def run_lora(prompt, image_input, image_strength, task_type, blur_mask, blur_fac
690
  # Generate image
691
  progress(0, desc="Running Inference.")
692
  if is_i2i:
693
- final_image = generate_image_to_image(prompt_mash, image_input, image_strength, is_inpaint, blur_mask, blur_factor, steps, cfg_scale, width, height, seed, cn_on)
 
694
  yield save_image(final_image, None, last_model, prompt_mash, height, width, steps, cfg_scale, seed), seed, gr.update(visible=False)
695
  else:
696
- image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, cn_on)
697
  # Consume the generator to get the final image
698
  final_image = None
699
  step_counter = 0
@@ -856,43 +893,47 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
856
  progress_bar = gr.Markdown(elem_id="progress",visible=False)
857
  result = gr.Image(label="Generated Image", format="png", type="filepath", show_share_button=False, interactive=False)
858
  with gr.Accordion("History", open=False):
 
859
  history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False, format="png",
860
  show_share_button=False, show_download_button=True)
861
- history_files = gr.Files(interactive=False, visible=False)
862
  history_clear_button = gr.Button(value="Clear History", variant="secondary")
863
  history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, show_api=False)
864
- with gr.Group():
865
- with gr.Row():
866
- model_name = gr.Dropdown(label="Base Model", info="You can enter a huggingface model repo_id or path of single safetensors file to want to use.",
867
- choices=models, value=models[0], allow_custom_value=True, min_width=320, scale=5)
868
- model_type = gr.Radio(label="Model type", info="Model type of single safetensors file",
869
- choices=list(single_file_base_models.keys()), value=list(single_file_base_models.keys())[0], scale=1)
870
- model_info = gr.Markdown(elem_classes="info")
871
-
872
  with gr.Row():
873
- with gr.Accordion("Advanced Settings", open=False):
874
- with gr.Row():
875
- with gr.Column():
876
- #input_image = gr.Image(label="Input image", type="filepath", height=256, sources=["upload", "clipboard"], show_share_button=False)
877
- input_image = gr.ImageEditor(label='Input image', type='filepath', sources=["upload", "clipboard"], image_mode='RGB', show_share_button=False, show_fullscreen_button=False,
878
- layers=False, brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed", default_size=32), value=None,
879
- canvas_size=(384, 384), width=384, height=512)
880
  with gr.Column():
881
- task_type = gr.Radio(label="Task", choices=["Text-to-Image", "Image-to-Image", "Inpainting"], value="Text-to-Image")
882
- image_strength = gr.Slider(label="Strength", info="Lower means more image influence in I2I, opposite in Inpaint", minimum=0.01, maximum=1.0, step=0.01, value=0.75)
883
- blur_mask = gr.Checkbox(label="Blur mask", value=False)
884
- blur_factor = gr.Slider(label="Blur factor", minimum=0, maximum=50, step=1, value=33)
885
- input_image_preprocess = gr.Checkbox(True, label="Preprocess Input image")
886
- with gr.Column():
887
- with gr.Row():
888
- width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
889
- height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
890
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
891
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
 
 
 
 
 
 
 
892
  with gr.Row():
893
- randomize_seed = gr.Checkbox(True, label="Randomize seed")
894
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
895
- disable_model_cache = gr.Checkbox(False, label="Disable model caching")
 
 
 
 
 
 
 
 
 
896
  with gr.Accordion("External LoRA", open=True):
897
  with gr.Column():
898
  deselect_lora_button = gr.Button("Remove External LoRAs", variant="secondary")
@@ -939,27 +980,27 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
939
  lora_download = [None] * num_loras
940
  for i in range(num_loras):
941
  lora_download[i] = gr.Button(f"Get and set LoRA to {int(i+1)}")
942
- with gr.Accordion("ControlNet (extremely slow)", open=True, visible=False):
943
- with gr.Column():
944
- cn_on = gr.Checkbox(False, label="Use ControlNet")
945
- cn_mode = [None] * num_cns
946
- cn_scale = [None] * num_cns
947
- cn_image = [None] * num_cns
948
- cn_image_ref = [None] * num_cns
949
- cn_res = [None] * num_cns
950
- cn_num = [None] * num_cns
951
- with gr.Row():
952
- for i in range(num_cns):
953
- with gr.Column():
954
- cn_mode[i] = gr.Radio(label=f"ControlNet {int(i+1)} Mode", choices=get_control_union_mode(), value=get_control_union_mode()[0])
955
- with gr.Row():
956
- cn_scale[i] = gr.Slider(label=f"ControlNet {int(i+1)} Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.75)
957
- cn_res[i] = gr.Slider(label=f"ControlNet {int(i+1)} Preprocess resolution", minimum=128, maximum=512, value=384, step=1)
958
- cn_num[i] = gr.Number(i, visible=False)
959
- with gr.Row():
960
- cn_image_ref[i] = gr.Image(label="Image Reference", type="pil", format="png", height=256, sources=["upload", "clipboard"], show_share_button=False)
961
- cn_image[i] = gr.Image(label="Control Image", type="pil", format="png", height=256, show_share_button=False, interactive=False)
962
-
963
  gallery.select(
964
  update_selection,
965
  inputs=[selected_indices, loras_state, width, height],
@@ -1000,7 +1041,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
1000
  ).success(
1001
  fn=run_lora,
1002
  inputs=[prompt, input_image, image_strength, task_type, blur_mask, blur_factor, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2,
1003
- randomize_seed, seed, width, height, loras_state, lora_repo_json, cn_on, auto_trans],
1004
  outputs=[result, seed, progress_bar],
1005
  queue=True,
1006
  show_api=True,
 
13
  import time
14
  import requests
15
  import pandas as pd
16
+ import numpy as np
17
  from pathlib import Path
18
 
19
  from env import models, num_loras, num_cns, HF_TOKEN, single_file_base_models
 
42
  #controlnet_model_union_repo = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
43
  dtype = torch.bfloat16
44
  #dtype = torch.float8_e4m3fn
45
+ CACHE_MODEL = False
46
+ device = "cuda" if torch.cuda.is_available() else "cpu"
47
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, token=HF_TOKEN)
48
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN)
49
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN)
50
+ if CACHE_MODEL:
51
+ taef1.to(device)
52
+ good_vae.to(device)
53
+ pipe.to(device)
54
+ pipe.transformer.to("cpu")
55
  pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
56
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
57
  pipe_ip = AutoPipelineForInpainting.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
 
73
  pipe.unload_lora_weights()
74
  #pipe_i2i.unfuse_lora()
75
  pipe_i2i.unload_lora_weights()
76
+ #pipe_ip.unfuse_lora()
77
  pipe_ip.unload_lora_weights()
78
  except Exception as e:
79
  print(e)
 
112
  if ".safetensors" in repo_id:
113
  safetensors_file = download_file_mod(repo_id)
114
  transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
115
+ if CACHE_MODEL:
116
+ pipe = FluxControlNetPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=transformer, text_encoder=pipe.text_encoder,
117
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
118
+ else: pipe = FluxControlNetPipeline.from_pretrained(single_file_base_model, transformer=transformer, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
119
  pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
120
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
121
  pipe_ip = FluxControlNetInpaintPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
122
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
123
  else:
124
+ if CACHE_MODEL:
125
+ transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, token=HF_TOKEN)
126
+ pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=good_vae, transformer=transformer, text_encoder=pipe.text_encoder,
127
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
128
  pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
129
  pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
130
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
 
139
  print(f"Loading model: {repo_id}")
140
  if ".safetensors" in repo_id:
141
  safetensors_file = download_file_mod(repo_id)
142
+ transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model, token=HF_TOKEN)
143
+ if CACHE_MODEL:
144
+ pipe = DiffusionPipeline.from_pretrained(single_file_base_model, vae=taef1, transformer=transformer, text_encoder=pipe.text_encoder,
145
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
146
+ else: pipe = DiffusionPipeline.from_pretrained(single_file_base_model, transformer=transformer, torch_dtype=dtype, token=HF_TOKEN)
147
  pipe_i2i = AutoPipelineForImage2Image.from_pretrained(single_file_base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
148
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
149
  pipe_ip = AutoPipelineForInpainting.from_pretrained(single_file_base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
150
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
151
  else:
152
+ if CACHE_MODEL:
153
+ transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, token=HF_TOKEN)
154
+ pipe = DiffusionPipeline.from_pretrained(repo_id, vae=taef1, transformer=transformer, text_encoder=pipe.text_encoder,
155
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
156
+ else: pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, token=HF_TOKEN)
157
  pipe_i2i = AutoPipelineForImage2Image.from_pretrained(repo_id, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
158
  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
159
  pipe_ip = AutoPipelineForInpainting.from_pretrained(repo_id, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
 
180
  print(f"Error: Failed to connect {repo_id}. {e}")
181
  return False
182
 
183
+ def calc_sigmas(num_inference_steps: int, sigmas_factor: float):
184
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
185
+ sigmas = sigmas * sigmas_factor
186
+ return sigmas
187
+
188
  class calculateDuration:
189
  def __init__(self, activity_name=""):
190
  self.activity_name = activity_name
 
441
 
442
  @spaces.GPU(duration=70)
443
  @torch.inference_mode()
444
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, sigmas_factor, cn_on, progress=gr.Progress(track_tqdm=True)):
445
  global pipe, taef1, good_vae, controlnet, controlnet_union
446
  try:
447
+ good_vae.to(device)
448
+ taef1.to(device)
449
+ generator = torch.Generator(device=device).manual_seed(int(float(seed)))
450
+ sigmas = calc_sigmas(steps, sigmas_factor)
451
 
452
  with calculateDuration("Generating image"):
453
  # Generate image
454
  modes, images, scales = get_control_params()
455
  if not cn_on or len(modes) == 0:
456
+ pipe.to(device)
457
  pipe.vae = taef1
458
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
459
  progress(0, desc="Start Inference.")
 
467
  joint_attention_kwargs={"scale": 1.0},
468
  output_type="pil",
469
  good_vae=good_vae,
470
+ sigmas=sigmas,
471
  ):
472
  yield img
473
  else:
474
+ pipe.to(device)
475
  pipe.vae = good_vae
476
+ if controlnet_union is not None: controlnet_union.to(device)
477
+ if controlnet is not None: controlnet.to(device)
478
  pipe.enable_model_cpu_offload()
479
  progress(0, desc="Start Inference with ControlNet.")
480
  for img in pipe(
 
488
  controlnet_conditioning_scale=scales,
489
  generator=generator,
490
  joint_attention_kwargs={"scale": 1.0},
491
+ sigmas=sigmas,
492
  ).images:
493
  yield img
494
  except Exception as e:
 
497
 
498
  @spaces.GPU(duration=70)
499
  @torch.inference_mode()
500
+ def generate_image_to_image(prompt_mash, image_input_path_dict, image_strength, is_inpaint, blur_mask, blur_factor, steps, cfg_scale, width, height,
501
+ sigmas_factor, seed, cn_on, progress=gr.Progress(track_tqdm=True)):
502
  global pipe_i2i, pipe_ip, good_vae, controlnet, controlnet_union
503
  try:
504
+ good_vae.to(device)
505
+ generator = torch.Generator(device=device).manual_seed(int(float(seed)))
506
  image_input_path = image_input_path_dict['background']
507
  mask_path = image_input_path_dict['layers'][0]
508
+ sigmas = calc_sigmas(steps, sigmas_factor)
509
 
510
  with calculateDuration("Generating image"):
511
  # Generate image
512
  modes, images, scales = get_control_params()
513
  if not cn_on or len(modes) == 0:
514
  if is_inpaint: # Inpainting
515
+ pipe_ip.to(device)
516
  pipe_ip.vae = good_vae
517
  image_input = load_image(image_input_path)
518
  mask_input = load_image(mask_path)
 
530
  generator=generator,
531
  joint_attention_kwargs={"scale": 1.0},
532
  output_type="pil",
533
+ #sigmas=sigmas,
534
  ).images[0]
535
  return final_image
536
  else:
537
+ pipe_i2i.to(device)
538
  pipe_i2i.vae = good_vae
539
  image_input = load_image(image_input_path)
540
  progress(0, desc="Start I2I Inference.")
 
549
  generator=generator,
550
  joint_attention_kwargs={"scale": 1.0},
551
  output_type="pil",
552
+ #sigmas=sigmas,
553
  ).images[0]
554
  return final_image
555
  else:
556
  if is_inpaint: # Inpainting
557
+ pipe_ip.to(device)
558
  pipe_ip.vae = good_vae
559
  image_input = load_image(image_input_path)
560
  mask_input = load_image(mask_path)
561
  if blur_mask: mask_input = pipe_ip.mask_processor.blur(mask_input, blur_factor=blur_factor)
562
+ if controlnet_union is not None: controlnet_union.to(device)
563
+ if controlnet is not None: controlnet.to(device)
564
  pipe_ip.enable_model_cpu_offload()
565
  progress(0, desc="Start Inpainting Inference with ControlNet.")
566
  final_image = pipe_ip(
 
578
  generator=generator,
579
  joint_attention_kwargs={"scale": 1.0},
580
  output_type="pil",
581
+ #sigmas=sigmas,
582
  ).images[0]
583
  return final_image
584
  else:
585
+ pipe_i2i.to(device)
586
  pipe_i2i.vae = good_vae
587
  image_input = load_image(image_input_path['background'])
588
+ if controlnet_union is not None: controlnet_union.to(device)
589
+ if controlnet is not None: controlnet.to(device)
590
  pipe_i2i.enable_model_cpu_offload()
591
  progress(0, desc="Start I2I Inference with ControlNet.")
592
  final_image = pipe_i2i(
 
603
  generator=generator,
604
  joint_attention_kwargs={"scale": 1.0},
605
  output_type="pil",
606
+ #sigmas=sigmas,
607
  ).images[0]
608
  return final_image
609
  except Exception as e:
 
611
  raise gr.Error(f"I2I Inference Error: {e}") from e
612
 
613
  def run_lora(prompt, image_input, image_strength, task_type, blur_mask, blur_factor, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2,
614
+ randomize_seed, seed, width, height, sigmas_factor, loras_state, lora_json, cn_on, translate_on, progress=gr.Progress(track_tqdm=True)):
615
  global pipe, pipe_i2i, pipe_ip
616
  if not selected_indices and not is_valid_lora(lora_json):
617
  gr.Info("LoRA isn't selected.")
 
726
  # Generate image
727
  progress(0, desc="Running Inference.")
728
  if is_i2i:
729
+ final_image = generate_image_to_image(prompt_mash, image_input, image_strength, is_inpaint, blur_mask, blur_factor,
730
+ steps, cfg_scale, width, height, sigmas_factor, seed, cn_on)
731
  yield save_image(final_image, None, last_model, prompt_mash, height, width, steps, cfg_scale, seed), seed, gr.update(visible=False)
732
  else:
733
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, sigmas_factor, cn_on)
734
  # Consume the generator to get the final image
735
  final_image = None
736
  step_counter = 0
 
893
  progress_bar = gr.Markdown(elem_id="progress",visible=False)
894
  result = gr.Image(label="Generated Image", format="png", type="filepath", show_share_button=False, interactive=False)
895
  with gr.Accordion("History", open=False):
896
+ history_files = gr.Files(interactive=False, visible=False)
897
  history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False, format="png",
898
  show_share_button=False, show_download_button=True)
 
899
  history_clear_button = gr.Button(value="Clear History", variant="secondary")
900
  history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, show_api=False)
901
+
 
 
 
 
 
 
 
902
  with gr.Row():
903
+ with gr.Accordion("Advanced Settings", open=True):
904
+ with gr.Tab("Generation Settings"):
 
 
 
 
 
905
  with gr.Column():
906
+ with gr.Group():
907
+ with gr.Row():
908
+ model_name = gr.Dropdown(label="Base Model", info="You can enter a huggingface model repo_id or path of single safetensors file to want to use.",
909
+ choices=models, value=models[0], allow_custom_value=True, min_width=320, scale=5)
910
+ model_type = gr.Radio(label="Model type", info="Model type of single safetensors file",
911
+ choices=list(single_file_base_models.keys()), value=list(single_file_base_models.keys())[0], scale=1)
912
+ model_info = gr.Markdown(elem_classes="info")
913
+ with gr.Row():
914
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
915
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
916
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
917
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
918
+ with gr.Row():
919
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
920
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
921
+ sigmas_factor = gr.Slider(label="Sigmas factor", minimum=0.01, maximum=1.00, step=0.01, value=0.95)
922
+ disable_model_cache = gr.Checkbox(False, label="Disable model caching")
923
+ with gr.Tab("Image-to-Image"):
924
  with gr.Row():
925
+ with gr.Column():
926
+ #input_image = gr.Image(label="Input image", type="filepath", height=256, sources=["upload", "clipboard"], show_share_button=False)
927
+ input_image = gr.ImageEditor(label='Input image', type='filepath', sources=["upload", "clipboard"], image_mode='RGB', show_share_button=False, show_fullscreen_button=False,
928
+ layers=False, brush=gr.Brush(colors=["white"], color_mode="fixed", default_size=32), eraser=gr.Eraser(default_size="32"), value=None,
929
+ canvas_size=(384, 384), width=384, height=512)
930
+ with gr.Column():
931
+ task_type = gr.Radio(label="Task", choices=["Text-to-Image", "Image-to-Image", "Inpainting"], value="Text-to-Image")
932
+ image_strength = gr.Slider(label="Strength", info="Lower means more image influence in I2I, opposite in Inpaint", minimum=0.01, maximum=1.0, step=0.01, value=0.75)
933
+ blur_mask = gr.Checkbox(label="Blur mask", value=False)
934
+ blur_factor = gr.Slider(label="Blur factor", minimum=0, maximum=50, step=1, value=33)
935
+ input_image_preprocess = gr.Checkbox(True, label="Preprocess Input image")
936
+ with gr.Tab("More LoRA"):
937
  with gr.Accordion("External LoRA", open=True):
938
  with gr.Column():
939
  deselect_lora_button = gr.Button("Remove External LoRAs", variant="secondary")
 
980
  lora_download = [None] * num_loras
981
  for i in range(num_loras):
982
  lora_download[i] = gr.Button(f"Get and set LoRA to {int(i+1)}")
983
+ with gr.Tab("ControlNet", visible=False):
984
+ with gr.Column():
985
+ cn_on = gr.Checkbox(False, label="Use ControlNet")
986
+ cn_mode = [None] * num_cns
987
+ cn_scale = [None] * num_cns
988
+ cn_image = [None] * num_cns
989
+ cn_image_ref = [None] * num_cns
990
+ cn_res = [None] * num_cns
991
+ cn_num = [None] * num_cns
992
+ with gr.Row():
993
+ for i in range(num_cns):
994
+ with gr.Column():
995
+ cn_mode[i] = gr.Radio(label=f"ControlNet {int(i+1)} Mode", choices=get_control_union_mode(), value=get_control_union_mode()[0])
996
+ with gr.Row():
997
+ cn_scale[i] = gr.Slider(label=f"ControlNet {int(i+1)} Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.75)
998
+ cn_res[i] = gr.Slider(label=f"ControlNet {int(i+1)} Preprocess resolution", minimum=128, maximum=512, value=384, step=1)
999
+ cn_num[i] = gr.Number(i, visible=False)
1000
+ with gr.Row():
1001
+ cn_image_ref[i] = gr.Image(label="Image Reference", type="pil", format="png", height=256, sources=["upload", "clipboard"], show_share_button=False)
1002
+ cn_image[i] = gr.Image(label="Control Image", type="pil", format="png", height=256, show_share_button=False, interactive=False)
1003
+
1004
  gallery.select(
1005
  update_selection,
1006
  inputs=[selected_indices, loras_state, width, height],
 
1041
  ).success(
1042
  fn=run_lora,
1043
  inputs=[prompt, input_image, image_strength, task_type, blur_mask, blur_factor, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2,
1044
+ randomize_seed, seed, width, height, sigmas_factor, loras_state, lora_repo_json, cn_on, auto_trans],
1045
  outputs=[result, seed, progress_bar],
1046
  queue=True,
1047
  show_api=True,
live_preview_helpers.py CHANGED
@@ -60,6 +60,7 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
60
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
61
  max_sequence_length: int = 512,
62
  good_vae: Optional[Any] = None,
 
63
  ):
64
  height = height or self.default_sample_size * self.vae_scale_factor
65
  width = width or self.default_sample_size * self.vae_scale_factor
@@ -108,7 +109,7 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
108
  latents,
109
  )
110
  # 5. Prepare timesteps
111
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
112
  image_seq_len = latents.shape[1]
113
  mu = calculate_shift(
114
  image_seq_len,
 
60
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
61
  max_sequence_length: int = 512,
62
  good_vae: Optional[Any] = None,
63
+ sigmas: Optional[List[float]] = None, # MOD
64
  ):
65
  height = height or self.default_sample_size * self.vae_scale_factor
66
  width = width or self.default_sample_size * self.vae_scale_factor
 
109
  latents,
110
  )
111
  # 5. Prepare timesteps
112
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas # MOD
113
  image_seq_len = latents.shape[1]
114
  mu = calculate_shift(
115
  image_seq_len,
modutils.py CHANGED
@@ -172,7 +172,7 @@ class ModelInformation:
172
  self.download_url = json_data.get("downloadUrl", "")
173
  self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
174
  self.filename_url = next(
175
- (v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "")), ""
176
  )
177
  self.filename_url = self.filename_url if self.filename_url else ""
178
  self.description = json_data.get("description", "")
@@ -302,6 +302,10 @@ def safe_float(input):
302
  return output
303
 
304
 
 
 
 
 
305
  def save_images(images: list[Image.Image], metadatas: list[str]):
306
  from PIL import PngImagePlugin
307
  import uuid
@@ -566,7 +570,8 @@ private_lora_model_list = get_private_lora_model_lists()
566
 
567
  def get_civitai_info(path):
568
  global civitai_not_exists_list
569
- if path in set(civitai_not_exists_list): return ["", "", "", "", ""]
 
570
  if not Path(path).exists(): return None
571
  user_agent = get_user_agent()
572
  headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
@@ -584,12 +589,12 @@ def get_civitai_info(path):
584
  r = session.get(url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
585
  except Exception as e:
586
  print(e)
587
- return ["", "", "", "", ""]
588
  if not r.ok: return None
589
  json = r.json()
590
  if not 'baseModel' in json:
591
  civitai_not_exists_list.append(path)
592
- return ["", "", "", "", ""]
593
  items = []
594
  items.append(" / ".join(json['trainedWords']))
595
  items.append(json['baseModel'])
@@ -690,7 +695,7 @@ def copy_lora(path: str, new_path: str):
690
  return None
691
 
692
 
693
- def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: str, lora5: str):
694
  path = download_lora(dl_urls)
695
  if path:
696
  if not lora1 or lora1 == "None":
@@ -703,9 +708,13 @@ def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: st
703
  lora4 = path
704
  elif not lora5 or lora5 == "None":
705
  lora5 = path
 
 
 
 
706
  choices = get_all_lora_tupled_list()
707
  return gr.update(value=lora1, choices=choices), gr.update(value=lora2, choices=choices), gr.update(value=lora3, choices=choices),\
708
- gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices)
709
 
710
 
711
  def get_valid_lora_name(query: str, model_name: str):
@@ -745,25 +754,31 @@ def get_valid_lora_wt(prompt: str, lora_path: str, lora_wt: float):
745
  return wt
746
 
747
 
748
- def set_prompt_loras(prompt, prompt_syntax, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
749
- if not "Classic" in str(prompt_syntax): return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
750
  lora1 = get_valid_lora_name(lora1, model_name)
751
  lora2 = get_valid_lora_name(lora2, model_name)
752
  lora3 = get_valid_lora_name(lora3, model_name)
753
  lora4 = get_valid_lora_name(lora4, model_name)
754
  lora5 = get_valid_lora_name(lora5, model_name)
755
- if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
 
 
756
  lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
757
  lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
758
  lora3_wt = get_valid_lora_wt(prompt, lora3, lora3_wt)
759
  lora4_wt = get_valid_lora_wt(prompt, lora4, lora4_wt)
760
  lora5_wt = get_valid_lora_wt(prompt, lora5, lora5_wt)
 
 
761
  on1, label1, tag1, md1 = get_lora_info(lora1)
762
  on2, label2, tag2, md2 = get_lora_info(lora2)
763
  on3, label3, tag3, md3 = get_lora_info(lora3)
764
  on4, label4, tag4, md4 = get_lora_info(lora4)
765
  on5, label5, tag5, md5 = get_lora_info(lora5)
766
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
 
 
767
  prompts = prompt.split(",") if prompt else []
768
  for p in prompts:
769
  p = str(p).strip()
@@ -780,30 +795,40 @@ def set_prompt_loras(prompt, prompt_syntax, model_name, lora1, lora1_wt, lora2,
780
  continue
781
  elif not on1:
782
  lora1 = path
783
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
784
  lora1_wt = safe_float(wt)
785
  on1 = True
786
  elif not on2:
787
  lora2 = path
788
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
789
  lora2_wt = safe_float(wt)
790
  on2 = True
791
  elif not on3:
792
  lora3 = path
793
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
794
  lora3_wt = safe_float(wt)
795
  on3 = True
796
  elif not on4:
797
  lora4 = path
798
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
799
  lora4_wt = safe_float(wt)
800
  on4 = True
801
  elif not on5:
802
  lora5 = path
803
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
804
  lora5_wt = safe_float(wt)
805
  on5 = True
806
- return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
 
 
 
 
 
 
 
 
 
 
807
 
808
 
809
  def get_lora_info(lora_path: str):
@@ -864,13 +889,15 @@ def apply_lora_prompt(prompt: str = "", lora_info: str = ""):
864
  return gr.update(value=prompt)
865
 
866
 
867
- def update_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
868
  on1, label1, tag1, md1 = get_lora_info(lora1)
869
  on2, label2, tag2, md2 = get_lora_info(lora2)
870
  on3, label3, tag3, md3 = get_lora_info(lora3)
871
  on4, label4, tag4, md4 = get_lora_info(lora4)
872
  on5, label5, tag5, md5 = get_lora_info(lora5)
873
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
 
 
874
 
875
  output_prompt = prompt
876
  if "Classic" in str(prompt_syntax):
@@ -895,6 +922,8 @@ def update_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3,
895
  if on3: lora_prompts.append(f"<lora:{to_lora_key(lora3)}:{lora3_wt:.2f}>")
896
  if on4: lora_prompts.append(f"<lora:{to_lora_key(lora4)}:{lora4_wt:.2f}>")
897
  if on5: lora_prompts.append(f"<lora:{to_lora_key(lora5)}:{lora5_wt:.2f}>")
 
 
898
  output_prompt = ", ".join(list_uniq(output_prompts + lora_prompts + [""]))
899
  choices = get_all_lora_tupled_list()
900
 
@@ -907,7 +936,11 @@ def update_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3,
907
  gr.update(value=lora4, choices=choices), gr.update(value=lora4_wt),\
908
  gr.update(value=tag4, label=label4, visible=on4), gr.update(visible=on4), gr.update(value=md4, visible=on4),\
909
  gr.update(value=lora5, choices=choices), gr.update(value=lora5_wt),\
910
- gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)
 
 
 
 
911
 
912
 
913
  def get_my_lora(link_url, romanize):
@@ -926,7 +959,6 @@ def get_my_lora(link_url, romanize):
926
  path.resolve().rename(new_path.resolve())
927
  update_lora_dict(str(new_path))
928
  l_path = str(new_path)
929
- new_lora_model_list = get_lora_model_list()
930
  new_lora_tupled_list = get_all_lora_tupled_list()
931
  msg_lora = "Downloaded"
932
  if l_name:
@@ -943,6 +975,10 @@ def get_my_lora(link_url, romanize):
943
  choices=new_lora_tupled_list
944
  ), gr.update(
945
  choices=new_lora_tupled_list
 
 
 
 
946
  ), gr.update(
947
  value=msg_lora
948
  )
@@ -975,12 +1011,19 @@ def move_file_lora(filepaths):
975
  choices=new_lora_tupled_list
976
  ), gr.update(
977
  choices=new_lora_tupled_list
 
 
 
 
978
  )
979
 
980
 
981
- CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Newest"]
982
  CIVITAI_PERIOD = ["AllTime", "Year", "Month", "Week", "Day"]
983
- CIVITAI_BASEMODEL = ["Pony", "Illustrious", "SDXL 1.0", "SD 1.5", "Flux.1 D", "Flux.1 S"]
 
 
 
984
 
985
 
986
  def get_civitai_info(path):
@@ -1025,6 +1068,7 @@ def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1
1025
  sort: str = "Highest Rated", period: str = "AllTime", tag: str = "", user: str = "", page: int = 1):
1026
  user_agent = get_user_agent()
1027
  headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
 
1028
  base_url = 'https://civitai.com/api/v1/models'
1029
  params = {'types': ['LORA'], 'sort': sort, 'period': period, 'limit': limit, 'page': int(page), 'nsfw': 'true'}
1030
  if query: params["query"] = query
 
172
  self.download_url = json_data.get("downloadUrl", "")
173
  self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
174
  self.filename_url = next(
175
+ (v.get("name", "") for v in reversed(json_data.get("files", [])) if str(self.model_version_id) in v.get("downloadUrl", "")), ""
176
  )
177
  self.filename_url = self.filename_url if self.filename_url else ""
178
  self.description = json_data.get("description", "")
 
302
  return output
303
 
304
 
305
+ def valid_model_name(model_name: str):
306
+ return model_name.split(" ")[0]
307
+
308
+
309
  def save_images(images: list[Image.Image], metadatas: list[str]):
310
  from PIL import PngImagePlugin
311
  import uuid
 
570
 
571
  def get_civitai_info(path):
572
  global civitai_not_exists_list
573
+ default = ["", "", "", "", ""]
574
+ if path in set(civitai_not_exists_list): return default
575
  if not Path(path).exists(): return None
576
  user_agent = get_user_agent()
577
  headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
 
589
  r = session.get(url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
590
  except Exception as e:
591
  print(e)
592
+ return default
593
  if not r.ok: return None
594
  json = r.json()
595
  if not 'baseModel' in json:
596
  civitai_not_exists_list.append(path)
597
+ return default
598
  items = []
599
  items.append(" / ".join(json['trainedWords']))
600
  items.append(json['baseModel'])
 
695
  return None
696
 
697
 
698
+ def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: str, lora5: str, lora6: str, lora7: str):
699
  path = download_lora(dl_urls)
700
  if path:
701
  if not lora1 or lora1 == "None":
 
708
  lora4 = path
709
  elif not lora5 or lora5 == "None":
710
  lora5 = path
711
+ #elif not lora6 or lora6 == "None":
712
+ # lora6 = path
713
+ #elif not lora7 or lora7 == "None":
714
+ # lora7 = path
715
  choices = get_all_lora_tupled_list()
716
  return gr.update(value=lora1, choices=choices), gr.update(value=lora2, choices=choices), gr.update(value=lora3, choices=choices),\
717
+ gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices), gr.update(value=lora6, choices=choices), gr.update(value=lora7, choices=choices)
718
 
719
 
720
  def get_valid_lora_name(query: str, model_name: str):
 
754
  return wt
755
 
756
 
757
+ def set_prompt_loras(prompt, prompt_syntax, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt, lora6, lora6_wt, lora7, lora7_wt):
758
+ if not "Classic" in str(prompt_syntax): return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt, lora6, lora6_wt, lora7, lora7_wt
759
  lora1 = get_valid_lora_name(lora1, model_name)
760
  lora2 = get_valid_lora_name(lora2, model_name)
761
  lora3 = get_valid_lora_name(lora3, model_name)
762
  lora4 = get_valid_lora_name(lora4, model_name)
763
  lora5 = get_valid_lora_name(lora5, model_name)
764
+ #lora6 = get_valid_lora_name(lora6, model_name)
765
+ #lora7 = get_valid_lora_name(lora7, model_name)
766
+ if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt, lora6, lora6_wt, lora7, lora7_wt
767
  lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
768
  lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
769
  lora3_wt = get_valid_lora_wt(prompt, lora3, lora3_wt)
770
  lora4_wt = get_valid_lora_wt(prompt, lora4, lora4_wt)
771
  lora5_wt = get_valid_lora_wt(prompt, lora5, lora5_wt)
772
+ #lora6_wt = get_valid_lora_wt(prompt, lora6, lora5_wt)
773
+ #lora7_wt = get_valid_lora_wt(prompt, lora7, lora5_wt)
774
  on1, label1, tag1, md1 = get_lora_info(lora1)
775
  on2, label2, tag2, md2 = get_lora_info(lora2)
776
  on3, label3, tag3, md3 = get_lora_info(lora3)
777
  on4, label4, tag4, md4 = get_lora_info(lora4)
778
  on5, label5, tag5, md5 = get_lora_info(lora5)
779
+ #on6, label6, tag6, md6 = get_lora_info(lora6)
780
+ #on7, label7, tag7, md7 = get_lora_info(lora7)
781
+ lora_paths = [lora1, lora2, lora3, lora4, lora5, lora6, lora7]
782
  prompts = prompt.split(",") if prompt else []
783
  for p in prompts:
784
  p = str(p).strip()
 
795
  continue
796
  elif not on1:
797
  lora1 = path
798
+ lora_paths = [lora1, lora2, lora3, lora4, lora5, lora6, lora7]
799
  lora1_wt = safe_float(wt)
800
  on1 = True
801
  elif not on2:
802
  lora2 = path
803
+ lora_paths = [lora1, lora2, lora3, lora4, lora5, lora6, lora7]
804
  lora2_wt = safe_float(wt)
805
  on2 = True
806
  elif not on3:
807
  lora3 = path
808
+ lora_paths = [lora1, lora2, lora3, lora4, lora5, lora6, lora7]
809
  lora3_wt = safe_float(wt)
810
  on3 = True
811
  elif not on4:
812
  lora4 = path
813
+ lora_paths = [lora1, lora2, lora3, lora4, lora5, lora6, lora7]
814
  lora4_wt = safe_float(wt)
815
  on4 = True
816
  elif not on5:
817
  lora5 = path
818
+ lora_paths = [lora1, lora2, lora3, lora4, lora5, lora6, lora7]
819
  lora5_wt = safe_float(wt)
820
  on5 = True
821
+ #elif not on6:
822
+ # lora6 = path
823
+ # lora_paths = [lora1, lora2, lora3, lora4, lora5, lora6, lora7]
824
+ # lora6_wt = safe_float(wt)
825
+ # on6 = True
826
+ #elif not on7:
827
+ # lora7 = path
828
+ # lora_paths = [lora1, lora2, lora3, lora4, lora5, lora6, lora7]
829
+ # lora7_wt = safe_float(wt)
830
+ # on7 = True
831
+ return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt, lora6, lora6_wt, lora7, lora7_wt
832
 
833
 
834
  def get_lora_info(lora_path: str):
 
889
  return gr.update(value=prompt)
890
 
891
 
892
+ def update_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt, lora6, lora6_wt, lora7, lora7_wt):
893
  on1, label1, tag1, md1 = get_lora_info(lora1)
894
  on2, label2, tag2, md2 = get_lora_info(lora2)
895
  on3, label3, tag3, md3 = get_lora_info(lora3)
896
  on4, label4, tag4, md4 = get_lora_info(lora4)
897
  on5, label5, tag5, md5 = get_lora_info(lora5)
898
+ on6, label6, tag6, md6 = get_lora_info(lora6)
899
+ on7, label7, tag7, md7 = get_lora_info(lora7)
900
+ lora_paths = [lora1, lora2, lora3, lora4, lora5, lora6, lora7]
901
 
902
  output_prompt = prompt
903
  if "Classic" in str(prompt_syntax):
 
922
  if on3: lora_prompts.append(f"<lora:{to_lora_key(lora3)}:{lora3_wt:.2f}>")
923
  if on4: lora_prompts.append(f"<lora:{to_lora_key(lora4)}:{lora4_wt:.2f}>")
924
  if on5: lora_prompts.append(f"<lora:{to_lora_key(lora5)}:{lora5_wt:.2f}>")
925
+ #if on6: lora_prompts.append(f"<lora:{to_lora_key(lora6)}:{lora6_wt:.2f}>")
926
+ #if on7: lora_prompts.append(f"<lora:{to_lora_key(lora7)}:{lora7_wt:.2f}>")
927
  output_prompt = ", ".join(list_uniq(output_prompts + lora_prompts + [""]))
928
  choices = get_all_lora_tupled_list()
929
 
 
936
  gr.update(value=lora4, choices=choices), gr.update(value=lora4_wt),\
937
  gr.update(value=tag4, label=label4, visible=on4), gr.update(visible=on4), gr.update(value=md4, visible=on4),\
938
  gr.update(value=lora5, choices=choices), gr.update(value=lora5_wt),\
939
+ gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5),\
940
+ gr.update(value=lora6, choices=choices), gr.update(value=lora6_wt),\
941
+ gr.update(value=tag6, label=label6, visible=on6), gr.update(visible=on6), gr.update(value=md6, visible=on6),\
942
+ gr.update(value=lora7, choices=choices), gr.update(value=lora7_wt),\
943
+ gr.update(value=tag7, label=label7, visible=on7), gr.update(visible=on7), gr.update(value=md7, visible=on7)
944
 
945
 
946
  def get_my_lora(link_url, romanize):
 
959
  path.resolve().rename(new_path.resolve())
960
  update_lora_dict(str(new_path))
961
  l_path = str(new_path)
 
962
  new_lora_tupled_list = get_all_lora_tupled_list()
963
  msg_lora = "Downloaded"
964
  if l_name:
 
975
  choices=new_lora_tupled_list
976
  ), gr.update(
977
  choices=new_lora_tupled_list
978
+ ), gr.update(
979
+ choices=new_lora_tupled_list
980
+ ), gr.update(
981
+ choices=new_lora_tupled_list
982
  ), gr.update(
983
  value=msg_lora
984
  )
 
1011
  choices=new_lora_tupled_list
1012
  ), gr.update(
1013
  choices=new_lora_tupled_list
1014
+ ), gr.update(
1015
+ choices=new_lora_tupled_list
1016
+ ), gr.update(
1017
+ choices=new_lora_tupled_list
1018
  )
1019
 
1020
 
1021
+ CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Most Liked", "Most Discussed", "Most Collected", "Most Buzz", "Newest"]
1022
  CIVITAI_PERIOD = ["AllTime", "Year", "Month", "Week", "Day"]
1023
+ CIVITAI_BASEMODEL = ["Pony", "Illustrious", "SDXL 1.0", "SD 1.5", "Flux.1 D", "Flux.1 S"] # , "SD 3.5"
1024
+ CIVITAI_TYPE = ["Checkpoint", "TextualInversion", "Hypernetwork", "AestheticGradient", "LORA", "LoCon", "DoRA",
1025
+ "Controlnet", "Upscaler", "MotionModule", "VAE", "Poses", "Wildcards", "Workflows", "Other"]
1026
+ CIVITAI_FILETYPE = ["Model", "VAE", "Config", "Training Data"]
1027
 
1028
 
1029
  def get_civitai_info(path):
 
1068
  sort: str = "Highest Rated", period: str = "AllTime", tag: str = "", user: str = "", page: int = 1):
1069
  user_agent = get_user_agent()
1070
  headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
1071
+ if CIVITAI_API_KEY: headers['Authorization'] = f'Bearer {{{CIVITAI_API_KEY}}}'
1072
  base_url = 'https://civitai.com/api/v1/models'
1073
  params = {'types': ['LORA'], 'sort': sort, 'period': period, 'limit': limit, 'page': int(page), 'nsfw': 'true'}
1074
  if query: params["query"] = query
requirements.txt CHANGED
@@ -14,6 +14,6 @@ numpy<2
14
  opencv-python
15
  deepspeed
16
  mediapipe
17
- openai==1.37.0
18
  translatepy
19
  unidecode
 
14
  opencv-python
15
  deepspeed
16
  mediapipe
17
+ openai>=1.37.0
18
  translatepy
19
  unidecode