DJStomp commited on
Commit
a937644
·
verified ·
1 Parent(s): 1e2293c

Update app.py

Browse files

Fix gated repo issue with Flux

Files changed (1) hide show
  1. app.py +39 -32
app.py CHANGED
@@ -1,24 +1,42 @@
 
 
1
  import gradio as gr
2
  import torch
3
  from diffusers.utils import load_image
4
  from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
5
  from diffusers.models.controlnet_flux import FluxControlNetModel
6
- import random
7
  import numpy as np
 
8
 
9
- import os
10
- from huggingface_hub import login
11
-
12
- login(os.getenv("hfapikey"))
13
-
14
- # Initialize models
15
  base_model = 'black-forest-labs/FLUX.1-dev'
16
  controlnet_model = 'promeai/FLUX.1-controlnet-lineart-promeai'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
19
 
20
  controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch_dtype)
21
- pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch_dtype)
22
  pipe = pipe.to(device)
23
 
24
  MAX_SEED = np.iinfo(np.int32).max
@@ -50,31 +68,23 @@ def infer(
50
 
51
  return result, seed
52
 
53
- css = """
54
- #col-container {
55
- margin: 0 auto;
56
- max-width: 640px;
57
- }
58
- """
59
-
60
  with gr.Blocks(css=css) as demo:
61
  with gr.Column(elem_id="col-container"):
62
- gr.Markdown("## Zero-shot Partial Style Transfer for Line Art Images, Powered by FLUX.1")
63
-
64
- with gr.Row():
65
- prompt = gr.Textbox(
66
- label="Prompt",
67
- placeholder="Enter your prompt",
68
- max_lines=1,
69
- )
70
- run_button = gr.Button("Generate", variant="primary")
71
-
72
- with gr.Accordion("Advanced Settings", open=True):
73
- control_image = gr.Image(
74
  sources=['upload', 'webcam', 'clipboard'],
75
  type="filepath",
76
- label="Control Image (Line Art)"
77
- )
 
 
 
 
 
 
 
 
78
  controlnet_conditioning_scale = gr.Slider(
79
  label="ControlNet Conditioning Scale",
80
  minimum=0.0,
@@ -105,9 +115,6 @@ with gr.Blocks(css=css) as demo:
105
  )
106
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
107
 
108
-
109
- result = gr.Image(label="Result", show_label=False)
110
-
111
  gr.Examples(
112
  examples=[
113
  "Shiba Inu wearing dinosaur costume riding skateboard",
 
1
+ import os
2
+ import random
3
  import gradio as gr
4
  import torch
5
  from diffusers.utils import load_image
6
  from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
7
  from diffusers.models.controlnet_flux import FluxControlNetModel
 
8
  import numpy as np
9
+ from huggingface_hub import login, snapshot_download
10
 
11
+ # Configuration
 
 
 
 
 
12
  base_model = 'black-forest-labs/FLUX.1-dev'
13
  controlnet_model = 'promeai/FLUX.1-controlnet-lineart-promeai'
14
+ css = """
15
+ #col-container {
16
+ margin: 0 auto;
17
+ max-width: 640px;
18
+ }
19
+ """
20
+
21
+ # Setup
22
+ auth_token = os.getenv("HF_AUTH_TOKEN")
23
+ if not auth_token:
24
+ raise ValueError("Hugging Face auth token not found. Please set HF_AUTH_TOKEN in the environment.")
25
+
26
+ login(auth_token)
27
+
28
+ model_dir = snapshot_download(
29
+ repo_id=base_model,
30
+ revision="main",
31
+ use_auth_token=auth_token
32
+ )
33
+
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
  torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
36
+ print(f"Using device: {device} (torch_dtype={torch_dtype})")
37
 
38
  controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch_dtype)
39
+ pipe = FluxControlNetPipeline.from_pretrained(model_dir, controlnet=controlnet, torch_dtype=torch_dtype)
40
  pipe = pipe.to(device)
41
 
42
  MAX_SEED = np.iinfo(np.int32).max
 
68
 
69
  return result, seed
70
 
 
 
 
 
 
 
 
71
  with gr.Blocks(css=css) as demo:
72
  with gr.Column(elem_id="col-container"):
73
+ gr.Markdown("Flux.1[dev] LineArt")
74
+ gr.Markdown("### Zero-shot Partial Style Transfer for Line Art Images, Powered by FLUX.1")
75
+ control_image = gr.Image(
 
 
 
 
 
 
 
 
 
76
  sources=['upload', 'webcam', 'clipboard'],
77
  type="filepath",
78
+ label="Control Image (LineArt)"
79
+ )
80
+ prompt = gr.Textbox(
81
+ label="Prompt",
82
+ placeholder="Enter your prompt",
83
+ max_lines=1,
84
+ )
85
+ run_button = gr.Button("Generate", variant="primary")
86
+ result = gr.Image(label="Result", show_label=False)
87
+ with gr.Accordion("Advanced Settings", open=False):
88
  controlnet_conditioning_scale = gr.Slider(
89
  label="ControlNet Conditioning Scale",
90
  minimum=0.0,
 
115
  )
116
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
117
 
 
 
 
118
  gr.Examples(
119
  examples=[
120
  "Shiba Inu wearing dinosaur costume riding skateboard",