Image2Paragraph / models /segment_models /semgent_anything_model.py
Awiny's picture
update lightweight code
b25eb4e
raw
history blame
1.16 kB
import os

import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

from utils.util import resize_long_edge_cv2
class SegmentAnything:
    """Thin wrapper that builds a SAM automatic mask generator and runs it on image files.

    The constructor picks a checkpoint path from the requested backbone name,
    loads the model onto ``device`` and stores the resulting
    ``SamAutomaticMaskGenerator`` in ``self.model``.
    """

    # Checkpoint path for each supported backbone name.
    # NOTE(review): the vit_l and vit_h file names share the suffix "0e2f7b",
    # which does not look like the officially published SAM checkpoint names
    # (sam_vit_l_0b3195.pth / sam_vit_h_4b8939.pth) -- confirm against the
    # files actually placed in pretrained_models/ before changing anything.
    _CHECKPOINTS = {
        "vit_b": "pretrained_models/sam_vit_b_01ec64.pth",
        "vit_l": "pretrained_models/sam_vit_l_0e2f7b.pth",
        "vit_h": "pretrained_models/sam_vit_h_0e2f7b.pth",
    }

    def __init__(self, device, arch="vit_b"):
        """Load the SAM backbone named by ``arch`` onto ``device``.

        Args:
            device: Target device, forwarded unchanged to ``sam.to(device=...)``.
            arch: One of ``"vit_b"``, ``"vit_l"`` or ``"vit_h"``.

        Raises:
            ValueError: If ``arch`` is not one of the supported names.
        """
        self.device = device
        try:
            pretrained_weights = self._CHECKPOINTS[arch]
        except (KeyError, TypeError):
            # TypeError covers unhashable values so that every unsupported
            # ``arch`` is reported the same way.
            raise ValueError(f"arch {arch} not supported") from None
        self.model = self.initialize_model(arch, pretrained_weights)

    def initialize_model(self, arch, pretrained_weights):
        """Build the SAM model and wrap it in an automatic mask generator.

        Args:
            arch: Key into ``sam_model_registry``.
            pretrained_weights: Checkpoint path handed to the registry factory.

        Returns:
            A ``SamAutomaticMaskGenerator`` wrapping the model, which has been
            moved to ``self.device``.
        """
        sam = sam_model_registry[arch](checkpoint=pretrained_weights)
        sam.to(device=self.device)
        return SamAutomaticMaskGenerator(sam)

    def generate_mask(self, img_src):
        """Read an image from disk and run the mask generator on it.

        Args:
            img_src: Path of the image file to segment.

        Returns:
            Whatever ``self.model.generate`` returns for the prepared image
            (per the segment_anything documentation, a list of per-mask
            annotation dicts).

        Raises:
            FileNotFoundError: If ``img_src`` does not point to an existing file.
            ValueError: If the file exists but OpenCV cannot decode it.
        """
        if not os.path.isfile(img_src):
            raise FileNotFoundError(f"image file not found: {img_src}")

        image = cv2.imread(img_src)
        # cv2.imread signals failure by returning None rather than raising;
        # without this check the failure would only surface as an opaque
        # assertion error inside cvtColor.
        if image is None:
            raise ValueError(f"could not decode image: {img_src}")

        # OpenCV loads images in BGR channel order; convert to RGB before
        # handing the array to the mask generator.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Presumably rescales so the longer side is 384 px -- helper is
        # defined in utils.util, confirm there.
        image = resize_long_edge_cv2(image, 384)
        return self.model.generate(image)