Felix Marty committed on
Commit
7804c1f
1 Parent(s): e27b4eb
Files changed (2) hide show
  1. app.py +14 -7
  2. onnx_export.py +10 -4
app.py CHANGED
@@ -8,6 +8,8 @@ from onnx_export import convert
8
  from huggingface_hub import HfApi, Repository
9
 
10
 
 
 
11
  DATASET_REPO_URL = "https://huggingface.co/datasets/safetensors/conversions"
12
  DATA_FILENAME = "data.csv"
13
  DATA_FILE = os.path.join("data", DATA_FILENAME)
@@ -17,7 +19,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
17
  repo: Optional[Repository] = None
18
  if HF_TOKEN:
19
  repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
20
-
21
 
22
  def onnx_export(token: str, model_id: str, task: str) -> str:
23
  if token == "" or model_id == "":
@@ -28,9 +30,16 @@ def onnx_export(token: str, model_id: str, task: str) -> str:
28
  """
29
  try:
30
  api = HfApi(token=token)
31
- commit_info = convert(api=api, model_id=model_id, task=task)
 
 
 
 
32
  print("[commit_info]", commit_info)
33
 
 
 
 
34
  """
35
  # save in a private dataset:
36
  if repo is not None:
@@ -51,8 +60,6 @@ def onnx_export(token: str, model_id: str, task: str) -> str:
51
 
52
  return f"### Success 🔥 Yay! This model was successfully converted and a PR was open using your token, here: {commit_info.pr_url}]({commit_info.pr_url})"
53
  """
54
- except Exception as e:
55
- return f"### Error: {e}"
56
 
57
 
58
  DESCRIPTION = """
@@ -65,14 +72,14 @@ The steps are the following:
65
  """
66
 
67
  demo = gr.Interface(
68
- title="Convert any model to Safetensors and open a PR",
69
  description=DESCRIPTION,
70
  allow_flagging="never",
71
  article="Check out the [Optimum repo on GitHub](https://github.com/huggingface/optimum)",
72
  inputs=[
73
- gr.Text(max_lines=1, label="your_hf_token"),
74
  gr.Text(max_lines=1, label="model_id"),
75
- gr.Text(max_lines=1, label="task")
76
  ],
77
  outputs=[gr.Markdown(label="output")],
78
  fn=onnx_export,
 
8
  from huggingface_hub import HfApi, Repository
9
 
10
 
11
+ # TODO: save stats about the Space?
12
+ """
13
  DATASET_REPO_URL = "https://huggingface.co/datasets/safetensors/conversions"
14
  DATA_FILENAME = "data.csv"
15
  DATA_FILE = os.path.join("data", DATA_FILENAME)
 
19
  repo: Optional[Repository] = None
20
  if HF_TOKEN:
21
  repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
22
+ """
23
 
24
  def onnx_export(token: str, model_id: str, task: str) -> str:
25
  if token == "" or model_id == "":
 
30
  """
31
  try:
32
  api = HfApi(token=token)
33
+
34
+ error, commit_info = convert(api=api, model_id=model_id, task=task)
35
+ if error != "":
36
+ return error
37
+
38
  print("[commit_info]", commit_info)
39
 
40
+ return f"### Success 🔥 Yay! This model was successfully converted and a PR was open using your token, here: {commit_info.pr_url}]({commit_info.pr_url})"
41
+ except Exception as e:
42
+ return f"### Error: {e}"
43
  """
44
  # save in a private dataset:
45
  if repo is not None:
 
60
 
61
  return f"### Success 🔥 Yay! This model was successfully converted and a PR was open using your token, here: {commit_info.pr_url}]({commit_info.pr_url})"
62
  """
 
 
63
 
64
 
65
  DESCRIPTION = """
 
72
  """
73
 
74
  demo = gr.Interface(
75
+ title="POC to convert any PyTorch model to ONNX",
76
  description=DESCRIPTION,
77
  allow_flagging="never",
78
  article="Check out the [Optimum repo on GitHub](https://github.com/huggingface/optimum)",
79
  inputs=[
80
+ gr.Text(max_lines=1, label="Hugging Face token"),
81
  gr.Text(max_lines=1, label="model_id"),
82
+ gr.Text(value="auto", max_lines=1, label="task (can be left blank, will auto-infer)")
83
  ],
84
  outputs=[gr.Markdown(label="output")],
85
  fn=onnx_export,
onnx_export.py CHANGED
@@ -12,7 +12,7 @@ import os
12
  import shutil
13
  import argparse
14
 
15
- from typing import Optional
16
 
17
  from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
18
  from huggingface_hub.file_download import repo_folder_name
@@ -26,7 +26,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
26
  if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
27
  return discussion
28
 
29
- def convert_onnx(model_id: str, task: str, folder: str):
30
 
31
  # Allocate the model
32
  model = TasksManager.get_model_from_task(task, model_id, framework="pt")
@@ -77,11 +77,17 @@ def convert_onnx(model_id: str, task: str, folder: str):
77
  return operations
78
 
79
 
80
- def convert(api: "HfApi", model_id: str, task:str, force: bool=False) -> Optional["CommitInfo"]:
81
  pr_title = "Adding ONNX file of this model"
82
  info = api.model_info(model_id)
83
  filenames = set(s.rfilename for s in info.siblings)
84
 
 
 
 
 
 
 
85
  with TemporaryDirectory() as d:
86
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
87
  os.makedirs(folder)
@@ -105,7 +111,7 @@ def convert(api: "HfApi", model_id: str, task:str, force: bool=False) -> Optiona
105
  )
106
  finally:
107
  shutil.rmtree(folder)
108
- return new_pr
109
 
110
 
111
  if __name__ == "__main__":
 
12
  import shutil
13
  import argparse
14
 
15
+ from typing import Optional, Tuple, List
16
 
17
  from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
18
  from huggingface_hub.file_download import repo_folder_name
 
26
  if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
27
  return discussion
28
 
29
+ def convert_onnx(model_id: str, task: str, folder: str) -> List:
30
 
31
  # Allocate the model
32
  model = TasksManager.get_model_from_task(task, model_id, framework="pt")
 
77
  return operations
78
 
79
 
80
+ def convert(api: "HfApi", model_id: str, task: str, force: bool = False) -> Tuple[int, "CommitInfo"]:
81
  pr_title = "Adding ONNX file of this model"
82
  info = api.model_info(model_id)
83
  filenames = set(s.rfilename for s in info.siblings)
84
 
85
+ if task == "auto":
86
+ try:
87
+ task = TasksManager.infer_task_from_model(model_id)
88
+ except Exception as e:
89
+ return f"### Error: {e}. Please pass explicitly the task as it could not be inferred.", None
90
+
91
  with TemporaryDirectory() as d:
92
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
93
  os.makedirs(folder)
 
111
  )
112
  finally:
113
  shutil.rmtree(folder)
114
+ return "0", new_pr
115
 
116
 
117
  if __name__ == "__main__":