aiqtech committed on
Commit
64806f8
·
verified ·
1 Parent(s): 1bb836a

Create app-backup.py

Browse files
Files changed (1) hide show
  1. app-backup.py +956 -0
app-backup.py ADDED
@@ -0,0 +1,956 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import time
3
+ from collections.abc import Sequence
4
+ from typing import Any, cast
5
+ import os
6
+ from huggingface_hub import login, hf_hub_download
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import pillow_heif
11
+ import spaces
12
+ import torch
13
+ from gradio_image_annotation import image_annotator
14
+ from gradio_imageslider import ImageSlider
15
+ from PIL import Image
16
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
17
+ from refiners.fluxion.utils import no_grad
18
+ from refiners.solutions import BoxSegmenter
19
+ from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
20
+ from diffusers import FluxPipeline
21
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
22
+ import gc
23
+
24
+ from PIL import Image, ImageDraw, ImageFont
25
+ from PIL import Image
26
+ from gradio_client import Client, handle_file
27
+ import uuid
28
+
29
+
30
def clear_memory():
    """Run garbage collection and empty the CUDA cache on a best-effort basis.

    After ``gc.collect()``, the CUDA cache of device 0 is emptied when CUDA is
    available. Errors raised during the CUDA step are ignored because this is
    housekeeping that must not break the caller.
    """
    gc.collect()
    try:
        if torch.cuda.is_available():
            with torch.cuda.device(0):  # explicitly use device 0
                torch.cuda.empty_cache()
    except Exception:
        # Best-effort cleanup: ignore runtime failures, but (unlike a bare
        # ``except``) let KeyboardInterrupt / SystemExit propagate.
        pass
39
+
40
# GPU setup: use cuda:0 explicitly when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# CUDA tuning is optional, so a failure here only prints a warning.
if torch.cuda.is_available():
    try:
        with torch.cuda.device(0):
            torch.cuda.empty_cache()
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
    except Exception:
        # Narrowed from a bare ``except`` so KeyboardInterrupt / SystemExit
        # are not swallowed during start-up.
        print("Warning: Could not configure CUDA settings")
52
+
53
# Initialize the translation model (Helsinki-NLP/opus-mt-ko-en); the model is
# moved to the CPU before the pipeline is built.
model_name = "Helsinki-NLP/opus-mt-ko-en"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to('cpu')
translator = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)
58
+
59
def translate_to_english(text: str) -> str:
    """Translate *text* to English when it contains Hangul syllables.

    Text with no character in the range '가'..'힣' is returned unchanged.
    If anything fails, the error is printed and the original text is returned.
    """
    try:
        hangul_first, hangul_last = ord('가'), ord('힣')
        has_hangul = any(hangul_first <= ord(ch) <= hangul_last for ch in text)
        if not has_hangul:
            return text
        outputs = translator(text, max_length=128)
        translated = outputs[0]['translation_text']
        print(f"Translated '{text}' to '{translated}'")
        return translated
    except Exception as e:
        print(f"Translation error: {str(e)}")
        return text
70
+
71
# Bounding box as four integers; the box-input error message in this file
# describes the order as [xmin, ymin, xmax, ymax].
BoundingBox = tuple[int, int, int, int]

# Register the HEIF and AVIF openers provided by pillow_heif.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

# Hugging Face token setup: the app refuses to start without HF_TOKEN.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable")

# Log in with the token; any login failure is re-raised as ValueError.
try:
    login(token=HF_TOKEN)
except Exception as e:
    raise ValueError(f"Failed to login to Hugging Face: {str(e)}")
85
+
86
# Model initialization: the segmenter is created on the CPU and its model is
# then moved to `device`.
segmenter = BoxSegmenter(device="cpu")
segmenter.device = device
segmenter.model = segmenter.model.to(device=segmenter.device)

# Grounding DINO processor and detection model, loaded in float32 and moved to `device`.
gd_model_path = "IDEA-Research/grounding-dino-base"
gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
gd_model = gd_model.to(device=device)
assert isinstance(gd_model, GroundingDinoForObjectDetection)

# FLUX pipeline initialization.
# NOTE(review): `use_auth_token` may be deprecated in favor of `token` in
# newer diffusers / huggingface_hub releases — confirm against the pinned versions.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
    use_auth_token=HF_TOKEN
)
pipe.enable_attention_slicing(slice_size="auto")

# Load the LoRA weights and fuse them into the pipeline.
pipe.load_lora_weights(
    hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
        use_auth_token=HF_TOKEN
    )
)
pipe.fuse_lora(lora_scale=0.125)

# Moving the pipeline to the GPU is wrapped in try/except: a failure only
# prints a warning and leaves `pipe` where it was.
try:
    if torch.cuda.is_available():
        pipe = pipe.to("cuda:0")  # explicitly target cuda:0
except Exception as e:
    print(f"Warning: Could not move pipeline to CUDA: {str(e)}")

# Remote Gradio client used by remove_background().
client = Client("NabeelShar/BiRefNet_for_text_writing")
123
+
124
class timer:
    """Context manager that prints when a block starts and how long it took.

    Usage: ``with timer("step"):`` or ``with timer("step") as t:``. The start
    time (from ``time.time()``) is stored on ``self.start``. Exceptions raised
    inside the block are not suppressed.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")
        # Return the instance so ``with timer(...) as t`` binds something
        # useful instead of None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {str(round(end - self.start, 2))}s")
133
+
134
def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
    """Return the smallest box that encloses every box in *bboxes*.

    Each box must hold exactly four ints (checked with assertions). Returns
    None when *bboxes* is empty.
    """
    if not bboxes:
        return None
    for box in bboxes:
        assert len(box) == 4
        assert all(isinstance(value, int) for value in box)
    # Minimum of the first two coordinates, maximum of the last two.
    left = min(box[0] for box in bboxes)
    top = min(box[1] for box in bboxes)
    right = max(box[2] for box in bboxes)
    bottom = max(box[3] for box in bboxes)
    return (left, top, right, bottom)
146
+
147
def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Round boxes to int32 and clamp them to the image bounds.

    The last dimension of *bboxes* is unpacked as (x1, y1, x2, y2). The x
    values are clamped to [0, width] and the y values to [0, height].
    """
    pixels = bboxes.round().to(torch.int32)
    left, top, right, bottom = pixels.unbind(-1)
    clamped = (
        left.clamp(0, width),
        top.clamp(0, height),
        right.clamp(0, width),
        bottom.clamp(0, height),
    )
    return torch.stack(clamped, dim=-1)
150
+
151
def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
    """Detect *prompt* in *img* and return the union of all detected boxes.

    The prompt is passed to the Grounding DINO processor with a trailing
    period. Detected boxes are converted to clamped pixel coordinates and
    merged with ``bbox_union``, which returns None when there are no boxes.
    """
    model_inputs = gd_processor(images=img, text=f"{prompt}.", return_tensors="pt").to(device=device)
    with no_grad():
        model_outputs = gd_model(**model_inputs)

    img_width, img_height = img.size
    # target_sizes is given as (height, width).
    detections: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
        model_outputs,
        model_inputs["input_ids"],
        target_sizes=[(img_height, img_width)],
    )[0]
    assert "boxes" in detections and isinstance(detections["boxes"], torch.Tensor)

    pixel_boxes = corners_to_pixels_format(detections["boxes"].cpu(), img_width, img_height)
    return bbox_union(pixel_boxes.numpy().tolist())
164
+
165
def apply_mask(img: Image.Image, mask_img: Image.Image, defringe: bool = True) -> Image.Image:
    """Cut *img* out with *mask_img* and return an RGBA image.

    Both images must have the same size (asserted). When *defringe* is true,
    the colors are first re-estimated with ``estimate_foreground_ml`` using
    the mask as alpha.
    """
    assert img.size == mask_img.size
    color_img = img.convert("RGB")
    alpha_img = mask_img.convert("L")

    if defringe:
        # Both arrays are scaled to 0..1 before foreground estimation.
        rgb = np.asarray(color_img) / 255.0
        alpha = np.asarray(alpha_img) / 255.0
        foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
        color_img = Image.fromarray((foreground * 255).astype("uint8"))

    cutout = Image.new("RGBA", color_img.size)
    cutout.paste(color_img, (0, 0), alpha_img)
    return cutout
176
+
177
+
178
def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
    """Round *width* and *height* up to the next multiple of 8."""

    def round_up(value):
        # Adding 7 before the floor division rounds up to a multiple of 8.
        return (value + 7) // 8 * 8

    return round_up(width), round_up(height)
183
+
184
def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int, int]:
    """Return (width, height) for the selected aspect ratio.

    Supported ratios are "1:1", "16:9", "9:16" and "4:3". Any other value
    falls back to a square of *base_size*.
    """
    square = (base_size, base_size)
    long_side_16_9 = base_size * 16 // 9
    sizes = {
        "1:1": square,
        "16:9": (long_side_16_9, base_size),
        "9:16": (base_size, long_side_16_9),
        "4:3": (base_size * 4 // 3, base_size),
    }
    return sizes.get(aspect_ratio, square)
195
+
196
@spaces.GPU(duration=20)  # reduced from 40 s to 20 s
def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
    """Generate a background image for *prompt* with the FLUX pipeline.

    The size comes from *aspect_ratio*, is rounded up to multiples of 8, and
    is scaled down so the longer side is at most 768 before rounding again.
    If the pipeline fails, a white image of the computed size is returned.
    If anything else fails, a white 512x512 image is returned.
    """
    try:
        width, height = adjust_size_to_multiple_of_8(*calculate_dimensions(aspect_ratio))

        max_size = 768
        longest_side = max(width, height)
        if longest_side > max_size:
            shrink = max_size / longest_side
            width, height = adjust_size_to_multiple_of_8(int(width * shrink), int(height * shrink))

        with timer("Background generation"):
            try:
                with torch.inference_mode():
                    generated = pipe(
                        prompt=prompt,
                        width=width,
                        height=height,
                        num_inference_steps=8,
                        guidance_scale=4.0
                    )
                    image = generated.images[0]
            except Exception as e:
                print(f"Pipeline error: {str(e)}")
                return Image.new('RGB', (width, height), 'white')

        return image
    except Exception as e:
        print(f"Background generation error: {str(e)}")
        return Image.new('RGB', (512, 512), 'white')
227
+
228
def create_position_grid():
    """Return the HTML markup for a 3x3 grid of position buttons.

    The "bottom-center" button carries ``data-default="true"``.
    """
    grid_html = """
<div class="position-grid" style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; width: 150px; margin: auto;">
<button class="position-btn" data-pos="top-left">↖</button>
<button class="position-btn" data-pos="top-center">↑</button>
<button class="position-btn" data-pos="top-right">↗</button>
<button class="position-btn" data-pos="middle-left">←</button>
<button class="position-btn" data-pos="middle-center">•</button>
<button class="position-btn" data-pos="middle-right">→</button>
<button class="position-btn" data-pos="bottom-left">↙</button>
<button class="position-btn" data-pos="bottom-center" data-default="true">↓</button>
<button class="position-btn" data-pos="bottom-right">↘</button>
</div>
"""
    return grid_html
242
+
243
def calculate_object_position(position: str, bg_size: tuple[int, int], obj_size: tuple[int, int]) -> tuple[int, int]:
    """Compute the top-left paste coordinates of an object on a background.

    *position* is one of the nine "<row>-<column>" names ("top-left" ...
    "bottom-right"). Any other value falls back to "bottom-center".
    """
    bg_width, bg_height = bg_size
    obj_width, obj_height = obj_size

    free_x = bg_width - obj_width
    free_y = bg_height - obj_height
    x_by_column = {"left": 0, "center": free_x // 2, "right": free_x}
    y_by_row = {"top": 0, "middle": free_y // 2, "bottom": free_y}

    positions = {
        f"{row}-{column}": (x, y)
        for row, y in y_by_row.items()
        for column, x in x_by_column.items()
    }
    return positions.get(position, positions["bottom-center"])
261
+
262
def resize_object(image: Image.Image, scale_percent: float) -> Image.Image:
    """Resize *image* to *scale_percent* percent of its size (LANCZOS).

    Each target dimension is truncated to an int and clamped to at least one
    pixel. Without the clamp, a small image at a low scale truncates to 0,
    which ``Image.resize`` is expected to reject.
    """
    width = max(1, int(image.width * scale_percent / 100))
    height = max(1, int(image.height * scale_percent / 100))
    return image.resize((width, height), Image.Resampling.LANCZOS)
267
+
268
def combine_with_background(foreground: Image.Image, background: Image.Image,
                            position: str = "bottom-center", scale_percent: float = 100) -> Image.Image:
    """Paste the scaled *foreground* onto *background* at the named position.

    The background is converted to RGBA. The foreground is resized by
    *scale_percent* and pasted using itself as the mask. Returns the composite.
    """
    print(f"Combining with position: {position}, scale: {scale_percent}")

    canvas = background.convert('RGBA')
    scaled_foreground = resize_object(foreground, scale_percent)

    paste_xy = calculate_object_position(position, canvas.size, scaled_foreground.size)
    x, y = paste_xy
    print(f"Calculated position coordinates: ({x}, {y})")

    canvas.paste(scaled_foreground, (x, y), scaled_foreground)
    return canvas
281
+
282
@spaces.GPU(duration=30)  # reduced from 120 s to 30 s
def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
    """Detect (when given a text prompt) and segment the object in *img*.

    A string prompt is resolved to a bounding box with ``gd_detect``; any
    other prompt value is used as the box directly. Returns
    ``(mask, bbox, time_log)``. Raises ``gr.Error`` when detection finds
    nothing. Every exception is printed and re-raised.
    """
    time_log: list[str] = []
    try:
        if not isinstance(prompt, str):
            bbox = prompt
        else:
            detect_started = time.time()
            bbox = gd_detect(img, prompt)
            time_log.append(f"detect: {time.time() - detect_started}")
            if not bbox:
                print(time_log[0])
                raise gr.Error("No object detected")

        segment_started = time.time()
        mask = segmenter(img, bbox)
        time_log.append(f"segment: {time.time() - segment_started}")
        return mask, bbox, time_log
    except Exception as e:
        print(f"GPU process error: {str(e)}")
        raise
302
+
303
def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    """Extract the prompted object from *img* and optionally generate a background.

    Args:
        img: Input image; downscaled so its longer side is at most 1024.
        prompt: Text describing the object, or a bounding box.
        bg_prompt: When given, a background is generated and returned as the
            "combined" image. Otherwise the cut-out is composited on white.
        aspect_ratio: Aspect ratio passed to ``generate_background``.

    Returns:
        ``((img, combined, masked_alpha), download_button)``, where the button
        points at a PNG copy of ``combined``.

    Raises:
        gr.Error: When any step fails; the original exception is chained.
    """
    try:
        # Limit the input image size.
        max_size = 1024
        if img.width > max_size or img.height > max_size:
            ratio = max_size / max(img.width, img.height)
            new_size = (int(img.width * ratio), int(img.height * ratio))
            img = img.resize(new_size, Image.LANCZOS)
            if isinstance(prompt, (tuple, list)):
                # A bounding-box prompt refers to the image as supplied, so it
                # must shrink with the image or it would select the wrong region.
                prompt = cast(BoundingBox, tuple(int(round(c * ratio)) for c in prompt))

        # CUDA memory management is best-effort.
        try:
            if torch.cuda.is_available():
                current_device = torch.cuda.current_device()
                with torch.cuda.device(current_device):
                    torch.cuda.empty_cache()
        except Exception as e:
            print(f"CUDA memory management failed: {e}")

        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            mask, bbox, time_log = _gpu_process(img, prompt)
            masked_alpha = apply_mask(img, mask, defringe=True)

        if bg_prompt:
            background = generate_background(bg_prompt, aspect_ratio)
            combined = background
        else:
            combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)

        clear_memory()

        # Reserve a persistent file name, then save after the handle is closed
        # so the write does not depend on reopening an already-open file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
            download_path = temp.name
        combined.save(download_path)
        return (img, combined, masked_alpha), gr.DownloadButton(value=download_path, interactive=True)
    except Exception as e:
        clear_memory()
        print(f"Processing error: {str(e)}")
        raise gr.Error(f"Processing failed: {str(e)}") from e
340
+
341
def on_change_bbox(prompts: dict[str, Any] | None):
    """Return an update that makes the target interactive only when *prompts* is not None."""
    has_prompts = prompts is not None
    return gr.update(interactive=has_prompts)
343
+
344
+
345
def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
    """Return an update that makes the target interactive when both image and prompt are set.

    *bg_prompt* is accepted but not used.
    """
    ready = bool(img and prompt)
    return gr.update(interactive=ready)
347
+
348
+
349
def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None,
                   aspect_ratio: str = "1:1", position: str = "bottom-center",
                   scale_percent: float = 100) -> tuple[Image.Image, Image.Image]:
    """Extract the prompted object and optionally place it on a generated background.

    Args:
        img: Input image.
        prompt: Text naming the object to extract. Translated to English when
            it contains Korean.
        bg_prompt: Optional background description, translated the same way.
        aspect_ratio: Aspect ratio for the generated background.
        position: One of the nine "<row>-<column>" positions. Invalid values
            fall back to "bottom-center".
        scale_percent: Size of the pasted object in percent.

    Returns:
        ``(combined_image, extracted_object)``.

    Raises:
        gr.Error: On missing input or any processing failure.
    """
    try:
        # `not prompt` also covers None, which would otherwise fail on .strip().
        if img is None or not prompt or prompt.strip() == "":
            raise gr.Error("Please provide both image and prompt")

        print(f"Processing with position: {position}, scale: {scale_percent}")  # debugging aid

        try:
            prompt = translate_to_english(prompt)
            if bg_prompt:
                bg_prompt = translate_to_english(bg_prompt)
        except Exception as e:
            print(f"Translation error (continuing with original text): {str(e)}")

        results, _ = _process(img, prompt, bg_prompt, aspect_ratio)

        if bg_prompt:
            try:
                print(f"Using position: {position}")  # debugging aid
                # Validate the position value.
                valid_positions = ["top-left", "top-center", "top-right",
                                   "middle-left", "middle-center", "middle-right",
                                   "bottom-left", "bottom-center", "bottom-right"]
                if position not in valid_positions:
                    position = "bottom-center"
                    print(f"Invalid position, using default: {position}")

                combined = combine_with_background(
                    foreground=results[2],
                    background=results[1],
                    position=position,
                    scale_percent=scale_percent
                )
                return combined, results[2]
            except Exception as e:
                # Fall back to the un-composited background if pasting fails.
                print(f"Combination error: {str(e)}")
                return results[1], results[2]

        return results[1], results[2]  # default return
    except gr.Error as e:
        # Already a user-facing error: log it and pass it through unchanged
        # instead of wrapping its string form in a second gr.Error.
        print(f"Error in process_prompt: {str(e)}")
        raise
    except Exception as e:
        print(f"Error in process_prompt: {str(e)}")
        raise gr.Error(str(e)) from e
    finally:
        clear_memory()
395
+
396
+
397
def process_bbox(img: Image.Image, box_input: str) -> tuple[Image.Image, Image.Image]:
    """Extract the object inside a user-supplied bounding box.

    Args:
        img: Input image.
        box_input: Text of a four-element list, ``[xmin, ymin, xmax, ymax]``.

    Returns:
        ``(combined_image, extracted_object)``.

    Raises:
        gr.Error: On missing input, a malformed box, or a processing failure.
    """
    import ast  # local import: only needed to parse the box text

    try:
        # `not box_input` also covers None, which would otherwise fail on .strip().
        if img is None or not box_input or box_input.strip() == "":
            raise gr.Error("Please provide both image and bounding box coordinates")

        try:
            # SECURITY: box_input is user-typed text. It was previously passed
            # to eval(), which executes arbitrary code. literal_eval accepts
            # only Python literals.
            coords = ast.literal_eval(box_input)
            if not isinstance(coords, list) or len(coords) != 4:
                raise ValueError("Invalid box format")
            bbox = tuple(int(x) for x in coords)
        except Exception as e:
            raise gr.Error("Invalid box format. Please provide [xmin, ymin, xmax, ymax]") from e

        # Process the image
        results, _ = _process(img, bbox)

        # Return only the combined image and the extracted object.
        return results[1], results[2]
    except gr.Error:
        # Already user-facing; pass through unchanged.
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e
417
+
418
+ # Event handler functions 수정
419
def update_process_button(img, prompt):
    """Return an update that enables the button when both image and prompt are set.

    The button is interactive with the "primary" variant when ready, and
    non-interactive with the "secondary" variant otherwise.
    """
    ready = bool(img and prompt)
    variant = "primary" if ready else "secondary"
    return gr.update(interactive=ready, variant=variant)
424
+
425
def update_box_button(img, box_input):
    """Return an update that enables the button when *box_input* is a 4-element list.

    The button becomes interactive ("primary") only when an image is present
    and *box_input* parses to a list of four items. Any other input, including
    a parse failure, yields a non-interactive "secondary" button.
    """
    import ast  # local import: only needed to parse the box text

    try:
        if img and box_input:
            # SECURITY: box_input is user-typed text and this handler runs on
            # UI changes. It was previously passed to eval(), which executes
            # arbitrary code. literal_eval accepts only Python literals.
            coords = ast.literal_eval(box_input)
            if isinstance(coords, list) and len(coords) == 4:
                return gr.update(interactive=True, variant="primary")
        return gr.update(interactive=False, variant="secondary")
    except Exception:
        # Unparseable input simply disables the button.
        return gr.update(interactive=False, variant="secondary")
434
+
435
+
436
# Custom stylesheet passed to gr.Blocks(css=...). It hides the footer and
# styles the title banner, input/output panels and the position buttons,
# including the `.position-btn.selected` highlight.
css = """
footer {display: none}
.main-title {
text-align: center;
margin: 1em 0;
padding: 1.5em;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
.main-title h1 {
color: #2196F3;
font-size: 2.8em;
margin-bottom: 0.3em;
font-weight: 700;
}
.main-title p {
color: #555;
font-size: 1.3em;
line-height: 1.4;
}
.container {
max-width: 1200px;
margin: auto;
padding: 20px;
}
.input-panel, .output-panel {
background: white;
padding: 1.5em;
border-radius: 12px;
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
margin-bottom: 1em;
}
.controls-panel {
background: #f8f9fa;
padding: 1em;
border-radius: 8px;
margin: 1em 0;
}
.image-display {
min-height: 512px;
display: flex;
align-items: center;
justify-content: center;
background: #fafafa;
border-radius: 8px;
margin: 1em 0;
}
.example-section {
text-align: center;
padding: 2em;
background: #f5f5f5;
border-radius: 12px;
margin-top: 2em;
}
.example-section img {
max-width: 100%;
border-radius: 8px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.accordion {
border: 1px solid #e0e0e0;
border-radius: 8px;
margin: 1em 0;
}
.accordion-header {
padding: 1em;
background: #f5f5f5;
cursor: pointer;
}
.accordion-content {
padding: 1em;
display: none;
}
.accordion.open .accordion-content {
display: block;
}
.position-grid {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 8px;
margin: 1em 0;
}
.position-btn {
padding: 10px;
border: 1px solid #ddd;
border-radius: 4px;
background: white;
cursor: pointer;
transition: all 0.3s ease;
width: 40px;
height: 40px;
display: flex;
align-items: center;
justify-content: center;
}
.position-btn:hover {
background: #e3f2fd;
}
.position-btn.selected {
background-color: #2196F3;
color: white;
border-color: #1976D2;
}
"""
541
+
542
+
543
def add_text_with_stroke(draw, text, x, y, font, text_color, stroke_width):
    """Draw *text* repeatedly around (x, y) to thicken it.

    The text is drawn once for every integer offset in
    ``[-stroke_width, stroke_width]`` on both axes, always in *text_color*.
    A stroke width of 0 draws it exactly once.
    """
    offsets = range(-stroke_width, stroke_width + 1)
    for dx in offsets:
        for dy in offsets:
            anchor = (x + dx, y + dy)
            draw.text(anchor, text, font=font, fill=text_color)
549
+
550
def remove_background(image):
    """Send *image* to the remote background-removal endpoint and open the result.

    The image is written to a temporary PNG, passed to ``client.predict`` on
    the "/image" endpoint, and the first returned item is opened as an image.
    The temporary upload file is deleted afterwards. Previously a new
    ``image_<uuid>.png`` was left in the working directory on every call.
    """
    fd, filename = tempfile.mkstemp(prefix="image_", suffix=".png")
    os.close(fd)  # only the unique path is needed; PIL reopens it by name
    try:
        image.save(filename)
        # Call the Gradio client for background removal.
        result = client.predict(images=handle_file(filename), api_name="/image")
    finally:
        try:
            os.remove(filename)
        except OSError:
            # Cleanup is best-effort; a leftover temp file is not fatal.
            pass
    return Image.open(result[0])
557
+
558
def superimpose(image_with_text, overlay_image):
    """Paste *overlay_image* onto *image_with_text* and return the latter.

    The overlay is converted to RGBA and pasted at (0, 0) using itself as the
    mask. *image_with_text* is modified in place.
    """
    rgba_overlay = overlay_image.convert("RGBA")
    top_left = (0, 0)
    image_with_text.paste(rgba_overlay, top_left, rgba_overlay)
    return image_with_text
566
+
567
def add_text_to_image(
    input_image,
    text,
    font_size,
    color,
    opacity,
    x_position,
    y_position,
    thickness,
    text_position_type,
    font_choice
):
    """Draw *text* over the image, or behind its foreground subject.

    Args:
        input_image: PIL image or NumPy array. Returned unchanged when it is
            None or *text* is blank.
        text: Text to render.
        font_size: Font size passed to ``ImageFont.truetype``.
        color: One of the names in the color map. Unknown names give white.
        opacity: Alpha component of the text color.
        x_position, y_position: Placement in percent of the free space
            (image size minus text size).
        thickness: Stroke width passed to ``add_text_with_stroke``.
        text_position_type: "Text Behind Image" draws the text and then
            composites the background-removed subject on top. Any other value
            draws the text over the image.
        font_choice: Key into the font table. Falls back to the default font
            when the font file cannot be loaded.

    Returns:
        The rendered RGB image, or *input_image* unchanged if anything fails.
    """
    try:
        if input_image is None or text.strip() == "":
            return input_image

        # Convert to a PIL image that is safe to modify.
        if not isinstance(input_image, Image.Image):
            if isinstance(input_image, np.ndarray):
                image = Image.fromarray(input_image)
            else:
                raise ValueError("Unsupported image type")
        else:
            image = input_image.copy()

        # Work in RGBA so text opacity can be composited.
        if image.mode != 'RGBA':
            image = image.convert('RGBA')

        # Font setup
        font_files = {
            "Default": "DejaVuSans.ttf",
            "Korean Regular": "ko-Regular.ttf"
        }

        try:
            font_file = font_files.get(font_choice, "DejaVuSans.ttf")
            font = ImageFont.truetype(font_file, int(font_size))
        except Exception as e:
            print(f"Font loading error ({font_choice}): {str(e)}")
            font = ImageFont.load_default()

        # Color setup
        color_map = {
            'White': (255, 255, 255),
            'Black': (0, 0, 0),
            'Red': (255, 0, 0),
            'Green': (0, 255, 0),
            'Blue': (0, 0, 255),
            'Yellow': (255, 255, 0),
            'Purple': (128, 0, 128)
        }
        rgb_color = color_map.get(color, (255, 255, 255))

        # Measure the text with a temporary Draw object.
        temp_draw = ImageDraw.Draw(image)
        text_bbox = temp_draw.textbbox((0, 0), text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]

        # Position calculation
        actual_x = int((image.width - text_width) * (x_position / 100))
        actual_y = int((image.height - text_height) * (y_position / 100))

        # Text color including alpha
        text_color = (*rgb_color, int(opacity))

        if text_position_type == "Text Behind Image":
            try:
                # Extract only the foreground subject from the image.
                foreground = remove_background(image)
                # alpha_composite needs two RGBA images of identical size, and
                # the remote result is not guaranteed to be either.
                if foreground.mode != 'RGBA':
                    foreground = foreground.convert('RGBA')
                if foreground.size != image.size:
                    foreground = foreground.resize(image.size, Image.Resampling.LANCZOS)

                # The background starts as a copy of the image.
                background = image.copy()

                # Temporary transparent layer for the text
                text_layer = Image.new('RGBA', image.size, (255, 255, 255, 0))
                draw_text = ImageDraw.Draw(text_layer)

                # Draw the text.
                add_text_with_stroke(
                    draw_text,
                    text,
                    actual_x,
                    actual_y,
                    font,
                    text_color,
                    int(thickness)
                )

                # Composite the text onto the background.
                background = Image.alpha_composite(background, text_layer)

                # Composite the subject on top so the text appears behind it.
                output_image = Image.alpha_composite(background, foreground)
            except Exception as e:
                print(f"Error in Text Behind Image processing: {str(e)}")
                return input_image
        else:
            # Create the transparent text overlay.
            txt_overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
            draw = ImageDraw.Draw(txt_overlay)

            # Draw the text on top of the image.
            add_text_with_stroke(
                draw,
                text,
                actual_x,
                actual_y,
                font,
                text_color,
                int(thickness)
            )
            output_image = Image.alpha_composite(image, txt_overlay)

        # Convert to RGB.
        output_image = output_image.convert('RGB')

        return output_image

    except Exception as e:
        print(f"Error in add_text_to_image: {str(e)}")
        return input_image
690
+
691
+
692
def update_position(new_position):
    """Log the newly selected position and return it unchanged."""
    message = f"Position updated to: {new_position}"
    print(message)
    return new_position
696
+
697
def update_controls(bg_prompt):
    """Show or hide the background controls depending on *bg_prompt*.

    Returns two visibility updates, for the aspect-ratio dropdown and for the
    object controls. Both are visible only when *bg_prompt* is non-empty.
    """
    show = bool(bg_prompt)
    aspect_ratio_update = gr.update(visible=show)
    object_controls_update = gr.update(visible=show)
    return [aspect_ratio_update, object_controls_update]
704
+
705
# Gradio UI definition.
# NOTE(review): the component nesting below was reconstructed from a listing
# whose indentation was lost — confirm the layout against the running app.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    # Single shared state holding the selected paste position. A second,
    # identical gr.State used to be created further down and rebound this
    # name; it has been removed.
    position = gr.State(value="bottom-center")

    gr.HTML("""
<div class="main-title">
<h1>🎨 GiniGen Canvas-o3</h1>
<p>Remove background of specified objects, generate new backgrounds, and insert text over or behind images with prompts.</p>
</div>
""")

    with gr.Row(equal_height=True):
        # Left panel (input)
        with gr.Column(scale=1):
            with gr.Group(elem_classes="input-panel"):
                input_image = gr.Image(
                    type="pil",
                    label="Upload Image",
                    interactive=True,
                    height=400
                )
                text_prompt = gr.Textbox(
                    label="Object to Extract",
                    placeholder="Enter what you want to extract...",
                    interactive=True
                )
                with gr.Row():
                    bg_prompt = gr.Textbox(
                        label="Background Prompt (optional)",
                        placeholder="Describe the background...",
                        interactive=True,
                        scale=3
                    )
                    aspect_ratio = gr.Dropdown(
                        choices=["1:1", "16:9", "9:16", "4:3"],
                        value="1:1",
                        label="Aspect Ratio",
                        interactive=True,
                        visible=True,
                        scale=1
                    )

                with gr.Group(elem_classes="controls-panel", visible=False) as object_controls:
                    with gr.Column(scale=1):
                        with gr.Row():
                            btn_top_left = gr.Button("↖", elem_classes="position-btn")
                            btn_top_center = gr.Button("↑", elem_classes="position-btn")
                            btn_top_right = gr.Button("↗", elem_classes="position-btn")
                        with gr.Row():
                            btn_middle_left = gr.Button("←", elem_classes="position-btn")
                            btn_middle_center = gr.Button("•", elem_classes="position-btn")
                            btn_middle_right = gr.Button("→", elem_classes="position-btn")
                        with gr.Row():
                            btn_bottom_left = gr.Button("↙", elem_classes="position-btn")
                            # The label is already the positional `value`, so
                            # also passing value="selected" raised TypeError
                            # (duplicate argument). The default selection is
                            # expressed through the "selected" CSS class.
                            btn_bottom_center = gr.Button("↓", elem_classes=["position-btn", "selected"])
                            btn_bottom_right = gr.Button("↘", elem_classes="position-btn")
                    with gr.Column(scale=1):
                        scale_slider = gr.Slider(
                            minimum=10,
                            maximum=200,
                            value=50,
                            step=5,
                            label="Object Size (%)"
                        )

                process_btn = gr.Button(
                    "Process",
                    variant="primary",
                    interactive=False,
                    size="lg"
                )

        # Right panel (output)
        with gr.Column(scale=1):
            with gr.Group(elem_classes="output-panel"):
                with gr.Tab("Result"):
                    combined_image = gr.Image(
                        label="Combined Result",
                        show_download_button=True,
                        type="pil",
                        height=400
                    )

                    # Text insertion options live in an Accordion.
                    with gr.Accordion("Text Insertion Options", open=False):
                        with gr.Group():
                            with gr.Row():
                                text_input = gr.Textbox(
                                    label="Text Content",
                                    placeholder="Enter text to add..."
                                )
                                text_position_type = gr.Radio(
                                    choices=["Text Over Image", "Text Behind Image"],
                                    value="Text Over Image",
                                    label="Text Position"
                                )

                            with gr.Row():
                                with gr.Column(scale=1):
                                    font_choice = gr.Dropdown(
                                        choices=["Default", "Korean Regular"],  # "Korean Son" removed

                                        value="Default",
                                        label="Font Selection",
                                        interactive=True
                                    )

                                    font_size = gr.Slider(
                                        minimum=10,
                                        maximum=200,
                                        value=40,
                                        step=5,
                                        label="Font Size"
                                    )
                                    color_dropdown = gr.Dropdown(
                                        choices=["White", "Black", "Red", "Green", "Blue", "Yellow", "Purple"],
                                        value="White",
                                        label="Text Color"
                                    )
                                    thickness = gr.Slider(
                                        minimum=0,
                                        maximum=10,
                                        value=1,
                                        step=1,
                                        label="Text Thickness"
                                    )
                                with gr.Column(scale=1):
                                    opacity_slider = gr.Slider(
                                        minimum=0,
                                        maximum=255,
                                        value=255,
                                        step=1,
                                        label="Opacity"
                                    )
                                    x_position = gr.Slider(
                                        minimum=0,
                                        maximum=100,
                                        value=50,
                                        step=1,
                                        label="Left(0%)~Right(100%)"
                                    )
                                    y_position = gr.Slider(
                                        minimum=0,
                                        maximum=100,
                                        value=50,
                                        step=1,
                                        label="High(0%)~Low(100%)"
                                    )
                            add_text_btn = gr.Button("Apply Text", variant="primary")

                    extracted_image = gr.Image(
                        label="Extracted Object",
                        show_download_button=True,
                        type="pil",
                        height=200
                    )

    # Extra style rule for the selected position button.
    gr.HTML("""
<style>
.position-btn.selected {
background-color: #2196F3 !important;
color: white !important;
}
</style>
""")

    # Bind the position buttons: each one writes its position name to the state.
    position_mapping = {
        btn_top_left: "top-left",
        btn_top_center: "top-center",
        btn_top_right: "top-right",
        btn_middle_left: "middle-left",
        btn_middle_center: "middle-center",
        btn_middle_right: "middle-right",
        btn_bottom_left: "bottom-left",
        btn_bottom_center: "bottom-center",
        btn_bottom_right: "bottom-right"
    }

    for btn, pos in position_mapping.items():
        btn.click(
            fn=lambda pos=pos: update_position(pos),  # default arg avoids the late-binding closure problem
            outputs=position
        )

    # Event bindings
    bg_prompt.change(
        fn=update_controls,
        inputs=bg_prompt,
        outputs=[aspect_ratio, object_controls],
        queue=False
    )

    input_image.change(
        fn=update_process_button,
        inputs=[input_image, text_prompt],
        outputs=process_btn,
        queue=False
    )

    text_prompt.change(
        fn=update_process_button,
        inputs=[input_image, text_prompt],
        outputs=process_btn,
        queue=False
    )

    process_btn.click(
        fn=process_prompt,
        inputs=[
            input_image,
            text_prompt,
            bg_prompt,
            aspect_ratio,
            position,
            scale_slider
        ],
        outputs=[combined_image, extracted_image],
        queue=True
    )

    # Text rendering: the inputs follow add_text_to_image's parameter order.
    add_text_btn.click(
        fn=add_text_to_image,
        inputs=[
            combined_image,  # first argument: the image
            text_input,      # second argument: the text
            font_size,
            color_dropdown,
            opacity_slider,
            x_position,
            y_position,
            thickness,
            text_position_type,
            font_choice
        ],
        outputs=combined_image,
        api_name="add_text"  # API endpoint name
    )
947
+
948
+
949
+
950
# Enable the request queue (max_size=5), then start the server on all
# interfaces at port 7860 without a public share link.
demo.queue(max_size=5)
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    max_threads=2
)