ginipick committed
Commit 39b272a · verified · 1 Parent(s): b02e794

Update app.py

Files changed (1)
  1. app.py +125 -49
app.py CHANGED
@@ -9,8 +9,8 @@ from diffusers import DiffusionPipeline
 from custom_pipeline import FLUXPipelineWithIntermediateOutputs
 from transformers import pipeline
 
-# Translation model loading
-translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
+# Translation model loading with device specification
+translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
 
 # Constants
 MAX_SEED = np.iinfo(np.int32).max
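
Note: the Hangul test used by translate_if_korean covers compatibility jamo (U+3131 to U+3163) and precomposed syllables (U+AC00 to U+D7A3), so prompts that mix English and Korean still trigger translation. A minimal standalone sketch of that check (contains_korean is an illustrative name, not part of app.py):

    def contains_korean(text: str) -> bool:
        # True if any character falls in the jamo or syllable blocks.
        return any('\u3131' <= ch <= '\u3163' or '\uac00' <= ch <= '\ud7a3' for ch in text)

    assert contains_korean("고양이 한 마리")   # Hangul syllables
    assert contains_korean("a cute 고양이")    # mixed-language prompt
    assert not contains_korean("a cute cat")  # pure ASCII, returned untranslated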
@@ -18,13 +18,19 @@ MAX_IMAGE_SIZE = 2048
 DEFAULT_WIDTH = 1024
 DEFAULT_HEIGHT = 1024
 DEFAULT_INFERENCE_STEPS = 1
+GPU_DURATION = 15  # Reduced from 25 to stay within quota
+
+# Device and model setup with memory optimization
+def setup_model():
+    dtype = torch.float16
+    pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
+        "black-forest-labs/FLUX.1-schnell",
+        torch_dtype=dtype,
+        device_map="auto"  # Enable model parallelism
+    )
+    return pipe
 
-# Device and model setup
-dtype = torch.float16
-pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
-    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
-).to("cuda")
-torch.cuda.empty_cache()
+pipe = setup_model()
 
 # Menu labels dictionary
 english_labels = {
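
Note: pipeline-level device_map support in diffusers is version-sensitive; some releases accept only the "balanced" strategy for whole pipelines and reject "auto". A hedged sketch of a loader that falls back to an explicit single-GPU move (load_pipeline is an illustrative name; exact behavior depends on the installed diffusers version):

    import torch
    from diffusers import DiffusionPipeline

    def load_pipeline(model_id="black-forest-labs/FLUX.1-schnell"):
        try:
            # Some diffusers versions only support "balanced" for pipelines.
            return DiffusionPipeline.from_pretrained(
                model_id, torch_dtype=torch.float16, device_map="balanced"
            )
        except (NotImplementedError, ValueError):
            # Fall back to loading normally and moving the pipeline to one GPU.
            return DiffusionPipeline.from_pretrained(
                model_id, torch_dtype=torch.float16
            ).to("cuda")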
@@ -41,36 +47,67 @@ english_labels = {
 }
 
 def translate_if_korean(text):
-    if any('\u3131' <= char <= '\u3163' or '\uac00' <= char <= '\ud7a3' for char in text):
-        return translator(text)[0]['translation_text']
-    return text
-
-# Modified inference function to always use random seed for examples
-@spaces.GPU(duration=25)
-def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
-    prompt = translate_if_korean(prompt)
-
-    # Always generate a random seed if none provided or randomize_seed is True
-    if seed is None or randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator().manual_seed(seed)
-
-    start_time = time.time()
-
-    for img in pipe.generate_images(
-        prompt=prompt,
-        guidance_scale=0,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator
-    ):
-        latency = f"Processing Time: {(time.time()-start_time):.2f} seconds"
-        yield img, seed, latency
-
-# Function specifically for examples that always uses random seeds
+    """Safely translate Korean text to English."""
+    try:
+        if any('\u3131' <= char <= '\u3163' or '\uac00' <= char <= '\ud7a3' for char in text):
+            return translator(text)[0]['translation_text']
+        return text
+    except Exception as e:
+        print(f"Translation error: {e}")
+        return text
+
+# Modified inference function with error handling and memory management
+@spaces.GPU(duration=GPU_DURATION)
+def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
+                   randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
+    try:
+        # Input validation
+        if not isinstance(seed, (int, type(None))):
+            seed = None
+            randomize_seed = True
+
+        prompt = translate_if_korean(prompt)
+
+        if seed is None or randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+
+        # Ensure valid dimensions
+        width = min(max(256, width), MAX_IMAGE_SIZE)
+        height = min(max(256, height), MAX_IMAGE_SIZE)
+
+        generator = torch.Generator().manual_seed(seed)
+
+        start_time = time.time()
+
+        with torch.cuda.amp.autocast():  # Enable automatic mixed precision
+            for img in pipe.generate_images(
+                prompt=prompt,
+                guidance_scale=0,
+                num_inference_steps=num_inference_steps,
+                width=width,
+                height=height,
+                generator=generator
+            ):
+                latency = f"Processing Time: {(time.time()-start_time):.2f} seconds"
+
+                # Clear CUDA cache after generation
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+                yield img, seed, latency
+
+    except Exception as e:
+        print(f"Error in generate_image: {e}")
+        # Return a blank image or error message
+        yield None, seed, f"Error: {str(e)}"
+
+# Example generator with error handling
 def generate_example_image(prompt):
-    return generate_image(prompt, randomize_seed=True)
+    try:
+        return next(generate_image(prompt, randomize_seed=True))
+    except Exception as e:
+        print(f"Error in example generation: {e}")
+        return None, None, f"Error: {str(e)}"
 
 # Example prompts
 examples = [
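
Note: generate_image is a generator, so callers choose between streaming every intermediate frame and keeping only the first one; next(...), as used in generate_example_image, returns the first (img, seed, latency) triple and discards the rest. A self-contained illustration of the two consumption patterns (frames is a stand-in for generate_image):

    def frames():
        for i in range(3):
            yield f"img{i}", 42, f"{i}.0s"

    first = next(frames())      # ("img0", 42, "0.0s"); later frames are dropped
    streamed = list(frames())   # all three triples, as Gradio streaming sees them

Separately, torch.cuda.amp.autocast() is deprecated in newer PyTorch releases in favor of torch.amp.autocast("cuda"), which takes the device type as its first argument.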
@@ -88,12 +125,14 @@ footer {
 }
 """
 
-# --- Gradio UI ---
+# --- Gradio UI with improved error handling ---
 with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
     with gr.Column(elem_id="app-container"):
         with gr.Row():
             with gr.Column(scale=3):
-                result = gr.Image(label=english_labels["Generated Image"], show_label=False, interactive=False)
+                result = gr.Image(label=english_labels["Generated Image"],
+                                  show_label=False,
+                                  interactive=False)
             with gr.Column(scale=1):
                 prompt = gr.Text(
                     label=english_labels["Prompt"],
@@ -108,25 +147,53 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
         with gr.Row():
             latency = gr.Text(show_label=False)
         with gr.Row():
-            seed = gr.Number(label=english_labels["Seed"], value=42, precision=0)
-            randomize_seed = gr.Checkbox(label=english_labels["Randomize Seed"], value=True)
+            # Modified Number component with proper validation
+            seed = gr.Number(
+                label=english_labels["Seed"],
+                value=42,
+                precision=0,
+                minimum=0,
+                maximum=MAX_SEED
+            )
+            randomize_seed = gr.Checkbox(
+                label=english_labels["Randomize Seed"],
+                value=True
+            )
         with gr.Row():
-            width = gr.Slider(label=english_labels["Width"], minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
-            height = gr.Slider(label=english_labels["Height"], minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
-            num_inference_steps = gr.Slider(label=english_labels["Inference Steps"], minimum=1, maximum=4, step=1, value=DEFAULT_INFERENCE_STEPS)
+            width = gr.Slider(
+                label=english_labels["Width"],
+                minimum=256,
+                maximum=MAX_IMAGE_SIZE,
+                step=32,
+                value=DEFAULT_WIDTH
+            )
+            height = gr.Slider(
+                label=english_labels["Height"],
+                minimum=256,
+                maximum=MAX_IMAGE_SIZE,
+                step=32,
+                value=DEFAULT_HEIGHT
+            )
+            num_inference_steps = gr.Slider(
+                label=english_labels["Inference Steps"],
+                minimum=1,
+                maximum=4,
+                step=1,
+                value=DEFAULT_INFERENCE_STEPS
+            )
 
     with gr.Row():
         gr.Markdown(f"### 🌟 {english_labels['Inspiration Gallery']}")
     with gr.Row():
         gr.Examples(
             examples=examples,
-            fn=generate_example_image,  # Use the example-specific function
+            fn=generate_example_image,
             inputs=[prompt],
             outputs=[result, seed],
-            cache_examples=False  # Disable caching to ensure new generation each time
+            cache_examples=False
         )
 
-    # Event handling
+    # Event handling with improved error handling
     enhanceBtn.click(
         fn=generate_image,
         inputs=[prompt, seed, width, height],
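
Note: in Gradio, when cache_examples=False, clicking an example normally only populates the inputs; fn runs on click only if run_on_click=True is also passed (exact behavior varies by Gradio version). The handler here also returns an (img, seed, latency) triple while outputs lists two components. A hedged sketch of a configuration that matches both, under those assumptions:

    gr.Examples(
        examples=examples,
        fn=generate_example_image,
        inputs=[prompt],
        outputs=[result, seed, latency],  # match the 3-tuple the handler returns
        cache_examples=False,
        run_on_click=True,                # actually regenerate when clicked
    )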
@@ -136,9 +203,17 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
         queue=False
     )
 
+    # Modified event handler with proper input validation
+    def validated_generate(*args):
+        try:
+            return next(generate_image(*args))
+        except Exception as e:
+            print(f"Error in validated_generate: {e}")
+            return None, args[1], f"Error: {str(e)}"
+
     gr.on(
         triggers=[prompt.input, width.input, height.input, num_inference_steps.input],
-        fn=generate_image,
+        fn=validated_generate,
         inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
         outputs=[result, seed, latency],
         show_progress="hidden",
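
Note: because validated_generate consumes the generator with next(), only the first intermediate frame reaches the UI on these triggers. If the streaming behavior of generate_image should be preserved, one alternative (a sketch, not the commit's code) is to re-yield each frame:

    def validated_generate(*args):
        try:
            # Forward every (img, seed, latency) frame so Gradio keeps streaming.
            yield from generate_image(*args)
        except Exception as e:
            print(f"Error in validated_generate: {e}")
            yield None, args[1] if len(args) > 1 else None, f"Error: {e}"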
@@ -147,4 +222,5 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
         queue=False
     )
 
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()