Felix Marty
commited on
Commit
·
7804c1f
1
Parent(s):
e27b4eb
fix
Browse files- app.py +14 -7
- onnx_export.py +10 -4
app.py
CHANGED
@@ -8,6 +8,8 @@ from onnx_export import convert
|
|
8 |
from huggingface_hub import HfApi, Repository
|
9 |
|
10 |
|
|
|
|
|
11 |
DATASET_REPO_URL = "https://huggingface.co/datasets/safetensors/conversions"
|
12 |
DATA_FILENAME = "data.csv"
|
13 |
DATA_FILE = os.path.join("data", DATA_FILENAME)
|
@@ -17,7 +19,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
17 |
repo: Optional[Repository] = None
|
18 |
if HF_TOKEN:
|
19 |
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
|
20 |
-
|
21 |
|
22 |
def onnx_export(token: str, model_id: str, task: str) -> str:
|
23 |
if token == "" or model_id == "":
|
@@ -28,9 +30,16 @@ def onnx_export(token: str, model_id: str, task: str) -> str:
|
|
28 |
"""
|
29 |
try:
|
30 |
api = HfApi(token=token)
|
31 |
-
|
|
|
|
|
|
|
|
|
32 |
print("[commit_info]", commit_info)
|
33 |
|
|
|
|
|
|
|
34 |
"""
|
35 |
# save in a private dataset:
|
36 |
if repo is not None:
|
@@ -51,8 +60,6 @@ def onnx_export(token: str, model_id: str, task: str) -> str:
|
|
51 |
|
52 |
return f"### Success 🔥 Yay! This model was successfully converted and a PR was open using your token, here: {commit_info.pr_url}]({commit_info.pr_url})"
|
53 |
"""
|
54 |
-
except Exception as e:
|
55 |
-
return f"### Error: {e}"
|
56 |
|
57 |
|
58 |
DESCRIPTION = """
|
@@ -65,14 +72,14 @@ The steps are the following:
|
|
65 |
"""
|
66 |
|
67 |
demo = gr.Interface(
|
68 |
-
title="
|
69 |
description=DESCRIPTION,
|
70 |
allow_flagging="never",
|
71 |
article="Check out the [Optimum repo on GitHub](https://github.com/huggingface/optimum)",
|
72 |
inputs=[
|
73 |
-
gr.Text(max_lines=1, label="
|
74 |
gr.Text(max_lines=1, label="model_id"),
|
75 |
-
gr.Text(max_lines=1, label="task")
|
76 |
],
|
77 |
outputs=[gr.Markdown(label="output")],
|
78 |
fn=onnx_export,
|
|
|
8 |
from huggingface_hub import HfApi, Repository
|
9 |
|
10 |
|
11 |
+
# TODO: save stats about the Space?
|
12 |
+
"""
|
13 |
DATASET_REPO_URL = "https://huggingface.co/datasets/safetensors/conversions"
|
14 |
DATA_FILENAME = "data.csv"
|
15 |
DATA_FILE = os.path.join("data", DATA_FILENAME)
|
|
|
19 |
repo: Optional[Repository] = None
|
20 |
if HF_TOKEN:
|
21 |
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
|
22 |
+
"""
|
23 |
|
24 |
def onnx_export(token: str, model_id: str, task: str) -> str:
|
25 |
if token == "" or model_id == "":
|
|
|
30 |
"""
|
31 |
try:
|
32 |
api = HfApi(token=token)
|
33 |
+
|
34 |
+
error, commit_info = convert(api=api, model_id=model_id, task=task)
|
35 |
+
if error != "":
|
36 |
+
return error
|
37 |
+
|
38 |
print("[commit_info]", commit_info)
|
39 |
|
40 |
+
return f"### Success 🔥 Yay! This model was successfully converted and a PR was open using your token, here: {commit_info.pr_url}]({commit_info.pr_url})"
|
41 |
+
except Exception as e:
|
42 |
+
return f"### Error: {e}"
|
43 |
"""
|
44 |
# save in a private dataset:
|
45 |
if repo is not None:
|
|
|
60 |
|
61 |
return f"### Success 🔥 Yay! This model was successfully converted and a PR was open using your token, here: {commit_info.pr_url}]({commit_info.pr_url})"
|
62 |
"""
|
|
|
|
|
63 |
|
64 |
|
65 |
DESCRIPTION = """
|
|
|
72 |
"""
|
73 |
|
74 |
demo = gr.Interface(
|
75 |
+
title="POC to convert any PyTorch model to ONNX",
|
76 |
description=DESCRIPTION,
|
77 |
allow_flagging="never",
|
78 |
article="Check out the [Optimum repo on GitHub](https://github.com/huggingface/optimum)",
|
79 |
inputs=[
|
80 |
+
gr.Text(max_lines=1, label="Hugging Face token"),
|
81 |
gr.Text(max_lines=1, label="model_id"),
|
82 |
+
gr.Text(value="auto", max_lines=1, label="task (can be left blank, will auto-infer)")
|
83 |
],
|
84 |
outputs=[gr.Markdown(label="output")],
|
85 |
fn=onnx_export,
|
onnx_export.py
CHANGED
@@ -12,7 +12,7 @@ import os
|
|
12 |
import shutil
|
13 |
import argparse
|
14 |
|
15 |
-
from typing import Optional
|
16 |
|
17 |
from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
|
18 |
from huggingface_hub.file_download import repo_folder_name
|
@@ -26,7 +26,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
|
|
26 |
if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
|
27 |
return discussion
|
28 |
|
29 |
-
def convert_onnx(model_id: str, task: str, folder: str):
|
30 |
|
31 |
# Allocate the model
|
32 |
model = TasksManager.get_model_from_task(task, model_id, framework="pt")
|
@@ -77,11 +77,17 @@ def convert_onnx(model_id: str, task: str, folder: str):
|
|
77 |
return operations
|
78 |
|
79 |
|
80 |
-
def convert(api: "HfApi", model_id: str, task:str, force: bool=False) ->
|
81 |
pr_title = "Adding ONNX file of this model"
|
82 |
info = api.model_info(model_id)
|
83 |
filenames = set(s.rfilename for s in info.siblings)
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
with TemporaryDirectory() as d:
|
86 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
87 |
os.makedirs(folder)
|
@@ -105,7 +111,7 @@ def convert(api: "HfApi", model_id: str, task:str, force: bool=False) -> Optiona
|
|
105 |
)
|
106 |
finally:
|
107 |
shutil.rmtree(folder)
|
108 |
-
return new_pr
|
109 |
|
110 |
|
111 |
if __name__ == "__main__":
|
|
|
12 |
import shutil
|
13 |
import argparse
|
14 |
|
15 |
+
from typing import Optional, Tuple, List
|
16 |
|
17 |
from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
|
18 |
from huggingface_hub.file_download import repo_folder_name
|
|
|
26 |
if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
|
27 |
return discussion
|
28 |
|
29 |
+
def convert_onnx(model_id: str, task: str, folder: str) -> List:
|
30 |
|
31 |
# Allocate the model
|
32 |
model = TasksManager.get_model_from_task(task, model_id, framework="pt")
|
|
|
77 |
return operations
|
78 |
|
79 |
|
80 |
+
def convert(api: "HfApi", model_id: str, task: str, force: bool = False) -> Tuple[int, "CommitInfo"]:
|
81 |
pr_title = "Adding ONNX file of this model"
|
82 |
info = api.model_info(model_id)
|
83 |
filenames = set(s.rfilename for s in info.siblings)
|
84 |
|
85 |
+
if task == "auto":
|
86 |
+
try:
|
87 |
+
task = TasksManager.infer_task_from_model(model_id)
|
88 |
+
except Exception as e:
|
89 |
+
return f"### Error: {e}. Please pass explicitely the task as it could not be infered.", None
|
90 |
+
|
91 |
with TemporaryDirectory() as d:
|
92 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
93 |
os.makedirs(folder)
|
|
|
111 |
)
|
112 |
finally:
|
113 |
shutil.rmtree(folder)
|
114 |
+
return "0", new_pr
|
115 |
|
116 |
|
117 |
if __name__ == "__main__":
|