Spaces:
Sleeping
Sleeping
Update inference_from_video.py
Browse files- inference_from_video.py +9 -4
inference_from_video.py
CHANGED
@@ -123,11 +123,16 @@ def main():
|
|
123 |
model.eval()
|
124 |
|
125 |
# Load Trained Weight #
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
model.load_state_dict(torch.load(args.model, map_location=device), strict=False)
|
128 |
-
else:
|
129 |
-
from safetensors.torch import load_model
|
130 |
-
load_model(model, args.model, strict=False)
|
131 |
|
132 |
model.to(device)
|
133 |
|
|
|
123 |
model.eval()
|
124 |
|
125 |
# Load Trained Weight #
|
126 |
+
try:
|
127 |
+
if args.model.endswith(".pt") or args.model.endswith(".bin"):
|
128 |
+
model.load_state_dict(torch.load(args.model, map_location=device), strict=False)
|
129 |
+
else:
|
130 |
+
from safetensors.torch import load_model
|
131 |
+
load_model(model, args.model, strict=False)
|
132 |
+
except OSError as e:
|
133 |
+
print(f"Error loading model with safetensors: {e}")
|
134 |
+
print("Falling back to torch.load")
|
135 |
model.load_state_dict(torch.load(args.model, map_location=device), strict=False)
|
|
|
|
|
|
|
136 |
|
137 |
model.to(device)
|
138 |
|