John6666 commited on
Commit
09fa6ac
·
verified ·
1 Parent(s): efd9993

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +6 -28
  2. mod.py +52 -0
app.py CHANGED
@@ -8,15 +8,13 @@ from diffusers import DiffusionPipeline
8
  import copy
9
  import random
10
  import time
 
11
 
12
  # Load LoRAs from JSON file
13
  with open('loras.json', 'r') as f:
14
  loras = json.load(f)
15
 
16
  # Initialize the base model
17
- models = ["camenduru/FLUX.1-dev-diffusers", "black-forest-labs/FLUX.1-schnell",
18
- "sayakpaul/FLUX.1-merged", "John6666/blue-pencil-flux1-v001-fp8-flux",
19
- "John6666/fluxunchained-artfulnsfw-fut516xfp8e4m3fnv11-fp8-flux", "John6666/nepotism-fuxdevschnell-v3aio-flux"]
20
  base_model = models[0]
21
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
22
 
@@ -79,7 +77,8 @@ def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height,
79
 
80
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
81
  lora_scale, lora_repo, lora_weights, lora_trigger, progress=gr.Progress(track_tqdm=True)):
82
- #if selected_index is None and not lora_repo:
 
83
  # raise gr.Error("You must select a LoRA before proceeding.")
84
 
85
  if selected_index is not None and not lora_repo:
@@ -110,33 +109,12 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
110
  image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
111
  pipe.to("cpu")
112
  if selected_index is not None or lora_repo: pipe.unload_lora_weights()
 
113
  return image, seed
114
 
115
  run_lora.zerogpu = True
116
 
117
- def get_repo_safetensors(repo_id: str):
118
- from huggingface_hub import HfApi
119
- api = HfApi()
120
- try:
121
- if " " in repo_id or not api.repo_exists(repo_id): return gr.update(value="", choices=[])
122
- files = api.list_repo_files(repo_id=repo_id)
123
- except Exception as e:
124
- print(f"Error: Failed to get {repo_id}'s info. ")
125
- print(e)
126
- return gr.update(choices=[])
127
- files = [f for f in files if f.endswith(".safetensors")]
128
- if len(files) == 0: return gr.update(value="", choices=[])
129
- else: return gr.update(value=files[0], choices=files)
130
-
131
- def change_base_model(repo_id: str):
132
- from huggingface_hub import HfApi
133
- global pipe
134
- api = HfApi()
135
- try:
136
- if " " in repo_id or not api.repo_exists(repo_id): return
137
- pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
138
- except Exception as e:
139
- print(e)
140
 
141
  css = '''
142
  #gen_btn{height: 100%}
@@ -147,7 +125,7 @@ css = '''
147
  '''
148
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
149
  title = gr.HTML(
150
- """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
151
  elem_id="title",
152
  )
153
  selected_index = gr.State(None)
 
8
  import copy
9
  import random
10
  import time
11
+ from mod import models, clear_cache, get_repo_safetensors, change_base_model
12
 
13
  # Load LoRAs from JSON file
14
  with open('loras.json', 'r') as f:
15
  loras = json.load(f)
16
 
17
  # Initialize the base model
 
 
 
18
  base_model = models[0]
19
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
20
 
 
77
 
78
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
79
  lora_scale, lora_repo, lora_weights, lora_trigger, progress=gr.Progress(track_tqdm=True)):
80
+ if selected_index is None and not lora_repo:
81
+ gr.Info("LoRA isn't selected.")
82
  # raise gr.Error("You must select a LoRA before proceeding.")
83
 
84
  if selected_index is not None and not lora_repo:
 
109
  image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
110
  pipe.to("cpu")
111
  if selected_index is not None or lora_repo: pipe.unload_lora_weights()
112
+ clear_cache()
113
  return image, seed
114
 
115
  run_lora.zerogpu = True
116
 
117
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  css = '''
120
  #gen_btn{height: 100%}
 
125
  '''
126
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
127
  title = gr.HTML(
128
+ """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer Mod</h1>""",
129
  elem_id="title",
130
  )
131
  selected_index = gr.State(None)
mod.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ from diffusers import DiffusionPipeline
5
+ import gc
6
+ import subprocess
7
+
8
+
9
+ subprocess.run('pip cache purge', shell=True)
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ torch.set_grad_enabled(False)
12
+
13
+
14
+ models = ["camenduru/FLUX.1-dev-diffusers",
15
+ "black-forest-labs/FLUX.1-schnell",
16
+ "sayakpaul/FLUX.1-merged",
17
+ "John6666/blue-pencil-flux1-v001-fp8-flux",
18
+ "John6666/fluxunchained-artfulnsfw-fut516xfp8e4m3fnv11-fp8-flux",
19
+ "John6666/nepotism-fuxdevschnell-v3aio-flux"
20
+ ]
21
+
22
+
23
+ def clear_cache():
24
+ torch.cuda.empty_cache()
25
+ gc.collect()
26
+
27
+
28
+ def get_repo_safetensors(repo_id: str):
29
+ from huggingface_hub import HfApi
30
+ api = HfApi()
31
+ try:
32
+ if " " in repo_id or not api.repo_exists(repo_id): return gr.update(value="", choices=[])
33
+ files = api.list_repo_files(repo_id=repo_id)
34
+ except Exception as e:
35
+ print(f"Error: Failed to get {repo_id}'s info. ")
36
+ print(e)
37
+ return gr.update(choices=[])
38
+ files = [f for f in files if f.endswith(".safetensors")]
39
+ if len(files) == 0: return gr.update(value="", choices=[])
40
+ else: return gr.update(value=files[0], choices=files)
41
+
42
+
43
+ def change_base_model(repo_id: str):
44
+ from huggingface_hub import HfApi
45
+ global pipe
46
+ api = HfApi()
47
+ try:
48
+ if " " in repo_id or not api.repo_exists(repo_id): return
49
+ clear_cache()
50
+ pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
51
+ except Exception as e:
52
+ print(e)