Deadmon commited on
Commit
11d79a9
·
verified ·
1 Parent(s): 69fb279

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -24
app.py CHANGED
@@ -27,21 +27,6 @@ device = torch.device("cuda")
27
  offload = False
28
  is_schnell = name == "flux-schnell"
29
 
30
- model, ae, t5, clip, controlnet = None, None, None, None, None
31
-
32
- def load_models():
33
- global model, ae, t5, clip, controlnet
34
- t5 = load_t5(device, max_length=256 if is_schnell else 512)
35
- clip = load_clip(device)
36
- model = load_flow_model(name, device=device)
37
- ae = load_ae(name, device=device)
38
- controlnet = load_controlnet(name, device).to(device).to(torch.bfloat16)
39
-
40
- checkpoint = load_safetensors(model_path)
41
- controlnet.load_state_dict(checkpoint, strict=False)
42
-
43
- load_models()
44
-
45
  def preprocess_image(image, target_width, target_height, crop=True):
46
  if crop:
47
  image = c_crop(image) # Crop the image to square
@@ -78,11 +63,16 @@ def generate_image(prompt, control_image, num_steps=50, guidance=4, width=512, h
78
 
79
  torch_device = torch.device("cuda")
80
 
81
- model.to(torch_device)
82
- t5.to(torch_device)
83
- clip.to(torch_device)
84
- ae.to(torch_device)
85
- controlnet.to(torch_device)
 
 
 
 
 
86
 
87
  width = 16 * width // 16
88
  height = 16 * height // 16
@@ -116,8 +106,8 @@ interface = gr.Interface(
116
  gr.Image(type="pil", label="Control Image"),
117
  gr.Slider(step=1, minimum=1, maximum=64, value=28, label="Num Steps"),
118
  gr.Slider(minimum=0.1, maximum=10, value=4, label="Guidance"),
119
- gr.Slider(minimum=128, maximum=2048, step=128, value=1024, label="Width"),
120
- gr.Slider(minimum=128, maximum=2048, step=128, value=1024, label="Height"),
121
  gr.Number(value=42, label="Seed"),
122
  gr.Checkbox(label="Random Seed")
123
  ],
@@ -127,5 +117,4 @@ interface = gr.Interface(
127
  )
128
 
129
  if __name__ == "__main__":
130
- interface.launch()
131
-
 
27
  offload = False
28
  is_schnell = name == "flux-schnell"
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def preprocess_image(image, target_width, target_height, crop=True):
31
  if crop:
32
  image = c_crop(image) # Crop the image to square
 
63
 
64
  torch_device = torch.device("cuda")
65
 
66
+ torch.cuda.empty_cache() # Clear GPU cache
67
+
68
+ model = load_flow_model(name, device=torch_device)
69
+ t5 = load_t5(torch_device, max_length=256 if is_schnell else 512)
70
+ clip = load_clip(torch_device)
71
+ ae = load_ae(name, device=torch_device)
72
+ controlnet = load_controlnet(name, torch_device).to(torch_device).to(torch.bfloat16)
73
+
74
+ checkpoint = load_safetensors(model_path)
75
+ controlnet.load_state_dict(checkpoint, strict=False)
76
 
77
  width = 16 * width // 16
78
  height = 16 * height // 16
 
106
  gr.Image(type="pil", label="Control Image"),
107
  gr.Slider(step=1, minimum=1, maximum=64, value=28, label="Num Steps"),
108
  gr.Slider(minimum=0.1, maximum=10, value=4, label="Guidance"),
109
+ gr.Slider(minimum=128, maximum=1024, step=128, value=512, label="Width"),
110
+ gr.Slider(minimum=128, maximum=1024, step=128, value=512, label="Height"),
111
  gr.Number(value=42, label="Seed"),
112
  gr.Checkbox(label="Random Seed")
113
  ],
 
117
  )
118
 
119
  if __name__ == "__main__":
120
+ interface.launch()