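"""Gradio demo: open-vocabulary image segmentation with Florence-2 + SAM."""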
from typing import List

import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image

from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference

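# The setup below assumes a CUDA device; the commented line is the CPU fallback.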
DEVICE = torch.device("cuda")
# DEVICE = torch.device("cpu")

# Keep a global bfloat16 autocast context open for the lifetime of the process.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # TF32 matmuls/convolutions need Ampere (compute capability 8.x) or newer.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


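# Load Florence-2 (open-vocabulary detector) and SAM (segmenter) once at
# startup so every request reuses the same weights.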
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)


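# spaces.GPU requests a GPU for each call on ZeroGPU Spaces; inference_mode
# disables gradient tracking for the whole function.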
@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(
    image_input: Image.Image, text_input: str
) -> List[Image.Image]:
    if image_input is None:
        gr.Info("Please upload an image.")
        return []

    if not text_input:
        gr.Info("Please enter a text prompt.")
        return []

    # Treat the text box as a comma separated list of prompts; drop empties.
    texts = [prompt.strip() for prompt in text_input.split(",") if prompt.strip()]
    detections_list = []
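    # For each prompt: run Florence-2 open-vocabulary detection, then turn the
    # predicted boxes into masks with SAM.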
    for text in texts:
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROCESSOR,
            device=DEVICE,
            image=image_input,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=text
        )
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image_input.size
        )
        detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
        detections_list.append(detections)

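    # Merge the per-prompt detections and run one more SAM pass over the
    # combined boxes, then return each mask as a black-and-white PIL image.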
    detections = sv.Detections.merge(detections_list)
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    return [
        Image.fromarray(mask.astype("uint8") * 255)
        for mask
        in detections.mask
    ]


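# Minimal UI: an image plus comma separated prompts in, a gallery of masks out.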
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            text_input_component = gr.Textbox(
                label='Text prompt',
                placeholder='Enter comma separated text prompts')
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            gallery_output_component = gr.Gallery(label='Output masks')

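    # Run inference from both the Submit button and Enter in the textbox.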
    submit_button_component.click(
        fn=process_image,
        inputs=[
            image_input_component,
            text_input_component
        ],
        outputs=[
            gallery_output_component,
        ]
    )
    text_input_component.submit(
        fn=process_image,
        inputs=[
            image_input_component,
            text_input_component
        ],
        outputs=[
            gallery_output_component,
        ]
    )

demo.launch(debug=False, show_error=True)