File size: 6,926 Bytes
04c8db1
f75daf5
 
04c8db1
 
 
f75daf5
04c8db1
 
f75daf5
04c8db1
 
 
 
f75daf5
7567dc4
 
04c8db1
f75daf5
 
 
 
 
 
04c8db1
 
 
 
 
f75daf5
 
04c8db1
7804c1f
f75daf5
be527a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04c8db1
be527a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508744a
 
 
 
36e2d86
04c8db1
 
 
 
 
 
 
 
 
3166d00
04c8db1
 
 
 
 
 
3166d00
04c8db1
 
 
 
 
 
 
 
be527a9
f75daf5
 
04c8db1
 
 
f75daf5
 
 
 
7567dc4
 
7804c1f
 
 
 
04c8db1
 
 
 
7804c1f
f75daf5
 
 
 
 
 
 
 
 
 
 
04c8db1
46363ea
04c8db1
f75daf5
be527a9
 
7567dc4
171b6b3
7567dc4
171b6b3
7567dc4
171b6b3
7567dc4
be527a9
 
 
 
7567dc4
be527a9
 
f75daf5
 
7804c1f
f75daf5
 
 
 
 
 
 
 
 
 
 
be527a9
f75daf5
 
 
 
be527a9
f75daf5
 
 
 
 
 
 
 
 
 
04c8db1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import argparse
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional, Tuple

from huggingface_hub import (CommitOperationAdd, HfApi, get_repo_discussions,
                             hf_hub_download)
from huggingface_hub.file_download import repo_folder_name
from optimum.exporters.onnx import (OnnxConfigWithPast, export,
                                    validate_model_outputs)
from optimum.exporters.tasks import TasksManager
from transformers import AutoConfig, AutoTokenizer, is_torch_available

SPACES_URL = "https://huggingface.co/spaces/optimum/exporters"


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Return the open pull request on ``model_id`` titled ``pr_title``, if any.

    Args:
        api: Hub client used to list the repository's discussions.
        model_id: Repository id whose discussions are searched.
        pr_title: Exact title the pull request must have.

    Returns:
        The first discussion that is open, is a pull request and has exactly
        ``pr_title`` as its title; ``None`` if there is none or if listing the
        discussions fails (this lookup is deliberately best-effort).
    """
    try:
        # ``get_repo_discussions`` can return a lazy iterator, in which case
        # request errors only surface while iterating - so the loop itself
        # must sit inside the ``try`` for the best-effort fallback to work.
        for discussion in api.get_repo_discussions(repo_id=model_id):
            if (
                discussion.status == "open"
                and discussion.is_pull_request
                and discussion.title == pr_title
            ):
                return discussion
    except Exception:
        return None
    return None


def convert_onnx(model_id: str, task: str, folder: str) -> List:
    """Export ``model_id`` to ONNX inside ``folder`` and build the Hub upload operations.

    Loads the PyTorch model for ``task``, exports it to ``folder/model.onnx``
    with the exporter config's default opset, validates the outputs (a
    validation ``ValueError`` is only reported, not raised) and returns one
    ``CommitOperationAdd`` per regular, non-hidden file found in ``folder``.

    Args:
        model_id: Hub id of the model to export.
        task: Task name understood by ``TasksManager``.
        folder: Existing local directory to write the export into. It is
            presumably empty beforehand - every regular non-hidden file in it
            is turned into an upload operation.

    Returns:
        A list of ``CommitOperationAdd``. A single exported file is placed at
        the repository root; several files are placed under ``onnx/``.

    Raises:
        ValueError: If the model needs a pad token id and it cannot be taken
            from the model's tokenizer.
    """
    # Allocate the model
    model = TasksManager.get_model_from_task(task, model_id, framework="pt")
    model_type = model.config.model_type.replace("_", "-")
    model_name = getattr(model, "name", None)

    onnx_config_constructor = TasksManager.get_exporter_config_constructor(
        model_type, "onnx", task=task, model_name=model_name
    )
    onnx_config = onnx_config_constructor(model.config)

    # Task names have been spelled both with underscores and hyphens; compare
    # a normalised form so either spelling triggers the pad-token handling.
    normalized_task = (task or "").replace("_", "-")
    needs_pad_token_id = (
        isinstance(onnx_config, OnnxConfigWithPast)
        and getattr(model.config, "pad_token_id", None) is None
        and normalized_task in ("sequence-classification", "text-classification")
    )
    if needs_pad_token_id:
        # NOTE(review): this script exposes no --pad_token_id flag; the message
        # is kept as-is but may need rewording.
        pad_error = (
            "Could not infer the pad token id, which is needed in this case, "
            "please provide it with the --pad_token_id argument"
        )
        try:
            tok = AutoTokenizer.from_pretrained(model_id)
        except Exception as e:
            raise ValueError(pad_error) from e
        if tok.pad_token_id is None:
            # Assigning None would leave the config exactly as unusable as before.
            raise ValueError(pad_error)
        model.config.pad_token_id = tok.pad_token_id

    # No opset is requested by the caller: use the config's default one.
    opset = onnx_config.DEFAULT_ONNX_OPSET

    output = Path(folder).joinpath("model.onnx")
    _, onnx_outputs = export(
        model,
        onnx_config,
        opset,
        output,
    )

    atol = onnx_config.ATOL_FOR_VALIDATION
    if isinstance(atol, dict):
        # Per-task tolerances are keyed by the task name without "-with-past".
        atol = atol[task.replace("-with-past", "")]

    try:
        validate_model_outputs(onnx_config, model, output, onnx_outputs, atol)
        print(f"All good, model saved at: {output}")
    except ValueError:
        # Validation mismatches are reported but not fatal: the export is kept.
        print(f"An error occured, but the model was saved at: {output.as_posix()}")

    # Count and upload the same set of files: regular, non-hidden entries only
    # (directories cannot be uploaded as a single file).
    file_names = sorted(
        name
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name)) and not name.startswith(".")
    )

    # One file goes at the repo root; several are grouped under "onnx/".
    # Hub paths always use "/", independent of the local OS separator.
    prefix = "" if len(file_names) == 1 else "onnx/"
    return [
        CommitOperationAdd(
            path_in_repo=f"{prefix}{file_name}",
            path_or_fileobj=os.path.join(folder, file_name),
        )
        for file_name in file_names
    ]


def convert(
    api: "HfApi", model_id: str, task: str, force: bool = False
) -> Tuple[str, Optional["CommitInfo"]]:
    """Convert ``model_id`` to ONNX and open a pull request with the result.

    Args:
        api: Authenticated Hub client; ``api.whoami()`` names the requesting
            user in the PR description.
        model_id: Hub id of the model to convert.
        task: Export task, or ``"auto"`` to infer it with
            ``TasksManager.infer_task_from_model``.
        force: Open the PR even if the repo already contains ONNX weights or
            already has an open PR with the same title.

    Returns:
        ``("0", commit_info)`` on success, where ``commit_info`` is what
        ``api.create_commit`` returned, or ``(error_message, None)`` when the
        task is ``"auto"`` and could not be inferred.

    Raises:
        Exception: If ``force`` is False and the model is already converted
            or already has an open conversion PR. Errors raised by the export
            or by the upload propagate unchanged.
    """
    pr_title = "Adding ONNX file of this model"
    info = api.model_info(model_id)
    filenames = {s.rfilename for s in info.siblings}

    requesting_user = api.whoami()["name"]

    if task == "auto":
        try:
            task = TasksManager.infer_task_from_model(model_id)
        except Exception as e:
            return (
                f"### Error: {e}. Please pass explicitely the task as it could not be infered.",
                None,
            )

    if not force:
        # convert_onnx uploads either "model.onnx" (single file) or
        # "onnx/model.onnx" (several files); either one means "converted".
        if filenames & {"model.onnx", "onnx/model.onnx"}:
            raise Exception(f"Model {model_id} is already converted, skipping..")
        pr = previous_pr(api, model_id, pr_title)
        if pr is not None:
            url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
            raise Exception(
                f"Model {model_id} already has an open PR check out [{url}]({url})"
            )

    # TemporaryDirectory removes the export folder on exit, success or failure.
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        operations = convert_onnx(model_id, task, folder)

        commit_description = f"""
Beep boop I am the [ONNX export bot 🤖🏎️]({SPACES_URL}). On behalf of [{requesting_user}](https://huggingface.co/{requesting_user}), I would like to add to this repository the model converted to ONNX.

What is ONNX? It stands for "Open Neural Network Exchange", and is the most commonly used open standard for machine learning interoperability. You can find out more at [onnx.ai](https://onnx.ai/)!

The exported ONNX model can be then be consumed by various backends as TensorRT or TVM, or simply be used in a few lines with 🤗 Optimum through ONNX Runtime, check out how [here](https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/models)!
        """
        new_pr = api.create_commit(
            repo_id=model_id,
            operations=operations,
            commit_message=pr_title,
            commit_description=commit_description,
            create_pr=True,
        )
    return "0", new_pr


if __name__ == "__main__":
    DESCRIPTION = """
    Simple utility tool to convert automatically a model on the hub to onnx format.
    It is PyTorch exclusive for now.
    It works by downloading the weights (PT), converting them locally, and uploading them back
    as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--task",
        type=str,
        # Without a default, None would reach the exporter and the "auto"
        # task inference in convert() could never be used from the CLI.
        default="auto",
        help="The task the model is performing. Defaults to `auto`, which infers it from the model.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if it already exists or if the model was already converted.",
    )
    args = parser.parse_args()
    api = HfApi()
    status, commit_info = convert(api, args.model_id, task=args.task, force=args.force)
    if commit_info is None:
        # convert() reports an uninferable task through the status string
        # instead of raising; surface it and exit with a non-zero code.
        raise SystemExit(status)
    print(f"PR created: {getattr(commit_info, 'pr_url', commit_info)}")