Felix Marty commited on
Commit
f75daf5
·
1 Parent(s): 89d7a1e

add sketch

Browse files
Files changed (3) hide show
  1. app.py +84 -4
  2. onnx_export.py +132 -0
  3. requirements.txt +3 -0
app.py CHANGED
@@ -1,7 +1,87 @@
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import datetime
3
+ import os
4
+ from typing import Optional
5
  import gradio as gr
6
 
7
+ from onnx_export import convert
8
+ from huggingface_hub import HfApi, Repository
9
 
10
+
11
+ DATASET_REPO_URL = "https://huggingface.co/datasets/safetensors/conversions"
12
+ DATA_FILENAME = "data.csv"
13
+ DATA_FILE = os.path.join("data", DATA_FILENAME)
14
+
15
+ HF_TOKEN = os.environ.get("HF_TOKEN")
16
+
17
+ repo: Optional[Repository] = None
18
+ if HF_TOKEN:
19
+ repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
20
+
21
+
22
+ def onnx_export(token: str, model_id: str, task: str) -> str:
23
+ if token == "" or model_id == "":
24
+ return """
25
+ ### Invalid input 🐞
26
+
27
+ Please fill a token and model_id.
28
+ """
29
+ try:
30
+ api = HfApi(token=token)
31
+ commit_info = convert(api=api, model_id=model_id, task=task)
32
+ print("[commit_info]", commit_info)
33
+
34
+ # save in a private dataset:
35
+ if repo is not None:
36
+ repo.git_pull(rebase=True)
37
+ with open(DATA_FILE, "a") as csvfile:
38
+ writer = csv.DictWriter(
39
+ csvfile, fieldnames=["model_id", "pr_url", "time"]
40
+ )
41
+ writer.writerow(
42
+ {
43
+ "model_id": model_id,
44
+ "pr_url": commit_info.pr_url,
45
+ "time": str(datetime.now()),
46
+ }
47
+ )
48
+ commit_url = repo.push_to_hub()
49
+ print("[dataset]", commit_url)
50
+
51
+ return f"""
52
+ ### Success 🔥
53
+ Yay! This model was successfully converted and a PR was open using your token, here:
54
+ [{commit_info.pr_url}]({commit_info.pr_url})
55
+ """
56
+ except Exception as e:
57
+ return f"""
58
+ ### Error 😢😢
59
+
60
+ {e}
61
+ """
62
+
63
+
64
+ DESCRIPTION = """
65
+ The steps are the following:
66
+ - Paste a read-access token from hf.co/settings/tokens. Read access is enough given that we will open a PR against the source repo.
67
+ - Input a model id from the Hub
68
+ - If necessary, input the task for this model.
69
+ - Click "Convert to ONNX"
70
+ - That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the opened PR!
71
+ """
72
+
73
+ demo = gr.Interface(
74
+ title="Convert any model to Safetensors and open a PR",
75
+ description=DESCRIPTION,
76
+ allow_flagging="never",
77
+ article="Check out the [Optimum repo on GitHub](https://github.com/huggingface/optimum)",
78
+ inputs=[
79
+ gr.Text(max_lines=1, label="your_hf_token"),
80
+ gr.Text(max_lines=1, label="model_id"),
81
+ gr.Text(max_lines=1, label="task")
82
+ ],
83
+ outputs=[gr.Markdown(label="output")],
84
+ fn=onnx_export,
85
+ )
86
+
87
+ demo.launch()
onnx_export.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from optimum.exporters.tasks import TasksManager
2
+
3
+ from optimum.exporters.onnx import OnnxConfigWithPast, export, validate_model_outputs
4
+
5
+ from tempfile import TemporaryDirectory
6
+
7
+ from transformers import AutoConfig, is_torch_available
8
+
9
+ from transformers import AutoConfig
10
+
11
+ from pathlib import Path
12
+
13
+ import os
14
+ import shutil
15
+ import argparse
16
+
17
+ from typing import Optional
18
+
19
+ from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
20
+ from huggingface_hub.file_download import repo_folder_name
21
+
22
+ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
23
+ try:
24
+ discussions = api.get_repo_discussions(repo_id=model_id)
25
+ except Exception:
26
+ return None
27
+ for discussion in discussions:
28
+ if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
29
+ return discussion
30
+
31
+ def convert_onnx(model_id: str, task: str, folder: str):
32
+ model_class = TasksManager.get_model_class_for_task(task)
33
+ config = AutoConfig.from_pretrained(model_id)
34
+ model = model_class.from_config(config)
35
+
36
+ device = "cpu" # ?
37
+
38
+ # Dynamic axes aren't supported for YOLO-like models. This means they cannot be exported to ONNX on CUDA devices.
39
+ # See: https://github.com/ultralytics/yolov5/pull/8378
40
+ if model.__class__.__name__.startswith("Yolos") and device != "cpu":
41
+ return
42
+
43
+ onnx_config_class_constructor = TasksManager.get_exporter_config_constructor(model_type=config.model_type, exporter="onnx", task=task, model_name=model_id)
44
+ onnx_config = onnx_config_class_constructor(model.config)
45
+
46
+ # We need to set this to some value to be able to test the outputs values for batch size > 1.
47
+ if (
48
+ isinstance(onnx_config, OnnxConfigWithPast)
49
+ and getattr(model.config, "pad_token_id", None) is None
50
+ and task == "sequence-classification"
51
+ ):
52
+ model.config.pad_token_id = 0
53
+
54
+ if is_torch_available():
55
+ from optimum.exporters.onnx.utils import TORCH_VERSION
56
+
57
+ if not onnx_config.is_torch_support_available:
58
+ print(
59
+ "Skipping due to incompatible PyTorch version. Minimum required is"
60
+ f" {onnx_config.MIN_TORCH_VERSION}, got: {TORCH_VERSION}"
61
+ )
62
+
63
+ onnx_inputs, onnx_outputs = export(
64
+ model, onnx_config, onnx_config.DEFAULT_ONNX_OPSET, Path(folder), device=device
65
+ )
66
+ atol = onnx_config.ATOL_FOR_VALIDATION
67
+ if isinstance(atol, dict):
68
+ atol = atol[task.replace("-with-past", "")]
69
+ validate_model_outputs(
70
+ onnx_config,
71
+ model,
72
+ Path(folder),
73
+ onnx_outputs,
74
+ atol,
75
+ )
76
+
77
+ # TODO: iterate in folder and add all
78
+ operations = [CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames]
79
+
80
+ return operations
81
+
82
+
83
+ def convert(api: "HfApi", model_id: str, task:str, force: bool=False) -> Optional["CommitInfo"]:
84
+ pr_title = "Adding ONNX file of this model"
85
+ info = api.model_info(model_id)
86
+ filenames = set(s.rfilename for s in info.siblings)
87
+
88
+ with TemporaryDirectory() as d:
89
+ folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
90
+ os.makedirs(folder)
91
+ new_pr = None
92
+ try:
93
+ pr = previous_pr(api, model_id, pr_title)
94
+ if "model.onnx" in filenames and not force:
95
+ raise Exception(f"Model {model_id} is already converted, skipping..")
96
+ elif pr is not None and not force:
97
+ url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
98
+ new_pr = pr
99
+ raise Exception(f"Model {model_id} already has an open PR check out {url}")
100
+ else:
101
+ convert_onnx(model_id, task, folder)
102
+ finally:
103
+ shutil.rmtree(folder)
104
+ return new_pr
105
+
106
+
107
+ if __name__ == "__main__":
108
+ DESCRIPTION = """
109
+ Simple utility tool to convert automatically a model on the hub to onnx format.
110
+ It is PyTorch exclusive for now.
111
+ It works by downloading the weights (PT), converting them locally, and uploading them back
112
+ as a PR on the hub.
113
+ """
114
+ parser = argparse.ArgumentParser(description=DESCRIPTION)
115
+ parser.add_argument(
116
+ "model_id",
117
+ type=str,
118
+ help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
119
+ )
120
+ parser.add_argument(
121
+ "task",
122
+ type=str,
123
+ help="The task the model is performing",
124
+ )
125
+ parser.add_argument(
126
+ "--force",
127
+ action="store_true",
128
+ help="Create the PR even if it already exists of if the model was already converted.",
129
+ )
130
+ args = parser.parse_args()
131
+ api = HfApi()
132
+ convert(api, args.model_id, task=args.task, force=args.force)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ git+https://github.com/huggingface/optimum.git#egg=optimum[onnxruntime]