Spaces · Running

Bbmyy committed
Commit c92c0ec · 1 parent: ec0378d
first commit
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- README.md +5 -5
- app.py +90 -0
- ksort-logs/vote_log/gr_web_image_editing.log +0 -0
- ksort-logs/vote_log/gr_web_image_editing_multi.log +0 -0
- ksort-logs/vote_log/gr_web_image_generation.log +811 -0
- ksort-logs/vote_log/gr_web_image_generation_multi.log +6 -0
- ksort-logs/vote_log/gr_web_video_generation.log +0 -0
- ksort-logs/vote_log/gr_web_video_generation_multi.log +0 -0
- model/__init__.py +0 -0
- model/__pycache__/__init__.cpython-310.pyc +0 -0
- model/__pycache__/__init__.cpython-312.pyc +0 -0
- model/__pycache__/__init__.cpython-39.pyc +0 -0
- model/__pycache__/matchmaker.cpython-310.pyc +0 -0
- model/__pycache__/model_manager.cpython-310.pyc +0 -0
- model/__pycache__/model_registry.cpython-310.pyc +0 -0
- model/__pycache__/model_registry.cpython-312.pyc +0 -0
- model/__pycache__/model_registry.cpython-39.pyc +0 -0
- model/matchmaker.py +126 -0
- model/matchmaker_video.py +136 -0
- model/model_manager.py +239 -0
- model/model_registry.py +70 -0
- model/models/__init__.py +83 -0
- model/models/__pycache__/__init__.cpython-310.pyc +0 -0
- model/models/__pycache__/huggingface_models.cpython-310.pyc +0 -0
- model/models/__pycache__/local_models.cpython-310.pyc +0 -0
- model/models/__pycache__/openai_api_models.cpython-310.pyc +0 -0
- model/models/__pycache__/other_api_models.cpython-310.pyc +0 -0
- model/models/__pycache__/replicate_api_models.cpython-310.pyc +0 -0
- model/models/huggingface_models.py +65 -0
- model/models/local_models.py +16 -0
- model/models/openai_api_models.py +57 -0
- model/models/other_api_models.py +91 -0
- model/models/replicate_api_models.py +195 -0
- model_bbox/.gradio/certificate.pem +31 -0
- model_bbox/MIGC/__init__.py +0 -0
- model_bbox/MIGC/__pycache__/__init__.cpython-310.pyc +0 -0
- model_bbox/MIGC/__pycache__/inference_single_image.cpython-310.pyc +0 -0
- model_bbox/MIGC/inference_single_image.py +193 -0
- model_bbox/MIGC/migc/__init__.py +0 -0
- model_bbox/MIGC/migc/__pycache__/__init__.cpython-310.pyc +0 -0
- model_bbox/MIGC/migc/__pycache__/migc_arch.cpython-310.pyc +0 -0
- model_bbox/MIGC/migc/__pycache__/migc_layers.cpython-310.pyc +0 -0
- model_bbox/MIGC/migc/__pycache__/migc_pipeline.cpython-310.pyc +0 -0
- model_bbox/MIGC/migc/__pycache__/migc_utils.cpython-310.pyc +0 -0
- model_bbox/MIGC/migc/migc_arch.py +220 -0
- model_bbox/MIGC/migc/migc_layers.py +241 -0
- model_bbox/MIGC/migc/migc_pipeline.py +928 -0
- model_bbox/MIGC/migc/migc_utils.py +143 -0
- model_bbox/MIGC/pretrained_weights/MIGC_SD14.ckpt +3 -0
- model_bbox/MIGC/pretrained_weights/PUT_MIGC_CKPT_HERE +0 -0
README.md
CHANGED
@@ -1,12 +1,12 @@
 ---
 title: Control Ability Arena
-emoji:
+emoji: 🖼
-colorFrom:
+colorFrom: purple
-colorTo:
+colorTo: red
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.0.1
 app_file: app.py
 pinned: false
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,90 @@
import gradio as gr
import os
from serve.gradio_web import *
from serve.gradio_web_bbox import build_side_by_side_bbox_ui_anony
from serve.leaderboard import build_leaderboard_tab, build_leaderboard_video_tab, build_leaderboard_contributor
from model.model_manager import ModelManager
from pathlib import Path
from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR


def make_default_md():
    link_color = "#1976D2"  # This color should be clear in both light and dark mode
    leaderboard_md = f"""
# 🏅 Control-Ability-Arena: ...
### [Paper]... | [Twitter]...
- ⚡ For vision tasks, K-wise comparisons can provide much richer info but only take similar time as pairwise comparisons.
- 🎯 Well designed matchmaking algorithm can further save human efforts than random match pairing in normal Arena.
- 📈 Probabilistic modeling can obtain a faster and more stable convergence than Elo scoring system.
"""

    return leaderboard_md


def build_combine_demo(models):
    with gr.Blocks(
        title="Play with Open Vision Models",
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:

        with gr.Blocks():
            md = make_default_md()
            md_default = gr.Markdown(md, elem_id="default_leaderboard_markdown")



        with gr.Tabs() as tabs_combine:
            # with gr.Tab("Image Generation", id=0):
            #     with gr.Tabs() as tabs_ig:
            #         # with gr.Tab("Generation Leaderboard", id=0):
            #         #     build_leaderboard_tab()
            #         with gr.Tab("Generation Arena (battle)", id=1):
            #             build_side_by_side_ui_anony(models)

            with gr.Tab("BBox-to-Image Generation", id=0):
                with gr.Tabs() as tabs_ig:
                    # with gr.Tab("Generation Leaderboard", id=0):
                    #     build_leaderboard_tab()
                    with gr.Tab("Generation Arena (battle)", id=1):
                        build_side_by_side_bbox_ui_anony(models)

            # with gr.Tab("Contributor", id=2):
            #     build_leaderboard_contributor()

    return demo


def load_elo_results(elo_results_dir):
    from collections import defaultdict
    elo_results_file = defaultdict(lambda: None)
    leaderboard_table_file = defaultdict(lambda: None)

    if elo_results_dir is not None:
        elo_results_dir = Path(elo_results_dir)
        elo_results_file = {}
        leaderboard_table_file = {}
        for file in elo_results_dir.glob('elo_results_*.pkl'):
            if 't2i_generation' in file.name:
                elo_results_file['t2i_generation'] = file
            # else:
            #     raise ValueError(f"Unknown file name: {file.name}")
        for file in elo_results_dir.glob('*_leaderboard.csv'):
            if 't2i_generation' in file.name:
                leaderboard_table_file['t2i_generation'] = file
            # else:
            #     raise ValueError(f"Unknown file name: {file.name}")

    return elo_results_file, leaderboard_table_file


if __name__ == "__main__":
    server_port = int(SERVER_PORT)
    root_path = ROOT_PATH
    elo_results_dir = ELO_RESULTS_DIR
    models = ModelManager()

    # elo_results_file, leaderboard_table_file = load_elo_results(elo_results_dir)
    demo = build_combine_demo(models)
    demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True)

    # demo.launch(server_name="0.0.0.0", server_port=7860, root_path=ROOT_PATH)
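app.py imports SERVER_PORT, ROOT_PATH, and ELO_RESULTS_DIR from serve.constants, which is not part of this diff. As a rough sketch only (the actual module may look quite different), a minimal serve/constants.py that would satisfy these imports could read the values from environment variables; the variable names and defaults below are assumptions for illustration, not taken from the repository:

# serve/constants.py -- hypothetical sketch, not the repository's actual file.
# app.py only requires these three names to exist; the env-var names and defaults are assumed.
import os

SERVER_PORT = os.environ.get("SERVER_PORT", "7860")        # app.py casts this with int(...)
ROOT_PATH = os.environ.get("ROOT_PATH", "")                # forwarded to demo.launch(root_path=...)
ELO_RESULTS_DIR = os.environ.get("ELO_RESULTS_DIR", None)  # load_elo_results() accepts None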
ksort-logs/vote_log/gr_web_image_editing.log
ADDED
File without changes
ksort-logs/vote_log/gr_web_image_editing_multi.log
ADDED
File without changes
ksort-logs/vote_log/gr_web_image_generation.log
ADDED
@@ -0,0 +1,811 @@
2024-12-24 12:54:21 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.
2024-12-24 12:54:21 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
2024-12-24 12:54:24 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(<function generate_igm_annoy at 0x7f5da264b1c0>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7f5febc53d60>>), received 11.
2024-12-24 12:54:24 | ERROR | stderr | warnings.warn(
2024-12-24 12:54:24 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(<function generate_igm_annoy at 0x7f5da264b1c0>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7f5febc53d60>>), received 11.
2024-12-24 12:54:24 | ERROR | stderr | warnings.warn(
2024-12-24 12:54:24 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860
2024-12-24 12:54:52 | INFO | stdout | background.shape (600, 600, 4)
2024-12-24 12:54:52 | INFO | stdout | len(layers) 1
2024-12-24 12:54:52 | INFO | stdout | composite.shape (600, 600, 4)
2024-12-24 12:55:00 | INFO | stdout |
2024-12-24 12:55:00 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3.
2024-12-24 12:55:00 | INFO | stdout |
2024-12-24 12:55:00 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps:
2024-12-24 12:55:00 | INFO | stdout |
2024-12-24 12:55:00 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
2024-12-24 12:55:00 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3
2024-12-24 12:55:00 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio
2024-12-24 12:55:47 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values.
2024-12-24 12:55:47 | ERROR | stderr | Output components:
2024-12-24 12:55:47 | ERROR | stderr | [textbox, button, button]
2024-12-24 12:55:47 | ERROR | stderr | Output values returned:
2024-12-24 12:55:47 | ERROR | stderr | [{'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}]
2024-12-24 12:55:47 | ERROR | stderr | warnings.warn(
2024-12-24 12:55:47 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None.
2024-12-24 12:55:47 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.")
2024-12-24 12:55:47 | ERROR | stderr | Traceback (most recent call last):
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events
2024-12-24 12:55:47 | ERROR | stderr | response = await route_utils.call_process_api(
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api
2024-12-24 12:55:47 | ERROR | stderr | output = await app.get_blocks().process_api(
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api
2024-12-24 12:55:47 | ERROR | stderr | result = await self.call_function(
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function
2024-12-24 12:55:47 | ERROR | stderr | prediction = await utils.async_iteration(iterator)
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration
2024-12-24 12:55:47 | ERROR | stderr | return await anext(iterator)
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__
2024-12-24 12:55:47 | ERROR | stderr | return await anyio.to_thread.run_sync(
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
2024-12-24 12:55:47 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread(
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread
2024-12-24 12:55:47 | ERROR | stderr | return await future
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run
2024-12-24 12:55:47 | ERROR | stderr | result = context.run(func, *args)
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async
2024-12-24 12:55:47 | ERROR | stderr | return next(iterator)
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper
2024-12-24 12:55:47 | ERROR | stderr | response = next(iterator)
2024-12-24 12:55:47 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 793, in generate_igm_annoy
2024-12-24 12:55:47 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3)
2024-12-24 12:55:47 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 94, in generate_image_b2i_parallel_anony
2024-12-24 12:55:47 | ERROR | stderr | model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
2024-12-24 12:55:47 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 95, in matchmaker
2024-12-24 12:55:47 | ERROR | stderr | ratings, comparison_counts, total_comparisons = load_json_via_sftp()
2024-12-24 12:55:47 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 79, in load_json_via_sftp
2024-12-24 12:55:47 | ERROR | stderr | create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
2024-12-24 12:55:47 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 21, in create_ssh_matchmaker_client
2024-12-24 12:55:47 | ERROR | stderr | ssh_matchmaker_client.connect(server, port, user, password)
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 377, in connect
2024-12-24 12:55:47 | ERROR | stderr | to_try = list(self._families_and_addresses(hostname, port))
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 202, in _families_and_addresses
2024-12-24 12:55:47 | ERROR | stderr | addrinfos = socket.getaddrinfo(
2024-12-24 12:55:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/socket.py", line 955, in getaddrinfo
2024-12-24 12:55:47 | ERROR | stderr | for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
2024-12-24 12:55:47 | ERROR | stderr | socket.gaierror: [Errno -8] Servname not supported for ai_socktype
2024-12-24 12:55:47 | INFO | stdout | Rank
2024-12-24 13:17:15 | INFO | stdout | Keyboard interruption in main thread... closing server.
2024-12-24 13:17:15 | ERROR | stderr | Traceback (most recent call last):
2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2869, in block_thread
2024-12-24 13:17:15 | ERROR | stderr | time.sleep(0.1)
2024-12-24 13:17:15 | ERROR | stderr | KeyboardInterrupt
2024-12-24 13:17:15 | ERROR | stderr |
2024-12-24 13:17:15 | ERROR | stderr | During handling of the above exception, another exception occurred:
2024-12-24 13:17:15 | ERROR | stderr |
2024-12-24 13:17:15 | ERROR | stderr | Traceback (most recent call last):
2024-12-24 13:17:15 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/app.py", line 88, in <module>
2024-12-24 13:17:15 | ERROR | stderr | demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True)
2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2774, in launch
2024-12-24 13:17:15 | ERROR | stderr | self.block_thread()
2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2873, in block_thread
2024-12-24 13:17:15 | ERROR | stderr | self.server.close()
2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/http_server.py", line 69, in close
2024-12-24 13:17:15 | ERROR | stderr | self.thread.join(timeout=5)
2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1100, in join
2024-12-24 13:17:15 | ERROR | stderr | self._wait_for_tstate_lock(timeout=max(timeout, 0))
2024-12-24 13:17:15 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
2024-12-24 13:17:15 | ERROR | stderr | if lock.acquire(block, timeout):
2024-12-24 13:17:15 | ERROR | stderr | KeyboardInterrupt
2024-12-24 13:17:25 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 10 arguments for function functools.partial(<function generate_igm_annoy at 0x7f3b81f730a0>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7f3dcb5d7d60>>), received 11.
2024-12-24 13:17:25 | ERROR | stderr | warnings.warn(
2024-12-24 13:17:25 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1011: UserWarning: Expected maximum 10 arguments for function functools.partial(<function generate_igm_annoy at 0x7f3b81f730a0>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7f3dcb5d7d60>>), received 11.
2024-12-24 13:17:25 | ERROR | stderr | warnings.warn(
2024-12-24 13:17:25 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860
2024-12-24 13:17:47 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values.
2024-12-24 13:17:47 | ERROR | stderr | Output components:
2024-12-24 13:17:47 | ERROR | stderr | [textbox, button, button]
2024-12-24 13:17:47 | ERROR | stderr | Output values returned:
2024-12-24 13:17:47 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}]
2024-12-24 13:17:47 | ERROR | stderr | warnings.warn(
2024-12-24 13:17:47 | ERROR | stderr | Traceback (most recent call last):
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events
2024-12-24 13:17:47 | ERROR | stderr | response = await route_utils.call_process_api(
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api
2024-12-24 13:17:47 | ERROR | stderr | output = await app.get_blocks().process_api(
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api
2024-12-24 13:17:47 | ERROR | stderr | result = await self.call_function(
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function
2024-12-24 13:17:47 | ERROR | stderr | prediction = await utils.async_iteration(iterator)
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration
2024-12-24 13:17:47 | ERROR | stderr | return await anext(iterator)
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__
2024-12-24 13:17:47 | ERROR | stderr | return await anyio.to_thread.run_sync(
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
2024-12-24 13:17:47 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread(
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread
2024-12-24 13:17:47 | ERROR | stderr | return await future
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run
2024-12-24 13:17:47 | ERROR | stderr | result = context.run(func, *args)
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async
2024-12-24 13:17:47 | ERROR | stderr | return next(iterator)
2024-12-24 13:17:47 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 847, in gen_wrapper
2024-12-24 13:17:47 | ERROR | stderr | iterator = f(*args, **kwargs)
2024-12-24 13:17:47 | ERROR | stderr | TypeError: generate_igm_annoy() takes 11 positional arguments but 12 were given
2024-12-24 13:17:48 | INFO | stdout | Rank
2024-12-24 13:32:32 | INFO | stdout | Keyboard interruption in main thread... closing server.
2024-12-24 13:32:54 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(<function generate_b2i_annoy at 0x7f3f2c4731c0>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7f3da2b87970>>), received 11.
2024-12-24 13:32:54 | ERROR | stderr | warnings.warn(
2024-12-24 13:32:54 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(<function generate_b2i_annoy at 0x7f3f2c4731c0>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7f3da2b87970>>), received 11.
2024-12-24 13:32:54 | ERROR | stderr | warnings.warn(
2024-12-24 13:32:54 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860
2024-12-24 13:33:52 | INFO | stdout | Keyboard interruption in main thread... closing server.
2024-12-24 13:33:53 | ERROR | stderr | Exception ignored in: <module 'threading' from '/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py'>
2024-12-24 13:33:53 | ERROR | stderr | Traceback (most recent call last):
2024-12-24 13:33:53 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1567, in _shutdown
2024-12-24 13:33:53 | ERROR | stderr | lock.acquire()
2024-12-24 13:33:53 | ERROR | stderr | KeyboardInterrupt:
|
417 |
+
2024-12-24 13:34:29 | INFO | stdout | Rank
|
418 |
+
2024-12-24 13:43:01 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values.
|
419 |
+
2024-12-24 13:43:01 | ERROR | stderr | Output components:
|
420 |
+
2024-12-24 13:43:01 | ERROR | stderr | [textbox, button, button]
|
421 |
+
2024-12-24 13:43:01 | ERROR | stderr | Output values returned:
|
422 |
+
2024-12-24 13:43:01 | ERROR | stderr | [{'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}]
|
423 |
+
2024-12-24 13:43:01 | ERROR | stderr | warnings.warn(
|
424 |
+
2024-12-24 13:43:02 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None.
|
425 |
+
2024-12-24 13:43:02 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.")
|
426 |
+
2024-12-24 13:43:02 | INFO | stdout | Rank
|
427 |
+
2024-12-24 13:43:10 | INFO | stdout | background.shape (600, 600, 4)
|
428 |
+
2024-12-24 13:43:10 | INFO | stdout | len(layers) 1
|
429 |
+
2024-12-24 13:43:10 | INFO | stdout | composite.shape (600, 600, 4)
|
430 |
+
2024-12-24 13:43:12 | INFO | stdout | background.shape (600, 600, 4)
|
431 |
+
2024-12-24 13:43:12 | INFO | stdout | len(layers) 1
|
432 |
+
2024-12-24 13:43:12 | INFO | stdout | composite.shape (600, 600, 4)
|
433 |
+
2024-12-24 13:43:14 | INFO | stdout | background.shape (600, 600, 4)
|
434 |
+
2024-12-24 13:43:14 | INFO | stdout | len(layers) 1
|
435 |
+
2024-12-24 13:43:14 | INFO | stdout | composite.shape (600, 600, 4)
|
436 |
+
2024-12-24 13:43:17 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values.
|
437 |
+
2024-12-24 13:43:17 | ERROR | stderr | Output components:
|
438 |
+
2024-12-24 13:43:17 | ERROR | stderr | [textbox, button, button]
|
439 |
+
2024-12-24 13:43:17 | ERROR | stderr | Output values returned:
|
440 |
+
2024-12-24 13:43:17 | ERROR | stderr | [{'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}]
|
441 |
+
2024-12-24 13:43:17 | ERROR | stderr | warnings.warn(
|
442 |
+
2024-12-24 13:43:17 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None.
|
443 |
+
2024-12-24 13:43:17 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.")
|
444 |
+
2024-12-24 13:43:17 | ERROR | stderr | Traceback (most recent call last):
|
445 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events
|
446 |
+
2024-12-24 13:43:17 | ERROR | stderr | response = await route_utils.call_process_api(
|
447 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api
|
448 |
+
2024-12-24 13:43:17 | ERROR | stderr | output = await app.get_blocks().process_api(
|
449 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api
|
450 |
+
2024-12-24 13:43:17 | ERROR | stderr | result = await self.call_function(
|
451 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function
|
452 |
+
2024-12-24 13:43:17 | ERROR | stderr | prediction = await utils.async_iteration(iterator)
|
453 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration
|
454 |
+
2024-12-24 13:43:17 | ERROR | stderr | return await anext(iterator)
|
455 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__
|
456 |
+
2024-12-24 13:43:17 | ERROR | stderr | return await anyio.to_thread.run_sync(
|
457 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
|
458 |
+
2024-12-24 13:43:17 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread(
|
459 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread
|
460 |
+
2024-12-24 13:43:17 | ERROR | stderr | return await future
|
461 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run
|
462 |
+
2024-12-24 13:43:17 | ERROR | stderr | result = context.run(func, *args)
|
463 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async
|
464 |
+
2024-12-24 13:43:17 | ERROR | stderr | return next(iterator)
|
465 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper
|
466 |
+
2024-12-24 13:43:17 | ERROR | stderr | response = next(iterator)
|
467 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 896, in generate_b2i_annoy
|
468 |
+
2024-12-24 13:43:17 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3)
|
469 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 94, in generate_image_b2i_parallel_anony
|
470 |
+
2024-12-24 13:43:17 | ERROR | stderr | model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
|
471 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 95, in matchmaker
|
472 |
+
2024-12-24 13:43:17 | ERROR | stderr | ratings, comparison_counts, total_comparisons = load_json_via_sftp()
|
473 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 79, in load_json_via_sftp
|
474 |
+
2024-12-24 13:43:17 | ERROR | stderr | create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
475 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 21, in create_ssh_matchmaker_client
|
476 |
+
2024-12-24 13:43:17 | ERROR | stderr | ssh_matchmaker_client.connect(server, port, user, password)
|
477 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 377, in connect
|
478 |
+
2024-12-24 13:43:17 | ERROR | stderr | to_try = list(self._families_and_addresses(hostname, port))
|
479 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 202, in _families_and_addresses
|
480 |
+
2024-12-24 13:43:17 | ERROR | stderr | addrinfos = socket.getaddrinfo(
|
481 |
+
2024-12-24 13:43:17 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/socket.py", line 955, in getaddrinfo
|
482 |
+
2024-12-24 13:43:17 | ERROR | stderr | for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
|
483 |
+
2024-12-24 13:43:17 | ERROR | stderr | socket.gaierror: [Errno -8] Servname not supported for ai_socktype
|
484 |
+
2024-12-24 13:43:17 | INFO | stdout | Rank
|
485 |
+
2024-12-24 13:44:00 | INFO | stdout | background.shape (600, 600, 4)
|
486 |
+
2024-12-24 13:44:00 | INFO | stdout | len(layers) 1
|
487 |
+
2024-12-24 13:44:00 | INFO | stdout | composite.shape (600, 600, 4)
|
488 |
+
2024-12-24 13:44:01 | INFO | stdout | background.shape (600, 600, 4)
|
489 |
+
2024-12-24 13:44:01 | INFO | stdout | len(layers) 1
|
490 |
+
2024-12-24 13:44:01 | INFO | stdout | composite.shape (600, 600, 4)
|
491 |
+
2024-12-24 13:44:08 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py:1780: UserWarning: A function (disable_order_buttons) returned too many output values (needed: 3, returned: 5). Ignoring extra values.
|
492 |
+
2024-12-24 13:44:08 | ERROR | stderr | Output components:
|
493 |
+
2024-12-24 13:44:08 | ERROR | stderr | [textbox, button, button]
|
494 |
+
2024-12-24 13:44:08 | ERROR | stderr | Output values returned:
|
495 |
+
2024-12-24 13:44:08 | ERROR | stderr | [{'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': False, '__type__': 'update'}, {'interactive': True, '__type__': 'update'}]
|
496 |
+
2024-12-24 13:44:08 | ERROR | stderr | warnings.warn(
|
497 |
+
2024-12-24 13:44:08 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None.
|
498 |
+
2024-12-24 13:44:08 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.")
|
499 |
+
2024-12-24 13:44:08 | ERROR | stderr | Traceback (most recent call last):
|
500 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events
|
501 |
+
2024-12-24 13:44:08 | ERROR | stderr | response = await route_utils.call_process_api(
|
502 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api
|
503 |
+
2024-12-24 13:44:08 | ERROR | stderr | output = await app.get_blocks().process_api(
|
504 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api
|
505 |
+
2024-12-24 13:44:08 | ERROR | stderr | result = await self.call_function(
|
506 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function
|
507 |
+
2024-12-24 13:44:08 | ERROR | stderr | prediction = await utils.async_iteration(iterator)
|
508 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration
|
509 |
+
2024-12-24 13:44:08 | ERROR | stderr | return await anext(iterator)
|
510 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__
|
511 |
+
2024-12-24 13:44:08 | ERROR | stderr | return await anyio.to_thread.run_sync(
|
512 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
|
513 |
+
2024-12-24 13:44:08 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread(
|
514 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread
|
515 |
+
2024-12-24 13:44:08 | ERROR | stderr | return await future
|
516 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run
|
517 |
+
2024-12-24 13:44:08 | ERROR | stderr | result = context.run(func, *args)
|
518 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async
|
519 |
+
2024-12-24 13:44:08 | ERROR | stderr | return next(iterator)
|
520 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper
|
521 |
+
2024-12-24 13:44:08 | ERROR | stderr | response = next(iterator)
|
522 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 896, in generate_b2i_annoy
|
523 |
+
2024-12-24 13:44:08 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3)
|
524 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 94, in generate_image_b2i_parallel_anony
|
525 |
+
2024-12-24 13:44:08 | ERROR | stderr | model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
|
526 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 95, in matchmaker
|
527 |
+
2024-12-24 13:44:08 | ERROR | stderr | ratings, comparison_counts, total_comparisons = load_json_via_sftp()
|
528 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 79, in load_json_via_sftp
|
529 |
+
2024-12-24 13:44:08 | ERROR | stderr | create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
530 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 21, in create_ssh_matchmaker_client
|
531 |
+
2024-12-24 13:44:08 | ERROR | stderr | ssh_matchmaker_client.connect(server, port, user, password)
|
532 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 377, in connect
|
533 |
+
2024-12-24 13:44:08 | ERROR | stderr | to_try = list(self._families_and_addresses(hostname, port))
|
534 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 202, in _families_and_addresses
|
535 |
+
2024-12-24 13:44:08 | ERROR | stderr | addrinfos = socket.getaddrinfo(
|
536 |
+
2024-12-24 13:44:08 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/socket.py", line 955, in getaddrinfo
|
537 |
+
2024-12-24 13:44:08 | ERROR | stderr | for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
|
538 |
+
2024-12-24 13:44:08 | ERROR | stderr | socket.gaierror: [Errno -8] Servname not supported for ai_socktype
|
539 |
+
2024-12-24 13:44:09 | INFO | stdout | Rank
|
540 |
+
2024-12-24 13:45:55 | INFO | stdout | background.shape (600, 600, 4)
|
541 |
+
2024-12-24 13:45:55 | INFO | stdout | len(layers) 1
|
542 |
+
2024-12-24 13:45:55 | INFO | stdout | composite.shape (600, 600, 4)
|
543 |
+
2024-12-24 13:45:57 | INFO | stdout | background.shape (600, 600, 4)
|
544 |
+
2024-12-24 13:45:57 | INFO | stdout | len(layers) 1
|
545 |
+
2024-12-24 13:45:57 | INFO | stdout | composite.shape (600, 600, 4)
|
546 |
+
2024-12-24 13:45:59 | INFO | stdout | background.shape (600, 600, 4)
|
547 |
+
2024-12-24 13:45:59 | INFO | stdout | len(layers) 1
|
548 |
+
2024-12-24 13:45:59 | INFO | stdout | composite.shape (600, 600, 4)
|
549 |
+
2024-12-24 13:52:19 | INFO | stdout | Keyboard interruption in main thread... closing server.
|
550 |
+
2024-12-24 13:52:20 | ERROR | stderr | Traceback (most recent call last):
|
551 |
+
2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2869, in block_thread
|
552 |
+
2024-12-24 13:52:20 | ERROR | stderr | time.sleep(0.1)
|
553 |
+
2024-12-24 13:52:20 | ERROR | stderr | KeyboardInterrupt
|
554 |
+
2024-12-24 13:52:20 | ERROR | stderr |
|
555 |
+
2024-12-24 13:52:20 | ERROR | stderr | During handling of the above exception, another exception occurred:
|
556 |
+
2024-12-24 13:52:20 | ERROR | stderr |
|
557 |
+
2024-12-24 13:52:20 | ERROR | stderr | Traceback (most recent call last):
|
558 |
+
2024-12-24 13:52:20 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/app.py", line 88, in <module>
|
559 |
+
2024-12-24 13:52:20 | ERROR | stderr | demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True)
|
560 |
+
2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2774, in launch
|
561 |
+
2024-12-24 13:52:20 | ERROR | stderr | self.block_thread()
|
562 |
+
2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2873, in block_thread
|
563 |
+
2024-12-24 13:52:20 | ERROR | stderr | self.server.close()
|
564 |
+
2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/http_server.py", line 69, in close
|
565 |
+
2024-12-24 13:52:20 | ERROR | stderr | self.thread.join(timeout=5)
|
566 |
+
2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1100, in join
|
567 |
+
2024-12-24 13:52:20 | ERROR | stderr | self._wait_for_tstate_lock(timeout=max(timeout, 0))
|
568 |
+
2024-12-24 13:52:20 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
|
569 |
+
2024-12-24 13:52:20 | ERROR | stderr | if lock.acquire(block, timeout):
|
570 |
+
2024-12-24 13:52:20 | ERROR | stderr | KeyboardInterrupt
|
571 |
+
2024-12-24 13:52:32 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.
|
572 |
+
2024-12-24 13:52:32 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
|
573 |
+
2024-12-24 13:52:34 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(<function generate_b2i_annoy at 0x7fb8f926f250>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7fb76fa6b6a0>>), received 11.
|
574 |
+
2024-12-24 13:52:34 | ERROR | stderr | warnings.warn(
|
575 |
+
2024-12-24 13:52:34 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(<function generate_b2i_annoy at 0x7fb8f926f250>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7fb76fa6b6a0>>), received 11.
|
576 |
+
2024-12-24 13:52:34 | ERROR | stderr | warnings.warn(
|
577 |
+
2024-12-24 13:52:35 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860
|
578 |
+
2024-12-24 13:52:42 | INFO | stdout | background.shape (600, 600, 4)
|
579 |
+
2024-12-24 13:52:42 | INFO | stdout | len(layers) 1
|
580 |
+
2024-12-24 13:52:42 | INFO | stdout | composite.shape (600, 600, 4)
|
581 |
+
2024-12-24 13:52:42 | INFO | stdout | background.shape (600, 600, 4)
|
582 |
+
2024-12-24 13:52:42 | INFO | stdout | len(layers) 1
|
583 |
+
2024-12-24 13:52:42 | INFO | stdout | composite.shape (600, 600, 4)
|
584 |
+
2024-12-24 13:52:44 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None.
|
585 |
+
2024-12-24 13:52:44 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.")
|
586 |
+
2024-12-24 13:52:44 | ERROR | stderr | Traceback (most recent call last):
|
587 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events
|
588 |
+
2024-12-24 13:52:44 | ERROR | stderr | response = await route_utils.call_process_api(
|
589 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api
|
590 |
+
2024-12-24 13:52:44 | ERROR | stderr | output = await app.get_blocks().process_api(
|
591 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api
|
592 |
+
2024-12-24 13:52:44 | ERROR | stderr | result = await self.call_function(
|
593 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function
|
594 |
+
2024-12-24 13:52:44 | ERROR | stderr | prediction = await utils.async_iteration(iterator)
|
595 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration
|
596 |
+
2024-12-24 13:52:44 | ERROR | stderr | return await anext(iterator)
|
597 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__
|
598 |
+
2024-12-24 13:52:44 | ERROR | stderr | return await anyio.to_thread.run_sync(
|
599 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
|
600 |
+
2024-12-24 13:52:44 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread(
|
601 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread
|
602 |
+
2024-12-24 13:52:44 | ERROR | stderr | return await future
|
603 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run
|
604 |
+
2024-12-24 13:52:44 | ERROR | stderr | result = context.run(func, *args)
|
605 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async
|
606 |
+
2024-12-24 13:52:44 | ERROR | stderr | return next(iterator)
|
607 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper
|
608 |
+
2024-12-24 13:52:44 | ERROR | stderr | response = next(iterator)
|
609 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 896, in generate_b2i_annoy
|
610 |
+
2024-12-24 13:52:44 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3)
|
611 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 94, in generate_image_b2i_parallel_anony
|
612 |
+
2024-12-24 13:52:44 | ERROR | stderr | model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
|
613 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 95, in matchmaker
|
614 |
+
2024-12-24 13:52:44 | ERROR | stderr | ratings, comparison_counts, total_comparisons = load_json_via_sftp()
|
615 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 79, in load_json_via_sftp
|
616 |
+
2024-12-24 13:52:44 | ERROR | stderr | create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
617 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 21, in create_ssh_matchmaker_client
|
618 |
+
2024-12-24 13:52:44 | ERROR | stderr | ssh_matchmaker_client.connect(server, port, user, password)
|
619 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 377, in connect
|
620 |
+
2024-12-24 13:52:44 | ERROR | stderr | to_try = list(self._families_and_addresses(hostname, port))
|
621 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 202, in _families_and_addresses
|
622 |
+
2024-12-24 13:52:44 | ERROR | stderr | addrinfos = socket.getaddrinfo(
|
623 |
+
2024-12-24 13:52:44 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/socket.py", line 955, in getaddrinfo
|
624 |
+
2024-12-24 13:52:44 | ERROR | stderr | for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
|
625 |
+
2024-12-24 13:52:44 | ERROR | stderr | socket.gaierror: [Errno -8] Servname not supported for ai_socktype
|
626 |
+
2024-12-24 13:52:44 | INFO | stdout | Rank
|
627 |
+
2024-12-24 13:53:06 | INFO | stdout |
|
628 |
+
2024-12-24 13:53:06 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3.
|
629 |
+
2024-12-24 13:53:06 | INFO | stdout |
|
630 |
+
2024-12-24 13:53:06 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps:
|
631 |
+
2024-12-24 13:53:06 | INFO | stdout |
|
632 |
+
2024-12-24 13:53:06 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
|
633 |
+
2024-12-24 13:53:06 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3
|
634 |
+
2024-12-24 13:53:06 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio
|
635 |
+
2024-12-24 13:56:11 | INFO | stdout | Keyboard interruption in main thread... closing server.
|
636 |
+
2024-12-24 13:56:12 | ERROR | stderr | Traceback (most recent call last):
|
637 |
+
2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2869, in block_thread
|
638 |
+
2024-12-24 13:56:12 | ERROR | stderr | time.sleep(0.1)
|
639 |
+
2024-12-24 13:56:12 | ERROR | stderr | KeyboardInterrupt
|
640 |
+
2024-12-24 13:56:12 | ERROR | stderr |
|
641 |
+
2024-12-24 13:56:12 | ERROR | stderr | During handling of the above exception, another exception occurred:
|
642 |
+
2024-12-24 13:56:12 | ERROR | stderr |
|
643 |
+
2024-12-24 13:56:12 | ERROR | stderr | Traceback (most recent call last):
|
644 |
+
2024-12-24 13:56:12 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/app.py", line 88, in <module>
|
645 |
+
2024-12-24 13:56:12 | ERROR | stderr | demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True)
|
646 |
+
2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2774, in launch
|
647 |
+
2024-12-24 13:56:12 | ERROR | stderr | self.block_thread()
|
648 |
+
2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2873, in block_thread
|
649 |
+
2024-12-24 13:56:12 | ERROR | stderr | self.server.close()
|
650 |
+
2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/http_server.py", line 69, in close
|
651 |
+
2024-12-24 13:56:12 | ERROR | stderr | self.thread.join(timeout=5)
|
652 |
+
2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1100, in join
|
653 |
+
2024-12-24 13:56:12 | ERROR | stderr | self._wait_for_tstate_lock(timeout=max(timeout, 0))
|
654 |
+
2024-12-24 13:56:12 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
|
655 |
+
2024-12-24 13:56:12 | ERROR | stderr | if lock.acquire(block, timeout):
|
656 |
+
2024-12-24 13:56:12 | ERROR | stderr | KeyboardInterrupt
|
657 |
+
2024-12-24 13:56:24 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.
|
658 |
+
2024-12-24 13:56:24 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
|
659 |
+
2024-12-24 13:56:26 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(<function generate_b2i_annoy at 0x7ff04a643250>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7ff293c5fd60>>), received 11.
|
660 |
+
2024-12-24 13:56:26 | ERROR | stderr | warnings.warn(
|
661 |
+
2024-12-24 13:56:26 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(<function generate_b2i_annoy at 0x7ff04a643250>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7ff293c5fd60>>), received 11.
|
662 |
+
2024-12-24 13:56:26 | ERROR | stderr | warnings.warn(
|
663 |
+
2024-12-24 13:56:26 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860
|
664 |
+
2024-12-24 13:56:36 | INFO | stdout | background.shape (600, 600, 4)
|
665 |
+
2024-12-24 13:56:36 | INFO | stdout | len(layers) 1
|
666 |
+
2024-12-24 13:56:36 | INFO | stdout | composite.shape (600, 600, 4)
|
667 |
+
2024-12-24 13:56:36 | INFO | stdout | background.shape (600, 600, 4)
|
668 |
+
2024-12-24 13:56:36 | INFO | stdout | len(layers) 1
|
669 |
+
2024-12-24 13:56:36 | INFO | stdout | composite.shape (600, 600, 4)
|
670 |
+
2024-12-24 13:56:41 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None.
|
671 |
+
2024-12-24 13:56:41 | ERROR | stderr | warnings.warn("Unexpected argument. Filling with None.")
|
672 |
+
2024-12-24 13:56:41 | ERROR | stderr | Traceback (most recent call last):
|
673 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events
|
674 |
+
2024-12-24 13:56:41 | ERROR | stderr | response = await route_utils.call_process_api(
|
675 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api
|
676 |
+
2024-12-24 13:56:41 | ERROR | stderr | output = await app.get_blocks().process_api(
|
677 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api
|
678 |
+
2024-12-24 13:56:41 | ERROR | stderr | result = await self.call_function(
|
679 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function
|
680 |
+
2024-12-24 13:56:41 | ERROR | stderr | prediction = await utils.async_iteration(iterator)
|
681 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration
|
682 |
+
2024-12-24 13:56:41 | ERROR | stderr | return await anext(iterator)
|
683 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__
|
684 |
+
2024-12-24 13:56:41 | ERROR | stderr | return await anyio.to_thread.run_sync(
|
685 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
|
686 |
+
2024-12-24 13:56:41 | ERROR | stderr | return await get_async_backend().run_sync_in_worker_thread(
|
687 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread
|
688 |
+
2024-12-24 13:56:41 | ERROR | stderr | return await future
|
689 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run
|
690 |
+
2024-12-24 13:56:41 | ERROR | stderr | result = context.run(func, *args)
|
691 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async
|
692 |
+
2024-12-24 13:56:41 | ERROR | stderr | return next(iterator)
|
693 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper
|
694 |
+
2024-12-24 13:56:41 | ERROR | stderr | response = next(iterator)
|
695 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 896, in generate_b2i_annoy
|
696 |
+
2024-12-24 13:56:41 | ERROR | stderr | = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3)
|
697 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 94, in generate_image_b2i_parallel_anony
|
698 |
+
2024-12-24 13:56:41 | ERROR | stderr | model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
|
699 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 95, in matchmaker
|
700 |
+
2024-12-24 13:56:41 | ERROR | stderr | ratings, comparison_counts, total_comparisons = load_json_via_sftp()
|
701 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 79, in load_json_via_sftp
|
702 |
+
2024-12-24 13:56:41 | ERROR | stderr | create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
703 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/model/matchmaker.py", line 21, in create_ssh_matchmaker_client
|
704 |
+
2024-12-24 13:56:41 | ERROR | stderr | ssh_matchmaker_client.connect(server, port, user, password)
|
705 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 377, in connect
|
706 |
+
2024-12-24 13:56:41 | ERROR | stderr | to_try = list(self._families_and_addresses(hostname, port))
|
707 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/paramiko/client.py", line 202, in _families_and_addresses
|
708 |
+
2024-12-24 13:56:41 | ERROR | stderr | addrinfos = socket.getaddrinfo(
|
709 |
+
2024-12-24 13:56:41 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/socket.py", line 955, in getaddrinfo
|
710 |
+
2024-12-24 13:56:41 | ERROR | stderr | for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
|
711 |
+
2024-12-24 13:56:41 | ERROR | stderr | socket.gaierror: [Errno -8] Servname not supported for ai_socktype
|
712 |
+
2024-12-24 13:56:41 | INFO | stdout | Rank
|
713 |
+
2024-12-24 13:57:05 | INFO | stdout |
|
714 |
+
2024-12-24 13:57:05 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3.
|
715 |
+
2024-12-24 13:57:05 | INFO | stdout |
|
716 |
+
2024-12-24 13:57:05 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps:
|
717 |
+
2024-12-24 13:57:05 | INFO | stdout |
|
718 |
+
2024-12-24 13:57:05 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
|
719 |
+
2024-12-24 13:57:05 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3
|
720 |
+
2024-12-24 13:57:05 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio
|
721 |
+
2024-12-24 13:58:10 | INFO | stdout | Keyboard interruption in main thread... closing server.
|
722 |
+
2024-12-24 13:58:10 | ERROR | stderr | Traceback (most recent call last):
|
723 |
+
2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2869, in block_thread
|
724 |
+
2024-12-24 13:58:10 | ERROR | stderr | time.sleep(0.1)
|
725 |
+
2024-12-24 13:58:10 | ERROR | stderr | KeyboardInterrupt
|
726 |
+
2024-12-24 13:58:10 | ERROR | stderr |
|
727 |
+
2024-12-24 13:58:10 | ERROR | stderr | During handling of the above exception, another exception occurred:
|
728 |
+
2024-12-24 13:58:10 | ERROR | stderr |
|
729 |
+
2024-12-24 13:58:10 | ERROR | stderr | Traceback (most recent call last):
|
730 |
+
2024-12-24 13:58:10 | ERROR | stderr | File "/home/bcy/projects/Arena/Control_Ability_Arena/app.py", line 88, in <module>
|
731 |
+
2024-12-24 13:58:10 | ERROR | stderr | demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH, share=True)
|
732 |
+
2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2774, in launch
|
733 |
+
2024-12-24 13:58:10 | ERROR | stderr | self.block_thread()
|
734 |
+
2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2873, in block_thread
|
735 |
+
2024-12-24 13:58:10 | ERROR | stderr | self.server.close()
|
736 |
+
2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/http_server.py", line 69, in close
|
737 |
+
2024-12-24 13:58:10 | ERROR | stderr | self.thread.join(timeout=5)
|
738 |
+
2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1100, in join
|
739 |
+
2024-12-24 13:58:10 | ERROR | stderr | self._wait_for_tstate_lock(timeout=max(timeout, 0))
|
740 |
+
2024-12-24 13:58:10 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
|
741 |
+
2024-12-24 13:58:10 | ERROR | stderr | if lock.acquire(block, timeout):
|
742 |
+
2024-12-24 13:58:10 | ERROR | stderr | KeyboardInterrupt
|
743 |
+
2024-12-24 13:58:11 | ERROR | stderr | Exception ignored in: <module 'threading' from '/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py'>
|
744 |
+
2024-12-24 13:58:11 | ERROR | stderr | Traceback (most recent call last):
|
745 |
+
2024-12-24 13:58:11 | ERROR | stderr | File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/threading.py", line 1567, in _shutdown
|
746 |
+
2024-12-24 13:58:11 | ERROR | stderr | lock.acquire()
|
747 |
+
2024-12-24 13:58:11 | ERROR | stderr | KeyboardInterrupt:
|
748 |
+
2024-12-24 13:58:20 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.
|
749 |
+
2024-12-24 13:58:20 | ERROR | stderr | deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
|
750 |
+
2024-12-24 13:58:21 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1003: UserWarning: Expected 12 arguments for function functools.partial(<function generate_b2i_annoy at 0x7f8ee8393250>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7f913195bd60>>), received 11.
|
751 |
+
2024-12-24 13:58:21 | ERROR | stderr | warnings.warn(
|
752 |
+
2024-12-24 13:58:21 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py:1007: UserWarning: Expected at least 12 arguments for function functools.partial(<function generate_b2i_annoy at 0x7f8ee8393250>, <bound method ModelManager.generate_image_b2i_parallel_anony of <model.model_manager.ModelManager object at 0x7f913195bd60>>), received 11.
|
753 |
+
2024-12-24 13:58:21 | ERROR | stderr | warnings.warn(
|
754 |
+
2024-12-24 13:58:22 | INFO | stdout | * Running on local URL: http://127.0.0.1:7860
|
755 |
+
2024-12-24 13:58:32 | INFO | stdout | background.shape (600, 600, 4)
|
756 |
+
2024-12-24 13:58:32 | INFO | stdout | len(layers) 1
|
757 |
+
2024-12-24 13:58:32 | INFO | stdout | composite.shape (600, 600, 4)
|
758 |
+
2024-12-24 13:58:33 | INFO | stdout | background.shape (600, 600, 4)
|
759 |
+
2024-12-24 13:58:33 | INFO | stdout | len(layers) 1
|
760 |
+
2024-12-24 13:58:33 | INFO | stdout | composite.shape (600, 600, 4)
|
761 |
+
2024-12-24 13:58:37 | ERROR | stderr | /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/helpers.py:968: UserWarning: Unexpected argument. Filling with None.
2024-12-24 13:58:37 | ERROR | stderr |   warnings.warn("Unexpected argument. Filling with None.")
2024-12-24 13:58:37 | INFO | stdout | [0]
2024-12-24 13:58:37 | INFO | stdout | ['replicate_SDXL_text2image']
2024-12-24 13:58:37 | ERROR | stderr | Traceback (most recent call last):
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events
2024-12-24 13:58:37 | ERROR | stderr |     response = await route_utils.call_process_api(
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api
2024-12-24 13:58:37 | ERROR | stderr |     output = await app.get_blocks().process_api(
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 2047, in process_api
2024-12-24 13:58:37 | ERROR | stderr |     result = await self.call_function(
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/blocks.py", line 1606, in call_function
2024-12-24 13:58:37 | ERROR | stderr |     prediction = await utils.async_iteration(iterator)
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 714, in async_iteration
2024-12-24 13:58:37 | ERROR | stderr |     return await anext(iterator)
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 708, in __anext__
2024-12-24 13:58:37 | ERROR | stderr |     return await anyio.to_thread.run_sync(
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
2024-12-24 13:58:37 | ERROR | stderr |     return await get_async_backend().run_sync_in_worker_thread(
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2505, in run_sync_in_worker_thread
2024-12-24 13:58:37 | ERROR | stderr |     return await future
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 1005, in run
2024-12-24 13:58:37 | ERROR | stderr |     result = context.run(func, *args)
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 691, in run_sync_iterator_async
2024-12-24 13:58:37 | ERROR | stderr |     return next(iterator)
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/utils.py", line 852, in gen_wrapper
2024-12-24 13:58:37 | ERROR | stderr |     response = next(iterator)
2024-12-24 13:58:37 | ERROR | stderr |   File "/home/bcy/projects/Arena/Control_Ability_Arena/serve/vote_utils.py", line 896, in generate_b2i_annoy
2024-12-24 13:58:37 | ERROR | stderr |     = gen_func(text, grounding_instruction, out_imagebox, model_name0, model_name1, model_name2, model_name3)
2024-12-24 13:58:37 | ERROR | stderr |   File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 104, in generate_image_b2i_parallel_anony
2024-12-24 13:58:37 | ERROR | stderr |     results = [future.result() for future in futures]
2024-12-24 13:58:37 | ERROR | stderr |   File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 104, in <listcomp>
2024-12-24 13:58:37 | ERROR | stderr |     results = [future.result() for future in futures]
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/concurrent/futures/_base.py", line 451, in result
2024-12-24 13:58:37 | ERROR | stderr |     return self.__get_result()
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
2024-12-24 13:58:37 | ERROR | stderr |     raise self._exception
2024-12-24 13:58:37 | ERROR | stderr |   File "/share/bcy/miniconda3/envs/Arena/lib/python3.10/concurrent/futures/thread.py", line 58, in run
2024-12-24 13:58:37 | ERROR | stderr |     result = self.fn(*self.args, **self.kwargs)
2024-12-24 13:58:37 | ERROR | stderr |   File "/home/bcy/projects/Arena/Control_Ability_Arena/model/model_manager.py", line 87, in generate_image_b2i
2024-12-24 13:58:37 | ERROR | stderr |     return result
2024-12-24 13:58:37 | ERROR | stderr | UnboundLocalError: local variable 'result' referenced before assignment
2024-12-24 13:58:37 | INFO | stdout | Rank
|
804 |
+
2024-12-24 13:58:53 | INFO | stdout |
|
805 |
+
2024-12-24 13:58:53 | INFO | stdout | Could not create share link. Missing file: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3.
|
806 |
+
2024-12-24 13:58:53 | INFO | stdout |
|
807 |
+
2024-12-24 13:58:53 | INFO | stdout | Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps:
|
808 |
+
2024-12-24 13:58:53 | INFO | stdout |
|
809 |
+
2024-12-24 13:58:53 | INFO | stdout | 1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
|
810 |
+
2024-12-24 13:58:53 | INFO | stdout | 2. Rename the downloaded file to: frpc_linux_amd64_v0.3
|
811 |
+
2024-12-24 13:58:53 | INFO | stdout | 3. Move the file to this location: /share/bcy/miniconda3/envs/Arena/lib/python3.10/site-packages/gradio
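Two distinct failures recur in the log above. Every generate call that reaches model/matchmaker.py dies inside create_ssh_matchmaker_client with socket.gaierror: [Errno -8] Servname not supported for ai_socktype, which getaddrinfo raises when the port/service argument is neither a numeric port nor a known service name; that usually means SSH_PORT (or SSH_SERVER) from serve.constants was empty or non-numeric in this environment. The 13:58:37 run additionally hits UnboundLocalError: local variable 'result' referenced before assignment at the return result in ModelManager.generate_image_b2i, i.e. that function can reach its return statement without any branch having assigned result. The snippet below is a hypothetical hardening sketch, not part of this commit (create_ssh_matchmaker_client_checked is an illustrative name); it validates the connection settings before handing them to paramiko so the failure surfaces as a clear configuration error:

# Hypothetical sketch (not part of this commit): validate SSH settings up front.
import paramiko

def create_ssh_matchmaker_client_checked(server, port, user, password):
    if not server:
        raise ValueError("SSH_SERVER is empty; check serve/constants.py or the environment")
    try:
        port = int(port)  # an empty or non-numeric SSH_PORT is what makes getaddrinfo fail
    except (TypeError, ValueError):
        raise ValueError(f"SSH_PORT must be an integer, got {port!r}")
    client = paramiko.SSHClient()
    client.load_system_host_keys()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    client.connect(server, port=port, username=user, password=password)
    client.get_transport().set_keepalive(60)
    return client, client.open_sftp()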
|
ksort-logs/vote_log/gr_web_image_generation_multi.log
ADDED
@@ -0,0 +1,6 @@
2024-12-24 12:55:47 | INFO | gradio_web_server_image_generation_multi | generate. ip: None
2024-12-24 13:43:17 | INFO | gradio_web_server_image_generation_multi | generate. ip: None
2024-12-24 13:44:08 | INFO | gradio_web_server_image_generation_multi | generate. ip: None
2024-12-24 13:52:44 | INFO | gradio_web_server_image_generation_multi | generate. ip: None
2024-12-24 13:56:41 | INFO | gradio_web_server_image_generation_multi | generate. ip: None
2024-12-24 13:58:37 | INFO | gradio_web_server_image_generation_multi | generate. ip: None
ksort-logs/vote_log/gr_web_video_generation.log
ADDED
File without changes
ksort-logs/vote_log/gr_web_video_generation_multi.log
ADDED
File without changes
model/__init__.py
ADDED
File without changes
model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (157 Bytes). View file
model/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (161 Bytes). View file
model/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (146 Bytes). View file
model/__pycache__/matchmaker.cpython-310.pyc
ADDED
Binary file (4.04 kB). View file
model/__pycache__/model_manager.cpython-310.pyc
ADDED
Binary file (9.9 kB). View file
model/__pycache__/model_registry.cpython-310.pyc
ADDED
Binary file (1.64 kB). View file
model/__pycache__/model_registry.cpython-312.pyc
ADDED
Binary file (2.72 kB). View file
model/__pycache__/model_registry.cpython-39.pyc
ADDED
Binary file (1.81 kB). View file
model/matchmaker.py
ADDED
@@ -0,0 +1,126 @@
import numpy as np
import json
from trueskill import TrueSkill
import paramiko
import io, os
import sys
import random

sys.path.append('../')
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL
trueskill_env = TrueSkill()

ssh_matchmaker_client = None
sftp_matchmaker_client = None


def create_ssh_matchmaker_client(server, port, user, password):
    global ssh_matchmaker_client, sftp_matchmaker_client
    ssh_matchmaker_client = paramiko.SSHClient()
    ssh_matchmaker_client.load_system_host_keys()
    ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    ssh_matchmaker_client.connect(server, port, user, password)

    transport = ssh_matchmaker_client.get_transport()
    transport.set_keepalive(60)

    sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()


def is_connected():
    global ssh_matchmaker_client, sftp_matchmaker_client
    if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
        return False
    if not ssh_matchmaker_client.get_transport().is_active():
        return False
    try:
        sftp_matchmaker_client.listdir('.')
    except Exception as e:
        print(f"Error checking SFTP connection: {e}")
        return False
    return True


def ucb_score(trueskill_diff, t, n):
    exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
    ucb = -trueskill_diff + 1.0 * exploration_term
    return ucb


def update_trueskill(ratings, ranks):
    new_ratings = trueskill_env.rate(ratings, ranks)
    return new_ratings


def serialize_rating(rating):
    return {'mu': rating.mu, 'sigma': rating.sigma}


def deserialize_rating(rating_dict):
    return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])


def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
    global sftp_matchmaker_client
    if not is_connected():
        create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    data = {
        'ratings': [serialize_rating(r) for r in ratings],
        'comparison_counts': comparison_counts.tolist(),
        'total_comparisons': total_comparisons
    }
    json_data = json.dumps(data)
    with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
        f.write(json_data)


def load_json_via_sftp():
    global sftp_matchmaker_client
    if not is_connected():
        create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    with sftp_matchmaker_client.open(SSH_SKILL, 'r') as f:
        data = json.load(f)
    ratings = [deserialize_rating(r) for r in data['ratings']]
    comparison_counts = np.array(data['comparison_counts'])
    total_comparisons = data['total_comparisons']
    return ratings, comparison_counts, total_comparisons


class RunningPivot(object):
    running_pivot = []


def matchmaker(num_players, k_group=4, not_run=[]):
    trueskill_env = TrueSkill()

    ratings, comparison_counts, total_comparisons = load_json_via_sftp()

    ratings = ratings[:num_players]
    comparison_counts = comparison_counts[:num_players, :num_players]

    # Randomly select a player
    # selected_player = np.random.randint(0, num_players)
    comparison_counts[RunningPivot.running_pivot, :] = float('inf')
    comparison_counts[not_run, :] = float('inf')
    selected_player = np.argmin(comparison_counts.sum(axis=1))

    RunningPivot.running_pivot.append(selected_player)
    RunningPivot.running_pivot = RunningPivot.running_pivot[-5:]
    print(RunningPivot.running_pivot)

    selected_trueskill_score = trueskill_env.expose(ratings[selected_player])
    trueskill_scores = np.array([trueskill_env.expose(p) for p in ratings])
    trueskill_diff = np.abs(trueskill_scores - selected_trueskill_score)
    n = comparison_counts[selected_player]
    ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)

    # Exclude self, select opponent with highest UCB score
    ucb_scores[selected_player] = -float('inf')
    ucb_scores[not_run] = -float('inf')
    opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()

    # Group players
    model_ids = [selected_player] + opponents

    random.shuffle(model_ids)

    return model_ids
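Note: the matchmaker above picks a pivot model (fewest comparisons so far) and fills the rest of the group with the opponents whose UCB scores are highest, so similarly-rated models meet more often while under-sampled ones still get explored. A minimal usage sketch, not part of the commit, assuming the app runs from the repo root and the SSH-backed rating store is reachable:

# Usage sketch (illustrative only; requires serve.constants SSH settings to be valid).
from model.matchmaker import matchmaker
from model.models import IMAGE_GENERATION_MODELS

model_ids = matchmaker(num_players=len(IMAGE_GENERATION_MODELS), k_group=4, not_run=[])
model_names = [IMAGE_GENERATION_MODELS[i] for i in model_ids]
print(model_names)  # four anonymised players, pivot included, order shuffled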
model/matchmaker_video.py
ADDED
@@ -0,0 +1,136 @@
import numpy as np
import json
from trueskill import TrueSkill
import paramiko
import io, os
import sys
import random

sys.path.append('../')
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_VIDEO_SKILL
trueskill_env = TrueSkill()

ssh_matchmaker_client = None
sftp_matchmaker_client = None


def create_ssh_matchmaker_client(server, port, user, password):
    global ssh_matchmaker_client, sftp_matchmaker_client
    ssh_matchmaker_client = paramiko.SSHClient()
    ssh_matchmaker_client.load_system_host_keys()
    ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    ssh_matchmaker_client.connect(server, port, user, password)

    transport = ssh_matchmaker_client.get_transport()
    transport.set_keepalive(60)

    sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()


def is_connected():
    global ssh_matchmaker_client, sftp_matchmaker_client
    if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
        return False
    if not ssh_matchmaker_client.get_transport().is_active():
        return False
    try:
        sftp_matchmaker_client.listdir('.')
    except Exception as e:
        print(f"Error checking SFTP connection: {e}")
        return False
    return True


def ucb_score(trueskill_diff, t, n):
    exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
    ucb = -trueskill_diff + 1.0 * exploration_term
    return ucb


def update_trueskill(ratings, ranks):
    new_ratings = trueskill_env.rate(ratings, ranks)
    return new_ratings


def serialize_rating(rating):
    return {'mu': rating.mu, 'sigma': rating.sigma}


def deserialize_rating(rating_dict):
    return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])


def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
    global sftp_matchmaker_client
    if not is_connected():
        create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    data = {
        'ratings': [serialize_rating(r) for r in ratings],
        'comparison_counts': comparison_counts.tolist(),
        'total_comparisons': total_comparisons
    }
    json_data = json.dumps(data)
    with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f:
        f.write(json_data)


def load_json_via_sftp():
    global sftp_matchmaker_client
    if not is_connected():
        create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'r') as f:
        data = json.load(f)
    ratings = [deserialize_rating(r) for r in data['ratings']]
    comparison_counts = np.array(data['comparison_counts'])
    total_comparisons = data['total_comparisons']
    return ratings, comparison_counts, total_comparisons


def matchmaker_video(num_players, k_group=4):
    trueskill_env = TrueSkill()

    ratings, comparison_counts, total_comparisons = load_json_via_sftp()

    ratings = ratings[:num_players]
    comparison_counts = comparison_counts[:num_players, :num_players]

    selected_player = np.argmin(comparison_counts.sum(axis=1))

    selected_trueskill_score = trueskill_env.expose(ratings[selected_player])
    trueskill_scores = np.array([trueskill_env.expose(p) for p in ratings])
    trueskill_diff = np.abs(trueskill_scores - selected_trueskill_score)
    n = comparison_counts[selected_player]
    ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)

    # Exclude self, select opponent with highest UCB score
    ucb_scores[selected_player] = -float('inf')

    excluded_players_1 = [7, 10]
    excluded_players_2 = [6, 8, 9]
    excluded_players = excluded_players_1 + excluded_players_2
    if selected_player in excluded_players_1:
        for player in excluded_players:
            ucb_scores[player] = -float('inf')
    if selected_player in excluded_players_2:
        for player in excluded_players_1:
            ucb_scores[player] = -float('inf')
    else:
        excluded_ucb_scores = {player: ucb_scores[player] for player in excluded_players}
        max_player = max(excluded_ucb_scores, key=excluded_ucb_scores.get)
        if max_player in excluded_players_1:
            for player in excluded_players:
                if player != max_player:
                    ucb_scores[player] = -float('inf')
        else:
            for player in excluded_players_1:
                ucb_scores[player] = -float('inf')

    opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()

    # Group players
    model_ids = [selected_player] + opponents

    random.shuffle(model_ids)

    return model_ids
model/model_manager.py
ADDED
@@ -0,0 +1,239 @@
import concurrent.futures
import random
import gradio as gr
import requests, os
import io, base64, json
import spaces
import torch
from PIL import Image
from openai import OpenAI
from .models import IMAGE_GENERATION_MODELS, VIDEO_GENERATION_MODELS, B2I_MODELS, load_pipeline
from serve.upload import get_random_mscoco_prompt, get_random_video_prompt, get_ssh_random_video_prompt, get_ssh_random_image_prompt
from serve.constants import SSH_CACHE_OPENSOURCE, SSH_CACHE_ADVANCE, SSH_CACHE_PIKA, SSH_CACHE_SORA, SSH_CACHE_IMAGE


class ModelManager:
    def __init__(self):
        self.model_ig_list = IMAGE_GENERATION_MODELS
        self.model_ie_list = []  # IMAGE_EDITION_MODELS
        self.model_vg_list = VIDEO_GENERATION_MODELS
        self.model_b2i_list = B2I_MODELS
        self.loaded_models = {}

    def load_model_pipe(self, model_name):
        if not model_name in self.loaded_models:
            pipe = load_pipeline(model_name)
            self.loaded_models[model_name] = pipe
        else:
            pipe = self.loaded_models[model_name]
        return pipe

    @spaces.GPU(duration=120)
    def generate_image_ig(self, prompt, model_name):
        pipe = self.load_model_pipe(model_name)
        if 'Stable-cascade' not in model_name:
            result = pipe(prompt=prompt).images[0]
        else:
            prior, decoder = pipe
            prior.enable_model_cpu_offload()
            prior_output = prior(
                prompt=prompt,
                height=512,
                width=512,
                negative_prompt='',
                guidance_scale=4.0,
                num_images_per_prompt=1,
                num_inference_steps=20
            )
            decoder.enable_model_cpu_offload()
            result = decoder(
                image_embeddings=prior_output.image_embeddings.to(torch.float16),
                prompt=prompt,
                negative_prompt='',
                guidance_scale=0.0,
                output_type="pil",
                num_inference_steps=10
            ).images[0]
        return result

    def generate_image_ig_api(self, prompt, model_name):
        pipe = self.load_model_pipe(model_name)
        result = pipe(prompt=prompt)
        return result

    def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
        if model_A == "" and model_B == "" and model_C == "" and model_D == "":
            from .matchmaker import matchmaker
            not_run = [20,21,22, 25,26, 30]  # 12,13,14,15,16,17,18,19,20,21,22, #23,24,
            model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
            print(model_ids)
            model_names = [self.model_ig_list[i] for i in model_ids]
            print(model_names)
        else:
            model_names = [model_A, model_B, model_C, model_D]

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
                       else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
            results = [future.result() for future in futures]

        return results[0], results[1], results[2], results[3], \
            model_names[0], model_names[1], model_names[2], model_names[3]

    def generate_image_b2i(self, prompt, grounding_instruction, bbox, model_name):
        pipe = self.load_model_pipe(model_name)
        if model_name == "local_MIGC_b2i":
            from model_bbox.MIGC.inference_single_image import inference_image
            result = inference_image(pipe, prompt, grounding_instruction, bbox)
        elif model_name == "huggingface_ReCo_b2i":
            from model_bbox.ReCo.inference import inference_image
            result = inference_image(pipe, prompt, grounding_instruction, bbox)
        return result

    def generate_image_b2i_parallel_anony(self, prompt, grounding_instruction, bbox, model_A, model_B, model_C, model_D):
        if model_A == "" and model_B == "" and model_C == "" and model_D == "":
            from .matchmaker import matchmaker
            not_run = []  # 12,13,14,15,16,17,18,19,20,21,22, #23,24,
            # model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
            model_ids = [0, 1]
            print(model_ids)
            model_names = [self.model_b2i_list[i] for i in model_ids]
            print(model_names)
        else:
            model_names = [model_A, model_B, model_C, model_D]

        from concurrent.futures import ProcessPoolExecutor
        with ProcessPoolExecutor() as executor:
            futures = [executor.submit(self.generate_image_b2i, prompt, grounding_instruction, bbox, model)
                       for model in model_names]
            results = [future.result() for future in futures]

        # with concurrent.futures.ThreadPoolExecutor() as executor:
        #     futures = [executor.submit(self.generate_image_b2i, prompt, grounding_instruction, bbox, model) for model in model_names]
        #     results = [future.result() for future in futures]

        blank_image = None
        final_results = []
        for i in range(4):
            if i < len(model_ids):
                # For a valid model, return its generated result
                final_results.append(results[i])
            else:
                # If there is no generated result, return a blank image instead
                final_results.append(blank_image)
        final_model_names = []
        for i in range(4):
            if i < len(model_ids):
                final_model_names.append(model_names[i])
            else:
                final_model_names.append("")

        return final_results[0], final_results[1], final_results[2], final_results[3], \
            final_model_names[0], final_model_names[1], final_model_names[2], final_model_names[3]

    def generate_image_ig_cache_anony(self, model_A, model_B, model_C, model_D):
        if model_A == "" and model_B == "" and model_C == "" and model_D == "":
            from .matchmaker import matchmaker
            not_run = [20,21,22]
            model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
            print(model_ids)
            model_names = [self.model_ig_list[i] for i in model_ids]
            print(model_names)
        else:
            model_names = [model_A, model_B, model_C, model_D]

        root_dir = SSH_CACHE_IMAGE
        local_dir = "./cache_image"
        if not os.path.exists(local_dir):
            os.makedirs(local_dir)
        prompt, results = get_ssh_random_image_prompt(root_dir, local_dir, model_names)

        return results[0], results[1], results[2], results[3], \
            model_names[0], model_names[1], model_names[2], model_names[3], prompt

    def generate_video_vg_parallel_anony(self, model_A, model_B, model_C, model_D):
        if model_A == "" and model_B == "" and model_C == "" and model_D == "":
            # model_names = random.sample([model for model in self.model_vg_list], 4)

            from .matchmaker_video import matchmaker_video
            model_ids = matchmaker_video(num_players=len(self.model_vg_list))
            print(model_ids)
            model_names = [self.model_vg_list[i] for i in model_ids]
            print(model_names)
        else:
            model_names = [model_A, model_B, model_C, model_D]

        root_dir = SSH_CACHE_OPENSOURCE
        for name in model_names:
            if "Runway-Gen3" in name or "Runway-Gen2" in name or "Pika-v1.0" in name:
                root_dir = SSH_CACHE_ADVANCE
            elif "Pika-beta" in name:
                root_dir = SSH_CACHE_PIKA
            elif "Sora" in name and "OpenSora" not in name:
                root_dir = SSH_CACHE_SORA

        local_dir = "./cache_video"
        if not os.path.exists(local_dir):
            os.makedirs(local_dir)
        prompt, results = get_ssh_random_video_prompt(root_dir, local_dir, model_names)
        cache_dir = local_dir

        return results[0], results[1], results[2], results[3], \
            model_names[0], model_names[1], model_names[2], model_names[3], prompt, cache_dir

    def generate_image_ig_museum_parallel_anony(self, model_A, model_B, model_C, model_D):
        if model_A == "" and model_B == "" and model_C == "" and model_D == "":
            # model_names = random.sample([model for model in self.model_ig_list], 4)

            from .matchmaker import matchmaker
            model_ids = matchmaker(num_players=len(self.model_ig_list))
            print(model_ids)
            model_names = [self.model_ig_list[i] for i in model_ids]
            print(model_names)
        else:
            model_names = [model_A, model_B, model_C, model_D]

        prompt = get_random_mscoco_prompt()
        print(prompt)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
                       else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
            results = [future.result() for future in futures]

        return results[0], results[1], results[2], results[3], \
            model_names[0], model_names[1], model_names[2], model_names[3], prompt

    def generate_image_ig_parallel(self, prompt, model_A, model_B):
        model_names = [model_A, model_B]
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("imagenhub")
                       else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
            results = [future.result() for future in futures]
        return results[0], results[1]

    @spaces.GPU(duration=200)
    def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
        pipe = self.load_model_pipe(model_name)
        result = pipe(src_image=source_image, src_prompt=textbox_source, target_prompt=textbox_target, instruct_prompt=textbox_instruct)
        return result

    def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
        model_names = [model_A, model_B]
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.generate_image_ie, textbox_source, textbox_target, textbox_instruct, source_image,
                                model) for model in model_names]
            results = [future.result() for future in futures]
        return results[0], results[1]

    def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
        if model_A == "" and model_B == "":
            model_names = random.sample([model for model in self.model_ie_list], 2)
        else:
            model_names = [model_A, model_B]
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self.generate_image_ie, textbox_source, textbox_target, textbox_instruct, source_image, model) for model in model_names]
            results = [future.result() for future in futures]
        return results[0], results[1], model_names[0], model_names[1]
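Note: ModelManager is the entry point the Gradio app calls; passing empty model names triggers the matchmaker, while explicit names bypass it. A minimal sketch (not part of the commit; it assumes the Replicate/OpenAI keys and the serve/ helpers are configured):

# Illustrative only: one anonymous four-way image battle for a fixed prompt.
from model.model_manager import ModelManager

manager = ModelManager()
outputs = manager.generate_image_ig_parallel_anony(
    "a red bicycle leaning against a brick wall", "", "", "", "")
images, names = outputs[:4], outputs[4:]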
model/model_registry.py
ADDED
@@ -0,0 +1,70 @@
from collections import namedtuple
from typing import List

ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"])
model_info = {}

def register_model_info(
    full_names: List[str], simple_name: str, link: str, description: str
):
    info = ModelInfo(simple_name, link, description)

    for full_name in full_names:
        model_info[full_name] = info

def get_model_info(name: str) -> ModelInfo:
    if name in model_info:
        return model_info[name]
    else:
        # To fix this, please use `register_model_info` to register your model
        return ModelInfo(
            name, "", "Register the description at fastchat/model/model_registry.py"
        )

def get_model_description_md(model_list):
    model_description_md = """
| | | | | | | | | | | |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
"""
    ct = 0
    visited = set()
    for i, name in enumerate(model_list):
        model_source, model_name, model_type = name.split("_")
        minfo = get_model_info(model_name)
        if minfo.simple_name in visited:
            continue
        visited.add(minfo.simple_name)
        # one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
        one_model_md = f"{minfo.simple_name}"

        if ct % 11 == 0:
            model_description_md += "|"
        model_description_md += f" {one_model_md} |"
        if ct % 11 == 10:
            model_description_md += "\n"
        ct += 1
    return model_description_md

def get_video_model_description_md(model_list):
    model_description_md = """
| | | | | | |
| ---- | ---- | ---- | ---- | ---- | ---- |
"""
    ct = 0
    visited = set()
    for i, name in enumerate(model_list):
        model_source, model_name, model_type = name.split("_")
        minfo = get_model_info(model_name)
        if minfo.simple_name in visited:
            continue
        visited.add(minfo.simple_name)
        # one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
        one_model_md = f"{minfo.simple_name}"

        if ct % 7 == 0:
            model_description_md += "|"
        model_description_md += f" {one_model_md} |"
        if ct % 7 == 6:
            model_description_md += "\n"
        ct += 1
    return model_description_md
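Note: get_model_info falls back to a placeholder entry for unregistered models. A minimal sketch of how a model would be registered, not part of the commit (the link and description below are illustrative placeholders):

# Illustrative only: register a friendly name so the description table renders it.
register_model_info(
    ["SDXL"],
    simple_name="SDXL",
    link="https://replicate.com/stability-ai/sdxl",
    description="Stable Diffusion XL text-to-image model.",
)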
model/models/__init__.py
ADDED
@@ -0,0 +1,83 @@
from .huggingface_models import load_huggingface_model
from .replicate_api_models import load_replicate_model
from .openai_api_models import load_openai_model
from .other_api_models import load_other_model
from .local_models import load_local_model


IMAGE_GENERATION_MODELS = [
    'replicate_SDXL_text2image',
    'replicate_SD-v3.0_text2image',
    'replicate_SD-v2.1_text2image',
    'replicate_SD-v1.5_text2image',
    'replicate_SDXL-Lightning_text2image',
    'replicate_Kandinsky-v2.0_text2image',
    'replicate_Kandinsky-v2.2_text2image',
    'replicate_Proteus-v0.2_text2image',
    'replicate_Playground-v2.0_text2image',
    'replicate_Playground-v2.5_text2image',
    'replicate_Dreamshaper-xl-turbo_text2image',
    'replicate_SDXL-Deepcache_text2image',
    'replicate_Openjourney-v4_text2image',
    'replicate_LCM-v1.5_text2image',
    'replicate_Realvisxl-v3.0_text2image',
    'replicate_Realvisxl-v2.0_text2image',
    'replicate_Pixart-Sigma_text2image',
    'replicate_SSD-1b_text2image',
    'replicate_Open-Dalle-v1.1_text2image',
    'replicate_Deepfloyd-IF_text2image',
    'huggingface_SD-turbo_text2image',
    'huggingface_SDXL-turbo_text2image',
    'huggingface_Stable-cascade_text2image',
    'openai_Dalle-2_text2image',
    'openai_Dalle-3_text2image',
    'other_Midjourney-v6.0_text2image',
    'other_Midjourney-v5.0_text2image',
    "replicate_FLUX.1-schnell_text2image",
    "replicate_FLUX.1-pro_text2image",
    "replicate_FLUX.1-dev_text2image",
    'other_Meissonic_text2image',
    "replicate_FLUX-1.1-pro_text2image",
    'replicate_SD-v3.5-large_text2image',
    'replicate_SD-v3.5-large-turbo_text2image',
]

VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
                           'replicate_Animate-Diff_text2video',
                           'replicate_OpenSora_text2video',
                           'replicate_LaVie_text2video',
                           'replicate_VideoCrafter2_text2video',
                           'replicate_Stable-Video-Diffusion_text2video',
                           'other_Runway-Gen3_text2video',
                           'other_Pika-beta_text2video',
                           'other_Pika-v1.0_text2video',
                           'other_Runway-Gen2_text2video',
                           'other_Sora_text2video',
                           'replicate_Cogvideox-5b_text2video',
                           'other_KLing-v1.0_text2video',
                           ]

B2I_MODELS = ['local_MIGC_b2i', 'huggingface_ReCo_b2i']


def load_pipeline(model_name):
    """
    Load a model pipeline based on the model name
    Args:
        model_name (str): The name of the model to load, should be of the form {source}_{name}_{type}
    """
    model_source, model_name, model_type = model_name.split("_")

    if model_source == "replicate":
        pipe = load_replicate_model(model_name, model_type)
    elif model_source == "huggingface":
        pipe = load_huggingface_model(model_name, model_type)
    elif model_source == "openai":
        pipe = load_openai_model(model_name, model_type)
    elif model_source == "other":
        pipe = load_other_model(model_name, model_type)
    elif model_source == "local":
        pipe = load_local_model(model_name, model_type)
    else:
        raise ValueError(f"Model source {model_source} not supported")
    return pipe
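Note: every model name in the lists above follows the {source}_{name}_{type} convention, and load_pipeline dispatches on the first field. A minimal sketch, not part of the commit (assumes a REPLICATE_API_TOKEN is set for the replicate source):

# Illustrative only: resolve a registry name to a callable pipeline and run it.
from model.models import load_pipeline

pipe = load_pipeline("replicate_SDXL_text2image")
image = pipe(prompt="a watercolor painting of a lighthouse")  # returns a PIL.Image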
model/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (2.79 kB).

model/models/__pycache__/huggingface_models.cpython-310.pyc
ADDED
Binary file (1.8 kB).

model/models/__pycache__/local_models.cpython-310.pyc
ADDED
Binary file (578 Bytes).

model/models/__pycache__/openai_api_models.cpython-310.pyc
ADDED
Binary file (1.6 kB).

model/models/__pycache__/other_api_models.cpython-310.pyc
ADDED
Binary file (2.56 kB).

model/models/__pycache__/replicate_api_models.cpython-310.pyc
ADDED
Binary file (6.31 kB).
model/models/huggingface_models.py
ADDED
@@ -0,0 +1,65 @@
from diffusers import DiffusionPipeline
from diffusers import AutoPipelineForText2Image
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers import StableDiffusionPipeline
import torch
import os


def load_huggingface_model(model_name, model_type):
    if model_name == "SD-turbo":
        pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
        pipe = pipe.to("cuda")
    elif model_name == "SDXL-turbo":
        pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
        pipe = pipe.to("cuda")
    elif model_name == "Stable-cascade":
        prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
        decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
        pipe = [prior, decoder]
    elif model_name == "ReCo":
        path = '/home/bcy/cache/.cache/huggingface/hub/models--j-min--reco_sd14_coco/snapshots/11a062da5a0a84501047cb19e113f520eb610415' if os.path.isdir('/home/bcy/cache/.cache/huggingface/hub/models--j-min--reco_sd14_coco/snapshots/11a062da5a0a84501047cb19e113f520eb610415') else "CompVis/stable-diffusion-v1-4"
        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
        pipe = pipe.to("cuda")
    else:
        raise NotImplementedError
    # if model_name == "SD-turbo":
    #     pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
    # elif model_name == "SDXL-turbo":
    #     pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    # else:
    #     raise NotImplementedError
    # pipe = pipe.to("cpu")
    return pipe


if __name__ == "__main__":
    # for name in ["SD-turbo", "SDXL-turbo"]:  # "SD-turbo", "SDXL-turbo"
    #     pipe = load_huggingface_model(name, "text2image")

    # for name in ["IF-I-XL-v1.0"]:
    #     pipe = load_huggingface_model(name, 'text2image')
    #     pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)

    prompt = 'draw a tiger'
    pipe = load_huggingface_model('Stable-cascade', "text2image")
    prior, decoder = pipe
    prior.enable_model_cpu_offload()
    prior_output = prior(
        prompt=prompt,
        height=512,
        width=512,
        negative_prompt='',
        guidance_scale=4.0,
        num_images_per_prompt=1,
        num_inference_steps=20
    )
    decoder.enable_model_cpu_offload()
    result = decoder(
        image_embeddings=prior_output.image_embeddings.to(torch.float16),
        prompt=prompt,
        negative_prompt='',
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10
    ).images[0]
model/models/local_models.py
ADDED
@@ -0,0 +1,16 @@
import os
import sys

migc_path = os.path.dirname(os.path.abspath(__file__))
print(migc_path)
if migc_path not in sys.path:
    sys.path.append(migc_path)

from model_bbox.MIGC.inference_single_image import MIGC_Pipe

def load_local_model(model_name, model_type):
    if model_name == "MIGC":
        pipe = MIGC_Pipe()
    else:
        raise NotImplementedError
    return pipe
model/models/openai_api_models.py
ADDED
@@ -0,0 +1,57 @@
from openai import OpenAI
from PIL import Image
import requests
import io
import os
import base64


class OpenaiModel():
    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        if self.model_type == "text2image":
            assert "prompt" in kwargs, "prompt is required for text2image model"

            client = OpenAI()

            if 'Dalle-3' in self.model_name:
                client = OpenAI()
                response = client.images.generate(
                    model="dall-e-3",
                    prompt=kwargs["prompt"],
                    size="1024x1024",
                    quality="standard",
                    n=1,
                )
            elif 'Dalle-2' in self.model_name:
                client = OpenAI()
                response = client.images.generate(
                    model="dall-e-2",
                    prompt=kwargs["prompt"],
                    size="512x512",
                    quality="standard",
                    n=1,
                )
            else:
                raise NotImplementedError

            result_url = response.data[0].url
            response = requests.get(result_url)
            result = Image.open(io.BytesIO(response.content))
            return result
        else:
            raise ValueError("model_type must be text2image or image2image")


def load_openai_model(model_name, model_type):
    return OpenaiModel(model_name, model_type)


if __name__ == "__main__":
    pipe = load_openai_model('Dalle-3', 'text2image')
    result = pipe(prompt='draw a tiger')
    print(result)
model/models/other_api_models.py
ADDED
@@ -0,0 +1,91 @@
import requests
import json
import os
from PIL import Image
import io, time


class OtherModel():
    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type
        self.image_url = "https://www.xdai.online/mj/submit/imagine"
        self.key = os.environ.get('MIDJOURNEY_KEY')
        self.get_image_url = "https://www.xdai.online/mj/image/"
        self.repeat_num = 5

    def __call__(self, *args, **kwargs):
        if self.model_type == "text2image":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            if self.model_name == "Midjourney-v6.0":
                data = {
                    "base64Array": [],
                    "notifyHook": "",
                    "prompt": "{} --v 6.0".format(kwargs["prompt"]),
                    "state": "",
                    "botType": "MID_JOURNEY",
                }
            elif self.model_name == "Midjourney-v5.0":
                data = {
                    "base64Array": [],
                    "notifyHook": "",
                    "prompt": "{} --v 5.0".format(kwargs["prompt"]),
                    "state": "",
                    "botType": "MID_JOURNEY",
                }
            else:
                raise NotImplementedError

            headers = {
                "Authorization": "Bearer {}".format(self.key),
                "Content-Type": "application/json"
            }
            while 1:
                response = requests.post(self.image_url, data=json.dumps(data), headers=headers)
                if response.status_code == 200:
                    print("Submit success!")
                    response_json = json.loads(response.content.decode('utf-8'))
                    img_id = response_json["result"]
                    result_url = self.get_image_url + img_id
                    print(result_url)
                    self.repeat_num = 800
                    while 1:
                        time.sleep(1)
                        img_response = requests.get(result_url)
                        if img_response.status_code == 200:
                            result = Image.open(io.BytesIO(img_response.content))
                            width, height = result.size
                            new_width = width // 2
                            new_height = height // 2
                            result = result.crop((0, 0, new_width, new_height))
                            self.repeat_num = 5
                            return result
                        else:
                            self.repeat_num = self.repeat_num - 1
                            if self.repeat_num == 0:
                                raise ValueError("Image request failed.")
                            continue

                else:
                    self.repeat_num = self.repeat_num - 1
                    if self.repeat_num == 0:
                        raise ValueError("API request failed.")
                    continue
        if self.model_type == "text2video":
            assert "prompt" in kwargs, "prompt is required for text2video model"

        else:
            raise ValueError("model_type must be text2image")


def load_other_model(model_name, model_type):
    return OtherModel(model_name, model_type)

if __name__ == "__main__":
    import http.client
    import json

    pipe = load_other_model("Midjourney-v5.0", "text2image")
    result = pipe(prompt="An Impressionist illustration depicts a river winding through a meadow")
    print(result)
    exit()
model/models/replicate_api_models.py
ADDED
@@ -0,0 +1,195 @@
import replicate
from PIL import Image
import requests
import io
import os
import base64

Replicate_MODEl_NAME_MAP = {
    "SDXL": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
    "SD-v3.0": "stability-ai/stable-diffusion-3",
    "SD-v2.1": "stability-ai/stable-diffusion:ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4",
    "SD-v1.5": "stability-ai/stable-diffusion:b3d14e1cd1f9470bbb0bb68cac48e5f483e5be309551992cc33dc30654a82bb7",
    "SDXL-Lightning": "bytedance/sdxl-lightning-4step:5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
    "Kandinsky-v2.0": "ai-forever/kandinsky-2:3c6374e7a9a17e01afe306a5218cc67de55b19ea536466d6ea2602cfecea40a9",
    "Kandinsky-v2.2": "ai-forever/kandinsky-2.2:ad9d7879fbffa2874e1d909d1d37d9bc682889cc65b31f7bb00d2362619f194a",
    "Proteus-v0.2": "lucataco/proteus-v0.2:06775cd262843edbde5abab958abdbb65a0a6b58ca301c9fd78fa55c775fc019",
    "Playground-v2.0": "playgroundai/playground-v2-1024px-aesthetic:42fe626e41cc811eaf02c94b892774839268ce1994ea778eba97103fe1ef51b8",
    "Playground-v2.5": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
    "Dreamshaper-xl-turbo": "lucataco/dreamshaper-xl-turbo:0a1710e0187b01a255302738ca0158ff02a22f4638679533e111082f9dd1b615",
    "SDXL-Deepcache": "lucataco/sdxl-deepcache:eaf678fb34006669e9a3c6dd5971e2279bf20ee0adeced464d7b6d95de16dc93",
    "Openjourney-v4": "prompthero/openjourney:ad59ca21177f9e217b9075e7300cf6e14f7e5b4505b87b9689dbd866e9768969",
    "LCM-v1.5": "fofr/latent-consistency-model:683d19dc312f7a9f0428b04429a9ccefd28dbf7785fef083ad5cf991b65f406f",
    "Realvisxl-v3.0": "fofr/realvisxl-v3:33279060bbbb8858700eb2146350a98d96ef334fcf817f37eb05915e1534aa1c",

    "Realvisxl-v2.0": "lucataco/realvisxl-v2.0:7d6a2f9c4754477b12c14ed2a58f89bb85128edcdd581d24ce58b6926029de08",
    "Pixart-Sigma": "cjwbw/pixart-sigma:5a54352c99d9fef467986bc8f3a20205e8712cbd3df1cbae4975d6254c902de1",
    "SSD-1b": "lucataco/ssd-1b:b19e3639452c59ce8295b82aba70a231404cb062f2eb580ea894b31e8ce5bbb6",
    "Open-Dalle-v1.1": "lucataco/open-dalle-v1.1:1c7d4c8dec39c7306df7794b28419078cb9d18b9213ab1c21fdc46a1deca0144",
    "Deepfloyd-IF": "andreasjansson/deepfloyd-if:fb84d659df149f4515c351e394d22222a94144aa1403870c36025c8b28846c8d",

    "Zeroscope-v2-xl": "anotherjesse/zeroscope-v2-xl:9f747673945c62801b13b84701c783929c0ee784e4748ec062204894dda1a351",
    # "Damo-Text-to-Video": "cjwbw/damo-text-to-video:1e205ea73084bd17a0a3b43396e49ba0d6bc2e754e9283b2df49fad2dcf95755",
    "Animate-Diff": "lucataco/animate-diff:beecf59c4aee8d81bf04f0381033dfa10dc16e845b4ae00d281e2fa377e48a9f",
    "OpenSora": "camenduru/open-sora:8099e5722ba3d5f408cd3e696e6df058137056268939337a3fbe3912e86e72ad",
    "LaVie": "cjwbw/lavie:0bca850c4928b6c30052541fa002f24cbb4b677259c461dd041d271ba9d3c517",
    "VideoCrafter2": "lucataco/video-crafter:7757c5775e962c618053e7df4343052a21075676d6234e8ede5fa67c9e43bce0",
    "Stable-Video-Diffusion": "sunfjun/stable-video-diffusion:d68b6e09eedbac7a49e3d8644999d93579c386a083768235cabca88796d70d82",
    "FLUX.1-schnell": "black-forest-labs/flux-schnell",
    "FLUX.1-pro": "black-forest-labs/flux-pro",
    "FLUX.1-dev": "black-forest-labs/flux-dev",
    "FLUX-1.1-pro": "black-forest-labs/flux-1.1-pro",
    "SD-v3.5-large": "stability-ai/stable-diffusion-3.5-large",
    "SD-v3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo",
}


class ReplicateModel():
    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        if self.model_type == "text2image":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            output = replicate.run(
                f"{Replicate_MODEl_NAME_MAP[self.model_name]}",
                input={
                    "width": 512,
                    "height": 512,
                    "prompt": kwargs["prompt"]
                },
            )
            if 'Openjourney' in self.model_name:
                for item in output:
                    result_url = item
                    break
            elif isinstance(output, list):
                result_url = output[0]
            else:
                result_url = output
            print(self.model_name, result_url)
            response = requests.get(result_url)
            result = Image.open(io.BytesIO(response.content))
            return result

        elif self.model_type == "text2video":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            if self.model_name == "Zeroscope-v2-xl":
                input = {
                    "fps": 24,
                    "width": 512,
                    "height": 512,
                    "prompt": kwargs["prompt"],
                    "guidance_scale": 17.5,
                    # "negative_prompt": "very blue, dust, noisy, washed out, ugly, distorted, broken",
                    "num_frames": 48,
                }
            elif self.model_name == "Damo-Text-to-Video":
                input = {
                    "fps": 8,
                    "prompt": kwargs["prompt"],
                    "num_frames": 16,
                    "num_inference_steps": 50
                }
            elif self.model_name == "Animate-Diff":
                input = {
                    "path": "toonyou_beta3.safetensors",
                    "seed": 255224557,
                    "steps": 25,
                    "prompt": kwargs["prompt"],
                    "n_prompt": "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth",
                    "motion_module": "mm_sd_v14",
                    "guidance_scale": 7.5
                }
            elif self.model_name == "OpenSora":
                input = {
                    "seed": 1234,
                    "prompt": kwargs["prompt"],
                }
            elif self.model_name == "LaVie":
                input = {
                    "width": 512,
                    "height": 512,
                    "prompt": kwargs["prompt"],
                    "quality": 9,
                    "video_fps": 8,
                    "interpolation": False,
                    "sample_method": "ddpm",
                    "guidance_scale": 7,
                    "super_resolution": False,
                    "num_inference_steps": 50
                }
            elif self.model_name == "VideoCrafter2":
                input = {
                    "fps": 24,
                    "seed": 64045,
                    "steps": 40,
                    "width": 512,
                    "height": 512,
                    "prompt": kwargs["prompt"],
                }
            elif self.model_name == "Stable-Video-Diffusion":
                text2image_name = "SD-v2.1"
                output = replicate.run(
                    f"{Replicate_MODEl_NAME_MAP[text2image_name]}",
                    input={
                        "width": 512,
                        "height": 512,
                        "prompt": kwargs["prompt"]
                    },
                )
                if isinstance(output, list):
                    image_url = output[0]
                else:
                    image_url = output
                print(image_url)

                input = {
                    "cond_aug": 0.02,
                    "decoding_t": 14,
                    "input_image": "{}".format(image_url),
                    "video_length": "14_frames_with_svd",
                    "sizing_strategy": "maintain_aspect_ratio",
                    "motion_bucket_id": 127,
                    "frames_per_second": 6
                }

            output = replicate.run(
                f"{Replicate_MODEl_NAME_MAP[self.model_name]}",
                input=input,
            )
            if isinstance(output, list):
                result_url = output[0]
            else:
                result_url = output
            print(self.model_name)
            print(result_url)
            # response = requests.get(result_url)
            # result = Image.open(io.BytesIO(response.content))

            # for event in handler.iter_events(with_logs=True):
            #     if isinstance(event, fal_client.InProgress):
            #         print('Request in progress')
            #         print(event.logs)

            # result = handler.get()
            # print("result video: ====")
            # print(result)
            # result_url = result['video']['url']
            # return result_url
            return result_url
        else:
            raise ValueError("model_type must be text2image or image2image")


def load_replicate_model(model_name, model_type):
    return ReplicateModel(model_name, model_type)


if __name__ == "__main__":
    model_name = 'replicate_zeroscope-v2-xl_text2video'
    model_source, model_name, model_type = model_name.split("_")
    pipe = load_replicate_model(model_name, model_type)
    prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic"
    result = pipe(prompt=prompt)
model_bbox/.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
model_bbox/MIGC/__init__.py
ADDED
File without changes
model_bbox/MIGC/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (167 Bytes).

model_bbox/MIGC/__pycache__/inference_single_image.cpython-310.pyc
ADDED
Binary file (4.98 kB).
model_bbox/MIGC/inference_single_image.py
ADDED
@@ -0,0 +1,193 @@
import os
import sys
import torch

migc_path = os.path.dirname(os.path.abspath(__file__))
print(migc_path)
if migc_path not in sys.path:
    sys.path.append(migc_path)
import yaml
from diffusers import EulerDiscreteScheduler
from migc.migc_utils import seed_everything
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore

def normalize_bbox(bboxes, img_width, img_height):
    normalized_bboxes = []
    for box in bboxes:
        x_min, y_min, x_max, y_max = box

        x_min = x_min / img_width
        y_min = y_min / img_height
        x_max = x_max / img_width
        y_max = y_max / img_height

        normalized_bboxes.append([x_min, y_min, x_max, y_max])

    return [normalized_bboxes]

def create_simple_prompt(input_str):
    # First split the input string on semicolons and drop empty entries
    objects = [obj for obj in input_str.split(';') if obj.strip()]

    # Build the detailed description string
    prompt_description = "masterpiece, best quality, " + ", ".join(objects)

    # Build the final prompt structure
    prompt_final = [[prompt_description] + objects]

    return prompt_final


def inference_single_image(prompt, grounding_instruction, state):
    print(prompt)
    print(grounding_instruction)
    bbox = state['boxes']
    print(bbox)
    bbox = normalize_bbox(bbox, 600, 600)
    print(bbox)
    simple_prompt = create_simple_prompt(grounding_instruction)
    print(simple_prompt)
    migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt'
    migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path)
    print(migc_ckpt_path_all)
    assert os.path.isfile(migc_ckpt_path_all), "Please download the ckpt of migc and put it in the pretrained_weights/ folder!"

    sd1x_path = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b' if os.path.isdir('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b') else "CompVis/stable-diffusion-v1-4"
    # MIGC is a plug-and-play controller.
    # You can go to https://civitai.com/search/models?baseModel=SD%201.4&baseModel=SD%201.5&sortBy=models_v5 find a base model with better generation ability to achieve better creations.

    # Construct MIGC pipeline
    pipe = StableDiffusionMIGCPipeline.from_pretrained(
        sd1x_path)
    pipe.attention_store = AttentionStore()
    from migc.migc_utils import load_migc
    load_migc(pipe.unet, pipe.attention_store,
              migc_ckpt_path_all, attn_processor=MIGCProcessor)
    pipe = pipe.to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)


    # prompt_final = [['masterpiece, best quality,black colored ball,gray colored cat,white colored bed,\
    #     green colored plant,red colored teddy bear,blue colored wall,brown colored vase,orange colored book,\
    #     yellow colored hat', 'black colored ball', 'gray colored cat', 'white colored bed', 'green colored plant', \
    #     'red colored teddy bear', 'blue colored wall', 'brown colored vase', 'orange colored book', 'yellow colored hat']]

    # bboxes = [[[0.3125, 0.609375, 0.625, 0.875], [0.5625, 0.171875, 0.984375, 0.6875], \
    #     [0.0, 0.265625, 0.984375, 0.984375], [0.0, 0.015625, 0.21875, 0.328125], \
    #     [0.171875, 0.109375, 0.546875, 0.515625], [0.234375, 0.0, 1.0, 0.3125], \
    #     [0.71875, 0.625, 0.953125, 0.921875], [0.0625, 0.484375, 0.359375, 0.8125], \
    #     [0.609375, 0.09375, 0.90625, 0.28125]]]
    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    seed = 7351007268695528845
    seed_everything(seed)
    print("Start inference: ")
    image = pipe(simple_prompt, bbox, num_inference_steps=50, guidance_scale=7.5,
                 MIGCsteps=25, aug_phase_with_and=False, negative_prompt=negative_prompt).images[0]
    return image




# def MIGC_Pipe():
#     migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt'
#     migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path)
#     print(migc_ckpt_path_all)
#     assert os.path.isfile(migc_ckpt_path_all), "Please download the ckpt of migc and put it in the pretrained_weights/ folder!"
#     sd1x_path = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b' if os.path.isdir('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b') else "CompVis/stable-diffusion-v1-4"
#     pipe = StableDiffusionMIGCPipeline.from_pretrained(
#         sd1x_path)
#     pipe.attention_store = AttentionStore()
#     from migc.migc_utils import load_migc
#     load_migc(pipe.unet, pipe.attention_store,
#               migc_ckpt_path_all, attn_processor=MIGCProcessor)
#     pipe = pipe.to("cuda")
#     pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
#     return pipe

def MIGC_Pipe():
    migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt'
    migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path)
    print(f"Loading MIGC checkpoint from: {migc_ckpt_path_all}")

    assert os.path.isfile(migc_ckpt_path_all), f"Please download the MIGC ckpt and put it in the 'pretrained_weights/' folder: {migc_ckpt_path_all}"

    sd1x_path = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b' if os.path.isdir('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b') else "CompVis/stable-diffusion-v1-4"
    print(f"Loading StableDiffusion model: {sd1x_path}")

    # Load the StableDiffusionMIGCPipeline
    print("load sd:")
    pipe = StableDiffusionMIGCPipeline.from_pretrained(sd1x_path)
    pipe.attention_store = AttentionStore()

    # Import and load the MIGC weights
    print("load migc")
    from migc.migc_utils import load_migc
    load_migc(pipe.unet, pipe.attention_store, migc_ckpt_path_all, attn_processor=MIGCProcessor)

    # Make sure the model and attention_store are loaded correctly
    assert pipe.unet is not None, "The unet model was not loaded correctly!"
    assert pipe.attention_store is not None, "The attention_store was not loaded correctly!"

    # Move to CUDA if available
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using CUDA device")
    else:
        device = torch.device("cpu")
        print("Using CPU")

    pipe = pipe.to(device)

    # Set the scheduler
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    return pipe


def create_simple_prompt(input_str):
    # First split the input string on semicolons and drop empty entries
    objects = [obj for obj in input_str.split(';') if obj.strip()]

    # Build the detailed description string
    prompt_description = "masterpiece, best quality, " + ", ".join(objects)

    # Build the final prompt structure
    prompt_final = [[prompt_description] + objects]

    return prompt_final


def inference_image(pipe, prompt, grounding_instruction, state):
    print(prompt)
    print(grounding_instruction)
    bbox = state['boxes']
    print(bbox)
    bbox = normalize_bbox(bbox, 600, 600)
    print(bbox)
    simple_prompt = create_simple_prompt(grounding_instruction)
    print(simple_prompt)
    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    seed = 7351007268695528845
    seed_everything(seed)
    print("Start inference: ")
    image = pipe(simple_prompt, bbox, num_inference_steps=50, guidance_scale=7.5,
                 MIGCsteps=25, aug_phase_with_and=False, negative_prompt=negative_prompt).images[0]
    return image



if __name__ == "__main__":
    prompt_final = [['masterpiece, best quality,black colored ball,gray colored cat,white colored bed,\
green colored plant,red colored teddy bear,blue colored wall,brown colored vase,orange colored book,\
yellow colored hat', 'black colored ball', 'gray colored cat', 'white colored bed', 'green colored plant', \
'red colored teddy bear', 'blue colored wall', 'brown colored vase', 'orange colored book', 'yellow colored hat']]

    bboxes = [[[0.3125, 0.609375, 0.625, 0.875], [0.5625, 0.171875, 0.984375, 0.6875], \
[0.0, 0.265625, 0.984375, 0.984375], [0.0, 0.015625, 0.21875, 0.328125], \
[0.171875, 0.109375, 0.546875, 0.515625], [0.234375, 0.0, 1.0, 0.3125], \
[0.71875, 0.625, 0.953125, 0.921875], [0.0625, 0.484375, 0.359375, 0.8125], \
[0.609375, 0.09375, 0.90625, 0.28125]]]
    image = inference_single_image("a cat", prompt_final, bboxes)
    image.save("output.png")
    print("done")
model_bbox/MIGC/migc/__init__.py
ADDED
File without changes
|
model_bbox/MIGC/migc/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (172 Bytes). View file
|
|
model_bbox/MIGC/migc/__pycache__/migc_arch.cpython-310.pyc
ADDED
Binary file (6.76 kB). View file
|
|
model_bbox/MIGC/migc/__pycache__/migc_layers.cpython-310.pyc
ADDED
Binary file (8.28 kB). View file
|
|
model_bbox/MIGC/migc/__pycache__/migc_pipeline.cpython-310.pyc
ADDED
Binary file (25.2 kB). View file
|
|
model_bbox/MIGC/migc/__pycache__/migc_utils.cpython-310.pyc
ADDED
Binary file (5.62 kB). View file
|
|
model_bbox/MIGC/migc/migc_arch.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
from migc.migc_layers import CBAM, CrossAttention, LayoutAttention
|
6 |
+
|
7 |
+
|
8 |
+
class FourierEmbedder():
|
9 |
+
def __init__(self, num_freqs=64, temperature=100):
|
10 |
+
self.num_freqs = num_freqs
|
11 |
+
self.temperature = temperature
|
12 |
+
self.freq_bands = temperature ** ( torch.arange(num_freqs) / num_freqs )
|
13 |
+
|
14 |
+
@ torch.no_grad()
|
15 |
+
def __call__(self, x, cat_dim=-1):
|
16 |
+
out = []
|
17 |
+
for freq in self.freq_bands:
|
18 |
+
out.append( torch.sin( freq*x ) )
|
19 |
+
out.append( torch.cos( freq*x ) )
|
20 |
+
return torch.cat(out, cat_dim) # torch.Size([5, 30, 64])
|
21 |
+
|
22 |
+
|
23 |
+
class PositionNet(nn.Module):
|
24 |
+
def __init__(self, in_dim, out_dim, fourier_freqs=8):
|
25 |
+
super().__init__()
|
26 |
+
self.in_dim = in_dim
|
27 |
+
self.out_dim = out_dim
|
28 |
+
|
29 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
|
30 |
+
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
|
31 |
+
|
32 |
+
# -------------------------------------------------------------- #
|
33 |
+
self.linears_position = nn.Sequential(
|
34 |
+
nn.Linear(self.position_dim, 512),
|
35 |
+
nn.SiLU(),
|
36 |
+
nn.Linear(512, 512),
|
37 |
+
nn.SiLU(),
|
38 |
+
nn.Linear(512, out_dim),
|
39 |
+
)
|
40 |
+
|
41 |
+
def forward(self, boxes):
|
42 |
+
|
43 |
+
# embedding position (it may includes padding as placeholder)
|
44 |
+
xyxy_embedding = self.fourier_embedder(boxes) # B*1*4 --> B*1*C torch.Size([5, 1, 64])
|
45 |
+
xyxy_embedding = self.linears_position(xyxy_embedding) # B*1*C --> B*1*768 torch.Size([5, 1, 768])
|
46 |
+
|
47 |
+
return xyxy_embedding
|
48 |
+
|
49 |
+
|
50 |
+
class SAC(nn.Module):
|
51 |
+
def __init__(self, C, number_pro=30):
|
52 |
+
super().__init__()
|
53 |
+
self.C = C
|
54 |
+
self.number_pro = number_pro
|
55 |
+
self.conv1 = nn.Conv2d(C + 1, C, 1, 1)
|
56 |
+
self.cbam1 = CBAM(C)
|
57 |
+
self.conv2 = nn.Conv2d(C, 1, 1, 1)
|
58 |
+
self.cbam2 = CBAM(number_pro, reduction_ratio=1)
|
59 |
+
|
60 |
+
def forward(self, x, guidance_mask, sac_scale=None):
|
61 |
+
'''
|
62 |
+
:param x: (B, phase_num, HW, C)
|
63 |
+
:param guidance_mask: (B, phase_num, H, W)
|
64 |
+
:return:
|
65 |
+
'''
|
66 |
+
B, phase_num, HW, C = x.shape
|
67 |
+
_, _, H, W = guidance_mask.shape
|
68 |
+
guidance_mask = guidance_mask.view(guidance_mask.shape[0], phase_num, -1)[
|
69 |
+
..., None] # (B, phase_num, HW, 1)
|
70 |
+
|
71 |
+
null_x = torch.zeros_like(x[:, [0], ...]).to(x.device)
|
72 |
+
null_mask = torch.zeros_like(guidance_mask[:, [0], ...]).to(guidance_mask.device)
|
73 |
+
|
74 |
+
x = torch.cat([x, null_x], dim=1)
|
75 |
+
guidance_mask = torch.cat([guidance_mask, null_mask], dim=1)
|
76 |
+
phase_num += 1
|
77 |
+
|
78 |
+
|
79 |
+
scale = torch.cat([x, guidance_mask], dim=-1) # (B, phase_num, HW, C+1)
|
80 |
+
scale = scale.view(-1, H, W, C + 1) # (B * phase_num, H, W, C+1)
|
81 |
+
scale = scale.permute(0, 3, 1, 2) # (B * phase_num, C+1, H, W)
|
82 |
+
scale = self.conv1(scale) # (B * phase_num, C, H, W)
|
83 |
+
scale = self.cbam1(scale) # (B * phase_num, C, H, W)
|
84 |
+
scale = self.conv2(scale) # (B * phase_num, 1, H, W)
|
85 |
+
scale = scale.view(B, phase_num, H, W) # (B, phase_num, H, W)
|
86 |
+
|
87 |
+
null_scale = scale[:, [-1], ...]
|
88 |
+
scale = scale[:, :-1, ...]
|
89 |
+
x = x[:, :-1, ...]
|
90 |
+
|
91 |
+
pad_num = self.number_pro - phase_num + 1
|
92 |
+
|
93 |
+
ori_phase_num = scale[:, 1:-1, ...].shape[1]
|
94 |
+
phase_scale = torch.cat([scale[:, 1:-1, ...], null_scale.repeat(1, pad_num, 1, 1)], dim=1)
|
95 |
+
shuffled_order = torch.randperm(phase_scale.shape[1])
|
96 |
+
inv_shuffled_order = torch.argsort(shuffled_order)
|
97 |
+
|
98 |
+
random_phase_scale = phase_scale[:, shuffled_order, ...]
|
99 |
+
|
100 |
+
scale = torch.cat([scale[:, [0], ...], random_phase_scale, scale[:, [-1], ...]], dim=1)
|
101 |
+
# (B, number_pro, H, W)
|
102 |
+
|
103 |
+
scale = self.cbam2(scale) # (B, number_pro, H, W)
|
104 |
+
scale = scale.view(B, self.number_pro, HW)[..., None] # (B, number_pro, HW)
|
105 |
+
|
106 |
+
random_phase_scale = scale[:, 1: -1, ...]
|
107 |
+
phase_scale = random_phase_scale[:, inv_shuffled_order[:ori_phase_num], :]
|
108 |
+
if sac_scale is not None:
|
109 |
+
instance_num = len(sac_scale)
|
110 |
+
for i in range(instance_num):
|
111 |
+
phase_scale[:, i, ...] = phase_scale[:, i, ...] * sac_scale[i]
|
112 |
+
|
113 |
+
|
114 |
+
scale = torch.cat([scale[:, [0], ...], phase_scale, scale[:, [-1], ...]], dim=1)
|
115 |
+
|
116 |
+
scale = scale.softmax(dim=1) # (B, phase_num, HW, 1)
|
117 |
+
out = (x * scale).sum(dim=1, keepdims=True) # (B, 1, HW, C)
|
118 |
+
return out, scale
|
119 |
+
|
120 |
+
|
121 |
+
class MIGC(nn.Module):
|
122 |
+
def __init__(self, C, attn_type='base', context_dim=768, heads=8):
|
123 |
+
super().__init__()
|
124 |
+
self.ea = CrossAttention(query_dim=C, context_dim=context_dim,
|
125 |
+
heads=heads, dim_head=C // heads,
|
126 |
+
dropout=0.0)
|
127 |
+
self.la = LayoutAttention(query_dim=C,
|
128 |
+
heads=heads, dim_head=C // heads,
|
129 |
+
dropout=0.0)
|
130 |
+
self.norm = nn.LayerNorm(C)
|
131 |
+
self.sac = SAC(C)
|
132 |
+
self.pos_net = PositionNet(in_dim=768, out_dim=768)
|
133 |
+
|
134 |
+
def forward(self, ca_x, guidance_mask, other_info, return_fuser_info=False):
|
135 |
+
# x: (B, instance_num+1, HW, C)
|
136 |
+
# guidance_mask: (B, instance_num, H, W)
|
137 |
+
# box: (instance_num, 4)
|
138 |
+
# image_token: (B, instance_num+1, HW, C)
|
139 |
+
full_H = other_info['height']
|
140 |
+
full_W = other_info['width']
|
141 |
+
B, _, HW, C = ca_x.shape
|
142 |
+
instance_num = guidance_mask.shape[1]
|
143 |
+
down_scale = int(math.sqrt(full_H * full_W // ca_x.shape[2]))
|
144 |
+
H = full_H // down_scale
|
145 |
+
W = full_W // down_scale
|
146 |
+
guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='bilinear') # (B, instance_num, H, W)
|
147 |
+
|
148 |
+
|
149 |
+
supplement_mask = other_info['supplement_mask'] # (B, 1, 64, 64)
|
150 |
+
supplement_mask = F.interpolate(supplement_mask, size=(H, W), mode='bilinear') # (B, 1, H, W)
|
151 |
+
image_token = other_info['image_token']
|
152 |
+
assert image_token.shape == ca_x.shape
|
153 |
+
context = other_info['context_pooler']
|
154 |
+
box = other_info['box']
|
155 |
+
box = box.view(B * instance_num, 1, -1)
|
156 |
+
box_token = self.pos_net(box)
|
157 |
+
context = torch.cat([context[1:, ...], box_token], dim=1)
|
158 |
+
ca_scale = other_info['ca_scale'] if 'ca_scale' in other_info else None
|
159 |
+
ea_scale = other_info['ea_scale'] if 'ea_scale' in other_info else None
|
160 |
+
sac_scale = other_info['sac_scale'] if 'sac_scale' in other_info else None
|
161 |
+
|
162 |
+
ea_x, ea_attn = self.ea(self.norm(image_token[:, 1:, ...].view(B * instance_num, HW, C)),
|
163 |
+
context=context, return_attn=True)
|
164 |
+
ea_x = ea_x.view(B, instance_num, HW, C)
|
165 |
+
ea_x = ea_x * guidance_mask.view(B, instance_num, HW, 1)
|
166 |
+
|
167 |
+
ca_x[:, 1:, ...] = ca_x[:, 1:, ...] * guidance_mask.view(B, instance_num, HW, 1) # (B, phase_num, HW, C)
|
168 |
+
if ca_scale is not None:
|
169 |
+
assert len(ca_scale) == instance_num
|
170 |
+
for i in range(instance_num):
|
171 |
+
ca_x[:, i+1, ...] = ca_x[:, i+1, ...] * ca_scale[i] + ea_x[:, i, ...] * ea_scale[i]
|
172 |
+
else:
|
173 |
+
ca_x[:, 1:, ...] = ca_x[:, 1:, ...] + ea_x
|
174 |
+
|
175 |
+
ori_image_token = image_token[:, 0, ...] # (B, HW, C)
|
176 |
+
fusion_template = self.la(x=ori_image_token, guidance_mask=torch.cat([guidance_mask[:, :, ...], supplement_mask], dim=1)) # (B, HW, C)
|
177 |
+
fusion_template = fusion_template.view(B, 1, HW, C) # (B, 1, HW, C)
|
178 |
+
|
179 |
+
ca_x = torch.cat([ca_x, fusion_template], dim = 1)
|
180 |
+
ca_x[:, 0, ...] = ca_x[:, 0, ...] * supplement_mask.view(B, HW, 1)
|
181 |
+
guidance_mask = torch.cat([
|
182 |
+
supplement_mask,
|
183 |
+
guidance_mask,
|
184 |
+
torch.ones(B, 1, H, W).to(guidance_mask.device)
|
185 |
+
], dim=1)
|
186 |
+
|
187 |
+
|
188 |
+
out_MIGC, sac_scale = self.sac(ca_x, guidance_mask, sac_scale=sac_scale)
|
189 |
+
if return_fuser_info:
|
190 |
+
fuser_info = {}
|
191 |
+
fuser_info['sac_scale'] = sac_scale.view(B, instance_num + 2, H, W)
|
192 |
+
fuser_info['ea_attn'] = ea_attn.mean(dim=1).view(B, instance_num, H, W, 2)
|
193 |
+
return out_MIGC, fuser_info
|
194 |
+
else:
|
195 |
+
return out_MIGC
|
196 |
+
|
197 |
+
|
198 |
+
class NaiveFuser(nn.Module):
|
199 |
+
def __init__(self):
|
200 |
+
super().__init__()
|
201 |
+
def forward(self, ca_x, guidance_mask, other_info, return_fuser_info=False):
|
202 |
+
# ca_x: (B, instance_num+1, HW, C)
|
203 |
+
# guidance_mask: (B, instance_num, H, W)
|
204 |
+
# box: (instance_num, 4)
|
205 |
+
# image_token: (B, instance_num+1, HW, C)
|
206 |
+
full_H = other_info['height']
|
207 |
+
full_W = other_info['width']
|
208 |
+
B, _, HW, C = ca_x.shape
|
209 |
+
instance_num = guidance_mask.shape[1]
|
210 |
+
down_scale = int(math.sqrt(full_H * full_W // ca_x.shape[2]))
|
211 |
+
H = full_H // down_scale
|
212 |
+
W = full_W // down_scale
|
213 |
+
guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='bilinear') # (B, instance_num, H, W)
|
214 |
+
guidance_mask = torch.cat([torch.ones(B, 1, H, W).to(guidance_mask.device), guidance_mask * 10], dim=1) # (B, instance_num+1, H, W)
|
215 |
+
guidance_mask = guidance_mask.view(B, instance_num + 1, HW, 1)
|
216 |
+
out_MIGC = (ca_x * guidance_mask).sum(dim=1) / (guidance_mask.sum(dim=1) + 1e-6)
|
217 |
+
if return_fuser_info:
|
218 |
+
return out_MIGC, None
|
219 |
+
else:
|
220 |
+
return out_MIGC
|
model_bbox/MIGC/migc/migc_layers.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import random
|
5 |
+
import math
|
6 |
+
from inspect import isfunction
|
7 |
+
from einops import rearrange, repeat
|
8 |
+
from torch import nn, einsum
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def default(val, d):
|
16 |
+
if exists(val):
|
17 |
+
return val
|
18 |
+
return d() if isfunction(d) else d
|
19 |
+
|
20 |
+
|
21 |
+
class CrossAttention(nn.Module):
|
22 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
23 |
+
super().__init__()
|
24 |
+
inner_dim = dim_head * heads
|
25 |
+
context_dim = default(context_dim, query_dim)
|
26 |
+
|
27 |
+
self.scale = dim_head ** -0.5
|
28 |
+
self.heads = heads
|
29 |
+
|
30 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
31 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
32 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
33 |
+
|
34 |
+
self.to_out = nn.Sequential(
|
35 |
+
nn.Linear(inner_dim, query_dim),
|
36 |
+
nn.Dropout(dropout)
|
37 |
+
)
|
38 |
+
|
39 |
+
def forward(self, x, context=None, mask=None, return_attn=False, need_softmax=True, guidance_mask=None,
|
40 |
+
forward_layout_guidance=False):
|
41 |
+
h = self.heads
|
42 |
+
b = x.shape[0]
|
43 |
+
|
44 |
+
q = self.to_q(x)
|
45 |
+
context = default(context, x)
|
46 |
+
k = self.to_k(context)
|
47 |
+
v = self.to_v(context)
|
48 |
+
|
49 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
50 |
+
|
51 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
52 |
+
if forward_layout_guidance:
|
53 |
+
# sim: (B * phase_num * h, HW, 77), b = B * phase_num
|
54 |
+
# guidance_mask: (B, phase_num, 64, 64)
|
55 |
+
HW = sim.shape[1]
|
56 |
+
H = W = int(math.sqrt(HW))
|
57 |
+
guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='nearest') # (B, phase_num, H, W)
|
58 |
+
sim = sim.view(b, h, HW, 77)
|
59 |
+
guidance_mask = guidance_mask.view(b, 1, HW, 1)
|
60 |
+
guidance_mask[guidance_mask == 1] = 5.0
|
61 |
+
guidance_mask[guidance_mask == 0] = 0.1
|
62 |
+
sim[:, :, :, 1:] = sim[:, :, :, 1:] * guidance_mask
|
63 |
+
sim = sim.view(b * h, HW, 77)
|
64 |
+
|
65 |
+
if exists(mask):
|
66 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
67 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
68 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
69 |
+
sim.masked_fill_(~mask, max_neg_value)
|
70 |
+
|
71 |
+
if need_softmax:
|
72 |
+
attn = sim.softmax(dim=-1)
|
73 |
+
else:
|
74 |
+
attn = sim
|
75 |
+
|
76 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
77 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
78 |
+
if return_attn:
|
79 |
+
attn = attn.view(b, h, attn.shape[-2], attn.shape[-1])
|
80 |
+
return self.to_out(out), attn
|
81 |
+
else:
|
82 |
+
return self.to_out(out)
|
83 |
+
|
84 |
+
|
85 |
+
class LayoutAttention(nn.Module):
|
86 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., use_lora=False):
|
87 |
+
super().__init__()
|
88 |
+
inner_dim = dim_head * heads
|
89 |
+
context_dim = default(context_dim, query_dim)
|
90 |
+
|
91 |
+
self.use_lora = use_lora
|
92 |
+
self.scale = dim_head ** -0.5
|
93 |
+
self.heads = heads
|
94 |
+
|
95 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
96 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
97 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
98 |
+
|
99 |
+
self.to_out = nn.Sequential(
|
100 |
+
nn.Linear(inner_dim, query_dim),
|
101 |
+
nn.Dropout(dropout)
|
102 |
+
)
|
103 |
+
|
104 |
+
def forward(self, x, context=None, mask=None, return_attn=False, need_softmax=True, guidance_mask=None):
|
105 |
+
h = self.heads
|
106 |
+
b = x.shape[0]
|
107 |
+
|
108 |
+
q = self.to_q(x)
|
109 |
+
context = default(context, x)
|
110 |
+
k = self.to_k(context)
|
111 |
+
v = self.to_v(context)
|
112 |
+
|
113 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
114 |
+
|
115 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
116 |
+
|
117 |
+
_, phase_num, H, W = guidance_mask.shape
|
118 |
+
HW = H * W
|
119 |
+
guidance_mask_o = guidance_mask.view(b * phase_num, HW, 1)
|
120 |
+
guidance_mask_t = guidance_mask.view(b * phase_num, 1, HW)
|
121 |
+
guidance_mask_sim = torch.bmm(guidance_mask_o, guidance_mask_t) # (B * phase_num, HW, HW)
|
122 |
+
guidance_mask_sim = guidance_mask_sim.view(b, phase_num, HW, HW).sum(dim=1)
|
123 |
+
guidance_mask_sim[guidance_mask_sim > 1] = 1 # (B, HW, HW)
|
124 |
+
guidance_mask_sim = guidance_mask_sim.view(b, 1, HW, HW)
|
125 |
+
guidance_mask_sim = guidance_mask_sim.repeat(1, self.heads, 1, 1)
|
126 |
+
guidance_mask_sim = guidance_mask_sim.view(b * self.heads, HW, HW) # (B * head, HW, HW)
|
127 |
+
|
128 |
+
sim[:, :, :HW][guidance_mask_sim == 0] = -torch.finfo(sim.dtype).max
|
129 |
+
|
130 |
+
if exists(mask):
|
131 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
132 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
133 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
134 |
+
sim.masked_fill_(~mask, max_neg_value)
|
135 |
+
|
136 |
+
# attention, what we cannot get enough of
|
137 |
+
|
138 |
+
if need_softmax:
|
139 |
+
attn = sim.softmax(dim=-1)
|
140 |
+
else:
|
141 |
+
attn = sim
|
142 |
+
|
143 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
144 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
145 |
+
if return_attn:
|
146 |
+
attn = attn.view(b, h, attn.shape[-2], attn.shape[-1])
|
147 |
+
return self.to_out(out), attn
|
148 |
+
else:
|
149 |
+
return self.to_out(out)
|
150 |
+
|
151 |
+
|
152 |
+
class BasicConv(nn.Module):
|
153 |
+
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=False, bias=False):
|
154 |
+
super(BasicConv, self).__init__()
|
155 |
+
self.out_channels = out_planes
|
156 |
+
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
|
157 |
+
self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
|
158 |
+
self.relu = nn.ReLU() if relu else None
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
x = self.conv(x)
|
162 |
+
if self.bn is not None:
|
163 |
+
x = self.bn(x)
|
164 |
+
if self.relu is not None:
|
165 |
+
x = self.relu(x)
|
166 |
+
return x
|
167 |
+
|
168 |
+
class Flatten(nn.Module):
|
169 |
+
def forward(self, x):
|
170 |
+
return x.view(x.size(0), -1)
|
171 |
+
|
172 |
+
class ChannelGate(nn.Module):
|
173 |
+
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
|
174 |
+
super(ChannelGate, self).__init__()
|
175 |
+
self.gate_channels = gate_channels
|
176 |
+
self.mlp = nn.Sequential(
|
177 |
+
Flatten(),
|
178 |
+
nn.Linear(gate_channels, gate_channels // reduction_ratio),
|
179 |
+
nn.ReLU(),
|
180 |
+
nn.Linear(gate_channels // reduction_ratio, gate_channels)
|
181 |
+
)
|
182 |
+
self.pool_types = pool_types
|
183 |
+
def forward(self, x):
|
184 |
+
channel_att_sum = None
|
185 |
+
for pool_type in self.pool_types:
|
186 |
+
if pool_type=='avg':
|
187 |
+
avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
|
188 |
+
channel_att_raw = self.mlp( avg_pool )
|
189 |
+
elif pool_type=='max':
|
190 |
+
max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
|
191 |
+
channel_att_raw = self.mlp( max_pool )
|
192 |
+
elif pool_type=='lp':
|
193 |
+
lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
|
194 |
+
channel_att_raw = self.mlp( lp_pool )
|
195 |
+
elif pool_type=='lse':
|
196 |
+
# LSE pool only
|
197 |
+
lse_pool = logsumexp_2d(x)
|
198 |
+
channel_att_raw = self.mlp( lse_pool )
|
199 |
+
|
200 |
+
if channel_att_sum is None:
|
201 |
+
channel_att_sum = channel_att_raw
|
202 |
+
else:
|
203 |
+
channel_att_sum = channel_att_sum + channel_att_raw
|
204 |
+
|
205 |
+
scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
|
206 |
+
return x * scale
|
207 |
+
|
208 |
+
def logsumexp_2d(tensor):
|
209 |
+
tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
|
210 |
+
s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
|
211 |
+
outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
|
212 |
+
return outputs
|
213 |
+
|
214 |
+
class ChannelPool(nn.Module):
|
215 |
+
def forward(self, x):
|
216 |
+
return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
|
217 |
+
|
218 |
+
class SpatialGate(nn.Module):
|
219 |
+
def __init__(self):
|
220 |
+
super(SpatialGate, self).__init__()
|
221 |
+
kernel_size = 7
|
222 |
+
self.compress = ChannelPool()
|
223 |
+
self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
|
224 |
+
def forward(self, x):
|
225 |
+
x_compress = self.compress(x)
|
226 |
+
x_out = self.spatial(x_compress)
|
227 |
+
scale = F.sigmoid(x_out) # broadcasting
|
228 |
+
return x * scale
|
229 |
+
|
230 |
+
class CBAM(nn.Module):
|
231 |
+
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
|
232 |
+
super(CBAM, self).__init__()
|
233 |
+
self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
|
234 |
+
self.no_spatial=no_spatial
|
235 |
+
if not no_spatial:
|
236 |
+
self.SpatialGate = SpatialGate()
|
237 |
+
def forward(self, x):
|
238 |
+
x_out = self.ChannelGate(x)
|
239 |
+
if not self.no_spatial:
|
240 |
+
x_out = self.SpatialGate(x_out)
|
241 |
+
return x_out
|
model_bbox/MIGC/migc/migc_pipeline.py
ADDED
@@ -0,0 +1,928 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
5 |
+
# import moxing as mox
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from diffusers.loaders import TextualInversionLoaderMixin
|
9 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
10 |
+
from diffusers.models.attention_processor import Attention
|
11 |
+
from diffusers.pipelines.stable_diffusion import (
|
12 |
+
StableDiffusionPipeline,
|
13 |
+
StableDiffusionPipelineOutput,
|
14 |
+
StableDiffusionSafetyChecker,
|
15 |
+
)
|
16 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
17 |
+
from diffusers.utils import logging
|
18 |
+
from PIL import Image, ImageDraw, ImageFont
|
19 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
20 |
+
import inspect
|
21 |
+
import os
|
22 |
+
import math
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
# from utils import load_utils
|
26 |
+
import argparse
|
27 |
+
import yaml
|
28 |
+
import cv2
|
29 |
+
import math
|
30 |
+
from migc.migc_arch import MIGC, NaiveFuser
|
31 |
+
from scipy.ndimage import uniform_filter, gaussian_filter
|
32 |
+
|
33 |
+
logger = logging.get_logger(__name__)
|
34 |
+
|
35 |
+
class AttentionStore:
|
36 |
+
@staticmethod
|
37 |
+
def get_empty_store():
|
38 |
+
return {"down": [], "mid": [], "up": []}
|
39 |
+
|
40 |
+
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
41 |
+
if is_cross:
|
42 |
+
if attn.shape[1] in self.attn_res:
|
43 |
+
self.step_store[place_in_unet].append(attn)
|
44 |
+
|
45 |
+
self.cur_att_layer += 1
|
46 |
+
if self.cur_att_layer == self.num_att_layers:
|
47 |
+
self.cur_att_layer = 0
|
48 |
+
self.between_steps()
|
49 |
+
|
50 |
+
def between_steps(self):
|
51 |
+
self.attention_store = self.step_store
|
52 |
+
self.step_store = self.get_empty_store()
|
53 |
+
|
54 |
+
def maps(self, block_type: str):
|
55 |
+
return self.attention_store[block_type]
|
56 |
+
|
57 |
+
def reset(self):
|
58 |
+
self.cur_att_layer = 0
|
59 |
+
self.step_store = self.get_empty_store()
|
60 |
+
self.attention_store = {}
|
61 |
+
|
62 |
+
def __init__(self, attn_res=[64*64, 32*32, 16*16, 8*8]):
|
63 |
+
"""
|
64 |
+
Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
|
65 |
+
process
|
66 |
+
"""
|
67 |
+
self.num_att_layers = -1
|
68 |
+
self.cur_att_layer = 0
|
69 |
+
self.step_store = self.get_empty_store()
|
70 |
+
self.attention_store = {}
|
71 |
+
self.curr_step_index = 0
|
72 |
+
self.attn_res = attn_res
|
73 |
+
|
74 |
+
|
75 |
+
def get_sup_mask(mask_list):
|
76 |
+
or_mask = np.zeros_like(mask_list[0])
|
77 |
+
for mask in mask_list:
|
78 |
+
or_mask += mask
|
79 |
+
or_mask[or_mask >= 1] = 1
|
80 |
+
sup_mask = 1 - or_mask
|
81 |
+
return sup_mask
|
82 |
+
|
83 |
+
|
84 |
+
class MIGCProcessor(nn.Module):
|
85 |
+
def __init__(self, config, attnstore, place_in_unet):
|
86 |
+
super().__init__()
|
87 |
+
self.attnstore = attnstore
|
88 |
+
self.place_in_unet = place_in_unet
|
89 |
+
self.not_use_migc = config['not_use_migc']
|
90 |
+
self.naive_fuser = NaiveFuser()
|
91 |
+
self.embedding = {}
|
92 |
+
if not self.not_use_migc:
|
93 |
+
self.migc = MIGC(config['C'])
|
94 |
+
|
95 |
+
def __call__(
|
96 |
+
self,
|
97 |
+
attn: Attention,
|
98 |
+
hidden_states,
|
99 |
+
encoder_hidden_states=None,
|
100 |
+
attention_mask=None,
|
101 |
+
prompt_nums=[],
|
102 |
+
bboxes=[],
|
103 |
+
ith=None,
|
104 |
+
embeds_pooler=None,
|
105 |
+
timestep=None,
|
106 |
+
height=512,
|
107 |
+
width=512,
|
108 |
+
MIGCsteps=20,
|
109 |
+
NaiveFuserSteps=-1,
|
110 |
+
ca_scale=None,
|
111 |
+
ea_scale=None,
|
112 |
+
sac_scale=None,
|
113 |
+
use_sa_preserve=False,
|
114 |
+
sa_preserve=False,
|
115 |
+
):
|
116 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
117 |
+
assert(batch_size == 2, "We currently only implement sampling with batch_size=1, \
|
118 |
+
and we will implement sampling with batch_size=N as soon as possible.")
|
119 |
+
attention_mask = attn.prepare_attention_mask(
|
120 |
+
attention_mask, sequence_length, batch_size
|
121 |
+
)
|
122 |
+
|
123 |
+
instance_num = len(bboxes[0])
|
124 |
+
|
125 |
+
if ith > MIGCsteps:
|
126 |
+
not_use_migc = True
|
127 |
+
else:
|
128 |
+
not_use_migc = self.not_use_migc
|
129 |
+
is_vanilla_cross = (not_use_migc and ith > NaiveFuserSteps)
|
130 |
+
if instance_num == 0:
|
131 |
+
is_vanilla_cross = True
|
132 |
+
|
133 |
+
is_cross = encoder_hidden_states is not None
|
134 |
+
|
135 |
+
ori_hidden_states = hidden_states.clone()
|
136 |
+
|
137 |
+
# Only Need Negative Prompt and Global Prompt.
|
138 |
+
if is_cross and is_vanilla_cross:
|
139 |
+
encoder_hidden_states = encoder_hidden_states[:2, ...]
|
140 |
+
|
141 |
+
# In this case, we need to use MIGC or naive_fuser, so we copy the hidden_states_cond (instance_num+1) times for QKV
|
142 |
+
if is_cross and not is_vanilla_cross:
|
143 |
+
hidden_states_uncond = hidden_states[[0], ...]
|
144 |
+
hidden_states_cond = hidden_states[[1], ...].repeat(instance_num + 1, 1, 1)
|
145 |
+
hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond])
|
146 |
+
|
147 |
+
# QKV Operation of Vanilla Self-Attention or Cross-Attention
|
148 |
+
query = attn.to_q(hidden_states)
|
149 |
+
|
150 |
+
if (
|
151 |
+
not is_cross
|
152 |
+
and use_sa_preserve
|
153 |
+
and timestep.item() in self.embedding
|
154 |
+
and self.place_in_unet == "up"
|
155 |
+
):
|
156 |
+
hidden_states = torch.cat((hidden_states, torch.from_numpy(self.embedding[timestep.item()]).to(hidden_states.device)), dim=1)
|
157 |
+
|
158 |
+
if not is_cross and sa_preserve and self.place_in_unet == "up":
|
159 |
+
self.embedding[timestep.item()] = ori_hidden_states.cpu().numpy()
|
160 |
+
|
161 |
+
encoder_hidden_states = (
|
162 |
+
encoder_hidden_states
|
163 |
+
if encoder_hidden_states is not None
|
164 |
+
else hidden_states
|
165 |
+
)
|
166 |
+
key = attn.to_k(encoder_hidden_states)
|
167 |
+
value = attn.to_v(encoder_hidden_states)
|
168 |
+
query = attn.head_to_batch_dim(query)
|
169 |
+
key = attn.head_to_batch_dim(key)
|
170 |
+
value = attn.head_to_batch_dim(value)
|
171 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask) # 48 4096 77
|
172 |
+
self.attnstore(attention_probs, is_cross, self.place_in_unet)
|
173 |
+
hidden_states = torch.bmm(attention_probs, value)
|
174 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
175 |
+
hidden_states = attn.to_out[0](hidden_states)
|
176 |
+
hidden_states = attn.to_out[1](hidden_states)
|
177 |
+
|
178 |
+
###### Self-Attention Results ######
|
179 |
+
if not is_cross:
|
180 |
+
return hidden_states
|
181 |
+
|
182 |
+
###### Vanilla Cross-Attention Results ######
|
183 |
+
if is_vanilla_cross:
|
184 |
+
return hidden_states
|
185 |
+
|
186 |
+
###### Cross-Attention with MIGC ######
|
187 |
+
assert (not is_vanilla_cross)
|
188 |
+
# hidden_states: torch.Size([1+1+instance_num, HW, C]), the first 1 is the uncond ca output, the second 1 is the global ca output.
|
189 |
+
hidden_states_uncond = hidden_states[[0], ...] # torch.Size([1, HW, C])
|
190 |
+
cond_ca_output = hidden_states[1: , ...].unsqueeze(0) # torch.Size([1, 1+instance_num, 5, 64, 1280])
|
191 |
+
guidance_masks = []
|
192 |
+
in_box = []
|
193 |
+
# Construct Instance Guidance Mask
|
194 |
+
for bbox in bboxes[0]:
|
195 |
+
guidance_mask = np.zeros((height, width))
|
196 |
+
w_min = int(width * bbox[0])
|
197 |
+
w_max = int(width * bbox[2])
|
198 |
+
h_min = int(height * bbox[1])
|
199 |
+
h_max = int(height * bbox[3])
|
200 |
+
guidance_mask[h_min: h_max, w_min: w_max] = 1.0
|
201 |
+
guidance_masks.append(guidance_mask[None, ...])
|
202 |
+
in_box.append([bbox[0], bbox[2], bbox[1], bbox[3]])
|
203 |
+
|
204 |
+
# Construct Background Guidance Mask
|
205 |
+
sup_mask = get_sup_mask(guidance_masks)
|
206 |
+
supplement_mask = torch.from_numpy(sup_mask[None, ...])
|
207 |
+
supplement_mask = F.interpolate(supplement_mask, (height//8, width//8), mode='bilinear').float()
|
208 |
+
supplement_mask = supplement_mask.to(hidden_states.device) # (1, 1, H, W)
|
209 |
+
|
210 |
+
guidance_masks = np.concatenate(guidance_masks, axis=0)
|
211 |
+
guidance_masks = guidance_masks[None, ...]
|
212 |
+
guidance_masks = torch.from_numpy(guidance_masks).float().to(cond_ca_output.device)
|
213 |
+
guidance_masks = F.interpolate(guidance_masks, (height//8, width//8), mode='bilinear') # (1, instance_num, H, W)
|
214 |
+
|
215 |
+
in_box = torch.from_numpy(np.array(in_box))[None, ...].float().to(cond_ca_output.device) # (1, instance_num, 4)
|
216 |
+
|
217 |
+
other_info = {}
|
218 |
+
other_info['image_token'] = hidden_states_cond[None, ...]
|
219 |
+
other_info['context'] = encoder_hidden_states[1:, ...]
|
220 |
+
other_info['box'] = in_box
|
221 |
+
other_info['context_pooler'] =embeds_pooler # (instance_num, 1, 768)
|
222 |
+
other_info['supplement_mask'] = supplement_mask
|
223 |
+
other_info['attn2'] = None
|
224 |
+
other_info['attn'] = attn
|
225 |
+
other_info['height'] = height
|
226 |
+
other_info['width'] = width
|
227 |
+
other_info['ca_scale'] = ca_scale
|
228 |
+
other_info['ea_scale'] = ea_scale
|
229 |
+
other_info['sac_scale'] = sac_scale
|
230 |
+
|
231 |
+
if not not_use_migc:
|
232 |
+
hidden_states_cond, fuser_info = self.migc(cond_ca_output,
|
233 |
+
guidance_masks,
|
234 |
+
other_info=other_info,
|
235 |
+
return_fuser_info=True)
|
236 |
+
else:
|
237 |
+
hidden_states_cond, fuser_info = self.naive_fuser(cond_ca_output,
|
238 |
+
guidance_masks,
|
239 |
+
other_info=other_info,
|
240 |
+
return_fuser_info=True)
|
241 |
+
hidden_states_cond = hidden_states_cond.squeeze(1)
|
242 |
+
|
243 |
+
hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond])
|
244 |
+
return hidden_states
|
245 |
+
|
246 |
+
|
247 |
+
class StableDiffusionMIGCPipeline(StableDiffusionPipeline):
|
248 |
+
def __init__(
|
249 |
+
self,
|
250 |
+
vae: AutoencoderKL,
|
251 |
+
text_encoder: CLIPTextModel,
|
252 |
+
tokenizer: CLIPTokenizer,
|
253 |
+
unet: UNet2DConditionModel,
|
254 |
+
scheduler: KarrasDiffusionSchedulers,
|
255 |
+
safety_checker: StableDiffusionSafetyChecker,
|
256 |
+
feature_extractor: CLIPImageProcessor,
|
257 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
258 |
+
requires_safety_checker: bool = True,
|
259 |
+
):
|
260 |
+
# Get the parameter signature of the parent class constructor
|
261 |
+
parent_init_signature = inspect.signature(super().__init__)
|
262 |
+
parent_init_params = parent_init_signature.parameters
|
263 |
+
|
264 |
+
# Dynamically build a parameter dictionary based on the parameters of the parent class constructor
|
265 |
+
init_kwargs = {
|
266 |
+
"vae": vae,
|
267 |
+
"text_encoder": text_encoder,
|
268 |
+
"tokenizer": tokenizer,
|
269 |
+
"unet": unet,
|
270 |
+
"scheduler": scheduler,
|
271 |
+
"safety_checker": safety_checker,
|
272 |
+
"feature_extractor": feature_extractor,
|
273 |
+
"requires_safety_checker": requires_safety_checker
|
274 |
+
}
|
275 |
+
if 'image_encoder' in parent_init_params.items():
|
276 |
+
init_kwargs['image_encoder'] = image_encoder
|
277 |
+
super().__init__(**init_kwargs)
|
278 |
+
|
279 |
+
self.instance_set = set()
|
280 |
+
self.embedding = {}
|
281 |
+
|
282 |
+
def _encode_prompt(
|
283 |
+
self,
|
284 |
+
prompts,
|
285 |
+
device,
|
286 |
+
num_images_per_prompt,
|
287 |
+
do_classifier_free_guidance,
|
288 |
+
negative_prompt=None,
|
289 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
290 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
291 |
+
):
|
292 |
+
r"""
|
293 |
+
Encodes the prompt into text encoder hidden states.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
prompt (`str` or `List[str]`, *optional*):
|
297 |
+
prompt to be encoded
|
298 |
+
device: (`torch.device`):
|
299 |
+
torch device
|
300 |
+
num_images_per_prompt (`int`):
|
301 |
+
number of images that should be generated per prompt
|
302 |
+
do_classifier_free_guidance (`bool`):
|
303 |
+
whether to use classifier free guidance or not
|
304 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
305 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
306 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
307 |
+
less than `1`).
|
308 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
309 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
310 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
311 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
312 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
313 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
314 |
+
argument.
|
315 |
+
"""
|
316 |
+
if prompts is not None and isinstance(prompts, str):
|
317 |
+
batch_size = 1
|
318 |
+
elif prompts is not None and isinstance(prompts, list):
|
319 |
+
batch_size = len(prompts)
|
320 |
+
else:
|
321 |
+
batch_size = prompt_embeds.shape[0]
|
322 |
+
|
323 |
+
prompt_embeds_none_flag = (prompt_embeds is None)
|
324 |
+
prompt_embeds_list = []
|
325 |
+
embeds_pooler_list = []
|
326 |
+
for prompt in prompts:
|
327 |
+
if prompt_embeds_none_flag:
|
328 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
329 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
330 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
331 |
+
|
332 |
+
text_inputs = self.tokenizer(
|
333 |
+
prompt,
|
334 |
+
padding="max_length",
|
335 |
+
max_length=self.tokenizer.model_max_length,
|
336 |
+
truncation=True,
|
337 |
+
return_tensors="pt",
|
338 |
+
)
|
339 |
+
text_input_ids = text_inputs.input_ids
|
340 |
+
untruncated_ids = self.tokenizer(
|
341 |
+
prompt, padding="longest", return_tensors="pt"
|
342 |
+
).input_ids
|
343 |
+
|
344 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
345 |
+
-1
|
346 |
+
] and not torch.equal(text_input_ids, untruncated_ids):
|
347 |
+
removed_text = self.tokenizer.batch_decode(
|
348 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
|
349 |
+
)
|
350 |
+
logger.warning(
|
351 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
352 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
353 |
+
)
|
354 |
+
|
355 |
+
if (
|
356 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
357 |
+
and self.text_encoder.config.use_attention_mask
|
358 |
+
):
|
359 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
360 |
+
else:
|
361 |
+
attention_mask = None
|
362 |
+
|
363 |
+
prompt_embeds = self.text_encoder(
|
364 |
+
text_input_ids.to(device),
|
365 |
+
attention_mask=attention_mask,
|
366 |
+
)
|
367 |
+
embeds_pooler = prompt_embeds.pooler_output
|
368 |
+
prompt_embeds = prompt_embeds[0]
|
369 |
+
|
370 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
371 |
+
embeds_pooler = embeds_pooler.to(dtype=self.text_encoder.dtype, device=device)
|
372 |
+
|
373 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
374 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
375 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
376 |
+
embeds_pooler = embeds_pooler.repeat(1, num_images_per_prompt)
|
377 |
+
prompt_embeds = prompt_embeds.view(
|
378 |
+
bs_embed * num_images_per_prompt, seq_len, -1
|
379 |
+
)
|
380 |
+
embeds_pooler = embeds_pooler.view(
|
381 |
+
bs_embed * num_images_per_prompt, -1
|
382 |
+
)
|
383 |
+
prompt_embeds_list.append(prompt_embeds)
|
384 |
+
embeds_pooler_list.append(embeds_pooler)
|
385 |
+
prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
|
386 |
+
embeds_pooler = torch.cat(embeds_pooler_list, dim=0)
|
387 |
+
# negative_prompt_embeds: (prompt_nums[0]+prompt_nums[1]+...prompt_nums[n], token_num, token_channel), <class 'torch.Tensor'>
|
388 |
+
|
389 |
+
# get unconditional embeddings for classifier free guidance
|
390 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
391 |
+
uncond_tokens: List[str]
|
392 |
+
if negative_prompt is None:
|
393 |
+
negative_prompt = "worst quality, low quality, bad anatomy"
|
394 |
+
uncond_tokens = [negative_prompt] * batch_size
|
395 |
+
|
396 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
397 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
398 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
399 |
+
|
400 |
+
max_length = prompt_embeds.shape[1]
|
401 |
+
uncond_input = self.tokenizer(
|
402 |
+
uncond_tokens,
|
403 |
+
padding="max_length",
|
404 |
+
max_length=max_length,
|
405 |
+
truncation=True,
|
406 |
+
return_tensors="pt",
|
407 |
+
)
|
408 |
+
|
409 |
+
if (
|
410 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
411 |
+
and self.text_encoder.config.use_attention_mask
|
412 |
+
):
|
413 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
414 |
+
else:
|
415 |
+
attention_mask = None
|
416 |
+
|
417 |
+
negative_prompt_embeds = self.text_encoder(
|
418 |
+
uncond_input.input_ids.to(device),
|
419 |
+
attention_mask=attention_mask,
|
420 |
+
)
|
421 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
422 |
+
|
423 |
+
if do_classifier_free_guidance:
|
424 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
425 |
+
seq_len = negative_prompt_embeds.shape[1]
|
426 |
+
|
427 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
428 |
+
dtype=self.text_encoder.dtype, device=device
|
429 |
+
)
|
430 |
+
|
431 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
432 |
+
1, num_images_per_prompt, 1
|
433 |
+
)
|
434 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
435 |
+
batch_size * num_images_per_prompt, seq_len, -1
|
436 |
+
)
|
437 |
+
# negative_prompt_embeds: (len(prompt_nums), token_num, token_channel), <class 'torch.Tensor'>
|
438 |
+
|
439 |
+
# For classifier free guidance, we need to do two forward passes.
|
440 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
441 |
+
# to avoid doing two forward passes
|
442 |
+
final_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
443 |
+
|
444 |
+
return final_prompt_embeds, prompt_embeds, embeds_pooler[:, None, :]
|
445 |
+
|
446 |
+
def check_inputs(
|
447 |
+
self,
|
448 |
+
prompt,
|
449 |
+
token_indices,
|
450 |
+
bboxes,
|
451 |
+
height,
|
452 |
+
width,
|
453 |
+
callback_steps,
|
454 |
+
negative_prompt=None,
|
455 |
+
prompt_embeds=None,
|
456 |
+
negative_prompt_embeds=None,
|
457 |
+
):
|
458 |
+
if height % 8 != 0 or width % 8 != 0:
|
459 |
+
raise ValueError(
|
460 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
461 |
+
)
|
462 |
+
|
463 |
+
if (callback_steps is None) or (
|
464 |
+
callback_steps is not None
|
465 |
+
and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
466 |
+
):
|
467 |
+
raise ValueError(
|
468 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
469 |
+
f" {type(callback_steps)}."
|
470 |
+
)
|
471 |
+
|
472 |
+
if prompt is not None and prompt_embeds is not None:
|
473 |
+
raise ValueError(
|
474 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
475 |
+
" only forward one of the two."
|
476 |
+
)
|
477 |
+
elif prompt is None and prompt_embeds is None:
|
478 |
+
raise ValueError(
|
479 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
480 |
+
)
|
481 |
+
elif prompt is not None and (
|
482 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
483 |
+
):
|
484 |
+
raise ValueError(
|
485 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
486 |
+
)
|
487 |
+
|
488 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
489 |
+
raise ValueError(
|
490 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
491 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
492 |
+
)
|
493 |
+
|
494 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
495 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
496 |
+
raise ValueError(
|
497 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
498 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
499 |
+
f" {negative_prompt_embeds.shape}."
|
500 |
+
)
|
501 |
+
|
502 |
+
if token_indices is not None:
|
503 |
+
if isinstance(token_indices, list):
|
504 |
+
if isinstance(token_indices[0], list):
|
505 |
+
if isinstance(token_indices[0][0], list):
|
506 |
+
token_indices_batch_size = len(token_indices)
|
507 |
+
elif isinstance(token_indices[0][0], int):
|
508 |
+
token_indices_batch_size = 1
|
509 |
+
else:
|
510 |
+
raise TypeError(
|
511 |
+
"`token_indices` must be a list of lists of integers or a list of integers."
|
512 |
+
)
|
513 |
+
else:
|
514 |
+
raise TypeError(
|
515 |
+
"`token_indices` must be a list of lists of integers or a list of integers."
|
516 |
+
)
|
517 |
+
else:
|
518 |
+
raise TypeError(
|
519 |
+
"`token_indices` must be a list of lists of integers or a list of integers."
|
520 |
+
)
|
521 |
+
|
522 |
+
if bboxes is not None:
|
523 |
+
if isinstance(bboxes, list):
|
524 |
+
if isinstance(bboxes[0], list):
|
525 |
+
if (
|
526 |
+
isinstance(bboxes[0][0], list)
|
527 |
+
and len(bboxes[0][0]) == 4
|
528 |
+
and all(isinstance(x, float) for x in bboxes[0][0])
|
529 |
+
):
|
530 |
+
bboxes_batch_size = len(bboxes)
|
531 |
+
elif (
|
532 |
+
isinstance(bboxes[0], list)
|
533 |
+
and len(bboxes[0]) == 4
|
534 |
+
and all(isinstance(x, float) for x in bboxes[0])
|
535 |
+
):
|
536 |
+
bboxes_batch_size = 1
|
537 |
+
else:
|
538 |
+
print(isinstance(bboxes[0], list), len(bboxes[0]))
|
539 |
+
raise TypeError(
|
540 |
+
"`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats."
|
541 |
+
)
|
542 |
+
else:
|
543 |
+
print(isinstance(bboxes[0], list), len(bboxes[0]))
|
544 |
+
raise TypeError(
|
545 |
+
"`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats."
|
546 |
+
)
|
547 |
+
else:
|
548 |
+
print(isinstance(bboxes[0], list), len(bboxes[0]))
|
549 |
+
raise TypeError(
|
550 |
+
"`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats."
|
551 |
+
)
|
552 |
+
|
553 |
+
if prompt is not None and isinstance(prompt, str):
|
554 |
+
prompt_batch_size = 1
|
555 |
+
elif prompt is not None and isinstance(prompt, list):
|
556 |
+
prompt_batch_size = len(prompt)
|
557 |
+
elif prompt_embeds is not None:
|
558 |
+
prompt_batch_size = prompt_embeds.shape[0]
|
559 |
+
|
560 |
+
if token_indices_batch_size != prompt_batch_size:
|
561 |
+
raise ValueError(
|
562 |
+
f"token indices batch size must be same as prompt batch size. token indices batch size: {token_indices_batch_size}, prompt batch size: {prompt_batch_size}"
|
563 |
+
)
|
564 |
+
|
565 |
+
if bboxes_batch_size != prompt_batch_size:
|
566 |
+
raise ValueError(
|
567 |
+
f"bbox batch size must be same as prompt batch size. bbox batch size: {bboxes_batch_size}, prompt batch size: {prompt_batch_size}"
|
568 |
+
)
|
569 |
+
|
570 |
+
def get_indices(self, prompt: str) -> Dict[str, int]:
|
571 |
+
"""Utility function to list the indices of the tokens you wish to alte"""
|
572 |
+
ids = self.tokenizer(prompt).input_ids
|
573 |
+
indices = {
|
574 |
+
i: tok
|
575 |
+
for tok, i in zip(
|
576 |
+
self.tokenizer.convert_ids_to_tokens(ids), range(len(ids))
|
577 |
+
)
|
578 |
+
}
|
579 |
+
return indices
|
580 |
+
|
581 |
+
@staticmethod
|
582 |
+
    def draw_box(pil_img: Image, bboxes: List[List[float]]) -> Image:
        """Utility function to draw bbox on the image"""
        width, height = pil_img.size
        draw = ImageDraw.Draw(pil_img)

        for obj_box in bboxes:
            x_min, y_min, x_max, y_max = (
                obj_box[0] * width,
                obj_box[1] * height,
                obj_box[2] * width,
                obj_box[3] * height,
            )
            draw.rectangle(
                [int(x_min), int(y_min), int(x_max), int(y_max)],
                outline="red",
                width=4,
            )

        return pil_img

    @staticmethod
    def draw_box_desc(pil_img: Image, bboxes: List[List[float]], prompt: List[str]) -> Image:
        """Utility function to draw bbox on the image"""
        color_list = ['red', 'blue', 'yellow', 'purple', 'green', 'black', 'brown', 'orange', 'white', 'gray']
        width, height = pil_img.size
        draw = ImageDraw.Draw(pil_img)
        font_folder = os.path.dirname(os.path.dirname(__file__))
        font_path = os.path.join(font_folder, 'Rainbow-Party-2.ttf')
        font = ImageFont.truetype(font_path, 30)

        for box_id in range(len(bboxes)):
            obj_box = bboxes[box_id]
            text = prompt[box_id]
            fill = 'black'
            for color in prompt[box_id].split(' '):
                if color in color_list:
                    fill = color
            text = text.split(',')[0]
            x_min, y_min, x_max, y_max = (
                obj_box[0] * width,
                obj_box[1] * height,
                obj_box[2] * width,
                obj_box[3] * height,
            )
            draw.rectangle(
                [int(x_min), int(y_min), int(x_max), int(y_max)],
                outline=fill,
                width=4,
            )
            draw.text((int(x_min), int(y_min)), text, fill=fill, font=font)

        return pil_img

    @torch.no_grad()
    def __call__(
        self,
        prompt: List[List[str]] = None,
        bboxes: List[List[List[float]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        MIGCsteps=20,
        NaiveFuserSteps=-1,
        ca_scale=None,
        ea_scale=None,
        sac_scale=None,
        aug_phase_with_and=False,
        sa_preserve=False,
        use_sa_preserve=False,
        clear_set=False,
        GUI_progress=None
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            token_indices (Union[List[List[List[int]]], List[List[int]]], optional):
                The list of the indexes in the prompt to layout. Defaults to None.
            bboxes (Union[List[List[List[float]]], List[List[float]]], optional):
                The bounding boxes of the indexes to maintain layout in the image. Defaults to None.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            max_guidance_iter (`int`, *optional*, defaults to `10`):
                The maximum number of iterations for the layout guidance on attention maps in diffusion mode.
            max_guidance_iter_per_step (`int`, *optional*, defaults to `5`):
                The maximum number of iterations to run during each time step for layout guidance.
            scale_factor (`int`, *optional*, defaults to `50`):
                The scale factor used to update the latents during optimization.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        def aug_phase_with_and_function(phase, instance_num):
            instance_num = min(instance_num, 7)
            copy_phase = [phase] * instance_num
            phase = ', and '.join(copy_phase)
            return phase

        if aug_phase_with_and:
            instance_num = len(prompt[0]) - 1
            for i in range(1, len(prompt[0])):
                prompt[0][i] = aug_phase_with_and_function(prompt[0][i],
                                                           instance_num)
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        prompt_nums = [0] * len(prompt)
        for i, _ in enumerate(prompt):
            prompt_nums[i] = len(_)

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, cond_prompt_embeds, embeds_pooler = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        # print(prompt_embeds.shape) 3 77 768

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        if clear_set:
            self.instance_set = set()
            self.embedding = {}

        now_set = set()
        for i in range(len(bboxes[0])):
            now_set.add((tuple(bboxes[0][i]), prompt[0][i + 1]))

        mask_set = (now_set | self.instance_set) - (now_set & self.instance_set)
        self.instance_set = now_set

        guidance_mask = np.full((4, height // 8, width // 8), 1.0)

        for bbox, _ in mask_set:
            w_min = max(0, int(width * bbox[0] // 8) - 5)
            w_max = min(width, int(width * bbox[2] // 8) + 5)
            h_min = max(0, int(height * bbox[1] // 8) - 5)
            h_max = min(height, int(height * bbox[3] // 8) + 5)
            guidance_mask[:, h_min:h_max, w_min:w_max] = 0

        kernal_size = 5
        guidance_mask = uniform_filter(
            guidance_mask, axes=(1, 2), size=kernal_size
        )

        guidance_mask = torch.from_numpy(guidance_mask).to(self.device).unsqueeze(0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if GUI_progress is not None:
                    GUI_progress[0] = int((i + 1) / len(timesteps) * 100)
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # predict the noise residual
                cross_attention_kwargs = {'prompt_nums': prompt_nums,
                                          'bboxes': bboxes,
                                          'ith': i,
                                          'embeds_pooler': embeds_pooler,
                                          'timestep': t,
                                          'height': height,
                                          'width': width,
                                          'MIGCsteps': MIGCsteps,
                                          'NaiveFuserSteps': NaiveFuserSteps,
                                          'ca_scale': ca_scale,
                                          'ea_scale': ea_scale,
                                          'sac_scale': sac_scale,
                                          'sa_preserve': sa_preserve,
                                          'use_sa_preserve': use_sa_preserve}

                self.unet.eval()
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                step_output = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                )
                latents = step_output.prev_sample

                ori_input = latents.detach().clone()
                if use_sa_preserve and i in self.embedding:
                    latents = (
                        latents * (1.0 - guidance_mask)
                        + torch.from_numpy(self.embedding[i]).to(latents.device) * guidance_mask
                    ).float()

                if sa_preserve:
                    self.embedding[i] = ori_input.cpu().numpy()

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            image = latents
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=None
        )
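
The loop above pairs prompt[0][i + 1] with bboxes[0][i], so the pipeline expects `prompt` as a nested list (the full caption first, then one phrase per instance) and `bboxes` as normalized [x_min, y_min, x_max, y_max] boxes, one per phrase. A minimal usage sketch follows; the caption, boxes, parameter values, and output path are illustrative and not part of this commit, and `pipe` is assumed to be a StableDiffusionMIGCPipeline prepared as in migc_utils.py below.

# Illustrative call; caption, boxes, and file name are placeholders.
prompt = [[
    'masterpiece, best quality, a yellow dog and a gray cat on the grass',  # full caption
    'a yellow dog',   # phrase for instance 1
    'a gray cat',     # phrase for instance 2
]]
bboxes = [[
    [0.05, 0.40, 0.45, 0.90],   # normalized box for instance 1
    [0.55, 0.40, 0.95, 0.90],   # normalized box for instance 2
]]
out = pipe(prompt, bboxes,
           num_inference_steps=50,
           guidance_scale=7.5,
           MIGCsteps=25,
           aug_phase_with_and=False)
image = out.images[0]                      # StableDiffusionPipelineOutput
image = pipe.draw_box(image, bboxes[0])    # overlay the layout for inspection
image.save('migc_example.png')
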
model_bbox/MIGC/migc/migc_utils.py
ADDED
@@ -0,0 +1,143 @@
import argparse
import numpy as np
import torch
import os
import yaml
import random
from diffusers.utils.import_utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
from diffusers import EulerDiscreteScheduler
if is_accelerate_available():
    from accelerate import init_empty_weights
from contextlib import nullcontext


def seed_everything(seed):
    # np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)


import torch
from typing import Callable, Dict, List, Optional, Union
from collections import defaultdict

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"

# We need to set Attention Processors for the following keys.
all_processor_keys = [
    'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor',
    'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
    'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor',
    'mid_block.attentions.0.transformer_blocks.0.attn1.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor'
]

def load_migc(unet, attention_store, pretrained_MIGC_path: Union[str, Dict[str, torch.Tensor]], attn_processor,
              **kwargs):

    state_dict = torch.load(pretrained_MIGC_path, map_location="cpu")

    # fill attn processors
    attn_processors = {}
    state_dict = state_dict['state_dict']


    adapter_grouped_dict = defaultdict(dict)

    # change the key of MIGC.ckpt as the form of diffusers unet
    for key, value in state_dict.items():
        key_list = key.split(".")
        assert 'migc' in key_list
        if 'input_blocks' in key_list:
            model_type = 'down_blocks'
        elif 'middle_block' in key_list:
            model_type = 'mid_block'
        else:
            model_type = 'up_blocks'
        index_number = int(key_list[3])
        if model_type == 'down_blocks':
            input_num1 = str(index_number // 3)
            input_num2 = str((index_number % 3) - 1)
        elif model_type == 'mid_block':
            input_num1 = '0'
            input_num2 = '0'
        else:
            input_num1 = str(index_number // 3)
            input_num2 = str(index_number % 3)
        attn_key_list = [model_type, input_num1, 'attentions', input_num2, 'transformer_blocks', '0']
        if model_type == 'mid_block':
            attn_key_list = [model_type, 'attentions', input_num2, 'transformer_blocks', '0']
        attn_processor_key = '.'.join(attn_key_list)
        sub_key = '.'.join(key_list[key_list.index('migc'):])
        adapter_grouped_dict[attn_processor_key][sub_key] = value

    # Create MIGC Processor
    config = {'not_use_migc': False}
    for key, value_dict in adapter_grouped_dict.items():
        dim = value_dict['migc.norm.bias'].shape[0]
        config['C'] = dim
        key_final = key + '.attn2.processor'
        if key_final.startswith("mid_block"):
            place_in_unet = "mid"
        elif key_final.startswith("up_blocks"):
            place_in_unet = "up"
        elif key_final.startswith("down_blocks"):
            place_in_unet = "down"

        attn_processors[key_final] = attn_processor(config, attention_store, place_in_unet)
        attn_processors[key_final].load_state_dict(value_dict)
        attn_processors[key_final].to(device=unet.device, dtype=unet.dtype)

    # Create CrossAttention/SelfAttention Processor
    config = {'not_use_migc': True}
    for key in all_processor_keys:
        if key not in attn_processors.keys():
            if key.startswith("mid_block"):
                place_in_unet = "mid"
            elif key.startswith("up_blocks"):
                place_in_unet = "up"
            elif key.startswith("down_blocks"):
                place_in_unet = "down"
            attn_processors[key] = attn_processor(config, attention_store, place_in_unet)
    unet.set_attn_processor(attn_processors)
    attention_store.num_att_layers = 32


def offlinePipelineSetupWithSafeTensor(sd_safetensors_path):
    project_dir = os.path.dirname(os.path.dirname(__file__))
    migc_ckpt_path = os.path.join(project_dir, 'pretrained_weights/MIGC_SD14.ckpt')
    clip_model_path = os.path.join(project_dir, 'migc_gui_weights/clip/text_encoder')
    clip_tokenizer_path = os.path.join(project_dir, 'migc_gui_weights/clip/tokenizer')
    original_config_file = os.path.join(project_dir, 'migc_gui_weights/v1-inference.yaml')
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        # text_encoder = CLIPTextModel(config)
        text_encoder = CLIPTextModel.from_pretrained(clip_model_path)
        tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path)
    pipe = StableDiffusionMIGCPipeline.from_single_file(sd_safetensors_path,
                                                        original_config_file=original_config_file,
                                                        text_encoder=text_encoder,
                                                        tokenizer=tokenizer,
                                                        load_safety_checker=False)
    print('Initializing pipeline')
    pipe.attention_store = AttentionStore()
    from migc.migc_utils import load_migc
    load_migc(pipe.unet, pipe.attention_store,
              migc_ckpt_path, attn_processor=MIGCProcessor)

    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    return pipe
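
offlinePipelineSetupWithSafeTensor builds the pipeline from a single .safetensors checkpoint; the same MIGC wiring can also sit on top of a standard diffusers pipeline loaded with from_pretrained. A rough sketch of that variant is below; the base model id, device, checkpoint path, and seed are assumptions, not taken from this commit.

# Sketch: attach MIGC to a stock Stable Diffusion 1.4 pipeline (assumed base model and paths).
from diffusers import EulerDiscreteScheduler
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
from migc.migc_utils import load_migc, seed_everything

pipe = StableDiffusionMIGCPipeline.from_pretrained('CompVis/stable-diffusion-v1-4')  # assumed base model
pipe.attention_store = AttentionStore()
load_migc(pipe.unet, pipe.attention_store,
          'model_bbox/MIGC/pretrained_weights/MIGC_SD14.ckpt',  # the LFS checkpoint added below
          attn_processor=MIGCProcessor)
pipe = pipe.to('cuda')
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
seed_everything(42)  # illustrative seed
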
model_bbox/MIGC/pretrained_weights/MIGC_SD14.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81756dd19f7c75f9bba1ead1e6f8fcdfb00030cabb01dc46edd85d950236884c
size 229514282
model_bbox/MIGC/pretrained_weights/PUT_MIGC_CKPT_HERE
ADDED
File without changes