# Spaces: Runtime error
# (page-status text captured along with this file; it is not part of the program)
import numpy as np
import torch
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
# Maps each SAM model-type key to the checkpoint file loaded for it.
# Paths are relative, so they resolve against the process's working directory.
models = {
    'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
    'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
    'vit_h': './checkpoints/sam_vit_h_4b8939.pth',
}
def get_sam_predictor(model_type='vit_h', device=None, image=None):
    """Build a ``SamPredictor`` for the requested SAM variant.

    Args:
        model_type: Key used to look up both the model builder in
            ``sam_model_registry`` and the checkpoint path in ``models``.
        device: Device the model is moved to. When ``None``, ``'cuda'`` is
            used if ``torch.cuda.is_available()`` and ``'cpu'`` otherwise.
        image: Optional image handed to ``predictor.set_image`` so the
            returned predictor is ready for prompting.
            NOTE(review): presumably an HxWx3 uint8 RGB array - confirm
            against ``SamPredictor.set_image``.

    Returns:
        The constructed ``SamPredictor``.

    Raises:
        KeyError: If ``model_type`` is not a key of ``sam_model_registry``
            or of ``models``.
    """
    # Only pick a device automatically when the caller did not choose one.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # sam model
    sam = sam_model_registry[model_type](checkpoint=models[model_type])
    sam = sam.to(device)

    predictor = SamPredictor(sam)
    if image is not None:
        predictor.set_image(image)
    return predictor
def sam_seg(predictor, input_img, input_points, input_labels):
    """Segment ``input_img`` from point prompts and return an RGBA cut-out.

    The predictor is asked for several candidate masks; the one with the
    highest score is kept and written into the alpha channel of the result.

    Args:
        predictor: Object whose ``predict`` method is called with the point
            prompts.
            NOTE(review): presumably a ``SamPredictor`` that already had
            ``set_image`` called with ``input_img`` - confirm at call sites.
        input_img: Array copied into the first three channels of the output.
            Its first two dimensions set the output height and width, so it
            must be assignable into an ``(H, W, 3)`` uint8 slice.
        input_points: Point coordinates forwarded as ``point_coords``.
        input_labels: Point labels forwarded as ``point_labels``.

    Returns:
        A ``PIL.Image.Image`` in ``'RGBA'`` mode: colour channels taken from
        ``input_img``, alpha 255 where the selected mask is set and 0
        elsewhere.
    """
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )

    # Keep only the candidate mask with the highest score.
    opt_idx = np.argmax(scores)
    mask = masks[opt_idx]

    height, width = input_img.shape[0], input_img.shape[1]
    out_image = np.zeros((height, width, 4), dtype=np.uint8)
    out_image[:, :, :3] = input_img
    # Scale the mask to 0/255 so masked-out pixels become fully transparent.
    out_image[:, :, 3] = mask.astype(np.uint8) * 255

    # Ask PyTorch to release its cached, currently unused GPU memory.
    torch.cuda.empty_cache()
    return Image.fromarray(out_image, mode='RGBA')