John6666 commited on
Commit
ac5827a
Β·
verified Β·
1 Parent(s): fb8d74d

Upload multit2i.py

Browse files
Files changed (1) hide show
  1. multit2i.py +5 -5
multit2i.py CHANGED
@@ -104,7 +104,7 @@ def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
104
  def load_model(model_name: str):
105
  global loaded_models
106
  global model_info_dict
107
- if model_name in loaded_models.keys(): return model_name
108
  try:
109
  loaded_models[model_name] = gr.load(f'models/{model_name}')
110
  print(f"Loaded: {model_name}")
@@ -112,13 +112,13 @@ def load_model(model_name: str):
112
  if model_name in loaded_models.keys(): del loaded_models[model_name]
113
  print(f"Failed to load: {model_name}")
114
  print(e)
115
- return ""
116
  try:
117
  model_info_dict[model_name] = get_t2i_model_info_dict(model_name)
118
  except Exception as e:
119
  if model_name in model_info_dict.keys(): del model_info_dict[model_name]
120
  print(e)
121
- return model_name
122
 
123
 
124
  async def async_load_models(models: list, limit: int=5):
@@ -163,12 +163,12 @@ def infer(prompt: str, model_name: str, recom_prompt: bool, progress=gr.Progress
163
  caption = model_name.split("/")[-1]
164
  try:
165
  model = load_model(model_name)
166
- if not model: return ("", None)
167
  image_path = model(prompt + rprompt + seed)
168
  image = Image.open(image_path).convert('RGB')
169
  except Exception as e:
170
  print(e)
171
- return ("", None)
172
  return (image, caption)
173
 
174
 
 
104
  def load_model(model_name: str):
105
  global loaded_models
106
  global model_info_dict
107
+ if model_name in loaded_models.keys(): return loaded_models[model_name]
108
  try:
109
  loaded_models[model_name] = gr.load(f'models/{model_name}')
110
  print(f"Loaded: {model_name}")
 
112
  if model_name in loaded_models.keys(): del loaded_models[model_name]
113
  print(f"Failed to load: {model_name}")
114
  print(e)
115
+ return None
116
  try:
117
  model_info_dict[model_name] = get_t2i_model_info_dict(model_name)
118
  except Exception as e:
119
  if model_name in model_info_dict.keys(): del model_info_dict[model_name]
120
  print(e)
121
+ return loaded_models[model_name]
122
 
123
 
124
  async def async_load_models(models: list, limit: int=5):
 
163
  caption = model_name.split("/")[-1]
164
  try:
165
  model = load_model(model_name)
166
+ if not model: return (Image.Image(), None)
167
  image_path = model(prompt + rprompt + seed)
168
  image = Image.open(image_path).convert('RGB')
169
  except Exception as e:
170
  print(e)
171
+ return (Image.Image(), None)
172
  return (image, caption)
173
 
174