anudit committed
Commit 2b93a2c · verified · 1 Parent(s): 095fcae

Delete onnx_export.py

Files changed (1)
  1. onnx_export.py +0 -62
onnx_export.py DELETED
@@ -1,62 +0,0 @@
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- import torch
- import onnx
- import onnxruntime as ort
- import numpy as np
-
-
- model_dir = './finetune/gte-base-custom-matryoshka'
- model_out = f"{model_dir}/model.onnx"
-
- print("## Loading Model")
- tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
- model = AutoModelForSequenceClassification.from_pretrained(model_dir, trust_remote_code=True)
-
- # Set the model in evaluation mode
- model.eval()
-
- # Example input for export
- inputs = tokenizer("Example input text", return_tensors="pt")
-
-
- # Define the export function
- torch.onnx.export(
-     model,                                             # The model to export
-     (inputs["input_ids"], inputs["attention_mask"]),   # Model input (you can adjust based on your model's input)
-     model_out,                                          # The path to save the ONNX file
-     export_params=True,                                 # Store the trained parameter weights
-     opset_version=14,                                   # The ONNX version to use
-     input_names=['input_ids', 'attention_mask'],        # Model's input names
-     output_names=['output'],                            # Model's output names
-     dynamic_axes={'input_ids': {0: 'batch_size'},       # Dynamic axis for input (batch size)
-                   'attention_mask': {0: 'batch_size'},
-                   'output': {0: 'batch_size'}}          # Dynamic axis for output (batch size)
- )
-
- print("## ONNX Model Exported")
-
- # Verify the ONNX model
- print("## Verifying Onnx")
-
- ort_session = ort.InferenceSession(model_out)
-
- if "token_type_ids" in inputs:
-     del inputs["token_type_ids"]
-
- # Prepare inputs for ONNX inference
- ort_inputs = {k: v.cpu().detach().numpy() for k, v in inputs.items()}
- ort_outs = ort_session.run(None, ort_inputs)
-
- print("ONNX output:", ort_outs[0])
-
- with torch.no_grad():
-     pytorch_outputs = model(**inputs)
-     pytorch_output_array = pytorch_outputs.logits.cpu().numpy()
-
- print("PyTorch output:", pytorch_output_array)
-
- # Compare the outputs
- if np.allclose(pytorch_output_array, ort_outs[0], atol=1e-5):
-     print("The ONNX model output matches the PyTorch model output!")
- else:
-     print("The ONNX model output does NOT match the PyTorch model output.")