Bobby commited on
Commit
f3ff2c1
·
1 Parent(s): ab8a1ed
app.py CHANGED
@@ -1,435 +1,451 @@
1
- prod = False
2
- port = 8080
3
- show_options = False
4
- if prod:
5
- port = 8081
6
- # show_options = False
7
-
8
- import os
9
- import gc
10
- import random
11
- import time
12
- import gradio as gr
13
- import numpy as np
14
- # import imageio
15
- import torch
16
- from PIL import Image
17
- from diffusers import (
18
- ControlNetModel,
19
- DPMSolverMultistepScheduler,
20
- StableDiffusionControlNetPipeline,
21
- )
22
- from diffusers.models.attention_processor import AttnProcessor2_0
23
- from preprocess import Preprocessor
24
- MAX_SEED = np.iinfo(np.int32).max
25
- API_KEY = os.environ.get("API_KEY", None)
26
-
27
- print("CUDA version:", torch.version.cuda)
28
- print("loading pipe")
29
- compiled = False
30
- # api = HfApi()
31
-
32
- import spaces
33
-
34
- preprocessor = Preprocessor()
35
- preprocessor.load("NormalBae")
36
-
37
- if gr.NO_RELOAD:
38
- torch.cuda.max_memory_allocated(device="cuda")
39
-
40
- # Controlnet Normal
41
- model_id = "lllyasviel/control_v11p_sd15_normalbae"
42
- print("initializing controlnet")
43
- controlnet = ControlNetModel.from_pretrained(
44
- model_id,
45
- torch_dtype=torch.float16,
46
- attn_implementation="flash_attention_2",
47
- ).to("cuda")
48
-
49
- # Scheduler
50
- scheduler = DPMSolverMultistepScheduler.from_pretrained(
51
- "runwayml/stable-diffusion-v1-5",
52
- solver_order=2,
53
- subfolder="scheduler",
54
- use_karras_sigmas=True,
55
- final_sigmas_type="sigma_min",
56
- algorithm_type="sde-dpmsolver++",
57
- prediction_type="epsilon",
58
- thresholding=False,
59
- denoise_final=True,
60
- device_map="cuda",
61
- torch_dtype=torch.float16,
62
- )
63
-
64
- # Stable Diffusion Pipeline URL
65
- # base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
66
- base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
67
-
68
- pipe = StableDiffusionControlNetPipeline.from_single_file(
69
- base_model_url,
70
- # safety_checker=None,
71
- # load_safety_checker=True,
72
- controlnet=controlnet,
73
- scheduler=scheduler,
74
- torch_dtype=torch.float16,
75
- )
76
-
77
- pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",)
78
- pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
79
- pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="fcNeg-neg.pt", token="fcNeg-neg")
80
- pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
81
- pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
82
- pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_pet_play.pt", token="HDA_pet_play")
83
- pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_unconventional maid.pt", token="HDA_unconventional_maid")
84
- pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NakedHoodie.pt", token="HDA_NakedHoodie")
85
- pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
86
- pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
87
- pipe.to("cuda")
88
-
89
- # experimental speedup?
90
- # pipe.compile()
91
- # torch.cuda.empty_cache()
92
- # gc.collect()
93
- print("---------------Loaded controlnet pipeline---------------")
94
-
95
- @spaces.GPU(duration=12)
96
- def init(pipe):
97
- pipe.enable_xformers_memory_efficient_attention()
98
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
99
- pipe.unet.set_attn_processor(AttnProcessor2_0())
100
- print("Model Compiled!")
101
- init(pipe)
102
-
103
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
104
- if randomize_seed:
105
- seed = random.randint(0, MAX_SEED)
106
- return seed
107
-
108
- def get_additional_prompt():
109
- prompt = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
110
- top = ["tank top", "blouse", "button up shirt", "sweater", "corset top"]
111
- bottom = ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt", "leggings", "high-waisted shorts"]
112
- accessory = ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker", "necklace", "headband", "headphones"]
113
- return f"{prompt}, {random.choice(top)}, {random.choice(bottom)}, {random.choice(accessory)}, score_9"
114
- # outfit = ["schoolgirl outfit", "playboy outfit", "red dress", "gala dress", "cheerleader outfit", "nurse outfit", "Kimono"]
115
-
116
- def get_prompt(prompt, additional_prompt):
117
- interior = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"
118
- default = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
119
- default2 = f"professional 3d model {prompt},octane render,highly detailed,volumetric,dramatic lighting,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
120
- randomize = get_additional_prompt()
121
- # nude = "NSFW,((nude)),medium bare breasts,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
122
- # bodypaint = "((fully naked with no clothes)),nude naked seethroughxray,invisiblebodypaint,rating_newd,NSFW"
123
- lab_girl = "hyperrealistic photography, extremely detailed, shy assistant wearing minidress boots and gloves, laboratory background, score_9, 1girl"
124
- pet_play = "hyperrealistic photography, extremely detailed, playful, blush, glasses, collar, score_9, HDA_pet_play"
125
- bondage = "hyperrealistic photography, extremely detailed, submissive, glasses, score_9, HDA_Bondage"
126
- # ahegao = "((invisible clothing)), hyperrealistic photography,exposed vagina,sexy,nsfw,HDA_Ahegao"
127
- ahegao2 = "(invisiblebodypaint),rating_newd,HDA_Ahegao"
128
- athleisure = "hyperrealistic photography, extremely detailed, 1girl athlete, exhausted embarrassed sweaty,outdoors, ((athleisure clothing)), score_9"
129
- atompunk = "((atompunk world)), hyperrealistic photography, extremely detailed, short hair, bodysuit, glasses, neon cyberpunk background, score_9"
130
- maid = "hyperrealistic photography, extremely detailed, shy, blushing, score_9, pastel background, HDA_unconventional_maid"
131
- nundress = "hyperrealistic photography, extremely detailed, shy, blushing, fantasy background, score_9, HDA_NunDress"
132
- naked_hoodie = "hyperrealistic photography, extremely detailed, medium hair, cityscape, (neon lights), score_9, HDA_NakedHoodie"
133
- abg = "(1girl, asian body covered in words, words on body, tattoos of (words) on body),(masterpiece, best quality),medium breasts,(intricate details),unity 8k wallpaper,ultra detailed,(pastel colors),beautiful and aesthetic,see-through (clothes),detailed,solo"
134
- # shibari = "extremely detailed, hyperrealistic photography, earrings, blushing, lace choker, tattoo, medium hair, score_9, HDA_Shibari"
135
- shibari2 = "octane render, highly detailed, volumetric, HDA_Shibari"
136
-
137
- if prompt == "":
138
- girls = [randomize, pet_play, bondage, lab_girl, athleisure, atompunk, maid, nundress, naked_hoodie, abg, shibari2, ahegao2]
139
- prompts_nsfw = [abg, shibari2, ahegao2]
140
- prompt = f"{random.choice(girls)}"
141
- prompt = f"boho chic"
142
- # print(f"-------------{preset}-------------")
143
- else:
144
- prompt = f"Photo from Pinterest of {prompt} {interior}"
145
- # prompt = default2
146
- return f"{prompt} f{additional_prompt}"
147
-
148
- style_list = [
149
- {
150
- "name": "None",
151
- "prompt": ""
152
- },
153
- {
154
- "name": "Minimalistic",
155
- "prompt": "Minimalistic"
156
- },
157
- {
158
- "name": "Boho Chic",
159
- "prompt": "boho chic"
160
- },
161
- {
162
- "name": "Saudi Prince Gold",
163
- "prompt": "saudi prince gold",
164
- },
165
- {
166
- "name": "Modern Farmhouse",
167
- "prompt": "modern farmhouse",
168
- },
169
- {
170
- "name": "Neoclassical",
171
- "prompt": "Neoclassical",
172
- },
173
- {
174
- "name": "Eclectic",
175
- "prompt": "Eclectic",
176
- },
177
- {
178
- "name": "Parisian White",
179
- "prompt": "Parisian White",
180
- },
181
- {
182
- "name": "Hollywood Glam",
183
- "prompt": "Hollywood Glam",
184
- },
185
- {
186
- "name": "Scandinavian",
187
- "prompt": "Scandinavian",
188
- },
189
- {
190
- "name": "Japanese",
191
- "prompt": "Japanese",
192
- },
193
- {
194
- "name": "Texas Cowboy",
195
- "prompt": "Texas Cowboy",
196
- },
197
- ]
198
-
199
- styles = {k["name"]: (k["prompt"]) for k in style_list}
200
- STYLE_NAMES = list(styles.keys())
201
-
202
- def apply_style(style_name):
203
- if style_name in styles:
204
- p = styles.get(style_name, "boho chic")
205
- return p
206
-
207
-
208
- css = """
209
- h1 {
210
- text-align: center;
211
- display:block;
212
- }
213
- h2 {
214
- text-align: center;
215
- display:block;
216
- }
217
- h3 {
218
- text-align: center;
219
- display:block;
220
- }
221
- .gradio-container{max-width: 1200px !important}
222
- footer {visibility: hidden}
223
- """
224
- with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
225
- #############################################################################
226
- with gr.Row():
227
- with gr.Accordion("Advanced options", open=show_options, visible=show_options):
228
- num_images = gr.Slider(
229
- label="Images", minimum=1, maximum=4, value=1, step=1
230
- )
231
- image_resolution = gr.Slider(
232
- label="Image resolution",
233
- minimum=256,
234
- maximum=1024,
235
- value=768,
236
- step=256,
237
- )
238
- preprocess_resolution = gr.Slider(
239
- label="Preprocess resolution",
240
- minimum=128,
241
- maximum=1024,
242
- value=768,
243
- step=1,
244
- )
245
- num_steps = gr.Slider(
246
- label="Number of steps", minimum=1, maximum=100, value=12, step=1
247
- ) # 20/4.5 or 12 without lora, 4 with lora
248
- guidance_scale = gr.Slider(
249
- label="Guidance scale", minimum=0.1, maximum=30.0, value=5.5, step=0.1
250
- ) # 5 without lora, 2 with lora
251
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
252
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
253
- a_prompt = gr.Textbox(
254
- label="Additional prompt",
255
- value = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"
256
- )
257
- n_prompt = gr.Textbox(
258
- label="Negative prompt",
259
- value="EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
260
- )
261
- #############################################################################
262
- # input text
263
- with gr.Row():
264
- gr.Text(label="Interior Design Style Examples", value="Eclectic, Maximalist, Bohemian, Scandinavian, Minimalist, Rustic, Modern Farmhouse, Contemporary, Luxury, Airbnb, Boho Chic, Midcentury Modern, Art Deco, Zen, Beach, Neoclassical, Industrial, Biophilic, Eco-friendly, Hollywood Glam, Parisian White, Saudi Prince Gold, French Country, Monster Energy Drink, Cyberpunk, Vaporwave, Baroque, etc.\n\nPro tip: add a color to customize it! You can also describe the furniture type.")
265
- with gr.Column():
266
- prompt = gr.Textbox(
267
- label="Description",
268
- placeholder="boho chic",
269
- )
270
- with gr.Row(visible=True):
271
- style_selection = gr.Radio(
272
- show_label=True,
273
- container=True,
274
- interactive=True,
275
- choices=STYLE_NAMES,
276
- value="None",
277
- label="Design Styles",
278
- )
279
- # input image
280
- with gr.Row():
281
- with gr.Column():
282
- image = gr.Image(
283
- label="Input",
284
- sources=["upload"],
285
- show_label=True,
286
- mirror_webcam=True,
287
- format="webp",
288
- )
289
- # run button
290
- with gr.Column():
291
- run_button = gr.Button(value="Use this one", size=["lg"], visible=False)
292
- # output image
293
- with gr.Column():
294
- result = gr.Image(
295
- label="Output",
296
- interactive=False,
297
- format="webp",
298
- show_share_button= False,
299
- )
300
- # Use this image button
301
- with gr.Column():
302
- use_ai_button = gr.Button(value="Use this one", size=["lg"], visible=False)
303
- config = [
304
- image,
305
- style_selection,
306
- prompt,
307
- a_prompt,
308
- n_prompt,
309
- num_images,
310
- image_resolution,
311
- preprocess_resolution,
312
- num_steps,
313
- guidance_scale,
314
- seed,
315
- ]
316
-
317
- with gr.Row():
318
- helper_text = gr.Markdown("## Tap and hold (on mobile) to save the image.", visible=True)
319
-
320
- # image processing
321
- @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
322
- def auto_process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
323
- return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
324
-
325
- # AI Image Processing
326
- @gr.on(triggers=[use_ai_button.click], inputs=config, outputs=result, show_progress="minimal")
327
- def submit(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
328
- return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
329
-
330
- # Change input to result
331
- @gr.on(triggers=[use_ai_button.click], inputs=None, outputs=image, show_progress="hidden")
332
- def update_input():
333
- try:
334
- print("Updating image to AI Temp Image")
335
- ai_temp_image = Image.open("temp_image.jpg")
336
- return ai_temp_image
337
- except FileNotFoundError:
338
- print("No AI Image Available")
339
- return None
340
-
341
- # Turn off buttons when processing
342
- @gr.on(triggers=[image.upload, use_ai_button.click, run_button.click], inputs=None, outputs=[run_button, use_ai_button], show_progress="hidden")
343
- def turn_buttons_off():
344
- return gr.update(visible=False), gr.update(visible=False)
345
-
346
- # Turn on buttons when processing is complete
347
- @gr.on(triggers=[result.change], inputs=None, outputs=[use_ai_button, run_button], show_progress="hidden")
348
- def turn_buttons_on():
349
- return gr.update(visible=True), gr.update(visible=True)
350
-
351
- @spaces.GPU(duration=10)
352
- @torch.inference_mode()
353
- def process_image(
354
- image,
355
- style_selection,
356
- prompt,
357
- a_prompt,
358
- n_prompt,
359
- num_images,
360
- image_resolution,
361
- preprocess_resolution,
362
- num_steps,
363
- guidance_scale,
364
- seed,
365
- progress=gr.Progress(track_tqdm=True)
366
- ):
367
- preprocess_start = time.time()
368
- print("processing image")
369
- preprocessor.load("NormalBae")
370
- # preprocessor.load("Canny") #20 steps, 9 guidance, 512, 512
371
-
372
- global compiled
373
- if not compiled:
374
- print("Not Compiled")
375
- compiled = True
376
-
377
- seed = random.randint(0, MAX_SEED)
378
- generator = torch.cuda.manual_seed(seed)
379
- control_image = preprocessor(
380
- image=image,
381
- image_resolution=image_resolution,
382
- detect_resolution=preprocess_resolution,
383
- )
384
- preprocess_time = time.time() - preprocess_start
385
- if style_selection is not None or style_selection != "None":
386
- prompt = "Photo from Pinterest of " + apply_style(style_selection) + " " + prompt + " " + a_prompt
387
- else:
388
- prompt=str(get_prompt(prompt, a_prompt))
389
- negative_prompt=str(n_prompt)
390
- print(prompt)
391
- start = time.time()
392
- results = pipe(
393
- prompt=prompt,
394
- negative_prompt=negative_prompt,
395
- guidance_scale=guidance_scale,
396
- num_images_per_prompt=num_images,
397
- num_inference_steps=num_steps,
398
- generator=generator,
399
- image=control_image,
400
- ).images[0]
401
- torch.cuda.empty_cache()
402
- print(f"\n-------------------------Preprocess done in: {preprocess_time:.2f} seconds-------------------------")
403
- print(f"\n-------------------------Inference done in: {time.time() - start:.2f} seconds-------------------------")
404
-
405
- # timestamp = int(time.time())
406
- #if not os.path.exists("./outputs"):
407
- # os.makedirs("./outputs")
408
- # img_path = f"./{timestamp}.jpg"
409
- # results_path = f"./{timestamp}_out_{prompt}.jpg"
410
- # imageio.imsave(img_path, image)
411
- # results.save(results_path)
412
- results.save("temp_image.jpg")
413
-
414
- # api.upload_file(
415
- # path_or_fileobj=img_path,
416
- # path_in_repo=img_path,
417
- # repo_id="broyang/anime-ai-outputs",
418
- # repo_type="dataset",
419
- # token=API_KEY,
420
- # run_as_future=True,
421
- # )
422
- # api.upload_file(
423
- # path_or_fileobj=results_path,
424
- # path_in_repo=results_path,
425
- # repo_id="broyang/anime-ai-outputs",
426
- # repo_type="dataset",
427
- # token=API_KEY,
428
- # run_as_future=True,
429
- # )
430
-
431
- return results
432
- if prod:
433
- demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
434
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  demo.queue(api_open=False).launch(show_api=False)
 
1
+ prod = False
2
+ port = 8080
3
+ show_options = False
4
+ if prod:
5
+ port = 8081
6
+ # show_options = False
7
+
8
+ import os
9
+ import gc
10
+ import random
11
+ import time
12
+ import gradio as gr
13
+ import numpy as np
14
+ # import imageio
15
+ import torch
16
+ from PIL import Image
17
+ from diffusers import (
18
+ ControlNetModel,
19
+ DPMSolverMultistepScheduler,
20
+ StableDiffusionControlNetPipeline,
21
+ AutoencoderKL,
22
+ )
23
+ from diffusers.models.attention_processor import AttnProcessor2_0
24
+ from preprocess import Preprocessor
25
+ MAX_SEED = np.iinfo(np.int32).max
26
+ API_KEY = os.environ.get("API_KEY", None)
27
+
28
+ print("CUDA version:", torch.version.cuda)
29
+ print("loading pipe")
30
+ compiled = False
31
+ # api = HfApi()
32
+
33
+ import spaces
34
+
35
+ preprocessor = Preprocessor()
36
+ preprocessor.load("NormalBae")
37
+
38
+ if gr.NO_RELOAD:
39
+ torch.cuda.max_memory_allocated(device="cuda")
40
+
41
+ # Controlnet Normal
42
+ model_id = "lllyasviel/control_v11p_sd15_normalbae"
43
+ print("initializing controlnet")
44
+ controlnet = ControlNetModel.from_pretrained(
45
+ model_id,
46
+ torch_dtype=torch.float16,
47
+ attn_implementation="flash_attention_2",
48
+ ).to("cuda")
49
+
50
+ # Scheduler
51
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(
52
+ "runwayml/stable-diffusion-v1-5",
53
+ solver_order=2,
54
+ subfolder="scheduler",
55
+ use_karras_sigmas=True,
56
+ final_sigmas_type="sigma_min",
57
+ algorithm_type="sde-dpmsolver++",
58
+ prediction_type="epsilon",
59
+ thresholding=False,
60
+ denoise_final=True,
61
+ device_map="cuda",
62
+ torch_dtype=torch.float16,
63
+ )
64
+
65
+ # Stable Diffusion Pipeline URL
66
+ # base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
67
+ base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
68
+ vae_url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
69
+
70
+ vae = AutoencoderKL.from_single_file(vae_url, torch_dtype=torch.float16).to("cuda")
71
+ vae.to(memory_format=torch.channels_last)
72
+
73
+ pipe = StableDiffusionControlNetPipeline.from_single_file(
74
+ base_model_url,
75
+ # safety_checker=None,
76
+ # load_safety_checker=True,
77
+ controlnet=controlnet,
78
+ scheduler=scheduler,
79
+ vae=vae,
80
+ torch_dtype=torch.float16,
81
+ )
82
+
83
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",)
84
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
85
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="fcNeg-neg.pt", token="fcNeg-neg")
86
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
87
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
88
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_pet_play.pt", token="HDA_pet_play")
89
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_unconventional maid.pt", token="HDA_unconventional_maid")
90
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NakedHoodie.pt", token="HDA_NakedHoodie")
91
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
92
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
93
+ pipe.to("cuda")
94
+
95
+ # experimental speedup?
96
+ # pipe.compile()
97
+ # torch.cuda.empty_cache()
98
+ # gc.collect()
99
+ print("---------------Loaded controlnet pipeline---------------")
100
+
101
+ @spaces.GPU(duration=12)
102
+ def init(pipe):
103
+ pipe.enable_xformers_memory_efficient_attention()
104
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
105
+ pipe.unet.set_attn_processor(AttnProcessor2_0())
106
+ print("Model Compiled!")
107
+ init(pipe)
108
+
109
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
110
+ if randomize_seed:
111
+ seed = random.randint(0, MAX_SEED)
112
+ return seed
113
+
114
+ def get_additional_prompt():
115
+ prompt = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
116
+ top = ["tank top", "blouse", "button up shirt", "sweater", "corset top"]
117
+ bottom = ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt", "leggings", "high-waisted shorts"]
118
+ accessory = ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker", "necklace", "headband", "headphones"]
119
+ return f"{prompt}, {random.choice(top)}, {random.choice(bottom)}, {random.choice(accessory)}, score_9"
120
+ # outfit = ["schoolgirl outfit", "playboy outfit", "red dress", "gala dress", "cheerleader outfit", "nurse outfit", "Kimono"]
121
+
122
+ def get_prompt(prompt, additional_prompt):
123
+ interior = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"
124
+ default = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
125
+ default2 = f"professional 3d model {prompt},octane render,highly detailed,volumetric,dramatic lighting,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
126
+ randomize = get_additional_prompt()
127
+ # nude = "NSFW,((nude)),medium bare breasts,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
128
+ # bodypaint = "((fully naked with no clothes)),nude naked seethroughxray,invisiblebodypaint,rating_newd,NSFW"
129
+ lab_girl = "hyperrealistic photography, extremely detailed, shy assistant wearing minidress boots and gloves, laboratory background, score_9, 1girl"
130
+ pet_play = "hyperrealistic photography, extremely detailed, playful, blush, glasses, collar, score_9, HDA_pet_play"
131
+ bondage = "hyperrealistic photography, extremely detailed, submissive, glasses, score_9, HDA_Bondage"
132
+ # ahegao = "((invisible clothing)), hyperrealistic photography,exposed vagina,sexy,nsfw,HDA_Ahegao"
133
+ ahegao2 = "(invisiblebodypaint),rating_newd,HDA_Ahegao"
134
+ athleisure = "hyperrealistic photography, extremely detailed, 1girl athlete, exhausted embarrassed sweaty,outdoors, ((athleisure clothing)), score_9"
135
+ atompunk = "((atompunk world)), hyperrealistic photography, extremely detailed, short hair, bodysuit, glasses, neon cyberpunk background, score_9"
136
+ maid = "hyperrealistic photography, extremely detailed, shy, blushing, score_9, pastel background, HDA_unconventional_maid"
137
+ nundress = "hyperrealistic photography, extremely detailed, shy, blushing, fantasy background, score_9, HDA_NunDress"
138
+ naked_hoodie = "hyperrealistic photography, extremely detailed, medium hair, cityscape, (neon lights), score_9, HDA_NakedHoodie"
139
+ abg = "(1girl, asian body covered in words, words on body, tattoos of (words) on body),(masterpiece, best quality),medium breasts,(intricate details),unity 8k wallpaper,ultra detailed,(pastel colors),beautiful and aesthetic,see-through (clothes),detailed,solo"
140
+ # shibari = "extremely detailed, hyperrealistic photography, earrings, blushing, lace choker, tattoo, medium hair, score_9, HDA_Shibari"
141
+ shibari2 = "octane render, highly detailed, volumetric, HDA_Shibari"
142
+
143
+ if prompt == "":
144
+ girls = [randomize, pet_play, bondage, lab_girl, athleisure, atompunk, maid, nundress, naked_hoodie, abg, shibari2, ahegao2]
145
+ prompts_nsfw = [abg, shibari2, ahegao2]
146
+ prompt = f"{random.choice(girls)}"
147
+ prompt = f"boho chic"
148
+ # print(f"-------------{preset}-------------")
149
+ else:
150
+ prompt = f"Photo from Pinterest of {prompt} {interior}"
151
+ # prompt = default2
152
+ return f"{prompt} f{additional_prompt}"
153
+
154
+ style_list = [
155
+ {
156
+ "name": "None",
157
+ "prompt": ""
158
+ },
159
+ {
160
+ "name": "Minimalistic",
161
+ "prompt": "Minimalistic"
162
+ },
163
+ {
164
+ "name": "Boho Chic",
165
+ "prompt": "boho chic"
166
+ },
167
+ {
168
+ "name": "Saudi Prince Gold",
169
+ "prompt": "saudi prince gold",
170
+ },
171
+ {
172
+ "name": "Modern Farmhouse",
173
+ "prompt": "modern farmhouse",
174
+ },
175
+ {
176
+ "name": "Neoclassical",
177
+ "prompt": "Neoclassical",
178
+ },
179
+ {
180
+ "name": "Eclectic",
181
+ "prompt": "Eclectic",
182
+ },
183
+ {
184
+ "name": "Parisian White",
185
+ "prompt": "Parisian White",
186
+ },
187
+ {
188
+ "name": "Hollywood Glam",
189
+ "prompt": "Hollywood Glam",
190
+ },
191
+ {
192
+ "name": "Scandinavian",
193
+ "prompt": "Scandinavian",
194
+ },
195
+ {
196
+ "name": "Japanese",
197
+ "prompt": "Japanese",
198
+ },
199
+ {
200
+ "name": "Texas Cowboy",
201
+ "prompt": "Texas Cowboy",
202
+ },
203
+ {
204
+ "name": "Midcentury Modern",
205
+ "prompt": "Midcentury Modern",
206
+ },
207
+ {
208
+ "name": "Beach",
209
+ "prompt": "Beach",
210
+ },
211
+ ]
212
+
213
+ styles = {k["name"]: (k["prompt"]) for k in style_list}
214
+ STYLE_NAMES = list(styles.keys())
215
+
216
+ def apply_style(style_name):
217
+ if style_name in styles:
218
+ p = styles.get(style_name, "boho chic")
219
+ return p
220
+
221
+
222
+ css = """
223
+ h1 {
224
+ text-align: center;
225
+ display:block;
226
+ }
227
+ h2 {
228
+ text-align: center;
229
+ display:block;
230
+ }
231
+ h3 {
232
+ text-align: center;
233
+ display:block;
234
+ }
235
+ .gradio-container{max-width: 1200px !important}
236
+ footer {visibility: hidden}
237
+ """
238
+ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
239
+ #############################################################################
240
+ with gr.Row():
241
+ with gr.Accordion("Advanced options", open=show_options, visible=show_options):
242
+ num_images = gr.Slider(
243
+ label="Images", minimum=1, maximum=4, value=1, step=1
244
+ )
245
+ image_resolution = gr.Slider(
246
+ label="Image resolution",
247
+ minimum=256,
248
+ maximum=1024,
249
+ value=512,
250
+ step=256,
251
+ )
252
+ preprocess_resolution = gr.Slider(
253
+ label="Preprocess resolution",
254
+ minimum=128,
255
+ maximum=1024,
256
+ value=512,
257
+ step=1,
258
+ )
259
+ num_steps = gr.Slider(
260
+ label="Number of steps", minimum=1, maximum=100, value=15, step=1
261
+ ) # 20/4.5 or 12 without lora, 4 with lora
262
+ guidance_scale = gr.Slider(
263
+ label="Guidance scale", minimum=0.1, maximum=30.0, value=5.5, step=0.1
264
+ ) # 5 without lora, 2 with lora
265
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
266
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
267
+ a_prompt = gr.Textbox(
268
+ label="Additional prompt",
269
+ value = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"
270
+ )
271
+ n_prompt = gr.Textbox(
272
+ label="Negative prompt",
273
+ value="EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
274
+ )
275
+ #############################################################################
276
+ # input text
277
+ with gr.Row():
278
+ gr.Text(label="Interior Design Style Examples", value="Eclectic, Maximalist, Bohemian, Scandinavian, Minimalist, Rustic, Modern Farmhouse, Contemporary, Luxury, Airbnb, Boho Chic, Midcentury Modern, Art Deco, Zen, Beach, Neoclassical, Industrial, Biophilic, Eco-friendly, Hollywood Glam, Parisian White, Saudi Prince Gold, French Country, Monster Energy Drink, Cyberpunk, Vaporwave, Baroque, etc.\n\nPro tip: add a color to customize it! You can also describe the furniture type.")
279
+ with gr.Column():
280
+ prompt = gr.Textbox(
281
+ label="Custom Prompt",
282
+ placeholder="boho chic",
283
+ )
284
+ with gr.Row(visible=True):
285
+ style_selection = gr.Radio(
286
+ show_label=True,
287
+ container=True,
288
+ interactive=True,
289
+ choices=STYLE_NAMES,
290
+ value="None",
291
+ label="Design Styles",
292
+ )
293
+ # input image
294
+ with gr.Row():
295
+ with gr.Column():
296
+ image = gr.Image(
297
+ label="Input",
298
+ sources=["upload"],
299
+ show_label=True,
300
+ mirror_webcam=True,
301
+ format="webp",
302
+ )
303
+ # run button
304
+ with gr.Column():
305
+ run_button = gr.Button(value="Use this one", size=["lg"], visible=False)
306
+ # output image
307
+ with gr.Column():
308
+ result = gr.Image(
309
+ label="Output",
310
+ interactive=False,
311
+ format="webp",
312
+ show_share_button= False,
313
+ )
314
+ # Use this image button
315
+ with gr.Column():
316
+ use_ai_button = gr.Button(value="Use this one", size=["lg"], visible=False)
317
+ config = [
318
+ image,
319
+ style_selection,
320
+ prompt,
321
+ a_prompt,
322
+ n_prompt,
323
+ num_images,
324
+ image_resolution,
325
+ preprocess_resolution,
326
+ num_steps,
327
+ guidance_scale,
328
+ seed,
329
+ ]
330
+
331
+ with gr.Row():
332
+ helper_text = gr.Markdown("## Tap and hold (on mobile) to save the image.", visible=True)
333
+
334
+ # image processing
335
+ @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
336
+ def auto_process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
337
+ return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
338
+
339
+ # AI Image Processing
340
+ @gr.on(triggers=[use_ai_button.click], inputs=config, outputs=result, show_progress="minimal")
341
+ def submit(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
342
+ return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
343
+
344
+ # Change input to result
345
+ @gr.on(triggers=[use_ai_button.click], inputs=None, outputs=image, show_progress="hidden")
346
+ def update_input():
347
+ try:
348
+ print("Updating image to AI Temp Image")
349
+ ai_temp_image = Image.open("temp_image.jpg")
350
+ return ai_temp_image
351
+ except FileNotFoundError:
352
+ print("No AI Image Available")
353
+ return None
354
+
355
+ # Turn off buttons when processing
356
+ @gr.on(triggers=[image.upload, use_ai_button.click, run_button.click], inputs=None, outputs=[run_button, use_ai_button], show_progress="hidden")
357
+ def turn_buttons_off():
358
+ return gr.update(visible=False), gr.update(visible=False)
359
+
360
+ # Turn on buttons when processing is complete
361
+ @gr.on(triggers=[result.change], inputs=None, outputs=[use_ai_button, run_button], show_progress="hidden")
362
+ def turn_buttons_on():
363
+ return gr.update(visible=True), gr.update(visible=True)
364
+
365
+ @spaces.GPU(duration=10)
366
+ @torch.inference_mode()
367
+ def process_image(
368
+ image,
369
+ style_selection,
370
+ prompt,
371
+ a_prompt,
372
+ n_prompt,
373
+ num_images,
374
+ image_resolution,
375
+ preprocess_resolution,
376
+ num_steps,
377
+ guidance_scale,
378
+ seed,
379
+ progress=gr.Progress(track_tqdm=True)
380
+ ):
381
+ torch.cuda.synchronize()
382
+ preprocess_start = time.time()
383
+ print("processing image")
384
+ preprocessor.load("NormalBae")
385
+ # preprocessor.load("Canny") #20 steps, 9 guidance, 512, 512
386
+
387
+ global compiled
388
+ if not compiled:
389
+ print("Not Compiled")
390
+ compiled = True
391
+
392
+ seed = random.randint(0, MAX_SEED)
393
+ generator = torch.cuda.manual_seed(seed)
394
+ control_image = preprocessor(
395
+ image=image,
396
+ image_resolution=image_resolution,
397
+ detect_resolution=preprocess_resolution,
398
+ )
399
+ preprocess_time = time.time() - preprocess_start
400
+ if style_selection is not None or style_selection != "None":
401
+ prompt = "Photo from Pinterest of " + apply_style(style_selection) + " " + prompt + " " + a_prompt
402
+ else:
403
+ prompt=str(get_prompt(prompt, a_prompt))
404
+ negative_prompt=str(n_prompt)
405
+ print(prompt)
406
+ start = time.time()
407
+ results = pipe(
408
+ prompt=prompt,
409
+ negative_prompt=negative_prompt,
410
+ guidance_scale=guidance_scale,
411
+ num_images_per_prompt=num_images,
412
+ num_inference_steps=num_steps,
413
+ generator=generator,
414
+ image=control_image,
415
+ ).images[0]
416
+ torch.cuda.synchronize()
417
+ torch.cuda.empty_cache()
418
+ print(f"\n-------------------------Preprocess done in: {preprocess_time:.2f} seconds-------------------------")
419
+ print(f"\n-------------------------Inference done in: {time.time() - start:.2f} seconds-------------------------")
420
+
421
+ # timestamp = int(time.time())
422
+ #if not os.path.exists("./outputs"):
423
+ # os.makedirs("./outputs")
424
+ # img_path = f"./{timestamp}.jpg"
425
+ # results_path = f"./{timestamp}_out_{prompt}.jpg"
426
+ # imageio.imsave(img_path, image)
427
+ # results.save(results_path)
428
+ results.save("temp_image.jpg")
429
+
430
+ # api.upload_file(
431
+ # path_or_fileobj=img_path,
432
+ # path_in_repo=img_path,
433
+ # repo_id="broyang/anime-ai-outputs",
434
+ # repo_type="dataset",
435
+ # token=API_KEY,
436
+ # run_as_future=True,
437
+ # )
438
+ # api.upload_file(
439
+ # path_or_fileobj=results_path,
440
+ # path_in_repo=results_path,
441
+ # repo_id="broyang/anime-ai-outputs",
442
+ # repo_type="dataset",
443
+ # token=API_KEY,
444
+ # run_as_future=True,
445
+ # )
446
+
447
+ return results
448
+ if prod:
449
+ demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
450
+ else:
451
  demo.queue(api_open=False).launch(show_api=False)
app.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e24bf5be8b15309a5f50ad7c94d94dfb5fab0bff4f0baea1f1a67af5cc3f925
3
+ size 13317
app/app.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prod = False
2
+ port = 8080
3
+ show_options = False
4
+ if prod:
5
+ port = 8081
6
+ # show_options = False
7
+
8
+ import os
9
+ import gc
10
+ import random
11
+ import time
12
+ import gradio as gr
13
+ import numpy as np
14
+ # import imageio
15
+ import torch
16
+ from PIL import Image
17
+ from diffusers import (
18
+ ControlNetModel,
19
+ DPMSolverMultistepScheduler,
20
+ StableDiffusionControlNetPipeline,
21
+ AutoencoderKL,
22
+ )
23
+ from diffusers.models.attention_processor import AttnProcessor2_0
24
+ from preprocess import Preprocessor
25
+ MAX_SEED = np.iinfo(np.int32).max
26
+ API_KEY = os.environ.get("API_KEY", None)
27
+
28
+ print("CUDA version:", torch.version.cuda)
29
+ print("loading pipe")
30
+ compiled = False
31
+ # api = HfApi()
32
+
33
+ import spaces
34
+
35
+ preprocessor = Preprocessor()
36
+ preprocessor.load("NormalBae")
37
+
38
+ if gr.NO_RELOAD:
39
+ torch.cuda.max_memory_allocated(device="cuda")
40
+
41
+ # Controlnet Normal
42
+ model_id = "lllyasviel/control_v11p_sd15_normalbae"
43
+ print("initializing controlnet")
44
+ controlnet = ControlNetModel.from_pretrained(
45
+ model_id,
46
+ torch_dtype=torch.float16,
47
+ attn_implementation="flash_attention_2",
48
+ ).to("cuda")
49
+
50
+ # Scheduler
51
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(
52
+ "runwayml/stable-diffusion-v1-5",
53
+ solver_order=2,
54
+ subfolder="scheduler",
55
+ use_karras_sigmas=True,
56
+ final_sigmas_type="sigma_min",
57
+ algorithm_type="sde-dpmsolver++",
58
+ prediction_type="epsilon",
59
+ thresholding=False,
60
+ denoise_final=True,
61
+ device_map="cuda",
62
+ torch_dtype=torch.float16,
63
+ )
64
+
65
+ # Stable Diffusion Pipeline URL
66
+ # base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
67
+ base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
68
+ vae_url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
69
+
70
+ vae = AutoencoderKL.from_single_file(vae_url, torch_dtype=torch.float16).to("cuda")
71
+ vae.to(memory_format=torch.channels_last)
72
+
73
+ pipe = StableDiffusionControlNetPipeline.from_single_file(
74
+ base_model_url,
75
+ # safety_checker=None,
76
+ # load_safety_checker=True,
77
+ controlnet=controlnet,
78
+ scheduler=scheduler,
79
+ vae=vae,
80
+ torch_dtype=torch.float16,
81
+ )
82
+
83
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",)
84
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
85
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="fcNeg-neg.pt", token="fcNeg-neg")
86
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
87
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
88
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_pet_play.pt", token="HDA_pet_play")
89
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_unconventional maid.pt", token="HDA_unconventional_maid")
90
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NakedHoodie.pt", token="HDA_NakedHoodie")
91
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
92
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
93
+ pipe.to("cuda")
94
+
95
+ # experimental speedup?
96
+ # pipe.compile()
97
+ # torch.cuda.empty_cache()
98
+ # gc.collect()
99
+ print("---------------Loaded controlnet pipeline---------------")
100
+
101
+ @spaces.GPU(duration=12)
102
+ def init(pipe):
103
+ pipe.enable_xformers_memory_efficient_attention()
104
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
105
+ pipe.unet.set_attn_processor(AttnProcessor2_0())
106
+ print("Model Compiled!")
107
+ init(pipe)
108
+
109
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
110
+ if randomize_seed:
111
+ seed = random.randint(0, MAX_SEED)
112
+ return seed
113
+
114
+ def get_additional_prompt():
115
+ prompt = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
116
+ top = ["tank top", "blouse", "button up shirt", "sweater", "corset top"]
117
+ bottom = ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt", "leggings", "high-waisted shorts"]
118
+ accessory = ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker", "necklace", "headband", "headphones"]
119
+ return f"{prompt}, {random.choice(top)}, {random.choice(bottom)}, {random.choice(accessory)}, score_9"
120
+ # outfit = ["schoolgirl outfit", "playboy outfit", "red dress", "gala dress", "cheerleader outfit", "nurse outfit", "Kimono"]
121
+
122
+ def get_prompt(prompt, additional_prompt):
123
+ interior = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"
124
+ default = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
125
+ default2 = f"professional 3d model {prompt},octane render,highly detailed,volumetric,dramatic lighting,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
126
+ randomize = get_additional_prompt()
127
+ # nude = "NSFW,((nude)),medium bare breasts,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
128
+ # bodypaint = "((fully naked with no clothes)),nude naked seethroughxray,invisiblebodypaint,rating_newd,NSFW"
129
+ lab_girl = "hyperrealistic photography, extremely detailed, shy assistant wearing minidress boots and gloves, laboratory background, score_9, 1girl"
130
+ pet_play = "hyperrealistic photography, extremely detailed, playful, blush, glasses, collar, score_9, HDA_pet_play"
131
+ bondage = "hyperrealistic photography, extremely detailed, submissive, glasses, score_9, HDA_Bondage"
132
+ # ahegao = "((invisible clothing)), hyperrealistic photography,exposed vagina,sexy,nsfw,HDA_Ahegao"
133
+ ahegao2 = "(invisiblebodypaint),rating_newd,HDA_Ahegao"
134
+ athleisure = "hyperrealistic photography, extremely detailed, 1girl athlete, exhausted embarrassed sweaty,outdoors, ((athleisure clothing)), score_9"
135
+ atompunk = "((atompunk world)), hyperrealistic photography, extremely detailed, short hair, bodysuit, glasses, neon cyberpunk background, score_9"
136
+ maid = "hyperrealistic photography, extremely detailed, shy, blushing, score_9, pastel background, HDA_unconventional_maid"
137
+ nundress = "hyperrealistic photography, extremely detailed, shy, blushing, fantasy background, score_9, HDA_NunDress"
138
+ naked_hoodie = "hyperrealistic photography, extremely detailed, medium hair, cityscape, (neon lights), score_9, HDA_NakedHoodie"
139
+ abg = "(1girl, asian body covered in words, words on body, tattoos of (words) on body),(masterpiece, best quality),medium breasts,(intricate details),unity 8k wallpaper,ultra detailed,(pastel colors),beautiful and aesthetic,see-through (clothes),detailed,solo"
140
+ # shibari = "extremely detailed, hyperrealistic photography, earrings, blushing, lace choker, tattoo, medium hair, score_9, HDA_Shibari"
141
+ shibari2 = "octane render, highly detailed, volumetric, HDA_Shibari"
142
+
143
+ if prompt == "":
144
+ girls = [randomize, pet_play, bondage, lab_girl, athleisure, atompunk, maid, nundress, naked_hoodie, abg, shibari2, ahegao2]
145
+ prompts_nsfw = [abg, shibari2, ahegao2]
146
+ prompt = f"{random.choice(girls)}"
147
+ prompt = f"boho chic"
148
+ # print(f"-------------{preset}-------------")
149
+ else:
150
+ prompt = f"Photo from Pinterest of {prompt} {interior}"
151
+ # prompt = default2
152
+ return f"{prompt} f{additional_prompt}"
153
+
154
+ style_list = [
155
+ {
156
+ "name": "None",
157
+ "prompt": ""
158
+ },
159
+ {
160
+ "name": "Minimalistic",
161
+ "prompt": "Minimalistic"
162
+ },
163
+ {
164
+ "name": "Boho Chic",
165
+ "prompt": "boho chic"
166
+ },
167
+ {
168
+ "name": "Saudi Prince Gold",
169
+ "prompt": "saudi prince gold",
170
+ },
171
+ {
172
+ "name": "Modern Farmhouse",
173
+ "prompt": "modern farmhouse",
174
+ },
175
+ {
176
+ "name": "Neoclassical",
177
+ "prompt": "Neoclassical",
178
+ },
179
+ {
180
+ "name": "Eclectic",
181
+ "prompt": "Eclectic",
182
+ },
183
+ {
184
+ "name": "Parisian White",
185
+ "prompt": "Parisian White",
186
+ },
187
+ {
188
+ "name": "Hollywood Glam",
189
+ "prompt": "Hollywood Glam",
190
+ },
191
+ {
192
+ "name": "Scandinavian",
193
+ "prompt": "Scandinavian",
194
+ },
195
+ {
196
+ "name": "Japanese",
197
+ "prompt": "Japanese",
198
+ },
199
+ {
200
+ "name": "Texas Cowboy",
201
+ "prompt": "Texas Cowboy",
202
+ },
203
+ {
204
+ "name": "Midcentury Modern",
205
+ "prompt": "Midcentury Modern",
206
+ },
207
+ {
208
+ "name": "Beach",
209
+ "prompt": "Beach",
210
+ },
211
+ ]
212
+
213
+ styles = {k["name"]: (k["prompt"]) for k in style_list}
214
+ STYLE_NAMES = list(styles.keys())
215
+
216
+ def apply_style(style_name):
217
+ if style_name in styles:
218
+ p = styles.get(style_name, "boho chic")
219
+ return p
220
+
221
+
222
+ css = """
223
+ h1 {
224
+ text-align: center;
225
+ display:block;
226
+ }
227
+ h2 {
228
+ text-align: center;
229
+ display:block;
230
+ }
231
+ h3 {
232
+ text-align: center;
233
+ display:block;
234
+ }
235
+ .gradio-container{max-width: 1200px !important}
236
+ footer {visibility: hidden}
237
+ """
238
+ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
239
+ #############################################################################
240
+ with gr.Row():
241
+ with gr.Accordion("Advanced options", open=show_options, visible=show_options):
242
+ num_images = gr.Slider(
243
+ label="Images", minimum=1, maximum=4, value=1, step=1
244
+ )
245
+ image_resolution = gr.Slider(
246
+ label="Image resolution",
247
+ minimum=256,
248
+ maximum=1024,
249
+ value=512,
250
+ step=256,
251
+ )
252
+ preprocess_resolution = gr.Slider(
253
+ label="Preprocess resolution",
254
+ minimum=128,
255
+ maximum=1024,
256
+ value=512,
257
+ step=1,
258
+ )
259
+ num_steps = gr.Slider(
260
+ label="Number of steps", minimum=1, maximum=100, value=15, step=1
261
+ ) # 20/4.5 or 12 without lora, 4 with lora
262
+ guidance_scale = gr.Slider(
263
+ label="Guidance scale", minimum=0.1, maximum=30.0, value=5.5, step=0.1
264
+ ) # 5 without lora, 2 with lora
265
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
266
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
267
+ a_prompt = gr.Textbox(
268
+ label="Additional prompt",
269
+ value = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"
270
+ )
271
+ n_prompt = gr.Textbox(
272
+ label="Negative prompt",
273
+ value="EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
274
+ )
275
+ #############################################################################
276
+ # input text
277
+ with gr.Row():
278
+ gr.Text(label="Interior Design Style Examples", value="Eclectic, Maximalist, Bohemian, Scandinavian, Minimalist, Rustic, Modern Farmhouse, Contemporary, Luxury, Airbnb, Boho Chic, Midcentury Modern, Art Deco, Zen, Beach, Neoclassical, Industrial, Biophilic, Eco-friendly, Hollywood Glam, Parisian White, Saudi Prince Gold, French Country, Monster Energy Drink, Cyberpunk, Vaporwave, Baroque, etc.\n\nPro tip: add a color to customize it! You can also describe the furniture type.")
279
+ with gr.Column():
280
+ prompt = gr.Textbox(
281
+ label="Custom Prompt",
282
+ placeholder="boho chic",
283
+ )
284
+ with gr.Row(visible=True):
285
+ style_selection = gr.Radio(
286
+ show_label=True,
287
+ container=True,
288
+ interactive=True,
289
+ choices=STYLE_NAMES,
290
+ value="None",
291
+ label="Design Styles",
292
+ )
293
+ # input image
294
+ with gr.Row():
295
+ with gr.Column():
296
+ image = gr.Image(
297
+ label="Input",
298
+ sources=["upload"],
299
+ show_label=True,
300
+ mirror_webcam=True,
301
+ format="webp",
302
+ )
303
+ # run button
304
+ with gr.Column():
305
+ run_button = gr.Button(value="Use this one", size=["lg"], visible=False)
306
+ # output image
307
+ with gr.Column():
308
+ result = gr.Image(
309
+ label="Output",
310
+ interactive=False,
311
+ format="webp",
312
+ show_share_button= False,
313
+ )
314
+ # Use this image button
315
+ with gr.Column():
316
+ use_ai_button = gr.Button(value="Use this one", size=["lg"], visible=False)
317
+ config = [
318
+ image,
319
+ style_selection,
320
+ prompt,
321
+ a_prompt,
322
+ n_prompt,
323
+ num_images,
324
+ image_resolution,
325
+ preprocess_resolution,
326
+ num_steps,
327
+ guidance_scale,
328
+ seed,
329
+ ]
330
+
331
+ with gr.Row():
332
+ helper_text = gr.Markdown("## Tap and hold (on mobile) to save the image.", visible=True)
333
+
334
+ # image processing
335
+ @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
336
+ def auto_process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
337
+ return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
338
+
339
+ # AI Image Processing
340
+ @gr.on(triggers=[use_ai_button.click], inputs=config, outputs=result, show_progress="minimal")
341
+ def submit(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
342
+ return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
343
+
344
+ # Change input to result
345
+ @gr.on(triggers=[use_ai_button.click], inputs=None, outputs=image, show_progress="hidden")
346
+ def update_input():
347
+ try:
348
+ print("Updating image to AI Temp Image")
349
+ ai_temp_image = Image.open("temp_image.jpg")
350
+ return ai_temp_image
351
+ except FileNotFoundError:
352
+ print("No AI Image Available")
353
+ return None
354
+
355
+ # Turn off buttons when processing
356
+ @gr.on(triggers=[image.upload, use_ai_button.click, run_button.click], inputs=None, outputs=[run_button, use_ai_button], show_progress="hidden")
357
+ def turn_buttons_off():
358
+ return gr.update(visible=False), gr.update(visible=False)
359
+
360
+ # Turn on buttons when processing is complete
361
+ @gr.on(triggers=[result.change], inputs=None, outputs=[use_ai_button, run_button], show_progress="hidden")
362
+ def turn_buttons_on():
363
+ return gr.update(visible=True), gr.update(visible=True)
364
+
365
+ @spaces.GPU(duration=10)
366
+ @torch.inference_mode()
367
+ def process_image(
368
+ image,
369
+ style_selection,
370
+ prompt,
371
+ a_prompt,
372
+ n_prompt,
373
+ num_images,
374
+ image_resolution,
375
+ preprocess_resolution,
376
+ num_steps,
377
+ guidance_scale,
378
+ seed,
379
+ progress=gr.Progress(track_tqdm=True)
380
+ ):
381
+ torch.cuda.synchronize()
382
+ preprocess_start = time.time()
383
+ print("processing image")
384
+ preprocessor.load("NormalBae")
385
+ # preprocessor.load("Canny") #20 steps, 9 guidance, 512, 512
386
+
387
+ global compiled
388
+ if not compiled:
389
+ print("Not Compiled")
390
+ compiled = True
391
+
392
+ seed = random.randint(0, MAX_SEED)
393
+ generator = torch.cuda.manual_seed(seed)
394
+ control_image = preprocessor(
395
+ image=image,
396
+ image_resolution=image_resolution,
397
+ detect_resolution=preprocess_resolution,
398
+ )
399
+ preprocess_time = time.time() - preprocess_start
400
+ if style_selection is not None or style_selection != "None":
401
+ prompt = "Photo from Pinterest of " + apply_style(style_selection) + " " + prompt + " " + a_prompt
402
+ else:
403
+ prompt=str(get_prompt(prompt, a_prompt))
404
+ negative_prompt=str(n_prompt)
405
+ print(prompt)
406
+ start = time.time()
407
+ results = pipe(
408
+ prompt=prompt,
409
+ negative_prompt=negative_prompt,
410
+ guidance_scale=guidance_scale,
411
+ num_images_per_prompt=num_images,
412
+ num_inference_steps=num_steps,
413
+ generator=generator,
414
+ image=control_image,
415
+ ).images[0]
416
+ torch.cuda.synchronize()
417
+ torch.cuda.empty_cache()
418
+ print(f"\n-------------------------Preprocess done in: {preprocess_time:.2f} seconds-------------------------")
419
+ print(f"\n-------------------------Inference done in: {time.time() - start:.2f} seconds-------------------------")
420
+
421
+ # timestamp = int(time.time())
422
+ #if not os.path.exists("./outputs"):
423
+ # os.makedirs("./outputs")
424
+ # img_path = f"./{timestamp}.jpg"
425
+ # results_path = f"./{timestamp}_out_{prompt}.jpg"
426
+ # imageio.imsave(img_path, image)
427
+ # results.save(results_path)
428
+ results.save("temp_image.jpg")
429
+
430
+ # api.upload_file(
431
+ # path_or_fileobj=img_path,
432
+ # path_in_repo=img_path,
433
+ # repo_id="broyang/anime-ai-outputs",
434
+ # repo_type="dataset",
435
+ # token=API_KEY,
436
+ # run_as_future=True,
437
+ # )
438
+ # api.upload_file(
439
+ # path_or_fileobj=results_path,
440
+ # path_in_repo=results_path,
441
+ # repo_id="broyang/anime-ai-outputs",
442
+ # repo_type="dataset",
443
+ # token=API_KEY,
444
+ # run_as_future=True,
445
+ # )
446
+
447
+ return results
448
+ if prod:
449
+ demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
450
+ else:
451
+ demo.queue(api_open=False).launch(show_api=False)
app/local_app.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prod = True
2
+ port = 8080
3
+ show_options = False
4
+ if prod:
5
+ port = 8081
6
+ # show_options = False
7
+
8
+ import os
9
+ import gc
10
+ import random
11
+ import time
12
+ import gradio as gr
13
+ import numpy as np
14
+ # import imageio
15
+ import torch
16
+ from PIL import Image
17
+ from diffusers import (
18
+ ControlNetModel,
19
+ DPMSolverMultistepScheduler,
20
+ StableDiffusionControlNetPipeline,
21
+ AutoencoderKL,
22
+ )
23
+ from diffusers.models.attention_processor import AttnProcessor2_0
24
+ from local_preprocess import Preprocessor
25
+ MAX_SEED = np.iinfo(np.int32).max
26
+ API_KEY = os.environ.get("API_KEY", None)
27
+
28
+ print("CUDA version:", torch.version.cuda)
29
+ print("loading pipe")
30
+ compiled = False
31
+
32
+ preprocessor = Preprocessor()
33
+ preprocessor.load("NormalBae")
34
+
35
+ if gr.NO_RELOAD:
36
+ # torch.cuda.max_memory_allocated(device="cuda")
37
+
38
+ # Controlnet Normal
39
+ model_id = "lllyasviel/control_v11p_sd15_normalbae"
40
+ print("initializing controlnet")
41
+ controlnet = ControlNetModel.from_pretrained(
42
+ model_id,
43
+ torch_dtype=torch.float16,
44
+ attn_implementation="flash_attention_2",
45
+ ).to("cuda")
46
+
47
+ # Scheduler
48
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(
49
+ "stabilityai/stable-diffusion-xl-base-1.0",
50
+ subfolder="scheduler",
51
+ use_karras_sigmas=True,
52
+ # final_sigmas_type="sigma_min",
53
+ algorithm_type="sde-dpmsolver++",
54
+ # prediction_type="epsilon",
55
+ # thresholding=False,
56
+ denoise_final=True,
57
+ device_map="cuda",
58
+ attn_implementation="flash_attention_2",
59
+ )
60
+
61
+ # Stable Diffusion Pipeline URL
62
+ # base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
63
+ base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
64
+ base_model_id = "Lykon/absolute-reality-1.81"
65
+ vae_url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
66
+
67
+ vae = AutoencoderKL.from_single_file(vae_url, torch_dtype=torch.float16).to("cuda")
68
+ vae.to(memory_format=torch.channels_last)
69
+
70
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
71
+ base_model_id,
72
+ safety_checker=None,
73
+ controlnet=controlnet,
74
+ scheduler=scheduler,
75
+ vae=vae,
76
+ torch_dtype=torch.float16,
77
+ ).to("cuda")
78
+
79
+ # pipe = StableDiffusionControlNetPipeline.from_single_file(
80
+ # base_model_url,
81
+ # controlnet=controlnet,
82
+ # scheduler=scheduler,
83
+ # vae=vae,
84
+ # torch_dtype=torch.float16,
85
+ # )
86
+
87
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",)
88
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
89
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="fcNeg-neg.pt", token="fcNeg-neg")
90
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
91
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
92
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_pet_play.pt", token="HDA_pet_play")
93
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_unconventional maid.pt", token="HDA_unconventional_maid")
94
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NakedHoodie.pt", token="HDA_NakedHoodie")
95
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
96
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
97
+ pipe.to("cuda")
98
+
99
+ # experimental speedup?
100
+ # pipe.compile()
101
+ # torch.cuda.empty_cache()
102
+ # gc.collect()
103
+ print("---------------Loaded controlnet pipeline---------------")
104
+
105
+ # @spaces.GPU(duration=12)
106
+ # pipe.enable_xformers_memory_efficient_attention()
107
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
108
+ # pipe.unet.set_attn_processor(AttnProcessor2_0())
109
+ torch.cuda.empty_cache()
110
+ gc.collect()
111
+ print("Model Compiled!")
112
+
113
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
114
+ if randomize_seed:
115
+ seed = random.randint(0, MAX_SEED)
116
+ return seed
117
+
118
+ def get_additional_prompt():
119
+ prompt = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
120
+ top = ["tank top", "blouse", "button up shirt", "sweater", "corset top"]
121
+ bottom = ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt", "leggings", "high-waisted shorts"]
122
+ accessory = ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker", "necklace", "headband", "headphones"]
123
+ return f"{prompt}, {random.choice(top)}, {random.choice(bottom)}, {random.choice(accessory)}, score_9"
124
+ # outfit = ["schoolgirl outfit", "playboy outfit", "red dress", "gala dress", "cheerleader outfit", "nurse outfit", "Kimono"]
125
+
126
+ def get_prompt(prompt, additional_prompt):
127
+ interior = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"
128
+ default = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
129
+ default2 = f"professional 3d model {prompt},octane render,highly detailed,volumetric,dramatic lighting,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
130
+ randomize = get_additional_prompt()
131
+ # nude = "NSFW,((nude)),medium bare breasts,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
132
+ # bodypaint = "((fully naked with no clothes)),nude naked seethroughxray,invisiblebodypaint,rating_newd,NSFW"
133
+ lab_girl = "hyperrealistic photography, extremely detailed, shy assistant wearing minidress boots and gloves, laboratory background, score_9, 1girl"
134
+ pet_play = "hyperrealistic photography, extremely detailed, playful, blush, glasses, collar, score_9, HDA_pet_play"
135
+ bondage = "hyperrealistic photography, extremely detailed, submissive, glasses, score_9, HDA_Bondage"
136
+ # ahegao = "((invisible clothing)), hyperrealistic photography,exposed vagina,sexy,nsfw,HDA_Ahegao"
137
+ ahegao2 = "(invisiblebodypaint),rating_newd,HDA_Ahegao"
138
+ athleisure = "hyperrealistic photography, extremely detailed, 1girl athlete, exhausted embarrassed sweaty,outdoors, ((athleisure clothing)), score_9"
139
+ atompunk = "((atompunk world)), hyperrealistic photography, extremely detailed, short hair, bodysuit, glasses, neon cyberpunk background, score_9"
140
+ maid = "hyperrealistic photography, extremely detailed, shy, blushing, score_9, pastel background, HDA_unconventional_maid"
141
+ nundress = "hyperrealistic photography, extremely detailed, shy, blushing, fantasy background, score_9, HDA_NunDress"
142
+ naked_hoodie = "hyperrealistic photography, extremely detailed, medium hair, cityscape, (neon lights), score_9, HDA_NakedHoodie"
143
+ abg = "(1girl, asian body covered in words, words on body, tattoos of (words) on body),(masterpiece, best quality),medium breasts,(intricate details),unity 8k wallpaper,ultra detailed,(pastel colors),beautiful and aesthetic,see-through (clothes),detailed,solo"
144
+ # shibari = "extremely detailed, hyperrealistic photography, earrings, blushing, lace choker, tattoo, medium hair, score_9, HDA_Shibari"
145
+ shibari2 = "octane render, highly detailed, volumetric, HDA_Shibari"
146
+
147
+ if prompt == "":
148
+ girls = [randomize, pet_play, bondage, lab_girl, athleisure, atompunk, maid, nundress, naked_hoodie, abg, shibari2, ahegao2]
149
+ prompts_nsfw = [abg, shibari2, ahegao2]
150
+ prompt = f"{random.choice(girls)}"
151
+ prompt = f"boho chic"
152
+ # print(f"-------------{preset}-------------")
153
+ else:
154
+ prompt = f"Photo from Pinterest of {prompt} {interior}"
155
+ # prompt = default2
156
+ return f"{prompt} f{additional_prompt}"
157
+
158
+ style_list = [
159
+ {
160
+ "name": "None",
161
+ "prompt": ""
162
+ },
163
+ {
164
+ "name": "Minimalistic",
165
+ "prompt": "Minimalistic"
166
+ },
167
+ {
168
+ "name": "Boho Chic",
169
+ "prompt": "boho chic"
170
+ },
171
+ {
172
+ "name": "Saudi Prince Gold",
173
+ "prompt": "saudi prince gold",
174
+ },
175
+ {
176
+ "name": "Modern Farmhouse",
177
+ "prompt": "modern farmhouse",
178
+ },
179
+ {
180
+ "name": "Neoclassical",
181
+ "prompt": "Neoclassical",
182
+ },
183
+ {
184
+ "name": "Eclectic",
185
+ "prompt": "Eclectic",
186
+ },
187
+ {
188
+ "name": "Parisian White",
189
+ "prompt": "Parisian White",
190
+ },
191
+ {
192
+ "name": "Hollywood Glam",
193
+ "prompt": "Hollywood Glam",
194
+ },
195
+ {
196
+ "name": "Scandinavian",
197
+ "prompt": "Scandinavian",
198
+ },
199
+ {
200
+ "name": "Japanese",
201
+ "prompt": "Japanese",
202
+ },
203
+ {
204
+ "name": "Texas Cowboy",
205
+ "prompt": "Texas Cowboy",
206
+ },
207
+ {
208
+ "name": "Midcentury Modern",
209
+ "prompt": "Midcentury Modern",
210
+ },
211
+ {
212
+ "name": "Beach",
213
+ "prompt": "Beach",
214
+ },
215
+ ]
216
+
217
+ styles = {k["name"]: (k["prompt"]) for k in style_list}
218
+ STYLE_NAMES = list(styles.keys())
219
+
220
+ def apply_style(style_name):
221
+ if style_name in styles:
222
+ p = styles.get(style_name, "boho chic")
223
+ return p
224
+
225
+
226
+ css = """
227
+ h1 {
228
+ text-align: center;
229
+ display:block;
230
+ }
231
+ h2 {
232
+ text-align: center;
233
+ display:block;
234
+ }
235
+ h3 {
236
+ text-align: center;
237
+ display:block;
238
+ }
239
+ .gradio-container{max-width: 1200px !important}
240
+ footer {visibility: hidden}
241
+ """
242
+ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
243
+ #############################################################################
244
+ with gr.Row():
245
+ with gr.Accordion("Advanced options", open=show_options, visible=show_options):
246
+ num_images = gr.Slider(
247
+ label="Images", minimum=1, maximum=4, value=1, step=1
248
+ )
249
+ image_resolution = gr.Slider(
250
+ label="Image resolution",
251
+ minimum=256,
252
+ maximum=1024,
253
+ value=512,
254
+ step=256,
255
+ )
256
+ preprocess_resolution = gr.Slider(
257
+ label="Preprocess resolution",
258
+ minimum=128,
259
+ maximum=1024,
260
+ value=512,
261
+ step=1,
262
+ )
263
+ num_steps = gr.Slider(
264
+ label="Number of steps", minimum=1, maximum=100, value=15, step=1
265
+ ) # 20/4.5 or 12 without lora, 4 with lora
266
+ guidance_scale = gr.Slider(
267
+ label="Guidance scale", minimum=0.1, maximum=30.0, value=5.5, step=0.1
268
+ ) # 5 without lora, 2 with lora
269
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
270
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
271
+ a_prompt = gr.Textbox(
272
+ label="Additional prompt",
273
+ value = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"
274
+ )
275
+ n_prompt = gr.Textbox(
276
+ label="Negative prompt",
277
+ value="EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
278
+ )
279
+ #############################################################################
280
+ # input text
281
+ with gr.Row():
282
+ gr.Text(label="Interior Design Style Examples", value="Eclectic, Maximalist, Bohemian, Scandinavian, Minimalist, Rustic, Modern Farmhouse, Contemporary, Luxury, Airbnb, Boho Chic, Midcentury Modern, Art Deco, Zen, Beach, Neoclassical, Industrial, Biophilic, Eco-friendly, Hollywood Glam, Parisian White, Saudi Prince Gold, French Country, Monster Energy Drink, Cyberpunk, Vaporwave, Baroque, etc.\n\nPro tip: add a color to customize it! You can also describe the furniture type.")
283
+ with gr.Column():
284
+ prompt = gr.Textbox(
285
+ label="Custom Prompt",
286
+ placeholder="boho chic",
287
+ )
288
+ with gr.Row(visible=True):
289
+ style_selection = gr.Radio(
290
+ show_label=True,
291
+ container=True,
292
+ interactive=True,
293
+ choices=STYLE_NAMES,
294
+ value="None",
295
+ label="Design Styles",
296
+ )
297
+ # input image
298
+ with gr.Row():
299
+ with gr.Column():
300
+ image = gr.Image(
301
+ label="Input",
302
+ sources=["upload"],
303
+ show_label=True,
304
+ mirror_webcam=True,
305
+ format="webp",
306
+ )
307
+ # run button
308
+ with gr.Column():
309
+ run_button = gr.Button(value="Use this one", size=["lg"], visible=False)
310
+ # output image
311
+ with gr.Column():
312
+ result = gr.Image(
313
+ label="Output",
314
+ interactive=False,
315
+ format="webp",
316
+ show_share_button= False,
317
+ )
318
+ # Use this image button
319
+ with gr.Column():
320
+ use_ai_button = gr.Button(value="Use this one", size=["lg"], visible=False)
321
+ config = [
322
+ image,
323
+ style_selection,
324
+ prompt,
325
+ a_prompt,
326
+ n_prompt,
327
+ num_images,
328
+ image_resolution,
329
+ preprocess_resolution,
330
+ num_steps,
331
+ guidance_scale,
332
+ seed,
333
+ ]
334
+
335
+ with gr.Row():
336
+ helper_text = gr.Markdown("## Tap and hold (on mobile) to save the image.", visible=True)
337
+
338
+ # image processing
339
+ @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
340
+ def auto_process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
341
+ return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
342
+
343
+ # AI Image Processing
344
+ @gr.on(triggers=[use_ai_button.click], inputs=config, outputs=result, show_progress="minimal")
345
+ def submit(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
346
+ return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
347
+
348
+ # Change input to result
349
+ @gr.on(triggers=[use_ai_button.click], inputs=None, outputs=image, show_progress="hidden")
350
+ def update_input():
351
+ try:
352
+ print("Updating image to AI Temp Image")
353
+ ai_temp_image = Image.open("temp_image.jpg")
354
+ return ai_temp_image
355
+ except FileNotFoundError:
356
+ print("No AI Image Available")
357
+ return None
358
+
359
+ # Turn off buttons when processing
360
+ @gr.on(triggers=[image.upload, use_ai_button.click, run_button.click], inputs=None, outputs=[run_button, use_ai_button], show_progress="hidden")
361
+ def turn_buttons_off():
362
+ return gr.update(visible=False), gr.update(visible=False)
363
+
364
+ # Turn on buttons when processing is complete
365
+ @gr.on(triggers=[result.change], inputs=None, outputs=[use_ai_button, run_button], show_progress="hidden")
366
+ def turn_buttons_on():
367
+ return gr.update(visible=True), gr.update(visible=True)
368
+
369
+ # @spaces.GPU(duration=12)
370
+ @torch.inference_mode()
371
+ def process_image(
372
+ image,
373
+ style_selection,
374
+ prompt,
375
+ a_prompt,
376
+ n_prompt,
377
+ num_images,
378
+ image_resolution,
379
+ preprocess_resolution,
380
+ num_steps,
381
+ guidance_scale,
382
+ seed,
383
+ progress=gr.Progress(track_tqdm=True)
384
+ ):
385
+ torch.cuda.synchronize()
386
+ preprocess_start = time.time()
387
+ print("processing image")
388
+ preprocessor.load("NormalBae")
389
+ # preprocessor.load("Canny") #20 steps, 9 guidance, 512, 512
390
+
391
+ global compiled
392
+ if not compiled:
393
+ print("Not Compiled")
394
+ compiled = True
395
+
396
+ seed = random.randint(0, MAX_SEED)
397
+ generator = torch.cuda.manual_seed(seed)
398
+ control_image = preprocessor(
399
+ image=image,
400
+ image_resolution=image_resolution,
401
+ detect_resolution=preprocess_resolution,
402
+ )
403
+ preprocess_time = time.time() - preprocess_start
404
+ if style_selection is not None or style_selection != "None":
405
+ prompt = "Photo from Pinterest of " + apply_style(style_selection) + " " + prompt + " " + a_prompt
406
+ else:
407
+ prompt=str(get_prompt(prompt, a_prompt))
408
+ negative_prompt=str(n_prompt)
409
+ print(prompt)
410
+ start = time.time()
411
+ results = pipe(
412
+ prompt=prompt,
413
+ negative_prompt=negative_prompt,
414
+ guidance_scale=guidance_scale,
415
+ num_images_per_prompt=num_images,
416
+ num_inference_steps=num_steps,
417
+ generator=generator,
418
+ image=control_image,
419
+ ).images[0]
420
+ torch.cuda.synchronize()
421
+ torch.cuda.empty_cache()
422
+ print(f"\n-------------------------Preprocess done in: {preprocess_time:.2f} seconds-------------------------")
423
+ print(f"\n-------------------------Inference done in: {time.time() - start:.2f} seconds-------------------------")
424
+
425
+ # timestamp = int(time.time())
426
+ #if not os.path.exists("./outputs"):
427
+ # os.makedirs("./outputs")
428
+ # img_path = f"./{timestamp}.jpg"
429
+ # results_path = f"./{timestamp}_out_{prompt}.jpg"
430
+ # imageio.imsave(img_path, image)
431
+ # results.save(results_path)
432
+ results.save("temp_image.jpg")
433
+
434
+ # api.upload_file(
435
+ # path_or_fileobj=img_path,
436
+ # path_in_repo=img_path,
437
+ # repo_id="broyang/anime-ai-outputs",
438
+ # repo_type="dataset",
439
+ # token=API_KEY,
440
+ # run_as_future=True,
441
+ # )
442
+ # api.upload_file(
443
+ # path_or_fileobj=results_path,
444
+ # path_in_repo=results_path,
445
+ # repo_id="broyang/anime-ai-outputs",
446
+ # repo_type="dataset",
447
+ # token=API_KEY,
448
+ # run_as_future=True,
449
+ # )
450
+
451
+ return results
452
+ if prod:
453
+ demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
454
+ else:
455
+ demo.queue(api_open=False).launch(show_api=False)
app/local_preprocess.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import numpy as np
2
+ import PIL.Image
3
+ import torch
4
+ import gc
5
+ # from controlnet_aux_local import NormalBaeDetector#, CannyDetector
6
+ from controlnet_aux import NormalBaeDetector
7
+
8
+ # from controlnet_aux.util import HWC3
9
+ # import cv2
10
+ # from cv_utils import resize_image
11
+
12
+ class Preprocessor:
13
+ MODEL_ID = "lllyasviel/Annotators"
14
+
15
+ # def resize_image(input_image, resolution, interpolation=None):
16
+ # H, W, C = input_image.shape
17
+ # H = float(H)
18
+ # W = float(W)
19
+ # k = float(resolution) / max(H, W)
20
+ # H *= k
21
+ # W *= k
22
+ # H = int(np.round(H / 64.0)) * 64
23
+ # W = int(np.round(W / 64.0)) * 64
24
+ # if interpolation is None:
25
+ # interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
26
+ # img = cv2.resize(input_image, (W, H), interpolation=interpolation)
27
+ # return img
28
+
29
+
30
+ def __init__(self):
31
+ self.model = None
32
+ self.name = ""
33
+
34
+ def load(self, name: str) -> None:
35
+ if name == self.name:
36
+ return
37
+ elif name == "NormalBae":
38
+ print("Loading NormalBae")
39
+ self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
40
+ # elif name == "Canny":
41
+ # self.model = CannyDetector()
42
+ else:
43
+ raise ValueError
44
+ torch.cuda.empty_cache()
45
+ gc.collect()
46
+
47
+ self.name = name
48
+
49
+ def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
50
+ # if self.name == "Canny":
51
+ # if "detect_resolution" in kwargs:
52
+ # detect_resolution = kwargs.pop("detect_resolution")
53
+ # image = np.array(image)
54
+ # image = HWC3(image)
55
+ # image = resize_image(image, resolution=detect_resolution)
56
+ # image = self.model(image, **kwargs)
57
+ # return PIL.Image.fromarray(image)
58
+ # elif self.name == "Midas":
59
+ # detect_resolution = kwargs.pop("detect_resolution", 512)
60
+ # image_resolution = kwargs.pop("image_resolution", 512)
61
+ # image = np.array(image)
62
+ # image = HWC3(image)
63
+ # image = resize_image(image, resolution=detect_resolution)
64
+ # image = self.model(image, **kwargs)
65
+ # image = HWC3(image)
66
+ # image = resize_image(image, resolution=image_resolution)
67
+ # return PIL.Image.fromarray(image)
68
+ # else:
69
+ return self.model(image, **kwargs)
app/preprocess.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import numpy as np
2
+ import PIL.Image
3
+ # import torch
4
+ from controlnet_aux import NormalBaeDetector#, CannyDetector
5
+
6
+ # from controlnet_aux.util import HWC3
7
+ # import cv2
8
+ # from cv_utils import resize_image
9
+
10
+ class Preprocessor:
11
+ MODEL_ID = "lllyasviel/Annotators"
12
+
13
+ # def resize_image(input_image, resolution, interpolation=None):
14
+ # H, W, C = input_image.shape
15
+ # H = float(H)
16
+ # W = float(W)
17
+ # k = float(resolution) / max(H, W)
18
+ # H *= k
19
+ # W *= k
20
+ # H = int(np.round(H / 64.0)) * 64
21
+ # W = int(np.round(W / 64.0)) * 64
22
+ # if interpolation is None:
23
+ # interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
24
+ # img = cv2.resize(input_image, (W, H), interpolation=interpolation)
25
+ # return img
26
+
27
+
28
+ def __init__(self):
29
+ self.model = None
30
+ self.name = ""
31
+
32
+ def load(self, name: str) -> None:
33
+ if name == self.name:
34
+ return
35
+ elif name == "NormalBae":
36
+ print("Loading NormalBae")
37
+ self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
38
+ # elif name == "Canny":
39
+ # self.model = CannyDetector()
40
+ else:
41
+ raise ValueError
42
+ # torch.cuda.empty_cache()
43
+ # gc.collect()
44
+
45
+ self.name = name
46
+
47
+ def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
48
+ # if self.name == "Canny":
49
+ # if "detect_resolution" in kwargs:
50
+ # detect_resolution = kwargs.pop("detect_resolution")
51
+ # image = np.array(image)
52
+ # image = HWC3(image)
53
+ # image = resize_image(image, resolution=detect_resolution)
54
+ # image = self.model(image, **kwargs)
55
+ # return PIL.Image.fromarray(image)
56
+ # elif self.name == "Midas":
57
+ # detect_resolution = kwargs.pop("detect_resolution", 512)
58
+ # image_resolution = kwargs.pop("image_resolution", 512)
59
+ # image = np.array(image)
60
+ # image = HWC3(image)
61
+ # image = resize_image(image, resolution=detect_resolution)
62
+ # image = self.model(image, **kwargs)
63
+ # image = HWC3(image)
64
+ # image = resize_image(image, resolution=image_resolution)
65
+ # return PIL.Image.fromarray(image)
66
+ # else:
67
+ return self.model(image, **kwargs)
app/requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ diffusers
4
+ einops
5
+ huggingface-hub
6
+ mediapipe
7
+ opencv-python-headless
8
+ safetensors
9
+ transformers
10
+ xformers
11
+ accelerate
12
+ imageio
app/win.requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ torchaudio
4
+ --index-url https://download.pytorch.org/whl/cu121
5
+
6
+ diffusers
7
+ einops
8
+ gradio
9
+ gradio-client
10
+ mediapipe
11
+ opencv-python-headless
12
+ safetensors
13
+ transformers
14
+ xformers
15
+ accelerate
16
+ imageio
17
+ controlnet_aux
local_app.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prod = True
2
+ port = 8080
3
+ show_options = False
4
+ if prod:
5
+ port = 8081
6
+ # show_options = False
7
+
8
+ import os
9
+ import gc
10
+ import random
11
+ import time
12
+ import gradio as gr
13
+ import numpy as np
14
+ # import imageio
15
+ import torch
16
+ from PIL import Image
17
+ from diffusers import (
18
+ ControlNetModel,
19
+ DPMSolverMultistepScheduler,
20
+ StableDiffusionControlNetPipeline,
21
+ AutoencoderKL,
22
+ )
23
+ from diffusers.models.attention_processor import AttnProcessor2_0
24
+ from local_preprocess import Preprocessor
25
+ MAX_SEED = np.iinfo(np.int32).max
26
+ API_KEY = os.environ.get("API_KEY", None)
27
+
28
+ print("CUDA version:", torch.version.cuda)
29
+ print("loading pipe")
30
+ compiled = False
31
+
32
+ preprocessor = Preprocessor()
33
+ preprocessor.load("NormalBae")
34
+
35
+ if gr.NO_RELOAD:
36
+ # torch.cuda.max_memory_allocated(device="cuda")
37
+
38
+ # Controlnet Normal
39
+ model_id = "lllyasviel/control_v11p_sd15_normalbae"
40
+ print("initializing controlnet")
41
+ controlnet = ControlNetModel.from_pretrained(
42
+ model_id,
43
+ torch_dtype=torch.float16,
44
+ attn_implementation="flash_attention_2",
45
+ ).to("cuda")
46
+
47
+ # Scheduler
48
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(
49
+ "stabilityai/stable-diffusion-xl-base-1.0",
50
+ subfolder="scheduler",
51
+ use_karras_sigmas=True,
52
+ # final_sigmas_type="sigma_min",
53
+ algorithm_type="sde-dpmsolver++",
54
+ # prediction_type="epsilon",
55
+ # thresholding=False,
56
+ denoise_final=True,
57
+ device_map="cuda",
58
+ attn_implementation="flash_attention_2",
59
+ )
60
+
61
+ # Stable Diffusion Pipeline URL
62
+ # base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
63
+ base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
64
+ base_model_id = "Lykon/absolute-reality-1.81"
65
+ vae_url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
66
+
67
+ vae = AutoencoderKL.from_single_file(vae_url, torch_dtype=torch.float16).to("cuda")
68
+ vae.to(memory_format=torch.channels_last)
69
+
70
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
71
+ base_model_id,
72
+ safety_checker=None,
73
+ controlnet=controlnet,
74
+ scheduler=scheduler,
75
+ vae=vae,
76
+ torch_dtype=torch.float16,
77
+ ).to("cuda")
78
+
79
+ # pipe = StableDiffusionControlNetPipeline.from_single_file(
80
+ # base_model_url,
81
+ # controlnet=controlnet,
82
+ # scheduler=scheduler,
83
+ # vae=vae,
84
+ # torch_dtype=torch.float16,
85
+ # )
86
+
87
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",)
88
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
89
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="fcNeg-neg.pt", token="fcNeg-neg")
90
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
91
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
92
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_pet_play.pt", token="HDA_pet_play")
93
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_unconventional maid.pt", token="HDA_unconventional_maid")
94
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NakedHoodie.pt", token="HDA_NakedHoodie")
95
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
96
+ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
97
+ pipe.to("cuda")
98
+
99
+ # experimental speedup?
100
+ # pipe.compile()
101
+ # torch.cuda.empty_cache()
102
+ # gc.collect()
103
+ print("---------------Loaded controlnet pipeline---------------")
104
+
105
+ # @spaces.GPU(duration=12)
106
+ # pipe.enable_xformers_memory_efficient_attention()
107
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
108
+ # pipe.unet.set_attn_processor(AttnProcessor2_0())
109
+ torch.cuda.empty_cache()
110
+ gc.collect()
111
+ print("Model Compiled!")
112
+
113
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
114
+ if randomize_seed:
115
+ seed = random.randint(0, MAX_SEED)
116
+ return seed
117
+
118
+ def get_additional_prompt():
119
+ prompt = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
120
+ top = ["tank top", "blouse", "button up shirt", "sweater", "corset top"]
121
+ bottom = ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt", "leggings", "high-waisted shorts"]
122
+ accessory = ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker", "necklace", "headband", "headphones"]
123
+ return f"{prompt}, {random.choice(top)}, {random.choice(bottom)}, {random.choice(accessory)}, score_9"
124
+ # outfit = ["schoolgirl outfit", "playboy outfit", "red dress", "gala dress", "cheerleader outfit", "nurse outfit", "Kimono"]
125
+
126
+ def get_prompt(prompt, additional_prompt):
127
+ interior = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"
128
+ default = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
129
+ default2 = f"professional 3d model {prompt},octane render,highly detailed,volumetric,dramatic lighting,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
130
+ randomize = get_additional_prompt()
131
+ # nude = "NSFW,((nude)),medium bare breasts,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
132
+ # bodypaint = "((fully naked with no clothes)),nude naked seethroughxray,invisiblebodypaint,rating_newd,NSFW"
133
+ lab_girl = "hyperrealistic photography, extremely detailed, shy assistant wearing minidress boots and gloves, laboratory background, score_9, 1girl"
134
+ pet_play = "hyperrealistic photography, extremely detailed, playful, blush, glasses, collar, score_9, HDA_pet_play"
135
+ bondage = "hyperrealistic photography, extremely detailed, submissive, glasses, score_9, HDA_Bondage"
136
+ # ahegao = "((invisible clothing)), hyperrealistic photography,exposed vagina,sexy,nsfw,HDA_Ahegao"
137
+ ahegao2 = "(invisiblebodypaint),rating_newd,HDA_Ahegao"
138
+ athleisure = "hyperrealistic photography, extremely detailed, 1girl athlete, exhausted embarrassed sweaty,outdoors, ((athleisure clothing)), score_9"
139
+ atompunk = "((atompunk world)), hyperrealistic photography, extremely detailed, short hair, bodysuit, glasses, neon cyberpunk background, score_9"
140
+ maid = "hyperrealistic photography, extremely detailed, shy, blushing, score_9, pastel background, HDA_unconventional_maid"
141
+ nundress = "hyperrealistic photography, extremely detailed, shy, blushing, fantasy background, score_9, HDA_NunDress"
142
+ naked_hoodie = "hyperrealistic photography, extremely detailed, medium hair, cityscape, (neon lights), score_9, HDA_NakedHoodie"
143
+ abg = "(1girl, asian body covered in words, words on body, tattoos of (words) on body),(masterpiece, best quality),medium breasts,(intricate details),unity 8k wallpaper,ultra detailed,(pastel colors),beautiful and aesthetic,see-through (clothes),detailed,solo"
144
+ # shibari = "extremely detailed, hyperrealistic photography, earrings, blushing, lace choker, tattoo, medium hair, score_9, HDA_Shibari"
145
+ shibari2 = "octane render, highly detailed, volumetric, HDA_Shibari"
146
+
147
+ if prompt == "":
148
+ girls = [randomize, pet_play, bondage, lab_girl, athleisure, atompunk, maid, nundress, naked_hoodie, abg, shibari2, ahegao2]
149
+ prompts_nsfw = [abg, shibari2, ahegao2]
150
+ prompt = f"{random.choice(girls)}"
151
+ prompt = f"boho chic"
152
+ # print(f"-------------{preset}-------------")
153
+ else:
154
+ prompt = f"Photo from Pinterest of {prompt} {interior}"
155
+ # prompt = default2
156
+ return f"{prompt} f{additional_prompt}"
157
+
158
+ style_list = [
159
+ {
160
+ "name": "None",
161
+ "prompt": ""
162
+ },
163
+ {
164
+ "name": "Minimalistic",
165
+ "prompt": "Minimalistic"
166
+ },
167
+ {
168
+ "name": "Boho Chic",
169
+ "prompt": "boho chic"
170
+ },
171
+ {
172
+ "name": "Saudi Prince Gold",
173
+ "prompt": "saudi prince gold",
174
+ },
175
+ {
176
+ "name": "Modern Farmhouse",
177
+ "prompt": "modern farmhouse",
178
+ },
179
+ {
180
+ "name": "Neoclassical",
181
+ "prompt": "Neoclassical",
182
+ },
183
+ {
184
+ "name": "Eclectic",
185
+ "prompt": "Eclectic",
186
+ },
187
+ {
188
+ "name": "Parisian White",
189
+ "prompt": "Parisian White",
190
+ },
191
+ {
192
+ "name": "Hollywood Glam",
193
+ "prompt": "Hollywood Glam",
194
+ },
195
+ {
196
+ "name": "Scandinavian",
197
+ "prompt": "Scandinavian",
198
+ },
199
+ {
200
+ "name": "Japanese",
201
+ "prompt": "Japanese",
202
+ },
203
+ {
204
+ "name": "Texas Cowboy",
205
+ "prompt": "Texas Cowboy",
206
+ },
207
+ {
208
+ "name": "Midcentury Modern",
209
+ "prompt": "Midcentury Modern",
210
+ },
211
+ {
212
+ "name": "Beach",
213
+ "prompt": "Beach",
214
+ },
215
+ ]
216
+
217
+ styles = {k["name"]: (k["prompt"]) for k in style_list}
218
+ STYLE_NAMES = list(styles.keys())
219
+
220
+ def apply_style(style_name):
221
+ if style_name in styles:
222
+ p = styles.get(style_name, "boho chic")
223
+ return p
224
+
225
+
226
+ css = """
227
+ h1 {
228
+ text-align: center;
229
+ display:block;
230
+ }
231
+ h2 {
232
+ text-align: center;
233
+ display:block;
234
+ }
235
+ h3 {
236
+ text-align: center;
237
+ display:block;
238
+ }
239
+ .gradio-container{max-width: 1200px !important}
240
+ footer {visibility: hidden}
241
+ """
242
+ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
243
+ #############################################################################
244
+ with gr.Row():
245
+ with gr.Accordion("Advanced options", open=show_options, visible=show_options):
246
+ num_images = gr.Slider(
247
+ label="Images", minimum=1, maximum=4, value=1, step=1
248
+ )
249
+ image_resolution = gr.Slider(
250
+ label="Image resolution",
251
+ minimum=256,
252
+ maximum=1024,
253
+ value=512,
254
+ step=256,
255
+ )
256
+ preprocess_resolution = gr.Slider(
257
+ label="Preprocess resolution",
258
+ minimum=128,
259
+ maximum=1024,
260
+ value=512,
261
+ step=1,
262
+ )
263
+ num_steps = gr.Slider(
264
+ label="Number of steps", minimum=1, maximum=100, value=15, step=1
265
+ ) # 20/4.5 or 12 without lora, 4 with lora
266
+ guidance_scale = gr.Slider(
267
+ label="Guidance scale", minimum=0.1, maximum=30.0, value=5.5, step=0.1
268
+ ) # 5 without lora, 2 with lora
269
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
270
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
271
+ a_prompt = gr.Textbox(
272
+ label="Additional prompt",
273
+ value = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"
274
+ )
275
+ n_prompt = gr.Textbox(
276
+ label="Negative prompt",
277
+ value="EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
278
+ )
279
+ #############################################################################
280
+ # input text
281
+ with gr.Row():
282
+ gr.Text(label="Interior Design Style Examples", value="Eclectic, Maximalist, Bohemian, Scandinavian, Minimalist, Rustic, Modern Farmhouse, Contemporary, Luxury, Airbnb, Boho Chic, Midcentury Modern, Art Deco, Zen, Beach, Neoclassical, Industrial, Biophilic, Eco-friendly, Hollywood Glam, Parisian White, Saudi Prince Gold, French Country, Monster Energy Drink, Cyberpunk, Vaporwave, Baroque, etc.\n\nPro tip: add a color to customize it! You can also describe the furniture type.")
283
+ with gr.Column():
284
+ prompt = gr.Textbox(
285
+ label="Custom Prompt",
286
+ placeholder="boho chic",
287
+ )
288
+ with gr.Row(visible=True):
289
+ style_selection = gr.Radio(
290
+ show_label=True,
291
+ container=True,
292
+ interactive=True,
293
+ choices=STYLE_NAMES,
294
+ value="None",
295
+ label="Design Styles",
296
+ )
297
+ # input image
298
+ with gr.Row():
299
+ with gr.Column():
300
+ image = gr.Image(
301
+ label="Input",
302
+ sources=["upload"],
303
+ show_label=True,
304
+ mirror_webcam=True,
305
+ format="webp",
306
+ )
307
+ # run button
308
+ with gr.Column():
309
+ run_button = gr.Button(value="Use this one", size=["lg"], visible=False)
310
+ # output image
311
+ with gr.Column():
312
+ result = gr.Image(
313
+ label="Output",
314
+ interactive=False,
315
+ format="webp",
316
+ show_share_button= False,
317
+ )
318
+ # Use this image button
319
+ with gr.Column():
320
+ use_ai_button = gr.Button(value="Use this one", size=["lg"], visible=False)
321
+ config = [
322
+ image,
323
+ style_selection,
324
+ prompt,
325
+ a_prompt,
326
+ n_prompt,
327
+ num_images,
328
+ image_resolution,
329
+ preprocess_resolution,
330
+ num_steps,
331
+ guidance_scale,
332
+ seed,
333
+ ]
334
+
335
+ with gr.Row():
336
+ helper_text = gr.Markdown("## Tap and hold (on mobile) to save the image.", visible=True)
337
+
338
+ # image processing
339
+ @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
340
+ def auto_process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
341
+ return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
342
+
343
+ # AI Image Processing
344
+ @gr.on(triggers=[use_ai_button.click], inputs=config, outputs=result, show_progress="minimal")
345
+ def submit(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
346
+ return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
347
+
348
+ # Change input to result
349
+ @gr.on(triggers=[use_ai_button.click], inputs=None, outputs=image, show_progress="hidden")
350
+ def update_input():
351
+ try:
352
+ print("Updating image to AI Temp Image")
353
+ ai_temp_image = Image.open("temp_image.jpg")
354
+ return ai_temp_image
355
+ except FileNotFoundError:
356
+ print("No AI Image Available")
357
+ return None
358
+
359
+ # Turn off buttons when processing
360
+ @gr.on(triggers=[image.upload, use_ai_button.click, run_button.click], inputs=None, outputs=[run_button, use_ai_button], show_progress="hidden")
361
+ def turn_buttons_off():
362
+ return gr.update(visible=False), gr.update(visible=False)
363
+
364
+ # Turn on buttons when processing is complete
365
+ @gr.on(triggers=[result.change], inputs=None, outputs=[use_ai_button, run_button], show_progress="hidden")
366
+ def turn_buttons_on():
367
+ return gr.update(visible=True), gr.update(visible=True)
368
+
369
+ # @spaces.GPU(duration=12)
370
+ @torch.inference_mode()
371
+ def process_image(
372
+ image,
373
+ style_selection,
374
+ prompt,
375
+ a_prompt,
376
+ n_prompt,
377
+ num_images,
378
+ image_resolution,
379
+ preprocess_resolution,
380
+ num_steps,
381
+ guidance_scale,
382
+ seed,
383
+ progress=gr.Progress(track_tqdm=True)
384
+ ):
385
+ torch.cuda.synchronize()
386
+ preprocess_start = time.time()
387
+ print("processing image")
388
+ preprocessor.load("NormalBae")
389
+ # preprocessor.load("Canny") #20 steps, 9 guidance, 512, 512
390
+
391
+ global compiled
392
+ if not compiled:
393
+ print("Not Compiled")
394
+ compiled = True
395
+
396
+ seed = random.randint(0, MAX_SEED)
397
+ generator = torch.cuda.manual_seed(seed)
398
+ control_image = preprocessor(
399
+ image=image,
400
+ image_resolution=image_resolution,
401
+ detect_resolution=preprocess_resolution,
402
+ )
403
+ preprocess_time = time.time() - preprocess_start
404
+ if style_selection is not None or style_selection != "None":
405
+ prompt = "Photo from Pinterest of " + apply_style(style_selection) + " " + prompt + " " + a_prompt
406
+ else:
407
+ prompt=str(get_prompt(prompt, a_prompt))
408
+ negative_prompt=str(n_prompt)
409
+ print(prompt)
410
+ start = time.time()
411
+ results = pipe(
412
+ prompt=prompt,
413
+ negative_prompt=negative_prompt,
414
+ guidance_scale=guidance_scale,
415
+ num_images_per_prompt=num_images,
416
+ num_inference_steps=num_steps,
417
+ generator=generator,
418
+ image=control_image,
419
+ ).images[0]
420
+ torch.cuda.synchronize()
421
+ torch.cuda.empty_cache()
422
+ print(f"\n-------------------------Preprocess done in: {preprocess_time:.2f} seconds-------------------------")
423
+ print(f"\n-------------------------Inference done in: {time.time() - start:.2f} seconds-------------------------")
424
+
425
+ # timestamp = int(time.time())
426
+ #if not os.path.exists("./outputs"):
427
+ # os.makedirs("./outputs")
428
+ # img_path = f"./{timestamp}.jpg"
429
+ # results_path = f"./{timestamp}_out_{prompt}.jpg"
430
+ # imageio.imsave(img_path, image)
431
+ # results.save(results_path)
432
+ results.save("temp_image.jpg")
433
+
434
+ # api.upload_file(
435
+ # path_or_fileobj=img_path,
436
+ # path_in_repo=img_path,
437
+ # repo_id="broyang/anime-ai-outputs",
438
+ # repo_type="dataset",
439
+ # token=API_KEY,
440
+ # run_as_future=True,
441
+ # )
442
+ # api.upload_file(
443
+ # path_or_fileobj=results_path,
444
+ # path_in_repo=results_path,
445
+ # repo_id="broyang/anime-ai-outputs",
446
+ # repo_type="dataset",
447
+ # token=API_KEY,
448
+ # run_as_future=True,
449
+ # )
450
+
451
+ return results
452
+ if prod:
453
+ demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
454
+ else:
455
+ demo.queue(api_open=False).launch(show_api=False)
local_preprocess.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import numpy as np
2
+ import PIL.Image
3
+ import torch
4
+ import gc
5
+ # from controlnet_aux_local import NormalBaeDetector#, CannyDetector
6
+ from controlnet_aux import NormalBaeDetector
7
+
8
+ # from controlnet_aux.util import HWC3
9
+ # import cv2
10
+ # from cv_utils import resize_image
11
+
12
+ class Preprocessor:
13
+ MODEL_ID = "lllyasviel/Annotators"
14
+
15
+ # def resize_image(input_image, resolution, interpolation=None):
16
+ # H, W, C = input_image.shape
17
+ # H = float(H)
18
+ # W = float(W)
19
+ # k = float(resolution) / max(H, W)
20
+ # H *= k
21
+ # W *= k
22
+ # H = int(np.round(H / 64.0)) * 64
23
+ # W = int(np.round(W / 64.0)) * 64
24
+ # if interpolation is None:
25
+ # interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
26
+ # img = cv2.resize(input_image, (W, H), interpolation=interpolation)
27
+ # return img
28
+
29
+
30
+ def __init__(self):
31
+ self.model = None
32
+ self.name = ""
33
+
34
+ def load(self, name: str) -> None:
35
+ if name == self.name:
36
+ return
37
+ elif name == "NormalBae":
38
+ print("Loading NormalBae")
39
+ self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
40
+ # elif name == "Canny":
41
+ # self.model = CannyDetector()
42
+ else:
43
+ raise ValueError
44
+ torch.cuda.empty_cache()
45
+ gc.collect()
46
+
47
+ self.name = name
48
+
49
+ def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
50
+ # if self.name == "Canny":
51
+ # if "detect_resolution" in kwargs:
52
+ # detect_resolution = kwargs.pop("detect_resolution")
53
+ # image = np.array(image)
54
+ # image = HWC3(image)
55
+ # image = resize_image(image, resolution=detect_resolution)
56
+ # image = self.model(image, **kwargs)
57
+ # return PIL.Image.fromarray(image)
58
+ # elif self.name == "Midas":
59
+ # detect_resolution = kwargs.pop("detect_resolution", 512)
60
+ # image_resolution = kwargs.pop("image_resolution", 512)
61
+ # image = np.array(image)
62
+ # image = HWC3(image)
63
+ # image = resize_image(image, resolution=detect_resolution)
64
+ # image = self.model(image, **kwargs)
65
+ # image = HWC3(image)
66
+ # image = resize_image(image, resolution=image_resolution)
67
+ # return PIL.Image.fromarray(image)
68
+ # else:
69
+ return self.model(image, **kwargs)
preprocess.py CHANGED
@@ -64,4 +64,4 @@ class Preprocessor:
64
  # image = resize_image(image, resolution=image_resolution)
65
  # return PIL.Image.fromarray(image)
66
  # else:
67
- return self.model(image, **kwargs)
 
64
  # image = resize_image(image, resolution=image_resolution)
65
  # return PIL.Image.fromarray(image)
66
  # else:
67
+ return self.model(image, **kwargs)
requirements.txt CHANGED
@@ -1,12 +1,12 @@
1
- torch
2
- torchvision
3
- diffusers
4
- einops
5
- huggingface-hub
6
- mediapipe
7
- opencv-python-headless
8
- safetensors
9
- transformers
10
- xformers
11
- accelerate
12
  imageio
 
1
+ torch
2
+ torchvision
3
+ diffusers
4
+ einops
5
+ huggingface-hub
6
+ mediapipe
7
+ opencv-python-headless
8
+ safetensors
9
+ transformers
10
+ xformers
11
+ accelerate
12
  imageio
win.requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ torchaudio
4
+ --index-url https://download.pytorch.org/whl/cu121
5
+
6
+ diffusers
7
+ einops
8
+ gradio
9
+ gradio-client
10
+ mediapipe
11
+ opencv-python-headless
12
+ safetensors
13
+ transformers
14
+ xformers
15
+ accelerate
16
+ imageio
17
+ controlnet_aux