Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Felix Marty
commited on
Commit
·
04c8db1
1
Parent(s):
ff649d1
style
Browse files- app.py +21 -10
- onnx_export.py +52 -26
app.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
import csv
|
2 |
-
import datetime
|
3 |
import os
|
|
|
4 |
from typing import Optional
|
5 |
-
import gradio as gr
|
6 |
|
7 |
-
|
8 |
from huggingface_hub import HfApi, Repository
|
9 |
|
|
|
10 |
|
11 |
DATASET_REPO_URL = "https://huggingface.co/datasets/optimum/exporters"
|
12 |
DATA_FILENAME = "data.csv"
|
@@ -20,6 +20,7 @@ repo: Optional[Repository] = None
|
|
20 |
if HF_TOKEN:
|
21 |
repo = Repository(local_dir=DATADIR, clone_from=DATASET_REPO_URL, token=HF_TOKEN)
|
22 |
|
|
|
23 |
def onnx_export(token: str, model_id: str, task: str) -> str:
|
24 |
if token == "" or model_id == "":
|
25 |
return """
|
@@ -33,7 +34,7 @@ def onnx_export(token: str, model_id: str, task: str) -> str:
|
|
33 |
error, commit_info = convert(api=api, model_id=model_id, task=task)
|
34 |
if error != "0":
|
35 |
return error
|
36 |
-
|
37 |
print("[commit_info]", commit_info)
|
38 |
|
39 |
# save in a private dataset
|
@@ -57,6 +58,7 @@ def onnx_export(token: str, model_id: str, task: str) -> str:
|
|
57 |
except Exception as e:
|
58 |
return f"#### Error: {e}"
|
59 |
|
|
|
60 |
TTILE_IMAGE = """
|
61 |
<div
|
62 |
style="
|
@@ -111,14 +113,23 @@ with gr.Blocks() as demo:
|
|
111 |
|
112 |
with gr.Column():
|
113 |
input_token = gr.Textbox(max_lines=1, label="Hugging Face token")
|
114 |
-
input_model = gr.Textbox(
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
btn = gr.Button("Convert to ONNX")
|
118 |
output = gr.Markdown(label="Output")
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
122 |
|
123 |
"""
|
124 |
demo = gr.Interface(
|
@@ -136,4 +147,4 @@ demo = gr.Interface(
|
|
136 |
)
|
137 |
"""
|
138 |
|
139 |
-
demo.launch()
|
|
|
1 |
import csv
|
|
|
2 |
import os
|
3 |
+
from datetime import datetime
|
4 |
from typing import Optional
|
|
|
5 |
|
6 |
+
import gradio as gr
|
7 |
from huggingface_hub import HfApi, Repository
|
8 |
|
9 |
+
from onnx_export import convert
|
10 |
|
11 |
DATASET_REPO_URL = "https://huggingface.co/datasets/optimum/exporters"
|
12 |
DATA_FILENAME = "data.csv"
|
|
|
20 |
if HF_TOKEN:
|
21 |
repo = Repository(local_dir=DATADIR, clone_from=DATASET_REPO_URL, token=HF_TOKEN)
|
22 |
|
23 |
+
|
24 |
def onnx_export(token: str, model_id: str, task: str) -> str:
|
25 |
if token == "" or model_id == "":
|
26 |
return """
|
|
|
34 |
error, commit_info = convert(api=api, model_id=model_id, task=task)
|
35 |
if error != "0":
|
36 |
return error
|
37 |
+
|
38 |
print("[commit_info]", commit_info)
|
39 |
|
40 |
# save in a private dataset
|
|
|
58 |
except Exception as e:
|
59 |
return f"#### Error: {e}"
|
60 |
|
61 |
+
|
62 |
TTILE_IMAGE = """
|
63 |
<div
|
64 |
style="
|
|
|
113 |
|
114 |
with gr.Column():
|
115 |
input_token = gr.Textbox(max_lines=1, label="Hugging Face token")
|
116 |
+
input_model = gr.Textbox(
|
117 |
+
max_lines=1,
|
118 |
+
label="Model name",
|
119 |
+
placeholder="textattack/distilbert-base-cased-CoLA",
|
120 |
+
)
|
121 |
+
input_task = gr.Textbox(
|
122 |
+
value="auto",
|
123 |
+
max_lines=1,
|
124 |
+
label='Task (can be left to "auto", will be automatically inferred)',
|
125 |
+
)
|
126 |
|
127 |
btn = gr.Button("Convert to ONNX")
|
128 |
output = gr.Markdown(label="Output")
|
129 |
+
|
130 |
+
btn.click(
|
131 |
+
fn=onnx_export, inputs=[input_token, input_model, input_task], outputs=output
|
132 |
+
)
|
133 |
|
134 |
"""
|
135 |
demo = gr.Interface(
|
|
|
147 |
)
|
148 |
"""
|
149 |
|
150 |
+
demo.launch()
|
onnx_export.py
CHANGED
@@ -1,33 +1,35 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
from optimum.exporters.onnx import OnnxConfigWithPast, export, validate_model_outputs
|
4 |
-
|
5 |
-
from tempfile import TemporaryDirectory
|
6 |
-
|
7 |
-
from transformers import AutoConfig, AutoTokenizer, is_torch_available
|
8 |
-
|
9 |
-
from pathlib import Path
|
10 |
-
|
11 |
import os
|
12 |
import shutil
|
13 |
-
import
|
14 |
-
|
15 |
-
from typing import Optional, Tuple
|
16 |
|
17 |
-
from huggingface_hub import CommitOperationAdd, HfApi,
|
|
|
18 |
from huggingface_hub.file_download import repo_folder_name
|
|
|
|
|
|
|
|
|
19 |
|
20 |
SPACES_URL = "https://huggingface.co/spaces/optimum/exporters"
|
21 |
|
|
|
22 |
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
|
23 |
try:
|
24 |
discussions = api.get_repo_discussions(repo_id=model_id)
|
25 |
except Exception:
|
26 |
return None
|
27 |
for discussion in discussions:
|
28 |
-
if
|
|
|
|
|
|
|
|
|
29 |
return discussion
|
30 |
|
|
|
31 |
def convert_onnx(model_id: str, task: str, folder: str) -> List:
|
32 |
|
33 |
# Allocate the model
|
@@ -46,7 +48,7 @@ def convert_onnx(model_id: str, task: str, folder: str) -> List:
|
|
46 |
and task in ["sequence_classification"]
|
47 |
)
|
48 |
if needs_pad_token_id:
|
49 |
-
#if args.pad_token_id is not None:
|
50 |
# model.config.pad_token_id = args.pad_token_id
|
51 |
try:
|
52 |
tok = AutoTokenizer.from_pretrained(model_id)
|
@@ -76,18 +78,37 @@ def convert_onnx(model_id: str, task: str, folder: str) -> List:
|
|
76 |
print(f"All good, model saved at: {output}")
|
77 |
except ValueError:
|
78 |
print(f"An error occured, but the model was saved at: {output.as_posix()}")
|
79 |
-
|
80 |
-
n_files = len(
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
if n_files == 1:
|
83 |
-
operations = [
|
|
|
|
|
|
|
|
|
|
|
84 |
else:
|
85 |
-
operations = [
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
return operations
|
88 |
|
89 |
|
90 |
-
def convert(
|
|
|
|
|
91 |
pr_title = "Adding ONNX file of this model"
|
92 |
info = api.model_info(model_id)
|
93 |
filenames = set(s.rfilename for s in info.siblings)
|
@@ -98,7 +119,10 @@ def convert(api: "HfApi", model_id: str, task: str, force: bool = False) -> Tupl
|
|
98 |
try:
|
99 |
task = TasksManager.infer_task_from_model(model_id)
|
100 |
except Exception as e:
|
101 |
-
return
|
|
|
|
|
|
|
102 |
|
103 |
with TemporaryDirectory() as d:
|
104 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
@@ -111,7 +135,9 @@ def convert(api: "HfApi", model_id: str, task: str, force: bool = False) -> Tupl
|
|
111 |
elif pr is not None and not force:
|
112 |
url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
|
113 |
new_pr = pr
|
114 |
-
raise Exception(
|
|
|
|
|
115 |
else:
|
116 |
operations = convert_onnx(model_id, task, folder)
|
117 |
|
@@ -159,4 +185,4 @@ if __name__ == "__main__":
|
|
159 |
)
|
160 |
args = parser.parse_args()
|
161 |
api = HfApi()
|
162 |
-
convert(api, args.model_id, task=args.task, force=args.force)
|
|
|
1 |
+
import argparse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import os
|
3 |
import shutil
|
4 |
+
from pathlib import Path
|
5 |
+
from tempfile import TemporaryDirectory
|
6 |
+
from typing import List, Optional, Tuple
|
7 |
|
8 |
+
from huggingface_hub import (CommitOperationAdd, HfApi, get_repo_discussions,
|
9 |
+
hf_hub_download)
|
10 |
from huggingface_hub.file_download import repo_folder_name
|
11 |
+
from optimum.exporters.onnx import (OnnxConfigWithPast, export,
|
12 |
+
validate_model_outputs)
|
13 |
+
from optimum.exporters.tasks import TasksManager
|
14 |
+
from transformers import AutoConfig, AutoTokenizer, is_torch_available
|
15 |
|
16 |
SPACES_URL = "https://huggingface.co/spaces/optimum/exporters"
|
17 |
|
18 |
+
|
19 |
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
|
20 |
try:
|
21 |
discussions = api.get_repo_discussions(repo_id=model_id)
|
22 |
except Exception:
|
23 |
return None
|
24 |
for discussion in discussions:
|
25 |
+
if (
|
26 |
+
discussion.status == "open"
|
27 |
+
and discussion.is_pull_request
|
28 |
+
and discussion.title == pr_title
|
29 |
+
):
|
30 |
return discussion
|
31 |
|
32 |
+
|
33 |
def convert_onnx(model_id: str, task: str, folder: str) -> List:
|
34 |
|
35 |
# Allocate the model
|
|
|
48 |
and task in ["sequence_classification"]
|
49 |
)
|
50 |
if needs_pad_token_id:
|
51 |
+
# if args.pad_token_id is not None:
|
52 |
# model.config.pad_token_id = args.pad_token_id
|
53 |
try:
|
54 |
tok = AutoTokenizer.from_pretrained(model_id)
|
|
|
78 |
print(f"All good, model saved at: {output}")
|
79 |
except ValueError:
|
80 |
print(f"An error occured, but the model was saved at: {output.as_posix()}")
|
81 |
+
|
82 |
+
n_files = len(
|
83 |
+
[
|
84 |
+
name
|
85 |
+
for name in os.listdir(folder)
|
86 |
+
if os.path.isfile(os.path.join(folder, name)) and not name.startswith(".")
|
87 |
+
]
|
88 |
+
)
|
89 |
+
|
90 |
if n_files == 1:
|
91 |
+
operations = [
|
92 |
+
CommitOperationAdd(
|
93 |
+
path_in_repo=file_name, path_or_fileobj=os.path.join(folder, file_name)
|
94 |
+
)
|
95 |
+
for file_name in os.listdir(folder)
|
96 |
+
]
|
97 |
else:
|
98 |
+
operations = [
|
99 |
+
CommitOperationAdd(
|
100 |
+
path_in_repo=os.path.join("onnx", file_name),
|
101 |
+
path_or_fileobj=os.path.join(folder, file_name),
|
102 |
+
)
|
103 |
+
for file_name in os.listdir(folder)
|
104 |
+
]
|
105 |
+
|
106 |
return operations
|
107 |
|
108 |
|
109 |
+
def convert(
|
110 |
+
api: "HfApi", model_id: str, task: str, force: bool = False
|
111 |
+
) -> Tuple[int, "CommitInfo"]:
|
112 |
pr_title = "Adding ONNX file of this model"
|
113 |
info = api.model_info(model_id)
|
114 |
filenames = set(s.rfilename for s in info.siblings)
|
|
|
119 |
try:
|
120 |
task = TasksManager.infer_task_from_model(model_id)
|
121 |
except Exception as e:
|
122 |
+
return (
|
123 |
+
f"### Error: {e}. Please pass explicitely the task as it could not be infered.",
|
124 |
+
None,
|
125 |
+
)
|
126 |
|
127 |
with TemporaryDirectory() as d:
|
128 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
|
|
135 |
elif pr is not None and not force:
|
136 |
url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
|
137 |
new_pr = pr
|
138 |
+
raise Exception(
|
139 |
+
f"Model {model_id} already has an open PR check out {url}"
|
140 |
+
)
|
141 |
else:
|
142 |
operations = convert_onnx(model_id, task, folder)
|
143 |
|
|
|
185 |
)
|
186 |
args = parser.parse_args()
|
187 |
api = HfApi()
|
188 |
+
convert(api, args.model_id, task=args.task, force=args.force)
|