Narsil HF staff commited on
Commit
eb54d89
·
1 Parent(s): 05acb17

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +5 -2
convert.py CHANGED
@@ -182,9 +182,13 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
182
  input_ids = torch.arange(10).unsqueeze(0)
183
  pixel_values = torch.randn(1, 3, 224, 224)
184
  input_values = torch.arange(1000).float().unsqueeze(0)
 
 
185
  kwargs = {}
186
  if "input_ids" in sig.parameters:
187
  kwargs["input_ids"] = input_ids
 
 
188
  if "decoder_input_ids" in sig.parameters:
189
  kwargs["decoder_input_ids"] = input_ids
190
  if "pixel_values" in sig.parameters:
@@ -213,8 +217,7 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
213
  kwargs["decoder_input_ids"] = decoder_input_ids
214
  pt_logits = pt_model(**kwargs)[0]
215
  except Exception:
216
- print(f"Model {model_id} could not be checked, ignoring {e}")
217
- return
218
  sf_logits = sf_model(**kwargs)[0]
219
 
220
  torch.testing.assert_close(sf_logits, pt_logits)
 
182
  input_ids = torch.arange(10).unsqueeze(0)
183
  pixel_values = torch.randn(1, 3, 224, 224)
184
  input_values = torch.arange(1000).float().unsqueeze(0)
185
+ # Hardcoded for whisper basically
186
+ input_features = torch.zeros((1, 80, 3000))
187
  kwargs = {}
188
  if "input_ids" in sig.parameters:
189
  kwargs["input_ids"] = input_ids
190
+ if "input_features" in sig.parameters:
191
+ kwargs["input_features"] = input_features
192
  if "decoder_input_ids" in sig.parameters:
193
  kwargs["decoder_input_ids"] = input_ids
194
  if "pixel_values" in sig.parameters:
 
217
  kwargs["decoder_input_ids"] = decoder_input_ids
218
  pt_logits = pt_model(**kwargs)[0]
219
  except Exception:
220
+ raise e
 
221
  sf_logits = sf_model(**kwargs)[0]
222
 
223
  torch.testing.assert_close(sf_logits, pt_logits)