File size: 4,579 Bytes
a937644
 
b95cb12
 
e8e7d81
 
f0a89f2
 
 
 
a937644
2683ee0
e8e7d81
a937644
2683ee0
 
 
a937644
 
 
 
 
 
 
2683ee0
 
 
 
a937644
 
2683ee0
 
a937644
2683ee0
a937644
 
2683ee0
 
e8e7d81
2683ee0
 
b62da8b
2683ee0
e8e7d81
 
 
b62da8b
e8e7d81
 
f0a89f2
 
e8e7d81
 
f0a89f2
 
e8e7d81
b95cb12
 
 
 
e8e7d81
 
 
f0a89f2
 
e8e7d81
f0a89f2
2683ee0
e8e7d81
f0a89f2
 
e8e7d81
f0a89f2
e8e7d81
 
 
f0a89f2
e8e7d81
2683ee0
e8e7d81
377c553
a937644
 
cc327b5
f0a89f2
a937644
 
b62da8b
a937644
 
 
b62da8b
a937644
 
 
 
f0a89f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8e7d81
 
 
 
 
 
f0a89f2
e8e7d81
f0a89f2
 
 
 
cc327b5
2683ee0
 
f0a89f2
 
 
b62da8b
 
 
 
e8e7d81
 
f0a89f2
 
e8e7d81
 
f0a89f2
 
e8e7d81
b62da8b
e8e7d81
 
 
f0a89f2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import random

import spaces
import gradio as gr
import torch
from diffusers.utils import load_image
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel
import numpy as np
from huggingface_hub import login, snapshot_download


# Configuration
BASE_MODEL = 'black-forest-labs/FLUX.1-dev'
CONTROLNET_MODEL = 'promeai/FLUX.1-controlnet-lineart-promeai'
# Centers the Gradio column and caps its width (applied via gr.Blocks(css=CSS)).
CSS = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Setup: FLUX.1-dev is a gated repo, so a Hugging Face token is mandatory.
AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
if AUTH_TOKEN:
    login(AUTH_TOKEN)
else:
    raise ValueError("Hugging Face auth token not found. Please set HF_AUTH_TOKEN in the environment.")

# Download the base model once up front; `token=` replaces the deprecated
# `use_auth_token=` parameter (removed in recent huggingface_hub releases).
MODEL_DIR = snapshot_download(
    repo_id=BASE_MODEL,
    revision="main",
    token=AUTH_TOKEN,
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 halves memory on GPU; fall back to float32 on CPU where bf16 is slow/unsupported.
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

CONTROLNET = FluxControlNetModel.from_pretrained(CONTROLNET_MODEL, torch_dtype=TORCH_DTYPE)
PIPE = FluxControlNetPipeline.from_pretrained(MODEL_DIR, controlnet=CONTROLNET, torch_dtype=TORCH_DTYPE)
# Free any cached allocations before moving the large pipeline onto the device
# (no-op when CUDA is unavailable/uninitialized).
torch.cuda.empty_cache()
PIPE = PIPE.to(DEVICE)

MAX_SEED = np.iinfo(np.int32).max

@spaces.GPU(duration=140)
def infer(
    prompt,
    control_image_path,
    controlnet_conditioning_scale,
    guidance_scale,
    num_inference_steps,
    seed,
    randomize_seed,
):
    """Generate an image from *prompt* guided by an optional line-art control image.

    Args:
        prompt: Text prompt for the diffusion pipeline.
        control_image_path: Filepath of the control (line-art) image, or None/empty
            to run without a control image.
        controlnet_conditioning_scale: Strength of the ControlNet conditioning.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        seed: RNG seed; ignored when *randomize_seed* is True.
        randomize_seed: When True, draw a fresh seed in [0, MAX_SEED].

    Returns:
        Tuple of (generated PIL image, seed actually used) — the seed is returned
        so the UI can display/reuse it.
    """
    # Recompute locally (the ZeroGPU decorator may attach a CUDA device that
    # was absent at import time); do NOT mutate the module-level constants.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    print(f"Inference: using device: {device} (torch_dtype={dtype})")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Use a dedicated generator rather than torch.manual_seed(), which would
    # clobber the process-global RNG state for every other consumer. A fresh
    # CPU generator seeded identically yields the same sample stream.
    generator = torch.Generator().manual_seed(seed)

    control_image = load_image(control_image_path) if control_image_path else None

    # Generate image
    result = PIPE(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]

    return result, seed

# UI definition: component creation order inside the Blocks context determines
# the rendered layout, so the structure below is intentionally linear.
with gr.Blocks(css=CSS) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Flux.1[dev] LineArt")
        gr.Markdown("### Zero-shot Partial Style Transfer for Line Art Images, Powered by FLUX.1")
        # type="filepath" hands infer() a path string, matching load_image().
        control_image = gr.Image(
                sources=['upload', 'webcam', 'clipboard'],
                type="filepath",
                label="Control Image (LineArt)"
        )
        prompt = gr.Text(
            label="Prompt",
            placeholder="Enter your prompt",
            max_lines=1,
            container=False
        )
        run_button = gr.Button("Generate", variant="primary")
        result = gr.Image(label="Result", show_label=False)
        # Generation knobs, collapsed by default; values mirror infer()'s inputs.
        with gr.Accordion("Advanced Settings", open=False):
            controlnet_conditioning_scale = gr.Slider(
                label="ControlNet Conditioning Scale",
                minimum=0.0,
                maximum=1.0,
                value=0.6,
                step=0.1
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1.0,
                maximum=10.0,
                value=3.5,
                step=0.1
            )
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=1,
                maximum=100,
                value=28,
                step=1
            )
            # Seed slider is effectively ignored while "Randomize Seed" is checked.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0
            )
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

        # Example prompts only prefill the prompt box; they do not auto-run inference.
        gr.Examples(
            examples=[
                "Shiba Inu wearing dinosaur costume riding skateboard",
                "Victorian style mansion interior with candlelight",
                "Loading screen for Grand Theft Otter: Clam Andreas"
            ],
            inputs=[prompt]
        )
        
    # Wire both the button click and Enter-in-prompt to the same handler;
    # infer() returns (image, seed) so the used seed is written back to the slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
        inputs=[
            prompt,
            control_image,
            controlnet_conditioning_scale,
            guidance_scale,
            num_inference_steps,
            seed,
            randomize_seed
        ],
        outputs = [result, seed]
    )

if __name__ == "__main__":
    demo.launch()