DJStomp committed (verified)
Commit b62da8b · 1 parent: b95cb12

Give ZeroGPU longer to finish

Files changed (1):
  1. app.py (+9, -6)
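
The substance of the change is the ZeroGPU decorator: @spaces.GPU is replaced with @spaces.GPU(duration=140), which asks ZeroGPU to hold the GPU for up to about 140 seconds per call instead of the shorter default window, so longer FLUX ControlNet runs are not cut off mid-generation. The commit also calls torch.cuda.empty_cache() before moving the pipeline to the device. Below is a minimal sketch of the decorator pattern only, with a placeholder body standing in for the Space's real infer():

import spaces
import torch

# Placeholder for the Space's real infer(); only the decorator pattern matters here.
@spaces.GPU(duration=140)  # request roughly 140 s of GPU time per call
def infer(prompt):
    # On ZeroGPU, the CUDA device is attached only while this function runs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"would generate on {device}: {prompt}"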
app.py CHANGED
@@ -39,11 +39,12 @@ TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
 CONTROLNET = FluxControlNetModel.from_pretrained(CONTROLNET_MODEL, torch_dtype=TORCH_DTYPE)
 PIPE = FluxControlNetPipeline.from_pretrained(MODEL_DIR, controlnet=CONTROLNET, torch_dtype=TORCH_DTYPE)
+torch.cuda.empty_cache()
 PIPE = PIPE.to(DEVICE)
 
 MAX_SEED = np.iinfo(np.int32).max
 
-@spaces.GPU
+@spaces.GPU(duration=140)
 def infer(
     prompt,
     control_image_path,
@@ -84,10 +85,11 @@ with gr.Blocks(css=CSS) as demo:
             type="filepath",
             label="Control Image (LineArt)"
         )
-        prompt = gr.Textbox(
+        prompt = gr.Text(
             label="Prompt",
             placeholder="Enter your prompt",
             max_lines=1,
+            container=False
         )
         run_button = gr.Button("Generate", variant="primary")
         result = gr.Image(label="Result", show_label=False)
@@ -130,9 +132,10 @@ with gr.Blocks(css=CSS) as demo:
             ],
             inputs=[prompt]
         )
-
-        run_button.click(
-            infer,
+
+        gr.on(
+            triggers=[run_button.click, prompt.submit],
+            fn = infer,
             inputs=[
                 prompt,
                 control_image,
@@ -142,7 +145,7 @@ with gr.Blocks(css=CSS) as demo:
                 seed,
                 randomize_seed
            ],
-            outputs=[result, seed]
+            outputs = [result, seed]
         )
 
 if __name__ == "__main__":
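
The event-wiring hunk replaces run_button.click(infer, ...) with gr.on(), so both the Generate button and pressing Enter in the prompt box (prompt.submit) trigger the same handler. A self-contained sketch of that wiring, using a stand-in echo() function (a hypothetical name) instead of the Space's pipeline:

import gradio as gr

def echo(prompt):
    # Stand-in for infer(); returns text so the sketch runs without a GPU.
    return f"You asked for: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt",
                     max_lines=1, container=False)
    run_button = gr.Button("Generate", variant="primary")
    result = gr.Textbox(label="Result")

    # One handler, two triggers: a button click or Enter in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=echo,
        inputs=[prompt],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()

Because gr.on() takes a list of triggers, the Enter-to-submit behaviour is added without duplicating the inputs/outputs wiring for each event.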