Spaces:
Runtime error
Runtime error
File size: 1,239 Bytes
57a1960 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
from PIL import Image
# Default checkpoint file for each SAM backbone, keyed by the same model-type
# string used to index ``sam_model_registry``. Paths are relative, so they are
# resolved against the current working directory.
models = dict(
    vit_b='./checkpoints/sam_vit_b_01ec64.pth',
    vit_l='./checkpoints/sam_vit_l_0b3195.pth',
    vit_h='./checkpoints/sam_vit_h_4b8939.pth',
)
def get_sam_predictor(model_type='vit_h', device=None, image=None, checkpoint=None):
    """Build a SAM model and wrap it in a ``SamPredictor``.

    Args:
        model_type: key into ``sam_model_registry`` selecting the backbone
            (the module-level ``models`` table provides default weights for
            ``'vit_b'``, ``'vit_l'`` and ``'vit_h'``).
        device: device the model is moved to with ``.to(device)``. Defaults to
            ``'cuda'`` when ``torch.cuda.is_available()``, otherwise ``'cpu'``.
        image: optional image handed to ``predictor.set_image`` so the returned
            predictor is ready to be prompted immediately.
        checkpoint: optional path to the weights file. When omitted, the path
            is taken from ``models[model_type]`` (relative to the current
            working directory).

    Returns:
        The ``SamPredictor`` wrapping the loaded model.

    Raises:
        KeyError: if ``checkpoint`` is omitted and ``model_type`` has no entry
            in ``models``, or if ``model_type`` is unknown to
            ``sam_model_registry``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if checkpoint is None:
        if model_type not in models:
            # Still a KeyError (as before), but with an actionable message.
            raise KeyError(
                f"no default checkpoint for model_type {model_type!r}; "
                f"expected one of {sorted(models)} or pass checkpoint=..."
            )
        checkpoint = models[model_type]

    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam = sam.to(device)
    predictor = SamPredictor(sam)
    if image is not None:
        predictor.set_image(image)
    return predictor
def sam_seg(predictor, input_img, input_points, input_labels):
    """Segment the prompted object and return it as an RGBA cut-out image.

    The predictor is queried with ``multimask_output=True``; of the candidate
    masks it returns, the one with the highest score is used as the alpha
    channel of the output, so pixels outside the mask become transparent.

    Args:
        predictor: a ``SamPredictor`` whose image has already been set via
            ``set_image``.
        input_img: the image to cut out. An ``(H, W, 3)`` array is the primary
            case; ``(H, W)`` grayscale, ``(H, W, 4)`` (existing alpha is
            discarded) and anything ``np.asarray`` accepts (e.g. a PIL image)
            also work. Presumably this is the same image that was passed to
            ``predictor.set_image``, since the mask must share its ``H, W``;
            TODO confirm at call sites.
        input_points: point prompts forwarded as ``point_coords``; lists are
            converted to arrays, ``None`` is passed through.
        input_labels: per-point labels forwarded as ``point_labels``; lists
            are converted to arrays, ``None`` is passed through.

    Returns:
        A PIL ``Image`` in RGBA mode: RGB from ``input_img``, alpha 255 inside
        the best mask and 0 outside.
    """
    # ``SamPredictor.predict`` documents its prompts as ndarrays; asarray lets
    # callers pass plain lists and is a no-op for arrays.
    point_coords = None if input_points is None else np.asarray(input_points)
    point_labels = None if input_labels is None else np.asarray(input_labels)

    masks, scores, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )
    # Keep only the candidate mask the model scored highest.
    mask = masks[np.argmax(scores)]

    rgb = np.asarray(input_img)
    if rgb.ndim == 2:
        # Grayscale: a trailing axis lets the single channel broadcast to R, G, B.
        rgb = rgb[:, :, None]
    elif rgb.ndim == 3 and rgb.shape[2] == 4:
        # Input already carries alpha; drop it, the mask becomes the new alpha.
        rgb = rgb[:, :, :3]

    out_image = np.zeros((rgb.shape[0], rgb.shape[1], 4), dtype=np.uint8)
    out_image[:, :, :3] = rgb
    out_image[:, :, 3] = mask.astype(np.uint8) * 255

    # Free PyTorch's cached-but-unused CUDA memory between calls.
    torch.cuda.empty_cache()

    # A uint8 (H, W, 4) array is inferred as RGBA by Pillow; the explicit
    # ``mode=`` argument is deprecated (Pillow 11.3) and therefore omitted.
    return Image.fromarray(out_image)
|