John6666 commited on
Commit
c87a258
Β·
verified Β·
1 Parent(s): 2956d04

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +13 -6
  2. multit2i.py +86 -18
app.py CHANGED
@@ -50,13 +50,18 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
50
  clear_prompt = gr.Button(value="Clear Prompt πŸ—‘οΈ", size="sm", scale=1)
51
  prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
52
  neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
 
 
 
 
 
53
  with gr.Accordion("Recommended Prompt", open=False):
54
  recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
55
  with gr.Row():
56
  positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
57
  positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
58
- negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[], visible=False)
59
- negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"], visible=False)
60
  with gr.Accordion("Prompt Transformer", open=False):
61
  v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
62
  v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
@@ -109,12 +114,14 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
109
  img_i = gr.Number(i, visible=False)
110
  image_num.change(lambda i, n: gr.update(visible = (i < n)), [img_i, image_num], o, show_api=False)
111
  gen_event = gr.on(triggers=[run_button.click, prompt.submit],
112
- fn=lambda i, n, m, t1, t2, l1, l2, l3, l4: infer_fn(m, t1, t2, l1, l2, l3, l4) if (i < n) else None,
113
- inputs=[img_i, image_num, model_name, prompt, neg_prompt, positive_prefix, positive_suffix, negative_prefix, negative_suffix],
 
114
  outputs=[o], queue=True, show_api=False)
115
  gen_event2 = gr.on(triggers=[random_button.click],
116
- fn=lambda i, n, m, t1, t2, l1, l2, l3, l4: infer_rand_fn(m, t1, t2, l1, l2, l3, l4) if (i < n) else None,
117
- inputs=[img_i, image_num, model_name, prompt, neg_prompt, positive_prefix, positive_suffix, negative_prefix, negative_suffix],
 
118
  outputs=[o], queue=True, show_api=False)
119
  o.change(save_gallery, [o, results], [results, image_files], show_api=False)
120
  stop_button.click(lambda: gr.update(interactive=False), None, stop_button, cancels=[gen_event, gen_event2], show_api=False)
 
50
  clear_prompt = gr.Button(value="Clear Prompt πŸ—‘οΈ", size="sm", scale=1)
51
  prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
52
  neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
53
+ with gr.Accordion("Advanced options", open=False):
54
+ width = gr.Number(label="Width", info="If 0, the default value is used.", maximum=1216, step=32, value=None)
55
+ height = gr.Number(label="Height", info="If 0, the default value is used.", maximum=1216, step=32, value=None)
56
+ steps = gr.Number(label="Number of inference steps", info="If 0, the default value is used.", maximum=100, step=1, value=None)
57
+ cfg = gr.Number(label="Guidance scale", info="If 0, the default value is used.", maximum=30.0, step=0.1, value=None)
58
  with gr.Accordion("Recommended Prompt", open=False):
59
  recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
60
  with gr.Row():
61
  positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
62
  positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
63
+ negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[])
64
+ negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"])
65
  with gr.Accordion("Prompt Transformer", open=False):
66
  v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
67
  v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
 
114
  img_i = gr.Number(i, visible=False)
115
  image_num.change(lambda i, n: gr.update(visible = (i < n)), [img_i, image_num], o, show_api=False)
116
  gen_event = gr.on(triggers=[run_button.click, prompt.submit],
117
+ fn=lambda i, n, m, t1, t2, n1, n2, n3, n4, l1, l2, l3, l4: infer_fn(m, t1, t2, n1, n2, n3, n4, l1, l2, l3, l4) if (i < n) else None,
118
+ inputs=[img_i, image_num, model_name, prompt, neg_prompt, height, width, steps, cfg,
119
+ positive_prefix, positive_suffix, negative_prefix, negative_suffix],
120
  outputs=[o], queue=True, show_api=False)
121
  gen_event2 = gr.on(triggers=[random_button.click],
122
+ fn=lambda i, n, m, t1, t2, n1, n2, n3, n4, l1, l2, l3, l4: infer_rand_fn(m, t1, t2, n1, n2, n3, n4, l1, l2, l3, l4) if (i < n) else None,
123
+ inputs=[img_i, image_num, model_name, prompt, neg_prompt, height, width, steps, cfg,
124
+ positive_prefix, positive_suffix, negative_prefix, negative_suffix],
125
  outputs=[o], queue=True, show_api=False)
126
  o.change(save_gallery, [o, results], [results, image_files], show_api=False)
127
  stop_button.click(lambda: gr.update(interactive=False), None, stop_button, cancels=[gen_event, gen_event2], show_api=False)
multit2i.py CHANGED
@@ -2,6 +2,11 @@ import gradio as gr
2
  import asyncio
3
  from threading import RLock
4
  from pathlib import Path
 
 
 
 
 
5
 
6
 
7
  lock = RLock()
@@ -80,7 +85,7 @@ def get_t2i_model_info_dict(repo_id: str):
80
  return info
81
 
82
 
83
- def rename_image(image_path: str | None, model_name: str):
84
  from PIL import Image
85
  from datetime import datetime, timezone, timedelta
86
  if image_path is None: return None
@@ -90,7 +95,10 @@ def rename_image(image_path: str | None, model_name: str):
90
  if Path(image_path).exists():
91
  png_path = "image.png"
92
  Image.open(image_path).convert('RGBA').save(png_path, "PNG")
93
- new_path = str(Path(png_path).resolve().rename(Path(filename).resolve()))
 
 
 
94
  return new_path
95
  else:
96
  return None
@@ -125,13 +133,14 @@ def load_from_model(model_name: str, hf_token: str = None):
125
  f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
126
  )
127
  headers["X-Wait-For-Model"] = "true"
128
- client = huggingface_hub.InferenceClient(model=model_name, headers=headers, token=hf_token, timeout=600)
 
129
  inputs = gr.components.Textbox(label="Input")
130
  outputs = gr.components.Image(label="Output")
131
  fn = client.text_to_image
132
 
133
- def query_huggingface_inference_endpoints(*data):
134
- return fn(*data)
135
 
136
  interface_info = {
137
  "fn": query_huggingface_inference_endpoints,
@@ -164,6 +173,34 @@ def load_model(model_name: str):
164
  return loaded_models[model_name]
165
 
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  def load_models(models: list):
168
  for model in models:
169
  load_model(model)
@@ -276,21 +313,48 @@ def get_model_info_md(model_name: str):
276
 
277
 
278
  def change_model(model_name: str):
279
- load_model(model_name)
280
  return get_model_info_md(model_name)
281
 
282
 
283
  def warm_model(model_name: str):
284
- model = load_model(model_name)
285
  if model:
286
  try:
287
  print(f"Warming model: {model_name}")
288
- model(" ")
289
  except Exception as e:
290
  print(e)
291
 
292
 
293
- async def infer(model_name: str, prompt: str, neg_prompt: str, timeout: float):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  import random
295
  noise = ""
296
  rand = random.randint(1, 500)
@@ -298,7 +362,8 @@ async def infer(model_name: str, prompt: str, neg_prompt: str, timeout: float):
298
  noise += " "
299
  model = load_model(model_name)
300
  if not model: return None
301
- task = asyncio.create_task(asyncio.to_thread(model, f'{prompt} {noise}'))
 
302
  await asyncio.sleep(0)
303
  try:
304
  result = await asyncio.wait_for(task, timeout=timeout)
@@ -309,20 +374,21 @@ async def infer(model_name: str, prompt: str, neg_prompt: str, timeout: float):
309
  result = None
310
  if task.done() and result is not None:
311
  with lock:
312
- image = rename_image(result, model_name)
313
  return image
314
  return None
315
 
316
 
317
- infer_timeout = 300
318
- def infer_fn(model_name: str, prompt: str, neg_prompt: str,
319
- pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = []):
320
  if model_name == 'NA':
321
  return None
322
  try:
323
  prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
324
  loop = asyncio.new_event_loop()
325
- result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, infer_timeout))
 
326
  except (Exception, asyncio.CancelledError) as e:
327
  print(e)
328
  print(f"Task aborted: {model_name}")
@@ -332,8 +398,9 @@ def infer_fn(model_name: str, prompt: str, neg_prompt: str,
332
  return result
333
 
334
 
335
- def infer_rand_fn(model_name_dummy: str, prompt: str, neg_prompt: str,
336
- pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = []):
 
337
  import random
338
  if model_name_dummy == 'NA':
339
  return None
@@ -342,7 +409,8 @@ def infer_rand_fn(model_name_dummy: str, prompt: str, neg_prompt: str,
342
  try:
343
  prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
344
  loop = asyncio.new_event_loop()
345
- result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, infer_timeout))
 
346
  except (Exception, asyncio.CancelledError) as e:
347
  print(e)
348
  print(f"Task aborted: {model_name}")
 
2
  import asyncio
3
  from threading import RLock
4
  from pathlib import Path
5
+ from huggingface_hub import InferenceClient
6
+
7
+
8
+ server_timeout = 600
9
+ inference_timeout = 300
10
 
11
 
12
  lock = RLock()
 
85
  return info
86
 
87
 
88
+ def rename_image(image_path: str | None, model_name: str, save_path: str | None = None):
89
  from PIL import Image
90
  from datetime import datetime, timezone, timedelta
91
  if image_path is None: return None
 
95
  if Path(image_path).exists():
96
  png_path = "image.png"
97
  Image.open(image_path).convert('RGBA').save(png_path, "PNG")
98
+ if save_path is not None:
99
+ new_path = str(Path(png_path).resolve().rename(Path(save_path).resolve()))
100
+ else:
101
+ new_path = str(Path(png_path).resolve().rename(Path(filename).resolve()))
102
  return new_path
103
  else:
104
  return None
 
133
  f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
134
  )
135
  headers["X-Wait-For-Model"] = "true"
136
+ client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
137
+ token=hf_token, timeout=server_timeout)
138
  inputs = gr.components.Textbox(label="Input")
139
  outputs = gr.components.Image(label="Output")
140
  fn = client.text_to_image
141
 
142
+ def query_huggingface_inference_endpoints(*data, **kwargs):
143
+ return fn(*data, **kwargs)
144
 
145
  interface_info = {
146
  "fn": query_huggingface_inference_endpoints,
 
173
  return loaded_models[model_name]
174
 
175
 
176
+ def load_model_api(model_name: str):
177
+ global loaded_models
178
+ global model_info_dict
179
+ if model_name in loaded_models.keys(): return loaded_models[model_name]
180
+ try:
181
+ client = InferenceClient(timeout=5)
182
+ status = client.get_model_status(model_name)
183
+ if status is None or status.framework != "diffusers" and not status.state in ["Loadable", "Loaded"]:
184
+ print(f"Failed to load by API: {model_name}")
185
+ return None
186
+ else:
187
+ loaded_models[model_name] = InferenceClient(model_name, timeout=server_timeout)
188
+ print(f"Loaded by API: {model_name}")
189
+ except Exception as e:
190
+ if model_name in loaded_models.keys(): del loaded_models[model_name]
191
+ print(f"Failed to load by API: {model_name}")
192
+ print(e)
193
+ return None
194
+ try:
195
+ model_info_dict[model_name] = get_t2i_model_info_dict(model_name)
196
+ print(f"Assigned by API: {model_name}")
197
+ except Exception as e:
198
+ if model_name in model_info_dict.keys(): del model_info_dict[model_name]
199
+ print(f"Failed to assigned by API: {model_name}")
200
+ print(e)
201
+ return loaded_models[model_name]
202
+
203
+
204
  def load_models(models: list):
205
  for model in models:
206
  load_model(model)
 
313
 
314
 
315
  def change_model(model_name: str):
316
+ load_model_api(model_name)
317
  return get_model_info_md(model_name)
318
 
319
 
320
  def warm_model(model_name: str):
321
+ model = load_model_api(model_name)
322
  if model:
323
  try:
324
  print(f"Warming model: {model_name}")
325
+ infer_body(model, " ")
326
  except Exception as e:
327
  print(e)
328
 
329
 
330
+ # https://huggingface.co/docs/api-inference/detailed_parameters
331
+ # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
332
+ def infer_body(client: InferenceClient | gr.Interface, prompt: str, neg_prompt: str | None = None,
333
+ height: int | None = None, width: int | None = None,
334
+ steps: int | None = None, cfg: int | None = None):
335
+ png_path = "image.png"
336
+ kwargs = {}
337
+ if height is not None and height >= 256: kwargs["height"] = height
338
+ if width is not None and width >= 256: kwargs["width"] = width
339
+ if steps is not None and steps >= 1: kwargs["num_inference_steps"] = steps
340
+ if cfg is not None and cfg > 0: cfg = kwargs["guidance_scale"] = cfg
341
+ try:
342
+ if isinstance(client, InferenceClient):
343
+ image = client.text_to_image(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
344
+ elif isinstance(client, gr.Interface):
345
+ image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
346
+ else: return None
347
+ image.save(png_path)
348
+ return str(Path(png_path).resolve())
349
+ except Exception as e:
350
+ print(e)
351
+ return None
352
+
353
+
354
+ async def infer(model_name: str, prompt: str, neg_prompt: str | None = None,
355
+ height: int | None = None, width: int | None = None,
356
+ steps: int | None = None, cfg: int | None = None,
357
+ save_path: str | None = None, timeout: float = inference_timeout):
358
  import random
359
  noise = ""
360
  rand = random.randint(1, 500)
 
362
  noise += " "
363
  model = load_model(model_name)
364
  if not model: return None
365
+ task = asyncio.create_task(asyncio.to_thread(infer_body, model, f"{prompt} {noise}", neg_prompt,
366
+ height, width, steps, cfg))
367
  await asyncio.sleep(0)
368
  try:
369
  result = await asyncio.wait_for(task, timeout=timeout)
 
374
  result = None
375
  if task.done() and result is not None:
376
  with lock:
377
+ image = rename_image(result, model_name, save_path)
378
  return image
379
  return None
380
 
381
 
382
+ def infer_fn(model_name: str, prompt: str, neg_prompt: str | None = None, height: int | None = None,
383
+ width: int | None = None, steps: int | None = None, cfg: int | None = None,
384
+ pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
385
  if model_name == 'NA':
386
  return None
387
  try:
388
  prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
389
  loop = asyncio.new_event_loop()
390
+ result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
391
+ steps, cfg, save_path, inference_timeout))
392
  except (Exception, asyncio.CancelledError) as e:
393
  print(e)
394
  print(f"Task aborted: {model_name}")
 
398
  return result
399
 
400
 
401
+ def infer_rand_fn(model_name_dummy: str, prompt: str, neg_prompt: str | None = None, height: int | None = None,
402
+ width: int | None = None, steps: int | None = None, cfg: int | None = None,
403
+ pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
404
  import random
405
  if model_name_dummy == 'NA':
406
  return None
 
409
  try:
410
  prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
411
  loop = asyncio.new_event_loop()
412
+ result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
413
+ steps, cfg, save_path, inference_timeout))
414
  except (Exception, asyncio.CancelledError) as e:
415
  print(e)
416
  print(f"Task aborted: {model_name}")