Spaces:
Running
on
T4
Running
on
T4
Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,14 @@ from PIL import Image
|
|
8 |
from sam2.build_sam import build_sam2
|
9 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
def show_mask(mask, ax, random_color=False, borders = True):
|
12 |
if random_color:
|
13 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
|
|
8 |
from sam2.build_sam import build_sam2
|
9 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
10 |
|
11 |
+
# use bfloat16 for the entire notebook
|
12 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
13 |
+
|
14 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
15 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
16 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
17 |
+
torch.backends.cudnn.allow_tf32 = True
|
18 |
+
|
19 |
def show_mask(mask, ax, random_color=False, borders = True):
|
20 |
if random_color:
|
21 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|