Sanshruth commited on
Commit
c3acf88
·
verified ·
1 Parent(s): e5b565c

Upload 3 files

Browse files
Files changed (3) hide show
  1. engine.py +97 -0
  2. requirements.txt +18 -0
  3. utils.py +97 -0
engine.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from torchvision.transforms.functional import to_pil_image, to_tensor
4
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
5
+ from diffusers import StableDiffusionInpaintPipeline, EulerDiscreteScheduler
6
+ from PIL import Image
7
+ import numpy as np
8
+ import cv2
9
+
10
+ class SegmentAnythingModel:
11
+ def __init__(self, sam_checkpoint, model_type, device):
12
+ self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
13
+ self.sam.to(device=device)
14
+ self.mask_generator = SamAutomaticMaskGenerator(
15
+ model=self.sam,
16
+ points_per_side=32,
17
+ pred_iou_thresh=0.99,
18
+ stability_score_thresh=0.92,
19
+ crop_n_layers=1,
20
+ crop_n_points_downscale_factor=2,
21
+ min_mask_region_area=100
22
+ )
23
+ self.target_size = (512, 512)
24
+
25
+ def preprocess_image(self, image):
26
+ """Resize image to 512x512"""
27
+ if isinstance(image, str):
28
+ image = Image.open(image)
29
+ elif isinstance(image, np.ndarray):
30
+ image = Image.fromarray(image)
31
+
32
+ # Get current dimensions
33
+ width, height = image.size
34
+
35
+ # Resize to 512x512 directly
36
+ image = image.resize(self.target_size, Image.Resampling.LANCZOS)
37
+ return np.array(image)
38
+
39
+ def generate_masks(self, image):
40
+ processed_image = self.preprocess_image(image)
41
+ return self.mask_generator.generate(processed_image)
42
+
43
+ class StableDiffusionInpaintingPipeline:
44
+ def __init__(self, model_dir):
45
+ # Initialize the scheduler first
46
+ self.scheduler = EulerDiscreteScheduler.from_pretrained(model_dir, subfolder="scheduler")
47
+
48
+ # Initialize the pipeline with the scheduler
49
+ self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
50
+ model_dir,
51
+ scheduler=self.scheduler,
52
+ revision="fp16",
53
+ torch_dtype=torch.float16
54
+ )
55
+ self.pipe = self.pipe.to("cuda")
56
+ self.pipe.enable_xformers_memory_efficient_attention()
57
+ self.target_size = (512, 512)
58
+
59
+ def preprocess_image(self, image):
60
+ """Ensure image is in the right format and size"""
61
+ if isinstance(image, np.ndarray):
62
+ image = Image.fromarray(image)
63
+ return image.resize(self.target_size, Image.Resampling.LANCZOS)
64
+
65
+ def inpaint(self, prompt, image, mask_image, guidance_scale=10, num_inference_steps=60, generator=None):
66
+ """
67
+ Args:
68
+ prompt (str): The prompt for inpainting
69
+ image (PIL.Image or np.ndarray): The original image
70
+ mask_image (PIL.Image or np.ndarray): The mask for inpainting
71
+ guidance_scale (float): Higher guidance scale encourages images that are closer to the prompt
72
+ num_inference_steps (int): Number of denoising steps
73
+ generator (torch.Generator): Generator for reproducibility
74
+ """
75
+ # Preprocess images
76
+ if isinstance(image, np.ndarray):
77
+ image = Image.fromarray(image)
78
+ if isinstance(mask_image, np.ndarray):
79
+ mask_image = Image.fromarray(mask_image)
80
+
81
+ # Resize images
82
+ image = image.resize(self.target_size, Image.Resampling.LANCZOS)
83
+ mask_image = mask_image.resize(self.target_size, Image.Resampling.NEAREST)
84
+
85
+ # Run inpainting
86
+ output = self.pipe(
87
+ prompt=prompt,
88
+ image=image,
89
+ mask_image=mask_image,
90
+ guidance_scale=guidance_scale,
91
+ num_inference_steps=num_inference_steps,
92
+ generator=generator,
93
+ height=512,
94
+ width=512
95
+ )
96
+
97
+ return output.images[0]
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ regex
2
+ tqdm
3
+ diffusers
4
+ transformers
5
+ scipy
6
+ accelerate
7
+ opencv-python
8
+ Xformers
9
+ gradio
10
+ torch
11
+ torchvision
12
+ Pillow
13
+ matplotlib
14
+ numpy
15
+ git+https://github.com/facebookresearch/segment-anything.git
16
+ pycocotools
17
+ onnxruntime
18
+ onnx
utils.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import numpy as np
3
+ import copy
4
+ import cv2
5
+ import matplotlib.pyplot as plt
6
+ from torchvision.transforms.functional import to_pil_image
7
+ import torch
8
+ from PIL import Image
9
+ import matplotlib
10
+ matplotlib.use('Agg')
11
+
12
+ def show_anns(anns, ax=None):
13
+ if len(anns) == 0:
14
+ return
15
+ if ax is None:
16
+ ax = plt.gca()
17
+
18
+ sorted_anns = sorted(enumerate(anns), key=(lambda x: x[1]['area']), reverse=True)
19
+
20
+ for original_idx, ann in sorted_anns:
21
+ m = ann['segmentation']
22
+ if m.shape != (512, 512): # Ensure mask is right size
23
+ m = cv2.resize(m.astype(float), (512, 512))
24
+
25
+ # Create a random color for this mask
26
+ color_mask = np.random.random(3)
27
+
28
+ # Create the colored mask
29
+ colored_mask = np.zeros((512, 512, 3))
30
+ for i in range(3):
31
+ colored_mask[:,:,i] = color_mask[i]
32
+
33
+ # Add the mask with transparency
34
+ ax.imshow(np.dstack([colored_mask, m * 0.35]))
35
+
36
+ # Find contours of the mask
37
+ contours, _ = cv2.findContours((m * 255).astype(np.uint8),
38
+ cv2.RETR_EXTERNAL,
39
+ cv2.CHAIN_APPROX_SIMPLE)
40
+
41
+ # Add mask number if contours exist
42
+ if contours:
43
+ # Get the largest contour
44
+ cnt = max(contours, key=cv2.contourArea)
45
+ M = cv2.moments(cnt)
46
+
47
+ if M["m00"] != 0:
48
+ cx = int(M["m10"] / M["m00"])
49
+ cy = int(M["m01"] / M["m00"])
50
+
51
+ # Add text with white color and black outline for visibility
52
+ ax.text(cx, cy, str(original_idx),
53
+ color='white',
54
+ fontsize=16,
55
+ ha='center',
56
+ va='center',
57
+ fontweight='bold',
58
+ bbox=dict(facecolor='black',
59
+ alpha=0.5,
60
+ edgecolor='none',
61
+ pad=1))
62
+
63
+
64
+ def create_image_grid(original_image, images, names, rows, columns):
65
+ names = copy.copy(names)
66
+ images = copy.copy(images)
67
+
68
+ # Filter out empty prompts and their corresponding images
69
+ filtered_images = []
70
+ filtered_names = []
71
+ for img, name in zip(images, names):
72
+ if name.strip():
73
+ filtered_images.append(img)
74
+ filtered_names.append(name)
75
+
76
+ images = filtered_images
77
+ names = filtered_names
78
+
79
+ # Add original image
80
+ images.insert(0, original_image)
81
+ names.insert(0, 'Original')
82
+
83
+ fig = plt.figure(figsize=(20, 20))
84
+
85
+ for idx, (img, name) in enumerate(zip(images, names)):
86
+ ax = fig.add_subplot(rows, columns, idx + 1)
87
+
88
+ if isinstance(img, PIL.Image.Image):
89
+ ax.imshow(img)
90
+ else:
91
+ ax.imshow(img)
92
+
93
+ ax.set_title(name, fontsize=12, pad=10)
94
+ ax.axis('off')
95
+
96
+ plt.tight_layout()
97
+ return fig