File size: 6,016 Bytes
4120479
 
 
a8a382e
 
4120479
 
 
1475e41
3eb8dac
4120479
 
 
 
 
 
 
20706a7
4120479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f40f84
 
 
4120479
ce04d24
 
 
 
4120479
ce04d24
4120479
ce04d24
 
 
 
4120479
 
84e6fa0
3eb8dac
 
 
4120479
3eb8dac
ce04d24
1d5c6d0
2f40f84
4120479
2f40f84
4120479
1d5c6d0
4120479
 
3eb8dac
 
 
 
 
4120479
 
 
 
ce04d24
4120479
 
ea9cf0a
4120479
 
 
 
 
 
 
 
a56c826
 
4120479
a56c826
 
4120479
 
 
 
 
 
 
 
ce04d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4120479
 
 
 
 
 
 
a56c826
 
4120479
 
 
a56c826
4120479
ce04d24
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
from time import sleep
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

import torch
import json
import random
import copy
import gc

# Download the LoRA catalogue (a JSON list) published by the LoraTheExplorer Space.
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")

# JSON is UTF-8 by specification; be explicit so decoding does not depend on
# the platform's default locale encoding. The file is only needed for parsing.
with open(lora_list, "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalise every catalogue entry: required keys are copied as-is, optional
# keys get a default, and relative image paths are resolved against the Space.
sdxl_loras = [
    {
        "image": (
            item["image"]
            if item["image"].startswith("https://")
            else f'https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/{item["image"]}'
        ),
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
    }
    for item in data
]

# Fetch the weight file of every LoRA up front and remember its local path
# under "saved_name", which merge_and_run later passes to load_lora_weights.
saved_names = [
    hf_hub_download(item["repo"], item["weights"]) for item in sdxl_loras
]

for item, saved_name in zip(sdxl_loras, saved_names):
    item["saved_name"] = saved_name

# Custom stylesheet handed to gr.Blocks(css=...). The selectors target the
# elem_id / elem_classes values assigned to components in the layout below
# (#title, #prompt, #run_button, .plus_column, .plus_button, .random_column,
# .roulette_group).
css = '''
#title{text-align:center;}
#title h1{font-size: 250%}
.plus_column{align-self: center}
.plus_button{font-size: 235% !important; text-align: center;margin-bottom: 19px}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button{position:absolute;margin-top: 12px;right: 0;margin-right: 1.5em;border-bottom-left-radius: 0px;
    border-top-left-radius: 0px;}
.random_column{align-self: center}
@media (max-width: 1024px) {
.roulette_group{flex-direction: column}
}
'''

# NOTE(review): the line below is a commented-out decorator that is not
# attached to any function; it looks like a leftover and has no effect.
#@spaces.GPU
# Load the SDXL base pipeline once, at import time, with float16 weights.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
# Keep a separate copy of the freshly loaded pipeline. merge_and_run deep-copies
# `original_pipe` on every call, presumably so LoRA weights fused for one
# request never carry over to the next.
# NOTE(review): `pipe` itself is not used again in this file, so this holds two
# copies of the model in memory - confirm whether the extra copy is needed.
original_pipe = copy.deepcopy(pipe)

def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
    """Fuse the two selected LoRAs into a fresh SDXL pipeline and render one image.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Optional negative prompt; an empty value means "none".
        shuffled_items: The two LoRA entries picked by ``shuffle_images``; each
            must carry a ``"saved_name"`` key with the local weights path.
        lora_1_scale: Fusion scale applied to the first LoRA.
        lora_2_scale: Fusion scale applied to the second LoRA.
        progress: Gradio progress tracker (mirrors tqdm progress in the UI).

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no pair of LoRAs has been selected yet.
    """
    if not shuffled_items or len(shuffled_items) < 2:
        raise gr.Error("No LoRAs are selected yet, please reload the page and try again.")

    # The pipeline's "not provided" value is None, not False or an empty string.
    if not negative_prompt:
        negative_prompt = None

    # Work on a private copy so fused weights never touch the pristine pipeline.
    pipe = copy.deepcopy(original_pipe)
    try:
        pipe.to("cuda")
        print("Loading LoRAs")
        for item, scale in zip(shuffled_items[:2], (lora_1_scale, lora_2_scale)):
            pipe.load_lora_weights(item["saved_name"])
            # Pass the scale by keyword: the first positional parameter of
            # fuse_lora is not the scale, so a positional value is misread.
            pipe.fuse_lora(lora_scale=scale)

        return pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25).images[0]
    finally:
        # Always release the GPU copy, even when loading or inference fails.
        del pipe
        gc.collect()
        torch.cuda.empty_cache()

def get_description(item):
    """Return ``(hint, trigger_word)`` for a LoRA catalogue entry.

    ``hint`` is the Markdown text shown under the LoRA image: it names the
    trigger word when there is one and otherwise says none is needed.
    ``trigger_word`` is the entry's raw ``"trigger_word"`` value, unchanged.
    """
    word = item["trigger_word"]
    if word:
        hint = f"Trigger: `{word}`"
    else:
        hint = "No trigger word, will be applied automatically"
    return hint, word
    
def shuffle_images():
    """Pick two distinct compatible LoRAs at random and build the UI updates.

    Returns:
        A 6-tuple matching the outputs wired to this callback: image update
        for LoRA 1, description update for LoRA 1, image update for LoRA 2,
        description update for LoRA 2, prompt update pre-filled with the
        trigger words, and the list of the two chosen catalogue entries.

    Raises:
        ValueError: If fewer than two compatible LoRAs are available.
    """
    compatible_items = [item for item in sdxl_loras if item["is_compatible"]]
    if len(compatible_items) < 2:
        raise ValueError("At least two compatible LoRAs are required for the roulette.")

    # Sampling two entries avoids shuffling the whole catalogue for one pair.
    two_shuffled_items = random.sample(compatible_items, 2)
    first_item, second_item = two_shuffled_items

    title_1 = gr.update(label=first_item["title"], value=first_item["image"])
    title_2 = gr.update(label=second_item["title"], value=second_item["image"])

    description_1, trigger_word_1 = get_description(first_item)
    description_2, trigger_word_2 = get_description(second_item)

    prompt_description_1 = gr.update(value=description_1, visible=True)
    prompt_description_2 = gr.update(value=description_2, visible=True)

    # Skip empty/None trigger words so the prompt never contains stray
    # spaces or the literal text "None".
    trigger_words = [word for word in (trigger_word_1, trigger_word_2) if word]
    prompt = gr.update(value=" ".join(trigger_words))

    return title_1, prompt_description_1, title_2, prompt_description_2, prompt, two_shuffled_items

# Build the UI. Statement order inside the nested context managers determines
# the on-page layout, so components are declared in display order.
with gr.Blocks(css=css) as demo:
  # Per-session state holding the two LoRA entries returned by shuffle_images.
  shuffled_items = gr.State()
  title = gr.HTML(
        '''<h1>LoRA Roulette 🎲</h1>
        ''',
        elem_id="title"
  )
  with gr.Row(elem_classes="roulette_group"):
    # Left column: the two randomly chosen LoRAs shown as "image + image =".
    with gr.Column(min_width=10, scale=16, elem_classes="plus_column"):
        gr.HTML("<p>This 2 random LoRAs are loaded to SDXL, find a fun way to combine them 🎨</p>")
        with gr.Row():
            with gr.Column(min_width=10, scale=8, elem_classes="random_column"):
              lora_1 = gr.Image(interactive=False, height=263)
              # Hidden until shuffle_images fills in the trigger-word hint.
              lora_1_prompt = gr.Markdown(visible=False)
            with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
              plus = gr.HTML("+", elem_classes="plus_button")
            with gr.Column(min_width=10, scale=8, elem_classes="random_column"):
              lora_2 = gr.Image(interactive=False, height=263)
              # Hidden until shuffle_images fills in the trigger-word hint.
              lora_2_prompt = gr.Markdown(visible=False)
            with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
               equal = gr.HTML("=", elem_classes="plus_button")
        # NOTE(review): created with visible=False and nothing in this file
        # makes it visible, so the click handler below is unreachable from the UI.
        shuffle_button = gr.Button("Reshuffle!", visible=False)
    # Right column: prompt box with its run button, and the generated image.
    with gr.Column(min_width=10, scale=14):
        with gr.Box():
            with gr.Row():
                prompt = gr.Textbox(label="Your prompt", show_label=False, interactive=True, elem_id="prompt")
                run_btn = gr.Button("Run", elem_id="run_button")
            output_image = gr.Image(label="Output", height=355)
  
  
  
  
  with gr.Accordion("Advanced settings", open=False):
    negative_prompt = gr.Textbox(label="Negative prompt")
    with gr.Row():
      # UI default is 0.7 for both scales; merge_and_run's own defaults (0.5)
      # are not used because the slider values are always passed as inputs.
      lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
      lora_2_scale = gr.Slider(label="LoRa 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
  
  # Pick a LoRA pair via the Blocks load event and again on "Reshuffle!";
  # both handlers write the same six outputs returned by shuffle_images.
  demo.load(shuffle_images, inputs=[], outputs=[lora_1, lora_1_prompt, lora_2, lora_2_prompt, prompt, shuffled_items], queue=False, show_progress="hidden")
  shuffle_button.click(shuffle_images, outputs=[lora_1, lora_1_prompt, lora_2, lora_2_prompt, prompt, shuffled_items], queue=False, show_progress="hidden")

  # Generation can be triggered by the Run button or by submitting the prompt
  # textbox; both call merge_and_run with identical inputs.
  run_btn.click(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image])
  prompt.submit(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image])

# Enable the request queue (the load/shuffle handlers above opt out with
# queue=False), then start the app.
demo.queue()
demo.launch(share=True)