Spaces:
Running
Running
Update convert.py
Browse files- convert.py +5 -2
convert.py
CHANGED
@@ -182,9 +182,13 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
|
182 |
input_ids = torch.arange(10).unsqueeze(0)
|
183 |
pixel_values = torch.randn(1, 3, 224, 224)
|
184 |
input_values = torch.arange(1000).float().unsqueeze(0)
|
|
|
|
|
185 |
kwargs = {}
|
186 |
if "input_ids" in sig.parameters:
|
187 |
kwargs["input_ids"] = input_ids
|
|
|
|
|
188 |
if "decoder_input_ids" in sig.parameters:
|
189 |
kwargs["decoder_input_ids"] = input_ids
|
190 |
if "pixel_values" in sig.parameters:
|
@@ -213,8 +217,7 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
|
213 |
kwargs["decoder_input_ids"] = decoder_input_ids
|
214 |
pt_logits = pt_model(**kwargs)[0]
|
215 |
except Exception:
|
216 |
-
|
217 |
-
return
|
218 |
sf_logits = sf_model(**kwargs)[0]
|
219 |
|
220 |
torch.testing.assert_close(sf_logits, pt_logits)
|
|
|
182 |
input_ids = torch.arange(10).unsqueeze(0)
|
183 |
pixel_values = torch.randn(1, 3, 224, 224)
|
184 |
input_values = torch.arange(1000).float().unsqueeze(0)
|
185 |
+
# Hardcoded for whisper basically
|
186 |
+
input_features = torch.zeros((1, 80, 3000))
|
187 |
kwargs = {}
|
188 |
if "input_ids" in sig.parameters:
|
189 |
kwargs["input_ids"] = input_ids
|
190 |
+
if "input_features" in sig.parameters:
|
191 |
+
kwargs["input_features"] = input_features
|
192 |
if "decoder_input_ids" in sig.parameters:
|
193 |
kwargs["decoder_input_ids"] = input_ids
|
194 |
if "pixel_values" in sig.parameters:
|
|
|
217 |
kwargs["decoder_input_ids"] = decoder_input_ids
|
218 |
pt_logits = pt_model(**kwargs)[0]
|
219 |
except Exception:
|
220 |
+
raise e
|
|
|
221 |
sf_logits = sf_model(**kwargs)[0]
|
222 |
|
223 |
torch.testing.assert_close(sf_logits, pt_logits)
|