Delete onnx_export.py
onnx_export.py
DELETED
@@ -1,62 +0,0 @@
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
-import torch
-import onnx
-import onnxruntime as ort
-import numpy as np
-
-
-model_dir = './finetune/gte-base-custom-matryoshka'
-model_out = f"{model_dir}/model.onnx"
-
-print("## Loading Model")
-tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
-model = AutoModelForSequenceClassification.from_pretrained(model_dir, trust_remote_code=True)
-
-# Set the model to evaluation mode
-model.eval()
-
-# Example input used to trace the model during export
-inputs = tokenizer("Example input text", return_tensors="pt")
-
-
-# Export the model to ONNX
-torch.onnx.export(
-    model,                                             # The model to export
-    (inputs["input_ids"], inputs["attention_mask"]),   # Model inputs (adjust to your model's signature)
-    model_out,                                         # The path to save the ONNX file
-    export_params=True,                                # Store the trained parameter weights
-    opset_version=14,                                  # The ONNX opset version to use
-    input_names=['input_ids', 'attention_mask'],       # Model's input names
-    output_names=['output'],                           # Model's output names
-    dynamic_axes={'input_ids': {0: 'batch_size'},      # Dynamic axes for the inputs (batch size)
-                  'attention_mask': {0: 'batch_size'},
-                  'output': {0: 'batch_size'}}         # Dynamic axis for the output (batch size)
-)
-
-print("## ONNX Model Exported")
-
-# Verify the ONNX model
-print("## Verifying ONNX")
-
-ort_session = ort.InferenceSession(model_out)
-
-if "token_type_ids" in inputs:
-    del inputs["token_type_ids"]
-
-# Prepare inputs for ONNX inference
-ort_inputs = {k: v.cpu().detach().numpy() for k, v in inputs.items()}
-ort_outs = ort_session.run(None, ort_inputs)
-
-print("ONNX output:", ort_outs[0])
-
-with torch.no_grad():
-    pytorch_outputs = model(**inputs)
-pytorch_output_array = pytorch_outputs.logits.cpu().numpy()
-
-print("PyTorch output:", pytorch_output_array)
-
-# Compare the outputs
-if np.allclose(pytorch_output_array, ort_outs[0], atol=1e-5):
-    print("The ONNX model output matches the PyTorch model output!")
-else:
-    print("The ONNX model output does NOT match the PyTorch model output.")
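Note: only axis 0 (the batch) is declared dynamic in the export call above, so the traced graph pins the sequence dimension to the token length of "Example input text", and onnxruntime will typically reject inputs padded to any other length. A minimal sketch of a dynamic_axes mapping that also frees the sequence axis (the axis names are illustrative labels, not values required by ONNX):

    # Declaring axis 1 dynamic as well lets the exported graph accept
    # inputs padded to any sequence length, not just the traced one.
    dynamic_axes = {
        'input_ids':      {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'output':         {0: 'batch_size'},
    }

With both axes dynamic, the verification step can run on batches of arbitrary shape rather than only the single traced example.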