John6666 commited on
Commit
c5a3258
Β·
verified Β·
1 Parent(s): 278e75f

Upload multit2i.py

Browse files
Files changed (1) hide show
  1. multit2i.py +10 -8
multit2i.py CHANGED
@@ -3,8 +3,10 @@ import asyncio
3
  from threading import RLock
4
  from pathlib import Path
5
  from huggingface_hub import InferenceClient
 
6
 
7
 
 
8
  server_timeout = 600
9
  inference_timeout = 300
10
 
@@ -38,14 +40,14 @@ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="l
38
  if not sort: sort = "last_modified"
39
  models = []
40
  try:
41
- model_infos = api.list_models(author=author, pipeline_tag="text-to-image",
42
  tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit * 5)
43
  except Exception as e:
44
  print(f"Error: Failed to list models.")
45
  print(e)
46
  return models
47
  for model in model_infos:
48
- if not model.private and not model.gated:
49
  if not_tag and not_tag in model.tags: continue
50
  models.append(model.id)
51
  if len(models) == limit: break
@@ -58,7 +60,7 @@ def get_t2i_model_info_dict(repo_id: str):
58
  info = {"md": "None"}
59
  try:
60
  if not is_repo_name(repo_id) or not api.repo_exists(repo_id=repo_id): return info
61
- model = api.model_info(repo_id=repo_id)
62
  except Exception as e:
63
  print(f"Error: Failed to get {repo_id}'s info.")
64
  print(e)
@@ -156,7 +158,7 @@ def load_model(model_name: str):
156
  global model_info_dict
157
  if model_name in loaded_models.keys(): return loaded_models[model_name]
158
  try:
159
- loaded_models[model_name] = load_from_model(model_name)
160
  print(f"Loaded: {model_name}")
161
  except Exception as e:
162
  if model_name in loaded_models.keys(): del loaded_models[model_name]
@@ -179,12 +181,12 @@ def load_model_api(model_name: str):
179
  if model_name in loaded_models.keys(): return loaded_models[model_name]
180
  try:
181
  client = InferenceClient(timeout=5)
182
- status = client.get_model_status(model_name)
183
  if status is None or status.framework != "diffusers" or status.state not in ["Loadable", "Loaded"]:
184
  print(f"Failed to load by API: {model_name}")
185
  return None
186
  else:
187
- loaded_models[model_name] = InferenceClient(model_name, timeout=server_timeout)
188
  print(f"Loaded by API: {model_name}")
189
  except Exception as e:
190
  if model_name in loaded_models.keys(): del loaded_models[model_name]
@@ -340,9 +342,9 @@ def infer_body(client: InferenceClient | gr.Interface, prompt: str, neg_prompt:
340
  if cfg is not None and cfg > 0: cfg = kwargs["guidance_scale"] = cfg
341
  try:
342
  if isinstance(client, InferenceClient):
343
- image = client.text_to_image(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
344
  elif isinstance(client, gr.Interface):
345
- image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
346
  else: return None
347
  image.save(png_path)
348
  return str(Path(png_path).resolve())
 
3
  from threading import RLock
4
  from pathlib import Path
5
  from huggingface_hub import InferenceClient
6
+ import os
7
 
8
 
9
+ HF_TOKEN = os.environ.get("HF_TOKEN") if os.environ.get("HF_TOKEN") else None
10
  server_timeout = 600
11
  inference_timeout = 300
12
 
 
40
  if not sort: sort = "last_modified"
41
  models = []
42
  try:
43
+ model_infos = api.list_models(author=author, pipeline_tag="text-to-image", token=HF_TOKEN,
44
  tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit * 5)
45
  except Exception as e:
46
  print(f"Error: Failed to list models.")
47
  print(e)
48
  return models
49
  for model in model_infos:
50
+ if not model.private and not model.gated and HF_TOKEN is None:
51
  if not_tag and not_tag in model.tags: continue
52
  models.append(model.id)
53
  if len(models) == limit: break
 
60
  info = {"md": "None"}
61
  try:
62
  if not is_repo_name(repo_id) or not api.repo_exists(repo_id=repo_id): return info
63
+ model = api.model_info(repo_id=repo_id, token=HF_TOKEN)
64
  except Exception as e:
65
  print(f"Error: Failed to get {repo_id}'s info.")
66
  print(e)
 
158
  global model_info_dict
159
  if model_name in loaded_models.keys(): return loaded_models[model_name]
160
  try:
161
+ loaded_models[model_name] = load_from_model(model_name, hf_token=HF_TOKEN)
162
  print(f"Loaded: {model_name}")
163
  except Exception as e:
164
  if model_name in loaded_models.keys(): del loaded_models[model_name]
 
181
  if model_name in loaded_models.keys(): return loaded_models[model_name]
182
  try:
183
  client = InferenceClient(timeout=5)
184
+ status = client.get_model_status(model_name, token=HF_TOKEN)
185
  if status is None or status.framework != "diffusers" or status.state not in ["Loadable", "Loaded"]:
186
  print(f"Failed to load by API: {model_name}")
187
  return None
188
  else:
189
+ loaded_models[model_name] = InferenceClient(model_name, token=HF_TOKEN, timeout=server_timeout)
190
  print(f"Loaded by API: {model_name}")
191
  except Exception as e:
192
  if model_name in loaded_models.keys(): del loaded_models[model_name]
 
342
  if cfg is not None and cfg > 0: cfg = kwargs["guidance_scale"] = cfg
343
  try:
344
  if isinstance(client, InferenceClient):
345
+ image = client.text_to_image(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
346
  elif isinstance(client, gr.Interface):
347
+ image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
348
  else: return None
349
  image.save(png_path)
350
  return str(Path(png_path).resolve())