Freak-ppa committed on
Commit
da05695
·
verified ·
1 Parent(s): 125ccb2

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +65 -0
  2. requirements.txt +7 -0
  3. sam_hq_vit_h.pth +3 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import time
5
+ import json
6
+ import os
7
+ from loguru import logger
8
+ from decouple import config
9
+ import io
10
+ import torch
11
+ import numpy as np
12
+ import torch
13
+ import cv2
14
+ from PIL import Image
15
+
16
+ from segment_anything import sam_model_registry, SamPredictor
17
+
18
+ import spaces
19
+
20
# --- Startup diagnostics -------------------------------------------------
# Query CUDA once and reuse the result. The device-name lookup is guarded
# because torch.cuda.get_device_name() raises when no CUDA device is present,
# which would abort the import before the app could start.
_cuda_available = torch.cuda.is_available()
_cuda_device_name = (
    torch.cuda.get_device_name(torch.cuda.current_device())
    if _cuda_available
    else None
)
print(f"Is CUDA available: {_cuda_available}")
print(f"CUDA device: {_cuda_device_name}")
print(torch.version.cuda)
print(_cuda_device_name)

# --- Model setup ---------------------------------------------------------
# NOTE(review): confirm this path matches where the checkpoint file actually
# lives in the repository; a missing file fails at import time.
sam_checkpoint = "sam-hq/models/sam_hq_vit_h.pth"
model_type = "vit_h"
# `device` is the torch device string the model is moved to. It stays "cuda"
# whenever CUDA is available and only falls back to "cpu" otherwise.
device = "cuda" if _cuda_available else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# Shared predictor instance used by generate_image().
predictor = SamPredictor(sam)
32
+
33
@spaces.GPU(duration=5)
def generate_image(prompt, image):
    """Segment *image* from point prompts and return a black/white mask image.

    Args:
        prompt: JSON text containing the keys ``"input_points"`` and
            ``"input_labels"``. The points are validated as an ``(N, 2)``
            array and the labels as a length-``N`` array before prediction.
        image: The picture to segment, as delivered by the Gradio image
            input (configured by the caller as an RGB numpy array).

    Returns:
        A ``PIL.Image.Image`` in RGB where pixels selected by the predicted
        mask are white (255) and all other pixels are black (0).

    Raises:
        ValueError: If the image is missing, the prompt is not valid JSON,
            a required key is absent, or points and labels do not line up.
    """
    if image is None:
        raise ValueError("No image provided.")

    # Parse and validate the prompt *before* computing the image embedding:
    # set_image() is the expensive step, so reject bad input first.
    try:
        parsed = json.loads(prompt)
        input_points = np.array(parsed['input_points'])
        input_labels = np.array(parsed['input_labels'])
    except (TypeError, ValueError) as err:
        # json.JSONDecodeError is a ValueError subclass.
        raise ValueError(f"Prompt is not valid JSON: {err}") from err
    except KeyError as err:
        raise ValueError(f"Prompt is missing required key {err}.") from err

    if input_points.ndim != 2 or input_points.shape[1] != 2:
        raise ValueError("'input_points' must be a list of [x, y] pairs.")
    if input_labels.shape != (input_points.shape[0],):
        raise ValueError(
            "'input_labels' must contain exactly one label per input point."
        )

    predictor.set_image(image)

    mask, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        box=None,
        multimask_output=False,
        hq_token_only=True,
    )

    # With multimask_output=False the mask is indexed as (1, H, W); take the
    # single plane and make sure it is boolean so it is used as a mask index
    # rather than as integer (fancy) indices.
    selected = np.asarray(mask[0]).astype(bool)
    rgb_array = np.zeros((mask.shape[1], mask.shape[2], 3), dtype=np.uint8)
    rgb_array[selected] = 255

    return Image.fromarray(rgb_array)
54
+
55
+
56
if __name__ == "__main__":
    # Build the UI: a text box carrying the JSON point prompt plus an RGB
    # image input, producing a single RGB image (the segmentation mask).
    demo = gr.Interface(
        fn=generate_image,
        inputs=[
            "text",
            gr.Image(image_mode='RGB', type="numpy"),
        ],
        outputs=[
            gr.Image(type="numpy", image_mode='RGB'),
        ],
    )
    # Log before launching: launch(debug=True) blocks the main thread, so a
    # message placed after it would only be emitted once the server stops.
    logger.debug('demo.launch()')
    demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ segment-anything-hq
2
+ python-decouple==3.8
3
+ torch
4
+ torchaudio
5
+ torchsde
6
+ torchvision
7
+ loguru==0.7.2
sam_hq_vit_h.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7ac14a085326d9fa6199c8c698c4f0e7280afdbb974d2c4660ec60877b45e35
3
+ size 2570940653