File size: 4,554 Bytes
f303715
 
 
 
045ff63
f303715
 
 
 
 
 
73aa562
f303715
 
 
 
73aa562
f303715
7d8193f
f303715
 
 
 
 
 
dceaee5
f303715
3b994bf
 
 
 
 
 
 
 
 
 
 
 
 
f303715
3b994bf
 
 
 
 
 
 
f303715
3b994bf
f303715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dceaee5
f303715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73aa562
f303715
 
 
73aa562
63f4ca3
f303715
 
 
 
 
 
 
 
73aa562
f303715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
import os
hf_token = os.environ.get("HF_TOKEN")
import spaces
from diffusers import DiffusionPipeline, UNet2DConditionModel, EulerAncestralDiscreteScheduler, AutoencoderKL
import torch
import time

class Dummy:
    """Empty attribute container.

    Instances carry no behaviour of their own; arbitrary attributes can be
    attached to them. Within this file the class is only referenced by the
    commented-out ``torch.compile`` experiment, where instances stand in for
    ``add_embedding`` / ``linear_1`` on the compiled UNet.
    """

# Output sizes offered in the UI. Each entry is "<width> <height>"; `infer`
# splits the string and converts both parts to integers.
resolutions = [
    "1536 1536",
    "1728 1280",
    "1856 1280",
    "1920 1088",
    "1088 1920",
    "1280 1856",
    "1280 1728",
]

# Load pipeline: the fp16-fix VAE and the HD UNet are loaded separately and
# injected into the base BRIA 2.2 pipeline, everything in float16.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained("briaai/BRIA-2.2-HD", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained("briaai/BRIA-2.2", torch_dtype=torch.float16, unet=unet, vae=vae)

# Fail fast if the loaded pipeline does not carry the expected scheduler.
# An explicit raise is used instead of `assert` so the check is not stripped
# under `python -O`, and isinstance() is the idiomatic type test.
if not isinstance(pipe.scheduler, EulerAncestralDiscreteScheduler):
    raise TypeError(
        "Expected pipe.scheduler to be an EulerAncestralDiscreteScheduler, "
        f"got {type(pipe.scheduler).__name__}"
    )

pipe.to('cuda')

# The pipeline holds its own references to these modules; drop the
# module-level names so they are not kept around separately.
del unet
del vae

# NOTE(review): set as a plain attribute on the pipeline object — confirm the
# installed diffusers version actually reads it from here.
pipe.force_zeros_for_empty_prompt = False

# Negative prompt passed to every generation request made by `infer`.
negative_prompt= "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"

# print("Optimizing BRIA 2.2 HD - this could take a while")
# t=time.time()
# pipe.unet = torch.compile(
#     pipe.unet, mode="reduce-overhead", fullgraph=True # 600 secs compilation
# )
# with torch.no_grad():
#     outputs = pipe(
#         prompt="an apple",
#         num_inference_steps=30,
#         width=1536,
#         height=1536,
#         negative_prompt=negative_prompt
#     )

#     # This will avoid future compilations on different shapes
#     unet_compiled = torch._dynamo.run(pipe.unet)
#     unet_compiled.config=pipe.unet.config
#     unet_compiled.add_embedding = Dummy()
#     unet_compiled.add_embedding.linear_1 = Dummy()
#     unet_compiled.add_embedding.linear_1.in_features = pipe.unet.add_embedding.linear_1.in_features
#     pipe.unet = unet_compiled

# print(f"Optimizing finished successfully after {time.time()-t} secs")

def _seed_to_generator(seed):
    """Return a seeded CUDA ``torch.Generator``, or ``None`` for a random seed.

    ``None`` is returned when *seed* is -1 (as an int or as text such as
    ``"-1"`` or ``" -1 "``), and also when *seed* cannot be used at all
    (empty or non-numeric text, or a value the generator rejects). The
    latter keeps the best-effort fallback of the UI: a bad seed never
    aborts a generation request.
    """
    try:
        seed_value = int(seed)
    except (TypeError, ValueError, OverflowError):
        print(f"Ignoring unusable seed {seed!r}; using a random seed")
        return None

    # -1 is the UI's sentinel for "no fixed seed". Comparing after parsing
    # means an int -1 or padded text is recognised too, instead of being
    # used as a real seed.
    if seed_value == -1:
        return None

    try:
        return torch.Generator("cuda").manual_seed(seed_value)
    except (RuntimeError, ValueError, OverflowError) as err:
        print(f"Could not seed generator with {seed_value!r} ({err}); using a random seed")
        return None


@spaces.GPU(enable_queue=True)
def infer(prompt,seed,resolution):
    """Generate one image for *prompt* with the module-level pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed as text or int. -1, or any value that cannot be used as a
            seed, means no fixed seed (the pipeline gets ``generator=None``).
        resolution: ``"<width> <height>"`` string, e.g. ``"1536 1536"``.

    Returns:
        The first image of the pipeline output.

    Raises:
        ValueError: If *resolution* is not exactly two whitespace-separated
            integers.
    """
    # NOTE(review): "/n" is printed literally; it may have been meant as a
    # newline escape. Left unchanged to keep the log output as it was.
    print(f"""
    —/n
    {prompt}
    """)

    started = time.perf_counter()

    generator = _seed_to_generator(seed)

    width, height = (int(part) for part in resolution.split())
    image = pipe(
        prompt,
        num_inference_steps=30,
        generator=generator,
        width=width,
        height=height,
        negative_prompt=negative_prompt,
    ).images[0]
    print(f'gen time is {time.perf_counter()-started} secs')

    # Future
    # Add amount of steps
    # if nsfw:
    #     raise gr.Error("Generated image is NSFW")

    return image

# Page stylesheet: centres the main column and caps its width.
css = """
#col-container{
    margin: 0 auto;
    max-width: 580px;
}
"""

# Layout: title and description, a group with the three inputs and the
# button, then the output image. Components are created in display order.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## BRIA 2.2 HD")
        gr.HTML('''
          <p style="margin-bottom: 10px; font-size: 94%">
            This is a demo for 
            <a href="https://huggingface.co/briaai/BRIA-2.2-HD" target="_blank">BRIA 2.2 HD </a>. 
            This is a high resolution version of BRIA 2.2 text-to-image model, still trained on licensed data, and so provide full legal liability coverage for copyright and privacy infringement.
          </p>
        ''')
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(
                    value="A smiling man with wavy brown hair and a trimmed beard",
                    label="Prompt",
                )
                resolution = gr.Dropdown(
                    choices=resolutions,
                    value=resolutions[0],
                    label="Resolution",
                    show_label=True,
                )
                # -1 is the "no fixed seed" sentinel understood by `infer`.
                seed = gr.Textbox(value=-1, label="Seed")
                submit_btn = gr.Button("Generate")
        result = gr.Image(label="BRIA 2.2 HD Result")

        # Example prompts, currently disabled.
        # gr.Examples(
        #     examples = [ 
        #         "Dragon, digital art, by Greg Rutkowski",
        #         "Armored knight holding sword",
        #         "A flat roof villa near a river with black walls and huge windows",
        #         "A calm and peaceful office",
        #         "Pirate guinea pig"
        #     ],
        #     fn = infer, 
        #     inputs = [
        #         prompt_in
        #     ],
        #     outputs = [
        #         result
        #     ]
        # )

    # The input order must match the parameter order of `infer`.
    submit_btn.click(fn=infer, inputs=[prompt_in, seed, resolution], outputs=[result])

# Enable request queuing and start the app with the API page hidden.
demo.queue().launch(show_api=False)