John6666 commited on
Commit
8f9d631
Β·
verified Β·
1 Parent(s): d1ad803

Upload multit2i.py

Browse files
Files changed (1) hide show
  1. multit2i.py +15 -4
multit2i.py CHANGED
@@ -143,20 +143,23 @@ def save_gallery(image_path: str | None, images: list[tuple] | None):
143
 
144
  # https://github.com/gradio-app/gradio/blob/main/gradio/external.py
145
  # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
146
- def load_from_model(model_name: str, hf_token: str = None):
 
147
  import httpx
148
  import huggingface_hub
149
- from gradio.exceptions import ModelNotFoundError
150
  model_url = f"https://huggingface.co/{model_name}"
151
  api_url = f"https://api-inference.huggingface.co/models/{model_name}"
152
  print(f"Fetching model from: {model_url}")
153
 
154
- headers = {"Authorization": f"Bearer {hf_token}"} if hf_token is not None else {}
155
  response = httpx.request("GET", api_url, headers=headers)
156
  if response.status_code != 200:
157
  raise ModelNotFoundError(
158
  f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
159
  )
 
 
160
  headers["X-Wait-For-Model"] = "true"
161
  client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
162
  token=hf_token, timeout=server_timeout)
@@ -165,7 +168,14 @@ def load_from_model(model_name: str, hf_token: str = None):
165
  fn = client.text_to_image
166
 
167
  def query_huggingface_inference_endpoints(*data, **kwargs):
168
- return fn(*data, **kwargs)
 
 
 
 
 
 
 
169
 
170
  interface_info = {
171
  "fn": query_huggingface_inference_endpoints,
@@ -370,6 +380,7 @@ def infer_body(client: InferenceClient | gr.Interface | object, prompt: str, neg
370
  elif isinstance(client, gr.Interface):
371
  image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
372
  else: return None
 
373
  image.save(png_path)
374
  return str(Path(png_path).resolve())
375
  except Exception as e:
 
143
 
144
  # https://github.com/gradio-app/gradio/blob/main/gradio/external.py
145
  # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
146
+ from typing import Literal
147
+ def load_from_model(model_name: str, hf_token: str | Literal[False] | None = None):
148
  import httpx
149
  import huggingface_hub
150
+ from gradio.exceptions import ModelNotFoundError, TooManyRequestsError
151
  model_url = f"https://huggingface.co/{model_name}"
152
  api_url = f"https://api-inference.huggingface.co/models/{model_name}"
153
  print(f"Fetching model from: {model_url}")
154
 
155
+ headers = ({} if hf_token in [False, None] else {"Authorization": f"Bearer {hf_token}"})
156
  response = httpx.request("GET", api_url, headers=headers)
157
  if response.status_code != 200:
158
  raise ModelNotFoundError(
159
  f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
160
  )
161
+ p = response.json().get("pipeline_tag")
162
+ if p != "text-to-image": raise ModelNotFoundError(f"This model isn't for text-to-image or unsupported: {model_name}.")
163
  headers["X-Wait-For-Model"] = "true"
164
  client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
165
  token=hf_token, timeout=server_timeout)
 
168
  fn = client.text_to_image
169
 
170
  def query_huggingface_inference_endpoints(*data, **kwargs):
171
+ try:
172
+ data = fn(*data, **kwargs) # type: ignore
173
+ except huggingface_hub.utils.HfHubHTTPError as e:
174
+ if "429" in str(e):
175
+ raise TooManyRequestsError() from e
176
+ except Exception as e:
177
+ raise Exception(e)
178
+ return data
179
 
180
  interface_info = {
181
  "fn": query_huggingface_inference_endpoints,
 
380
  elif isinstance(client, gr.Interface):
381
  image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
382
  else: return None
383
+ if isinstance(image, tuple): return None
384
  image.save(png_path)
385
  return str(Path(png_path).resolve())
386
  except Exception as e: