File size: 13,940 Bytes
0ad74ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
"""This module contains the EndpointV3Compatibility class, which is used to connect to Gradio apps running 3.x.x versions of Gradio."""

from __future__ import annotations

import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import httpx
import huggingface_hub
import websockets
from packaging import version

from gradio_client import serializing, utils
from gradio_client.exceptions import SerializationSetupError
from gradio_client.utils import (
    Communicator,
)

if TYPE_CHECKING:
    from gradio_client import Client


class EndpointV3Compatibility:
    """Endpoint class for connecting to v3 endpoints. Backwards compatibility.

    Wraps a single entry of the app's dependency list (identified by
    ``fn_index``) and knows how to serialize inputs, send them to the app over
    either a websocket or a plain HTTP POST, and post-process the outputs.
    """

    def __init__(self, client: Client, fn_index: int, dependency: dict, *_args):
        """Build the endpoint from the client's config.

        Parameters:
            client: The owning client; its ``src``, ``config`` and request settings are read.
            fn_index: Index of the function this endpoint represents.
            dependency: The dependency entry for this function. The keys read here are
                ``"api_name"``, ``"backend_fn"``, ``"inputs"``, ``"outputs"`` and ``"queue"``.
            *_args: Ignored; accepted so that both endpoint classes share a signature.
        """
        self.client: Client = client
        self.fn_index = fn_index
        self.dependency = dependency
        api_name = dependency.get("api_name")
        # A string name is normalised to "/name"; False / None are kept as-is.
        self.api_name: str | Literal[False] | None = (
            "/" + api_name if isinstance(api_name, str) else api_name
        )
        self.use_ws = self._use_websocket(self.dependency)
        self.protocol = "ws" if self.use_ws else "http"
        # Filled in by _setup_serializers() below.
        self.input_component_types = []
        self.output_component_types = []
        self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
        try:
            # Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid,
            # and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
            self.serializers, self.deserializers = self._setup_serializers()
            self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
        except SerializationSetupError:
            self.is_valid = False
        self.backend_fn = dependency.get("backend_fn")
        self.show_api = True

    def __repr__(self):
        return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}"

    def __str__(self):
        return self.__repr__()

    def make_end_to_end_fn(self, helper: Communicator | None = None):
        """Return a callable that runs the full request pipeline for this endpoint.

        The returned function inserts state placeholders, serializes the
        arguments, sends the prediction request and post-processes the result.
        It raises ``utils.InvalidAPIEndpointError`` if the endpoint is not valid.
        """
        _predict = self.make_predict(helper)

        def _inner(*data):
            if not self.is_valid:
                raise utils.InvalidAPIEndpointError()
            data = self.insert_state(*data)
            data = self.serialize(*data)
            predictions = _predict(*data)
            predictions = self.process_predictions(*predictions)
            # Append final output only if not already present
            # for consistency between generators and not generators
            if helper:
                with helper.lock:
                    if not helper.job.outputs:
                        helper.job.outputs.append(predictions)
            return predictions

        return _inner

    def make_cancel(self, helper: Communicator | None = None):  # noqa: ARG002 (needed so that both endpoints classes have the same api)
        """Return ``None``: this endpoint class provides no cancel function."""
        return None

    def make_predict(self, helper: Communicator | None = None):
        """Return a callable that sends already-serialized data and returns the raw outputs.

        The callable returns ``tuple(result["data"])``. It raises
        ``utils.TooManyRequestsError`` for a 429 error on a public Space,
        ``ValueError`` for any other reported error, and ``KeyError`` when the
        response has neither a ``"data"`` nor an ``"error"`` key.
        """

        def _predict(*data) -> tuple:
            data = json.dumps(
                {
                    "data": data,
                    "fn_index": self.fn_index,
                    "session_hash": self.client.session_hash,
                }
            )
            hash_data = json.dumps(
                {
                    "fn_index": self.fn_index,
                    "session_hash": self.client.session_hash,
                }
            )
            if self.use_ws:
                result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
                if "error" in result:
                    raise ValueError(result["error"])
            else:
                # NOTE(review): `data` is already a JSON string at this point, so
                # passing it through `json=` looks like it gets JSON-encoded a
                # second time by httpx - confirm this is what v3 servers expect.
                response = httpx.post(
                    self.client.api_url,
                    headers=self.client.headers,
                    json=data,
                    verify=self.client.ssl_verify,
                    **self.client.httpx_kwargs,
                )
                result = json.loads(response.content.decode("utf-8"))
            try:
                output = result["data"]
            except KeyError as ke:
                if "error" in result:
                    # Only look the Space up (a network call) when the answer
                    # can actually change which exception is raised.
                    if (
                        "429" in result["error"]
                        and self.client.space_id
                        and not huggingface_hub.space_info(
                            self.client.space_id
                        ).private
                    ):
                        raise utils.TooManyRequestsError(
                            f"Too many requests to the API, please try again later. To avoid being rate-limited, "
                            f"please duplicate the Space using Client.duplicate({self.client.space_id}) "
                            f"and pass in your Hugging Face token."
                        ) from None
                    raise ValueError(result["error"]) from None
                raise KeyError(
                    f"Could not find 'data' key in response. Response received: {result}"
                ) from ke
            return tuple(output)

        return _predict

    def _predict_resolve(self, *data) -> Any:
        """Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
        outputs = self.make_predict()(*data)
        if len(self.dependency["outputs"]) == 1:
            return outputs[0]
        return outputs

    def _upload(
        self, file_paths: list[str | list[str]]
    ) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
        """Upload local files to the app and describe them as file dictionaries.

        Parameters:
            file_paths: One entry per file component; each entry is a single
                path or a list of paths.
        Returns:
            An empty list when ``file_paths`` is empty. If the upload does not
            return status 200, ``file_paths`` itself is returned unchanged.
            Otherwise a list shaped like ``file_paths`` where every path is
            replaced by a dict with the keys ``is_file``, ``name``,
            ``orig_name`` and ``data``.
        """
        if not file_paths:
            return []
        # Put all the filepaths in one file
        # but then keep track of which index in the
        # original list they came from so we can recreate
        # the original structure
        files = []
        indices = []
        handles = []
        try:
            for i, fs in enumerate(file_paths):
                if not isinstance(fs, list):
                    fs = [fs]
                for f in fs:
                    name = Path(f).name
                    handle = open(f, "rb")  # noqa: SIM115
                    handles.append(handle)
                    files.append(("files", (name, handle)))
                    indices.append(i)
            r = httpx.post(
                self.client.upload_url,
                headers=self.client.headers,
                files=files,
                verify=self.client.ssl_verify,
                **self.client.httpx_kwargs,
            )
        finally:
            # Close every handle that was opened, even if a later open() or the
            # request itself failed, so file descriptors are not leaked.
            for handle in handles:
                handle.close()
        if r.status_code != 200:
            uploaded = file_paths
        else:
            uploaded = []
            result = r.json()
            for i, fs in enumerate(file_paths):
                if isinstance(fs, list):
                    output = [o for ix, o in enumerate(result) if indices[ix] == i]
                    res = [
                        {
                            "is_file": True,
                            "name": o,
                            "orig_name": Path(f).name,
                            "data": None,
                        }
                        for f, o in zip(fs, output, strict=False)
                    ]
                else:
                    o = next(o for ix, o in enumerate(result) if indices[ix] == i)
                    res = {
                        "is_file": True,
                        "name": o,
                        "orig_name": Path(fs).name,
                        "data": None,
                    }
                uploaded.append(res)
        return uploaded

    def _add_uploaded_files_to_data(
        self,
        files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]],
        data: list[Any],
    ) -> None:
        """Helper function to modify the input data with the uploaded files.

        ``data`` is mutated in place: the n-th "file"/"uploadbutton" input
        position receives ``files[n]``.
        """
        file_counter = 0
        for i, t in enumerate(self.input_component_types):
            if t in ["file", "uploadbutton"]:
                data[i] = files[file_counter]
                file_counter += 1

    def insert_state(self, *data) -> tuple:
        """Return ``data`` with ``None`` inserted at every state-component input position."""
        data = list(data)
        for i, input_component_type in enumerate(self.input_component_types):
            if input_component_type == utils.STATE_COMPONENT:
                data.insert(i, None)
        return tuple(data)

    def remove_skipped_components(self, *data) -> tuple:
        """Return ``data`` without the outputs whose component type is in ``utils.SKIP_COMPONENTS``."""
        data = [
            d
            for d, oct in zip(data, self.output_component_types, strict=False)
            if oct not in utils.SKIP_COMPONENTS
        ]
        return tuple(data)

    def reduce_singleton_output(self, *data) -> Any:
        """Return ``data[0]`` when exactly one output component is not skipped, else ``data``."""
        kept = sum(
            1 for oct in self.output_component_types if oct not in utils.SKIP_COMPONENTS
        )
        if kept == 1:
            return data[0]
        return data

    def serialize(self, *data) -> tuple:
        """Upload file inputs and run every argument through its serializer.

        Raises:
            ValueError: if the number of arguments differs from the number of serializers.
        """
        if len(data) != len(self.serializers):
            raise ValueError(
                f"Expected {len(self.serializers)} arguments, got {len(data)}"
            )

        files = [
            f
            for f, t in zip(data, self.input_component_types, strict=False)
            if t in ["file", "uploadbutton"]
        ]
        uploaded_files = self._upload(files)
        data = list(data)
        self._add_uploaded_files_to_data(uploaded_files, data)
        return tuple(
            s.serialize(d) for s, d in zip(self.serializers, data, strict=False)
        )

    def deserialize(self, *data) -> tuple:
        """Run every output through its deserializer.

        Raises:
            ValueError: if the number of outputs differs from the number of deserializers.
        """
        if len(data) != len(self.deserializers):
            raise ValueError(
                f"Expected {len(self.deserializers)} outputs, got {len(data)}"
            )
        return tuple(
            s.deserialize(
                d,
                save_dir=self.client.output_dir,
                hf_token=self.client.hf_token,
                root_url=self.root_url,
            )
            for s, d in zip(self.deserializers, data, strict=False)
        )

    def process_predictions(self, *predictions):
        """Post-process raw outputs: optionally deserialize, drop skipped components, unwrap singletons."""
        if self.client.download_files:
            predictions = self.deserialize(*predictions)
        predictions = self.remove_skipped_components(*predictions)
        predictions = self.reduce_singleton_output(*predictions)
        return predictions

    def _setup_serializers(
        self,
    ) -> tuple[list[serializing.Serializable], list[serializing.Serializable]]:
        """Instantiate the serializers for the inputs and the deserializers for the outputs.

        Also records the component type of every input/output in
        ``self.input_component_types`` / ``self.output_component_types``.

        Raises:
            SerializationSetupError: if a component names an unknown serializer
                or has an unknown component type.
        """
        serializers = self._instantiate_serializers(
            self.dependency["inputs"],
            self.input_component_types,
            allow_skipped=False,
        )
        deserializers = self._instantiate_serializers(
            self.dependency["outputs"],
            self.output_component_types,
            allow_skipped=True,
        )
        return serializers, deserializers

    def _instantiate_serializers(
        self,
        component_ids: list,
        component_types: list,
        *,
        allow_skipped: bool,
    ) -> list[serializing.Serializable]:
        """Look up each component id in the client config and instantiate its (de)serializer.

        The type of every matched component is appended to ``component_types``
        (before any error for that component is raised). When ``allow_skipped``
        is true, component types listed in ``utils.SKIP_COMPONENTS`` that do not
        name an explicit serializer get a ``serializing.SimpleSerializable``.
        """
        instances = []
        for component_id in component_ids:
            for component in self.client.config["components"]:
                if component["id"] != component_id:
                    continue
                component_name = component["type"]
                component_types.append(component_name)
                if component.get("serializer"):
                    serializer_name = component["serializer"]
                    if serializer_name not in serializing.SERIALIZER_MAPPING:
                        raise SerializationSetupError(
                            f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
                        )
                    serializer_cls = serializing.SERIALIZER_MAPPING[serializer_name]
                elif allow_skipped and component_name in utils.SKIP_COMPONENTS:
                    serializer_cls = serializing.SimpleSerializable
                elif component_name in serializing.COMPONENT_MAPPING:
                    serializer_cls = serializing.COMPONENT_MAPPING[component_name]
                else:
                    raise SerializationSetupError(
                        f"Unknown component: {component_name}, you may need to update your gradio_client version."
                    )
                instances.append(serializer_cls())  # type: ignore
        return instances

    def _use_websocket(self, dependency: dict) -> bool:
        """Decide whether this endpoint is called over a websocket.

        True only when the app config enables the queue, the configured version
        (default "2.0") is at least 3.2, and the dependency's ``"queue"`` value
        is anything other than ``False`` (a missing key counts as ``False``).
        """
        queue_enabled = self.client.config.get("enable_queue", False)
        queue_uses_websocket = version.parse(
            self.client.config.get("version", "2.0")
        ) >= version.Version("3.2")
        dependency_uses_queue = dependency.get("queue", False) is not False
        return queue_enabled and queue_uses_websocket and dependency_uses_queue

    async def _ws_fn(self, data, hash_data, helper: Communicator):
        """Open a websocket to the app and return the prediction read from it."""
        async with websockets.connect(  # type: ignore
            self.client.ws_url,
            open_timeout=10,
            extra_headers=self.client.headers,
            max_size=1024 * 1024 * 1024,
        ) as websocket:
            return await utils.get_pred_from_ws(websocket, data, hash_data, helper)