Spaces:
Runtime error
Runtime error
To support zero zero-gpu
Browse files
app.py
CHANGED
@@ -5,9 +5,11 @@ from PIL import Image
|
|
5 |
from timm.data import create_transform
|
6 |
|
7 |
|
|
|
|
|
8 |
# Prepare the model.
|
9 |
import models
|
10 |
-
model = models.mambaout_femto(pretrained=True) # can change different model name
|
11 |
model.eval()
|
12 |
|
13 |
# Prepare the transform.
|
@@ -17,9 +19,9 @@ transform = create_transform(input_size=224, crop_pct=model.default_cfg['crop_pc
|
|
17 |
response = requests.get("https://git.io/JJkYN")
|
18 |
labels = response.text.split("\n")
|
19 |
|
|
|
20 |
def predict(inp):
|
21 |
-
inp = transform(inp).unsqueeze(0)
|
22 |
-
|
23 |
with torch.no_grad():
|
24 |
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
|
25 |
confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
|
|
|
5 |
from timm.data import create_transform
|
6 |
|
7 |
|
8 |
+
device = "cuda"
|
9 |
+
|
10 |
# Prepare the model.
|
11 |
import models
|
12 |
+
model = models.mambaout_femto(pretrained=True).to(device=device) # can change different model name
|
13 |
model.eval()
|
14 |
|
15 |
# Prepare the transform.
|
|
|
19 |
response = requests.get("https://git.io/JJkYN")
|
20 |
labels = response.text.split("\n")
|
21 |
|
22 | |
23 |
def predict(inp):
|
24 |
+
inp = transform(inp).unsqueeze(0).to(device=device)
|
|
|
25 |
with torch.no_grad():
|
26 |
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
|
27 |
confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
|