import spaces
import torch
import gc


# Prefer the first CUDA device when available; fall back to CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_pipeline():
    from diffusers import DiffusionPipeline
    pipe = DiffusionPipeline.from_pretrained(
        "John6666/rae-diffusion-xl-v2-sdxl-spo-pcm",
        custom_pipeline="lpw_stable_diffusion_xl",
        #custom_pipeline="nyanko7/sdxl_smoothed_energy_guidance",
        torch_dtype=torch.float16,
    )
    # Keep the pipeline on CPU until a request arrives; generate_image()
    # moves it to the GPU inside the @spaces.GPU context.
    pipe.to("cpu")
    return pipe


def token_auto_concat_embeds(pipe, positive, negative):
    """Encode prompts longer than the tokenizer limit by chunking them and
    concatenating the resulting embeddings along the sequence dimension."""
    max_length = pipe.tokenizer.model_max_length
    positive_length = pipe.tokenizer(positive, return_tensors="pt").input_ids.shape[-1]
    negative_length = pipe.tokenizer(negative, return_tensors="pt").input_ids.shape[-1]

    print(f'Model max token length: {max_length}, positive prompt: {positive_length} tokens, negative prompt: {negative_length} tokens.')
    if max_length < positive_length or max_length < negative_length:
        print('Prompt exceeds the token limit; concatenating chunked embeddings.')
        # Pad the shorter prompt to the length of the longer one so both
        # produce the same number of chunks below.
        if positive_length > negative_length:
            positive_ids = pipe.tokenizer(positive, return_tensors="pt").input_ids.to(device)
            negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=positive_ids.shape[-1], return_tensors="pt").input_ids.to(device)
        else:
            negative_ids = pipe.tokenizer(negative, return_tensors="pt").input_ids.to(device)
            positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=negative_ids.shape[-1], return_tensors="pt").input_ids.to(device)
    else:
        positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=max_length, return_tensors="pt").input_ids.to(device)
        negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=max_length, return_tensors="pt").input_ids.to(device)

    # Encode in max_length-sized windows, then join along the sequence axis.
    positive_concat_embeds = []
    negative_concat_embeds = []
    for i in range(0, positive_ids.shape[-1], max_length):
        positive_concat_embeds.append(pipe.text_encoder(positive_ids[:, i: i + max_length])[0])
        negative_concat_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0])

    positive_prompt_embeds = torch.cat(positive_concat_embeds, dim=1)
    negative_prompt_embeds = torch.cat(negative_concat_embeds, dim=1)
    return positive_prompt_embeds, negative_prompt_embeds
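
# Usage sketch (assumption: this mirrors the commented-out call in
# generate_image below; prompt_embeds / negative_prompt_embeds are the
# standard diffusers pipeline keyword arguments):
#   positive_embeds, negative_embeds = token_auto_concat_embeds(pipe, prompt, neg_prompt)
#   images = pipe(prompt_embeds=positive_embeds, negative_prompt_embeds=negative_embeds, ...).images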


def save_image(image, metadata, output_dir):
    import os
    import uuid
    import json
    from PIL import PngImagePlugin
    # Save as PNG under a random filename, embedding the generation
    # metadata as a JSON string in a PNG text chunk.
    filename = str(uuid.uuid4()) + ".png"
    os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)
    metadata_str = json.dumps(metadata)
    info = PngImagePlugin.PngInfo()
    info.add_text("metadata", metadata_str)
    image.save(filepath, "PNG", pnginfo=info)
    return filepath
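

# A minimal sketch of reading the saved metadata back (assumption: this helper
# is not part of the original app; "metadata" is the text-chunk key written above).
def read_metadata(filepath):
    import json
    from PIL import Image
    with Image.open(filepath) as im:
        # Pillow exposes PNG text chunks via the .text mapping on PNG files.
        return json.loads(im.text["metadata"])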


# Load once at module import so the weights are reused across requests.
pipe = load_pipeline()


@torch.inference_mode()
@spaces.GPU
def generate_image(prompt, neg_prompt):
    pipe.to(device)
    # Append quality tags to the prompt and common artifact tags to the negative prompt.
    prompt += ", anime, masterpiece, best quality, very aesthetic, absurdres"
    neg_prompt += ", bad hands, bad feet, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], photo, deformed, disfigured, low contrast"
    metadata = {
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "resolution": "1024 x 1024",
        "guidance_scale": 7.0,
        "num_inference_steps": 28,
        "sampler": "Euler",
    }
    try:
        #positive_embeds, negative_embeds = token_auto_concat_embeds(pipe, prompt, neg_prompt)
        images = pipe(
            prompt=prompt,
            negative_prompt=neg_prompt,
            width=1024,
            height=1024,
            guidance_scale=7.0,  # seg_scale=3.0, seg_applied_layers=["mid"],
            num_inference_steps=28,
            output_type="pil",
            clip_skip=2,
        ).images
        return [save_image(image, metadata, "./outputs") for image in images] if images else []
    except Exception as e:
        print(e)
        return []
    finally:
        # Always offload the pipeline and release GPU memory, on success or failure.
        pipe.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
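

# A minimal sketch of wiring generate_image into a Gradio UI (assumption: this
# Space serves the function through Gradio, as @spaces.GPU suggests; the
# interface below is illustrative and not part of the original file).
if __name__ == "__main__":
    import gradio as gr

    demo = gr.Interface(
        fn=generate_image,
        inputs=[gr.Textbox(label="Prompt"), gr.Textbox(label="Negative prompt")],
        outputs=gr.Gallery(label="Generated images"),
    )
    demo.launch()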