Upload 3 files
Browse files
- engine.py +97 -0
- requirements.txt +18 -0
- utils.py +97 -0
engine.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import transforms
|
3 |
+
from torchvision.transforms.functional import to_pil_image, to_tensor
|
4 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
5 |
+
from diffusers import StableDiffusionInpaintPipeline, EulerDiscreteScheduler
|
6 |
+
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
class SegmentAnythingModel:
    """Run Segment Anything (SAM) automatic mask generation on resized images.

    Every input image is resized to ``target_size`` before segmentation, so the
    masks returned by :meth:`generate_masks` are in that resized pixel space,
    not in the coordinates of the original image.
    """

    def __init__(self, sam_checkpoint, model_type, device, target_size=(512, 512)):
        """Build the SAM model and its automatic mask generator.

        Args:
            sam_checkpoint: Value forwarded as ``checkpoint=`` to the model
                builder looked up in ``sam_model_registry``.
            model_type: Key used to select the builder in ``sam_model_registry``.
            device: Device the model is moved to.
            target_size: ``(width, height)`` every image is resized to before
                segmentation. Defaults to ``(512, 512)``.
        """
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=device)
        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            points_per_side=32,
            pred_iou_thresh=0.99,
            stability_score_thresh=0.92,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100,
        )
        self.target_size = tuple(target_size)

    def preprocess_image(self, image):
        """Return ``image`` as an RGB array resized to ``self.target_size``.

        Args:
            image: A file path (``str``), a ``numpy.ndarray``, or a PIL image.

        Returns:
            numpy.ndarray: The resized image with three colour channels.
        """
        if isinstance(image, str):
            # ``convert`` returns a fully loaded copy, so the file handle can
            # be released as soon as the ``with`` block exits.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # RGBA, palette and grayscale inputs would otherwise reach the mask
        # generator with a channel layout other than three-channel RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")

        image = image.resize(self.target_size, Image.Resampling.LANCZOS)
        return np.array(image)

    def generate_masks(self, image):
        """Preprocess ``image`` and return the mask generator's output for it.

        Args:
            image: Anything accepted by :meth:`preprocess_image`.

        Returns:
            Whatever ``SamAutomaticMaskGenerator.generate`` returns for the
            resized image.
        """
        processed_image = self.preprocess_image(image)
        return self.mask_generator.generate(processed_image)
|
42 |
+
|
43 |
+
class StableDiffusionInpaintingPipeline:
    """Thin wrapper around a diffusers Stable Diffusion inpainting pipeline.

    Images and masks are resized to ``self.target_size`` before every call,
    and the pipeline is asked to generate at that same resolution.
    """

    def __init__(self, model_dir, device="cuda"):
        """Load the scheduler and inpainting pipeline from ``model_dir``.

        Args:
            model_dir: Location passed to both ``from_pretrained`` calls.
            device: Device the pipeline is moved to. Defaults to ``"cuda"``.
        """
        # The scheduler is created first so it can be injected into the pipeline.
        self.scheduler = EulerDiscreteScheduler.from_pretrained(model_dir, subfolder="scheduler")

        # NOTE(review): recent diffusers releases select half-precision weights
        # with variant="fp16" rather than revision="fp16" -- confirm which form
        # the installed version expects before changing this.
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_dir,
            scheduler=self.scheduler,
            revision="fp16",
            torch_dtype=torch.float16,
        )
        self.pipe = self.pipe.to(device)
        self.pipe.enable_xformers_memory_efficient_attention()
        # PIL convention: (width, height).
        self.target_size = (512, 512)

    def _to_sized_pil(self, image, resample):
        """Convert an ndarray to a PIL image if needed and resize it to the target size."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        return image.resize(self.target_size, resample)

    def preprocess_image(self, image):
        """Return ``image`` as a PIL image resized to ``self.target_size``.

        Args:
            image: A PIL image or a ``numpy.ndarray``.
        """
        return self._to_sized_pil(image, Image.Resampling.LANCZOS)

    def inpaint(self, prompt, image, mask_image, guidance_scale=10, num_inference_steps=60, generator=None):
        """Inpaint the masked region of ``image`` according to ``prompt``.

        Args:
            prompt (str): The prompt for inpainting
            image (PIL.Image or np.ndarray): The original image
            mask_image (PIL.Image or np.ndarray): The mask for inpainting
            guidance_scale (float): Higher guidance scale encourages images that are closer to the prompt
            num_inference_steps (int): Number of denoising steps
            generator (torch.Generator): Generator for reproducibility

        Returns:
            The first image of the pipeline output.
        """
        image = self.preprocess_image(image)
        # Nearest-neighbour resampling keeps the mask from picking up
        # interpolated in-between values along its edges.
        mask_image = self._to_sized_pil(mask_image, Image.Resampling.NEAREST)

        # Request the same resolution the inputs were resized to instead of
        # repeating the literal 512s.
        width, height = self.target_size
        output = self.pipe(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
        )

        return output.images[0]
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
regex
|
2 |
+
tqdm
|
3 |
+
diffusers
|
4 |
+
transformers
|
5 |
+
scipy
|
6 |
+
accelerate
|
7 |
+
opencv-python
|
8 |
+
Xformers
|
9 |
+
gradio
|
10 |
+
torch
|
11 |
+
torchvision
|
12 |
+
Pillow
|
13 |
+
matplotlib
|
14 |
+
numpy
|
15 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
16 |
+
pycocotools
|
17 |
+
onnxruntime
|
18 |
+
onnx
|
utils.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
import numpy as np
|
3 |
+
import copy
|
4 |
+
import cv2
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from torchvision.transforms.functional import to_pil_image
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
import matplotlib
|
10 |
+
matplotlib.use('Agg')
|
11 |
+
|
12 |
+
def show_anns(anns, ax=None, mask_shape=(512, 512)):
    """Overlay segmentation masks on a matplotlib axes, each labelled with its index.

    Every mask is drawn as a randomly coloured, semi-transparent layer. A label
    with the mask's position in ``anns`` is placed at the centroid of its
    largest external contour.

    Args:
        anns: Sequence of dicts, each with a ``'segmentation'`` entry (2-D mask
            array) and an ``'area'`` entry used for ordering. ``None`` or an
            empty sequence draws nothing.
        ax: Axes to draw on. Defaults to the current pyplot axes.
        mask_shape: ``(height, width)`` every mask is brought to before being
            drawn. Defaults to ``(512, 512)``.

    Returns:
        None.
    """
    if anns is None or len(anns) == 0:
        return
    if ax is None:
        ax = plt.gca()

    height, width = mask_shape

    # Largest first, so smaller masks are painted on top and remain visible.
    # ``enumerate`` runs before the sort to keep each mask's original index.
    ordered = sorted(enumerate(anns), key=lambda pair: pair[1]['area'], reverse=True)

    for original_idx, ann in ordered:
        mask = ann['segmentation']
        if mask.shape != (height, width):
            # cv2.resize takes the destination size as (width, height).
            mask = cv2.resize(mask.astype(float), (width, height))

        # One random RGB colour broadcast over the whole layer.
        colored_mask = np.empty((height, width, 3))
        colored_mask[:] = np.random.random(3)

        # The mask scaled to 0.35 serves as the alpha channel.
        ax.imshow(np.dstack([colored_mask, mask * 0.35]))

        contours, _ = cv2.findContours((mask * 255).astype(np.uint8),
                                       cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            continue

        largest = max(contours, key=cv2.contourArea)
        moments = cv2.moments(largest)
        if moments["m00"] == 0:
            # Zero-area contour: no centroid to place the label at.
            continue

        cx = int(moments["m10"] / moments["m00"])
        cy = int(moments["m01"] / moments["m00"])

        # White bold text on a translucent black box for visibility.
        ax.text(cx, cy, str(original_idx),
                color='white',
                fontsize=16,
                ha='center',
                va='center',
                fontweight='bold',
                bbox=dict(facecolor='black',
                          alpha=0.5,
                          edgecolor='none',
                          pad=1))
|
62 |
+
|
63 |
+
|
64 |
+
def create_image_grid(original_image, images, names, rows, columns):
    """Build a figure showing the original image followed by the named results.

    Images whose name is empty, whitespace-only or ``None`` are left out. The
    input lists are not modified.

    Args:
        original_image: Image shown first, titled ``'Original'``.
        images: Result images, paired positionally with ``names``.
        names: Titles for ``images``.
        rows: Minimum number of grid rows. More rows are used when the kept
            images would not otherwise fit.
        columns: Number of grid columns.

    Returns:
        The matplotlib figure containing the grid.

    Raises:
        ValueError: If ``rows`` or ``columns`` is smaller than 1.
    """
    if rows < 1 or columns < 1:
        raise ValueError("rows and columns must both be at least 1")

    # Keep only the images that have a non-blank name.
    kept = [(img, name) for img, name in zip(images, names) if name and name.strip()]
    grid_images = [original_image] + [img for img, _ in kept]
    grid_names = ['Original'] + [name for _, name in kept]

    # Ceiling division: grow the grid rather than let add_subplot fail when
    # the original plus the kept images exceed rows * columns.
    needed_rows = -(-len(grid_images) // columns)
    rows = max(rows, needed_rows)

    fig = plt.figure(figsize=(20, 20))

    for position, (img, name) in enumerate(zip(grid_images, grid_names), start=1):
        ax = fig.add_subplot(rows, columns, position)
        ax.imshow(img)
        ax.set_title(name, fontsize=12, pad=10)
        ax.axis('off')

    # Lay out this figure explicitly instead of relying on pyplot's notion of
    # the current figure.
    fig.tight_layout()
    return fig
|