yonishafir committed
Commit 2e398f7 · verified · 1 Parent(s): e21dae7

Create app.py

Files changed (1)
app.py +457 -0
app.py ADDED
@@ -0,0 +1,457 @@
import os
import random
import gradio as gr


import cv2
import torch
import numpy as np
from PIL import Image

from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
# from diffusers.image_processor import IPAdapterMaskProcessor
from insightface.app import FaceAnalysis
# import sys
# import glob
# import os
import io
import spaces

from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps

import pandas as pd
import json
import requests
from io import BytesIO


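# Gradio demo for BRIA-2.3 ID preservation: detect a face in the uploaded photo,
# build an InstantID-style identity embedding plus keypoint/Canny control images,
# and generate new images of the subject, optionally styled with a LoRA.
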
def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
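    """
    Resize `input_image` to an explicit `size`, or rescale it relative to
    `min_side`/`max_side` with the output dimensions snapped to multiples of
    `base_pixel_number`. If `pad_to_max_side` is set, the result is pasted onto
    a white `max_side` x `max_side` canvas.
    """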

    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio * w), round(ratio * h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y + h_resize_new, offset_x:offset_x + w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image

def process_image_by_bbox_larger(input_image, bbox_xyxy, min_bbox_ratio=0.2):
    """
    Crop a square region around a bounding box and resize it to 1024x1024.

    Parameters:
    - input_image: PIL Image object.
    - bbox_xyxy: tuple (x1, y1, x2, y2) with the bounding box coordinates.
    - min_bbox_ratio: minimum fraction of the crop area the bounding box must occupy.

    Returns:
    - The cropped image resized to 1024x1024.
    """
    # Constants
    target_size = 1024
    # min_bbox_ratio = 0.2  # bounding box should cover at least 20% of the crop

    # Extract bounding box coordinates
    x1, y1, x2, y2 = bbox_xyxy
    bbox_w = x2 - x1
    bbox_h = y2 - y1

    # Calculate the area of the bounding box
    bbox_area = bbox_w * bbox_h

    # Start with the smallest square crop that keeps the bbox at least `min_bbox_ratio` of the crop area
    crop_size = max(bbox_w, bbox_h)
    initial_crop_area = crop_size * crop_size
    while (bbox_area / initial_crop_area) < min_bbox_ratio:
        crop_size += 10  # gradually enlarge until the bbox reaches the minimum ratio
        initial_crop_area = crop_size * crop_size

    # Once the minimum condition is satisfied, try to expand the crop further
    max_possible_crop_size = min(input_image.width, input_image.height)
    while crop_size < max_possible_crop_size:
        # Calculate a potential new area
        new_crop_size = crop_size + 10
        new_crop_area = new_crop_size * new_crop_size
        if (bbox_area / new_crop_area) < min_bbox_ratio:
            break  # stop if expanding further violates the minimum ratio
        crop_size = new_crop_size

    # Determine the center of the bounding box
    center_x = (x1 + x2) // 2
    center_y = (y1 + y2) // 2

    # Calculate the crop coordinates centered around the bounding box
    crop_x1 = max(0, center_x - crop_size // 2)
    crop_y1 = max(0, center_y - crop_size // 2)
    crop_x2 = min(input_image.width, crop_x1 + crop_size)
    crop_y2 = min(input_image.height, crop_y1 + crop_size)

    # Ensure the crop is square; adjust if it goes out of image bounds
    if crop_x2 - crop_x1 != crop_y2 - crop_y1:
        side_length = min(crop_x2 - crop_x1, crop_y2 - crop_y1)
        crop_x2 = crop_x1 + side_length
        crop_y2 = crop_y1 + side_length

    # Crop the image
    cropped_image = input_image.crop((crop_x1, crop_y1, crop_x2, crop_y2))

    # Resize the cropped image to 1024x1024
    resized_image = cropped_image.resize((target_size, target_size), Image.LANCZOS)

    return resized_image

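# Detect a face in `image` with insightface and return a square crop around it.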
def calc_emb_cropped(image, app):
    face_image = image.copy()

    face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))

    face_info = face_info[0]

    cropped_face_image = process_image_by_bbox_larger(face_image, face_info["bbox"], min_bbox_ratio=0.2)

    return cropped_face_image

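# Utility for filtering a benchmark CSV; it is not called anywhere in this demo.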
def process_benchmark_csv(benchmark_csv_path):
    # Read the benchmark CSV file into a DataFrame
    df = pd.read_csv(benchmark_csv_path)

    # Drop any unnamed columns
    df = df.loc[:, ~df.columns.str.contains('^Unnamed')]

    # Drop columns with all NaN values
    df.dropna(axis=1, how='all', inplace=True)

    # Drop rows with all NaN values
    df.dropna(axis=0, how='all', inplace=True)

    df = df.loc[df['High resolution'] == 1]

    df.reset_index(drop=True, inplace=True)

    return df

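# Build a 3-channel Canny edge map; optionally apply a bilateral filter first
# to suppress texture noise while keeping edges sharp.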
def make_canny_condition(image, min_val=100, max_val=200, w_bilateral=True):
    if w_bilateral:
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
        bilateral_filtered_image = cv2.bilateralFilter(image, d=9, sigmaColor=75, sigmaSpace=75)
        image = cv2.Canny(bilateral_filtered_image, min_val, max_val)
    else:
        image = np.array(image)
        image = cv2.Canny(image, min_val, max_val)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    image = Image.fromarray(image)
    return image


default_negative_prompt = "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"

# Load face detection and recognition package
app = FaceAnalysis(name='antelopev2', root='./', providers=['CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

base_dir = "./instantID_ckpt/checkpoint_174000"
face_adapter = f'{base_dir}/pytorch_model.bin'
controlnet_path = f'{base_dir}/controlnet'
base_model_path = 'briaai/BRIA-2.3'
resolution = 1024

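# Two ControlNets are used together: one conditioned on facial keypoints
# (from the InstantID checkpoint) and one conditioned on Canny edges.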
controlnet_lnmks = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)

controlnet_canny = ControlNetModel.from_pretrained("briaai/BRIA-2.3-ControlNet-Canny",
                                                   torch_dtype=torch.float16)

controlnet = [controlnet_lnmks, controlnet_canny]

device = "cuda" if torch.cuda.is_available() else "cpu"

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    '/home/ubuntu/BRIA-2.3-InstantID/ip_adapter/image_encoder',
    torch_dtype=torch.float16,
)

pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    image_encoder=image_encoder  # For compatibility issues - needs to be there
)

pipe = pipe.to(device)

use_native_ip_adapter = True
pipe.use_native_ip_adapter = use_native_ip_adapter

pipe.load_ip_adapter_instantid(face_adapter)

clip_embeds = None


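# Style LoRAs: keys are folder names under `lora_base_path`, values are prompt
# prefixes prepended to the user prompt when that LoRA is selected.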
Loras_dict = {
    "": "",
    "Vangogh_Vanilla": "bold, dramatic brush strokes, vibrant colors, swirling patterns, intense, emotionally charged paintings of",
    "Avatar_internlm": "2d anime sketch avatar of",
    # "Tomer_Hanuka_V3": "Fluid lines",
    "Storyboards": "Illustration style for storyboarding",
    "3D_illustration": "3D object illustration, abstract",
    # "beetl_general_death_style_v2": "a pale, dead, unnatural color face with dark circles around the eyes",
    "Characters": "gaming vector Art"
}

lora_names = list(Loras_dict.keys())

lora_base_path = "./LoRAs"

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, 99999999)
    return seed


@spaces.GPU
def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale=0.8, kps_scale=0.6, canny_scale=0.4, lora_name="", lora_scale=0.7, progress=gr.Progress(track_tqdm=True)):
    if image_path is None:
        raise gr.Error("Cannot find any input face image! Please upload a face image.")

    # img = np.array(Image.open(image_path))[:,:,::-1]
    img = Image.open(image_path)

    face_image_orig = img  # Image.open(BytesIO(response.content))
    face_image_cropped = calc_emb_cropped(face_image_orig, app)
    face_image = resize_img(face_image_cropped, max_side=resolution, min_side=resolution)
    # face_image_padded = resize_img(face_image_cropped, max_side=resolution, min_side=resolution, pad_to_max_side=True)
    face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
    face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]  # use the largest detected face
    face_emb = face_info['embedding']
    face_kps = draw_kps(face_image, face_info['kps'])

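    # When Canny control is enabled, remove the background via the Bria
    # background-removal API and composite the face onto white before
    # computing the edge map.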
    if canny_scale > 0.0:
        # Convert the PIL image to a file-like object
        image_file = io.BytesIO()
        face_image_cropped.save(image_file, format='JPEG')  # save in the desired format (e.g., 'JPEG' or 'PNG')
        image_file.seek(0)  # move to the start of the BytesIO stream

        url = "https://engine.prod.bria-api.com/v1/background/remove"

        payload = {}
        files = [
            ('file', ('image_name.jpeg', image_file, 'image/jpeg'))  # file name, file-like object, and MIME type
        ]
        headers = {
            'api_token': 'a10d6386dd6a11ebba800242ac130004'
        }

        response = requests.request("POST", url, headers=headers, data=payload, files=files)

        print(response.text)

        response_json = json.loads(response.content.decode('utf-8'))

        img = requests.get(response_json['result_url'])

        processed_image = Image.open(io.BytesIO(img.content))

        # `processed_image` may come back as RGBA
        if processed_image.mode == 'RGBA':
            # Create a white background image
            white_background = Image.new("RGB", processed_image.size, (255, 255, 255))
            # Composite the RGBA image over the white background
            face_image = Image.alpha_composite(white_background.convert('RGBA'), processed_image).convert('RGB')
        else:
            face_image = processed_image.convert('RGB')  # already RGB, just ensure the mode is correct

        canny_img = make_canny_condition(face_image, min_val=20, max_val=40, w_bilateral=True)

    generator = torch.Generator(device=device).manual_seed(seed)

    if lora_name != "":
        lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors")
        pipe.load_lora_weights(lora_path)
        pipe.fuse_lora(lora_scale=lora_scale)
        pipe.enable_lora()

        lora_prefix = Loras_dict[lora_name]

        prompt = f"{lora_prefix} {prompt}"


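    # Run the pipeline: identity embedding plus keypoint control, with Canny
    # edges as a second control image when canny_scale > 0.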
    print("Start inference...")
    images = pipe(
        prompt=prompt,
        negative_prompt=default_negative_prompt,
        image_embeds=face_emb,
        image=[face_kps, canny_img] if canny_scale > 0.0 else face_kps,
        controlnet_conditioning_scale=[kps_scale, canny_scale] if canny_scale > 0.0 else kps_scale,
        control_guidance_end=[1.0, 1.0] if canny_scale > 0.0 else 1.0,
        ip_adapter_scale=ip_adapter_scale,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        visual_prompt_embds=clip_embeds,
        cross_attention_kwargs=None,
        num_images_per_prompt=num_images,
    ).images  # [0]

    if lora_name != "":
        pipe.disable_lora()
        pipe.unfuse_lora()
        pipe.unload_lora_weights()

    return images

### Description
title = r"""
<h1>Bria-2.3 ID preservation</h1>
"""

description = r"""
<b>🤗 Gradio demo</b> for Bria ID preservation.<br>

Steps:<br>
1. Upload an image with a face. If multiple faces are detected, the largest one is used. Detection may fail on images that are already tightly cropped to the face; try images with a larger margin.
2. Click <b>Submit</b> to generate new images of the subject.
"""

Footer = r"""
Enjoy
"""

css = '''
.gradio-container {width: 85% !important}
'''
with gr.Blocks(css=css) as demo:

    # description
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():

            # upload face image
            img_file = gr.Image(label="Upload a photo with a face", type="filepath")

            # Textbox for entering a prompt
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here",
                info="Describe what you want to generate or modify in the image."
            )

            lora_name = gr.Dropdown(choices=lora_names, label="LoRA", value="", info="Select a LoRA from the list; leaving it empty disables LoRA.")

            submit = gr.Button("Submit", variant="primary")

            # use_lcm = gr.Checkbox(
            #     label="Use LCM-LoRA to accelerate sampling", value=False,
            #     info="Reduces sampling steps significantly, but may decrease quality.",
            # )

            with gr.Accordion(open=False, label="Advanced Options"):
                num_steps = gr.Slider(
                    label="Number of sample steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=30,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                num_images = gr.Slider(
                    label="Number of output images",
                    minimum=1,
                    maximum=3,
                    step=1,
                    value=1,
                )
                ip_adapter_scale = gr.Slider(
                    label="IP-Adapter scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.8,
                )
                kps_scale = gr.Slider(
                    label="Keypoints control scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.6,
                )
                canny_scale = gr.Slider(
                    label="Canny control scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.4,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=99999999,
                    step=1,
                    value=0,
                )
                lora_scale = gr.Slider(
                    label="LoRA scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.7,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")

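    # On Submit: first resolve the seed (randomize it if requested), then run generation.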
    submit.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate_image,
        inputs=[img_file, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, lora_name, lora_scale],
        outputs=[gallery]
    )

    # use_lcm.input(
    #     fn=toggle_lcm_ui,
    #     inputs=[use_lcm],
    #     outputs=[num_steps, guidance_scale],
    #     queue=False,
    # )

    # gr.Examples(
    #     examples=get_example(),
    #     inputs=[img_file],
    #     run_on_click=True,
    #     fn=run_example,
    #     outputs=[gallery],
    # )

    gr.Markdown(Footer)

demo.launch()