tianleliphoebe commited on
Commit
a724d8a
·
verified ·
1 Parent(s): 42acef2

Update serve/vote_utils.py

Browse files
Files changed (1) hide show
  1. serve/vote_utils.py +1741 -0
serve/vote_utils.py CHANGED
@@ -0,0 +1,1741 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import time
3
+ import json
4
+ import uuid
5
+ import gradio as gr
6
+ import regex as re
7
+ from pathlib import Path
8
+ from .utils import *
9
+ from .log_utils import build_logger
10
+ from .constants import IMAGE_DIR, VIDEO_DIR
11
+ import imageio
12
+ from diffusers.utils import load_image
13
+ import torch
14
+
15
+ ig_logger = build_logger("gradio_web_server_image_generation", "gr_web_image_generation.log") # ig = image generation, loggers for single model direct chat
16
+ igm_logger = build_logger("gradio_web_server_image_generation_multi", "gr_web_image_generation_multi.log") # igm = image generation multi, loggers for side-by-side and battle
17
+ ie_logger = build_logger("gradio_web_server_image_editing", "gr_web_image_editing.log") # ie = image editing, loggers for single model direct chat
18
+ iem_logger = build_logger("gradio_web_server_image_editing_multi", "gr_web_image_editing_multi.log") # iem = image editing multi, loggers for side-by-side and battle
19
+ vg_logger = build_logger("gradio_web_server_video_generation", "gr_web_video_generation.log") # vg = video generation, loggers for single model direct chat
20
+ vgm_logger = build_logger("gradio_web_server_video_generation_multi", "gr_web_video_generation_multi.log") # vgm = video generation multi, loggers for side-by-side and battle
21
+
22
+ def save_any_image(image_file, file_path):
23
+ if isinstance(image_file, str):
24
+ image = load_image(image_file)
25
+ image.save(file_path, 'JPEG')
26
+ else:
27
+ image_file.save(file_path, 'JPEG')
28
+
29
+ def vote_last_response_ig(state, vote_type, model_selector, request: gr.Request):
30
+ with open(get_conv_log_filename(), "a") as fout:
31
+ data = {
32
+ "tstamp": round(time.time(), 4),
33
+ "type": vote_type,
34
+ "model": model_selector,
35
+ "state": state.dict(),
36
+ "ip": get_ip(request),
37
+ }
38
+ fout.write(json.dumps(data) + "\n")
39
+ append_json_item_on_log_server(data, get_conv_log_filename())
40
+ output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
41
+ with open(output_file, 'w') as f:
42
+ save_any_image(state.output, f)
43
+ save_image_file_on_log_server(output_file)
44
+
45
+ def vote_last_response_igm(states, vote_type, model_selectors, request: gr.Request):
46
+ with open(get_conv_log_filename(), "a") as fout:
47
+ data = {
48
+ "tstamp": round(time.time(), 4),
49
+ "type": vote_type,
50
+ "models": [x for x in model_selectors],
51
+ "states": [x.dict() for x in states],
52
+ "ip": get_ip(request),
53
+ }
54
+ fout.write(json.dumps(data) + "\n")
55
+ append_json_item_on_log_server(data, get_conv_log_filename())
56
+ for state in states:
57
+ output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
58
+ with open(output_file, 'w') as f:
59
+ save_any_image(state.output, f)
60
+ save_image_file_on_log_server(output_file)
61
+
62
+ def vote_last_response_ie(state, vote_type, model_selector, request: gr.Request):
63
+ with open(get_conv_log_filename(), "a") as fout:
64
+ data = {
65
+ "tstamp": round(time.time(), 4),
66
+ "type": vote_type,
67
+ "model": model_selector,
68
+ "state": state.dict(),
69
+ "ip": get_ip(request),
70
+ }
71
+ fout.write(json.dumps(data) + "\n")
72
+ append_json_item_on_log_server(data, get_conv_log_filename())
73
+ output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
74
+ source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
75
+ with open(output_file, 'w') as f:
76
+ save_any_image(state.output, f)
77
+ with open(source_file, 'w') as sf:
78
+ save_any_image(state.source_image, sf)
79
+ save_image_file_on_log_server(output_file)
80
+ save_image_file_on_log_server(source_file)
81
+
82
+ def vote_last_response_iem(states, vote_type, model_selectors, request: gr.Request):
83
+ with open(get_conv_log_filename(), "a") as fout:
84
+ data = {
85
+ "tstamp": round(time.time(), 4),
86
+ "type": vote_type,
87
+ "models": [x for x in model_selectors],
88
+ "states": [x.dict() for x in states],
89
+ "ip": get_ip(request),
90
+ }
91
+ fout.write(json.dumps(data) + "\n")
92
+ append_json_item_on_log_server(data, get_conv_log_filename())
93
+ for state in states:
94
+ output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
95
+ source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
96
+ with open(output_file, 'w') as f:
97
+ save_any_image(state.output, f)
98
+ with open(source_file, 'w') as sf:
99
+ save_any_image(state.source_image, sf)
100
+ save_image_file_on_log_server(output_file)
101
+ save_image_file_on_log_server(source_file)
102
+
103
+
104
+ def vote_last_response_vg(state, vote_type, model_selector, request: gr.Request):
105
+ with open(get_conv_log_filename(), "a") as fout:
106
+ data = {
107
+ "tstamp": round(time.time(), 4),
108
+ "type": vote_type,
109
+ "model": model_selector,
110
+ "state": state.dict(),
111
+ "ip": get_ip(request),
112
+ }
113
+ fout.write(json.dumps(data) + "\n")
114
+ append_json_item_on_log_server(data, get_conv_log_filename())
115
+
116
+ output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
117
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
118
+ if state.model_name.startswith('fal'):
119
+ r = requests.get(state.output)
120
+ with open(output_file, 'wb') as outfile:
121
+ outfile.write(r.content)
122
+ else:
123
+ print("======== video shape: ========")
124
+ print(state.output.shape)
125
+ # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
126
+ if state.output.shape[-1] != 3:
127
+ state.output = state.output.permute(0, 2, 3, 1)
128
+ imageio.mimwrite(output_file, state.output, fps=8, quality=9)
129
+ save_video_file_on_log_server(output_file)
130
+
131
+
132
+
133
+ def vote_last_response_vgm(states, vote_type, model_selectors, request: gr.Request):
134
+ with open(get_conv_log_filename(), "a") as fout:
135
+ data = {
136
+ "tstamp": round(time.time(), 4),
137
+ "type": vote_type,
138
+ "models": [x for x in model_selectors],
139
+ "states": [x.dict() for x in states],
140
+ "ip": get_ip(request),
141
+ }
142
+ fout.write(json.dumps(data) + "\n")
143
+ append_json_item_on_log_server(data, get_conv_log_filename())
144
+ for state in states:
145
+ output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
146
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
147
+ if state.model_name.startswith('fal'):
148
+ r = requests.get(state.output)
149
+ with open(output_file, 'wb') as outfile:
150
+ outfile.write(r.content)
151
+ elif isinstance(state.output, torch.Tensor):
152
+ print("======== video shape: ========")
153
+ print(state.output.shape)
154
+ # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
155
+ if state.output.shape[-1] != 3:
156
+ state.output = state.output.permute(0, 2, 3, 1)
157
+ imageio.mimwrite(output_file, state.output, fps=8, quality=9)
158
+ else:
159
+ r = requests.get(state.output)
160
+ with open(output_file, 'wb') as outfile:
161
+ outfile.write(r.content)
162
+ save_video_file_on_log_server(output_file)
163
+
164
+
165
+ ## Image Generation (IG) Single Model Direct Chat
166
+ def upvote_last_response_ig(state, model_selector, request: gr.Request):
167
+ ip = get_ip(request)
168
+ ig_logger.info(f"upvote. ip: {ip}")
169
+ vote_last_response_ig(state, "upvote", model_selector, request)
170
+ return ("",) + (disable_btn,) * 3
171
+
172
+ def downvote_last_response_ig(state, model_selector, request: gr.Request):
173
+ ip = get_ip(request)
174
+ ig_logger.info(f"downvote. ip: {ip}")
175
+ vote_last_response_ig(state, "downvote", model_selector, request)
176
+ return ("",) + (disable_btn,) * 3
177
+
178
+
179
+ def flag_last_response_ig(state, model_selector, request: gr.Request):
180
+ ip = get_ip(request)
181
+ ig_logger.info(f"flag. ip: {ip}")
182
+ vote_last_response_ig(state, "flag", model_selector, request)
183
+ return ("",) + (disable_btn,) * 3
184
+
185
+ ## Image Generation Multi (IGM) Side-by-Side and Battle
186
+
187
+ def leftvote_last_response_igm(
188
+ state0, state1, model_selector0, model_selector1, request: gr.Request
189
+ ):
190
+ igm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
191
+ vote_last_response_igm(
192
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
193
+ )
194
+ if model_selector0 == "":
195
+ return ("",) + (disable_btn,) * 4 + (
196
+ gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
197
+ gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
198
+ else:
199
+ return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True),
200
+ gr.Markdown(state1.model_name, visible=True))
201
+
202
+ def rightvote_last_response_igm(
203
+ state0, state1, model_selector0, model_selector1, request: gr.Request
204
+ ):
205
+ igm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
206
+ vote_last_response_igm(
207
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
208
+ )
209
+ print(model_selector0)
210
+ if model_selector0 == "":
211
+ return ("",) + (disable_btn,) * 4 + (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
212
+ else:
213
+ return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True),
214
+ gr.Markdown(state1.model_name, visible=True))
215
+
216
+
217
+ def tievote_last_response_igm(
218
+ state0, state1, model_selector0, model_selector1, request: gr.Request
219
+ ):
220
+ igm_logger.info(f"tievote (named). ip: {get_ip(request)}")
221
+ vote_last_response_igm(
222
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
223
+ )
224
+ if model_selector0 == "":
225
+ return ("",) + (disable_btn,) * 4 + (
226
+ gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
227
+ gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
228
+ else:
229
+ return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True),
230
+ gr.Markdown(state1.model_name, visible=True))
231
+
232
+
233
+ def bothbad_vote_last_response_igm(
234
+ state0, state1, model_selector0, model_selector1, request: gr.Request
235
+ ):
236
+ igm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
237
+ vote_last_response_igm(
238
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
239
+ )
240
+ if model_selector0 == "":
241
+ return ("",) + (disable_btn,) * 4 + (
242
+ gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
243
+ gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
244
+ else:
245
+ return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True),
246
+ gr.Markdown(state1.model_name, visible=True))
247
+
248
+ ## Image Editing (IE) Single Model Direct Chat
249
+
250
+ def upvote_last_response_ie(state, model_selector, request: gr.Request):
251
+ ip = get_ip(request)
252
+ ie_logger.info(f"upvote. ip: {ip}")
253
+ vote_last_response_ie(state, "upvote", model_selector, request)
254
+ return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3
255
+
256
+ def downvote_last_response_ie(state, model_selector, request: gr.Request):
257
+ ip = get_ip(request)
258
+ ie_logger.info(f"downvote. ip: {ip}")
259
+ vote_last_response_ie(state, "downvote", model_selector, request)
260
+ return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3
261
+
262
+ def flag_last_response_ie(state, model_selector, request: gr.Request):
263
+ ip = get_ip(request)
264
+ ie_logger.info(f"flag. ip: {ip}")
265
+ vote_last_response_ie(state, "flag", model_selector, request)
266
+ return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3
267
+
268
+ ## Image Editing Multi (IEM) Side-by-Side and Battle
269
+ def leftvote_last_response_iem(
270
+ state0, state1, model_selector0, model_selector1, request: gr.Request
271
+ ):
272
+ iem_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
273
+ vote_last_response_iem(
274
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
275
+ )
276
+ # names = (
277
+ # "### Model A: " + state0.model_name,
278
+ # "### Model B: " + state1.model_name,
279
+ # )
280
+ # names = (state0.model_name, state1.model_name)
281
+ if model_selector0 == "":
282
+ names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
283
+ else:
284
+ names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False))
285
+ return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4
286
+
287
+ def rightvote_last_response_iem(
288
+ state0, state1, model_selector0, model_selector1, request: gr.Request
289
+ ):
290
+ iem_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
291
+ vote_last_response_iem(
292
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
293
+ )
294
+ # names = (
295
+ # "### Model A: " + state0.model_name,
296
+ # "### Model B: " + state1.model_name,
297
+ # )
298
+ if model_selector0 == "":
299
+ names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
300
+ gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
301
+ else:
302
+ names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False))
303
+ return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4
304
+
305
+ def tievote_last_response_iem(
306
+ state0, state1, model_selector0, model_selector1, request: gr.Request
307
+ ):
308
+ iem_logger.info(f"tievote (anony). ip: {get_ip(request)}")
309
+ vote_last_response_iem(
310
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
311
+ )
312
+ if model_selector0 == "":
313
+ names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
314
+ gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
315
+ else:
316
+ names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False))
317
+ return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4
318
+
319
+ def bothbad_vote_last_response_iem(
320
+ state0, state1, model_selector0, model_selector1, request: gr.Request
321
+ ):
322
+ iem_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
323
+ vote_last_response_iem(
324
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
325
+ )
326
+ if model_selector0 == "":
327
+ names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
328
+ gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
329
+ else:
330
+ names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False))
331
+ return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4
332
+
333
+
334
+ ## Video Generation (VG) Single Model Direct Chat
335
+ def upvote_last_response_vg(state, model_selector, request: gr.Request):
336
+ ip = get_ip(request)
337
+ vg_logger.info(f"upvote. ip: {ip}")
338
+ vote_last_response_vg(state, "upvote", model_selector, request)
339
+ return ("",) + (disable_btn,) * 3
340
+
341
+ def downvote_last_response_vg(state, model_selector, request: gr.Request):
342
+ ip = get_ip(request)
343
+ vg_logger.info(f"downvote. ip: {ip}")
344
+ vote_last_response_vg(state, "downvote", model_selector, request)
345
+ return ("",) + (disable_btn,) * 3
346
+
347
+
348
+ def flag_last_response_vg(state, model_selector, request: gr.Request):
349
+ ip = get_ip(request)
350
+ vg_logger.info(f"flag. ip: {ip}")
351
+ vote_last_response_vg(state, "flag", model_selector, request)
352
+ return ("",) + (disable_btn,) * 3
353
+
354
+ ## Image Generation Multi (IGM) Side-by-Side and Battle
355
+
356
+ def leftvote_last_response_vgm(
357
+ state0, state1, model_selector0, model_selector1, request: gr.Request
358
+ ):
359
+ vgm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
360
+ vote_last_response_vgm(
361
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
362
+ )
363
+ if model_selector0 == "":
364
+ return ("",) + (disable_btn,) * 4 + (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
365
+ else:
366
+ return ("",) + (disable_btn,) * 4 + (
367
+ gr.Markdown(state0.model_name, visible=False),
368
+ gr.Markdown(state1.model_name, visible=False))
369
+
370
+
371
+ def rightvote_last_response_vgm(
372
+ state0, state1, model_selector0, model_selector1, request: gr.Request
373
+ ):
374
+ vgm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
375
+ vote_last_response_vgm(
376
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
377
+ )
378
+ if model_selector0 == "":
379
+ return ("",) + (disable_btn,) * 4 + (
380
+ gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
381
+ gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
382
+ else:
383
+ return ("",) + (disable_btn,) * 4 + (
384
+ gr.Markdown(state0.model_name, visible=False),
385
+ gr.Markdown(state1.model_name, visible=False))
386
+
387
+ def tievote_last_response_vgm(
388
+ state0, state1, model_selector0, model_selector1, request: gr.Request
389
+ ):
390
+ vgm_logger.info(f"tievote (named). ip: {get_ip(request)}")
391
+ vote_last_response_vgm(
392
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
393
+ )
394
+ if model_selector0 == "":
395
+ return ("",) + (disable_btn,) * 4 + (
396
+ gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
397
+ gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
398
+ else:
399
+ return ("",) + (disable_btn,) * 4 + (
400
+ gr.Markdown(state0.model_name, visible=False),
401
+ gr.Markdown(state1.model_name, visible=False))
402
+
403
+
404
+ def bothbad_vote_last_response_vgm(
405
+ state0, state1, model_selector0, model_selector1, request: gr.Request
406
+ ):
407
+ vgm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
408
+ vote_last_response_vgm(
409
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
410
+ )
411
+ if model_selector0 == "":
412
+ return ("",) + (disable_btn,) * 4 + (
413
+ gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
414
+ gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
415
+ else:
416
+ return ("",) + (disable_btn,) * 4 + (
417
+ gr.Markdown(state0.model_name, visible=False),
418
+ gr.Markdown(state1.model_name, visible=False))
419
+
420
+ share_js = """
421
+ function (a, b, c, d) {
422
+ const captureElement = document.querySelector('#share-region-named');
423
+ html2canvas(captureElement)
424
+ .then(canvas => {
425
+ canvas.style.display = 'none'
426
+ document.body.appendChild(canvas)
427
+ return canvas
428
+ })
429
+ .then(canvas => {
430
+ const image = canvas.toDataURL('image/png')
431
+ const a = document.createElement('a')
432
+ a.setAttribute('download', 'chatbot-arena.png')
433
+ a.setAttribute('href', image)
434
+ a.click()
435
+ canvas.remove()
436
+ });
437
+ return [a, b, c, d];
438
+ }
439
+ """
440
+ def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request):
441
+ igm_logger.info(f"share (anony). ip: {get_ip(request)}")
442
+ if state0 is not None and state1 is not None:
443
+ vote_last_response_igm(
444
+ [state0, state1], "share", [model_selector0, model_selector1], request
445
+ )
446
+
447
+ def share_click_iem(state0, state1, model_selector0, model_selector1, request: gr.Request):
448
+ iem_logger.info(f"share (anony). ip: {get_ip(request)}")
449
+ if state0 is not None and state1 is not None:
450
+ vote_last_response_iem(
451
+ [state0, state1], "share", [model_selector0, model_selector1], request
452
+ )
453
+
454
+ ## All Generation Gradio Interface
455
+
456
+ class ImageStateIG:
457
+ def __init__(self, model_name):
458
+ self.conv_id = uuid.uuid4().hex
459
+ self.model_name = model_name
460
+ self.prompt = None
461
+ self.output = None
462
+
463
+ def dict(self):
464
+ base = {
465
+ "conv_id": self.conv_id,
466
+ "model_name": self.model_name,
467
+ "prompt": self.prompt
468
+ }
469
+ return base
470
+
471
+ class ImageStateIE:
472
+ def __init__(self, model_name):
473
+ self.conv_id = uuid.uuid4().hex
474
+ self.model_name = model_name
475
+ self.source_prompt = None
476
+ self.target_prompt = None
477
+ self.instruct_prompt = None
478
+ self.source_image = None
479
+ self.output = None
480
+
481
+ def dict(self):
482
+ base = {
483
+ "conv_id": self.conv_id,
484
+ "model_name": self.model_name,
485
+ "source_prompt": self.source_prompt,
486
+ "target_prompt": self.target_prompt,
487
+ "instruct_prompt": self.instruct_prompt
488
+ }
489
+ return base
490
+
491
+ class VideoStateVG:
492
+ def __init__(self, model_name):
493
+ self.conv_id = uuid.uuid4().hex
494
+ self.model_name = model_name
495
+ self.prompt = None
496
+ self.output = None
497
+
498
+ def dict(self):
499
+ base = {
500
+ "conv_id": self.conv_id,
501
+ "model_name": self.model_name,
502
+ "prompt": self.prompt
503
+ }
504
+ return base
505
+
506
+
507
+ def generate_ig(gen_func, state, text, model_name, request: gr.Request):
508
+ if not text:
509
+ raise gr.Warning("Prompt cannot be empty.")
510
+ if not model_name:
511
+ raise gr.Warning("Model name cannot be empty.")
512
+ state = ImageStateIG(model_name)
513
+ ip = get_ip(request)
514
+ ig_logger.info(f"generate. ip: {ip}")
515
+ start_tstamp = time.time()
516
+ generated_image = gen_func(text, model_name)
517
+ state.prompt = text
518
+ state.output = generated_image
519
+ state.model_name = model_name
520
+ if generated_image == '':
521
+ with open(get_conv_log_filename(), "a") as fout:
522
+ data = {
523
+ "type": "chat",
524
+ "model": model_name,
525
+ "gen_params": {},
526
+ "start": round(start_tstamp, 4),
527
+ "state": state.dict(),
528
+ "ip": get_ip(request),
529
+ }
530
+ fout.write(json.dumps(data) + "\n")
531
+ append_json_item_on_log_server(data, get_conv_log_filename())
532
+ raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
533
+
534
+ yield state, generated_image
535
+
536
+ finish_tstamp = time.time()
537
+ # logger.info(f"===output===: {output}")
538
+
539
+ with open(get_conv_log_filename(), "a") as fout:
540
+ data = {
541
+ "tstamp": round(finish_tstamp, 4),
542
+ "type": "chat",
543
+ "model": model_name,
544
+ "gen_params": {},
545
+ "start": round(start_tstamp, 4),
546
+ "finish": round(finish_tstamp, 4),
547
+ "state": state.dict(),
548
+ "ip": get_ip(request),
549
+ }
550
+ fout.write(json.dumps(data) + "\n")
551
+ append_json_item_on_log_server(data, get_conv_log_filename())
552
+
553
+ output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
554
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
555
+ with open(output_file, 'w') as f:
556
+ save_any_image(state.output, f)
557
+ save_image_file_on_log_server(output_file)
558
+
559
+ def generate_ig_museum(gen_func, state, model_name, request: gr.Request):
560
+ if not model_name:
561
+ raise gr.Warning("Model name cannot be empty.")
562
+ state = ImageStateIG(model_name)
563
+ ip = get_ip(request)
564
+ ig_logger.info(f"generate. ip: {ip}")
565
+ start_tstamp = time.time()
566
+ generated_image, text = gen_func(model_name)
567
+ state.prompt = text
568
+ state.output = generated_image
569
+ state.model_name = model_name
570
+
571
+ yield state, generated_image, text
572
+
573
+ finish_tstamp = time.time()
574
+ # logger.info(f"===output===: {output}")
575
+
576
+ with open(get_conv_log_filename(), "a") as fout:
577
+ data = {
578
+ "tstamp": round(finish_tstamp, 4),
579
+ "type": "chat",
580
+ "model": model_name,
581
+ "gen_params": {},
582
+ "start": round(start_tstamp, 4),
583
+ "finish": round(finish_tstamp, 4),
584
+ "state": state.dict(),
585
+ "ip": get_ip(request),
586
+ }
587
+ fout.write(json.dumps(data) + "\n")
588
+ append_json_item_on_log_server(data, get_conv_log_filename())
589
+
590
+ output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
591
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
592
+ with open(output_file, 'w') as f:
593
+ save_any_image(state.output, f)
594
+ save_image_file_on_log_server(output_file)
595
+
596
+ def generate_igm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
597
+ if not text:
598
+ raise gr.Warning("Prompt cannot be empty.")
599
+ if not model_name0:
600
+ raise gr.Warning("Model name A cannot be empty.")
601
+ if not model_name1:
602
+ raise gr.Warning("Model name B cannot be empty.")
603
+ state0 = ImageStateIG(model_name0)
604
+ state1 = ImageStateIG(model_name1)
605
+ ip = get_ip(request)
606
+ igm_logger.info(f"generate. ip: {ip}")
607
+ start_tstamp = time.time()
608
+ # Remove ### Model (A|B): from model name
609
+ model_name0 = re.sub(r"### Model A: ", "", model_name0)
610
+ model_name1 = re.sub(r"### Model B: ", "", model_name1)
611
+ generated_image0, generated_image1 = gen_func(text, model_name0, model_name1)
612
+ state0.prompt = text
613
+ state1.prompt = text
614
+ state0.output = generated_image0
615
+ state1.output = generated_image1
616
+ state0.model_name = model_name0
617
+ state1.model_name = model_name1
618
+ if generated_image0 == '' and generated_image1 == '':
619
+ with open(get_conv_log_filename(), "a") as fout:
620
+ data = {
621
+ "type": "chat",
622
+ "model": model_name0,
623
+ "gen_params": {},
624
+ "start": round(start_tstamp, 4),
625
+ "state": state0.dict(),
626
+ "ip": get_ip(request),
627
+ }
628
+ fout.write(json.dumps(data) + "\n")
629
+ append_json_item_on_log_server(data, get_conv_log_filename())
630
+ data = {
631
+ "type": "chat",
632
+ "model": model_name1,
633
+ "gen_params": {},
634
+ "start": round(start_tstamp, 4),
635
+ "state": state1.dict(),
636
+ "ip": get_ip(request),
637
+ }
638
+ fout.write(json.dumps(data) + "\n")
639
+ append_json_item_on_log_server(data, get_conv_log_filename())
640
+ raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
641
+
642
+ yield state0, state1, generated_image0, generated_image1
643
+
644
+ finish_tstamp = time.time()
645
+ # logger.info(f"===output===: {output}")
646
+
647
+ with open(get_conv_log_filename(), "a") as fout:
648
+ data = {
649
+ "tstamp": round(finish_tstamp, 4),
650
+ "type": "chat",
651
+ "model": model_name0,
652
+ "gen_params": {},
653
+ "start": round(start_tstamp, 4),
654
+ "finish": round(finish_tstamp, 4),
655
+ "state": state0.dict(),
656
+ "ip": get_ip(request),
657
+ }
658
+ fout.write(json.dumps(data) + "\n")
659
+ append_json_item_on_log_server(data, get_conv_log_filename())
660
+ data = {
661
+ "tstamp": round(finish_tstamp, 4),
662
+ "type": "chat",
663
+ "model": model_name1,
664
+ "gen_params": {},
665
+ "start": round(start_tstamp, 4),
666
+ "finish": round(finish_tstamp, 4),
667
+ "state": state1.dict(),
668
+ "ip": get_ip(request),
669
+ }
670
+ fout.write(json.dumps(data) + "\n")
671
+ append_json_item_on_log_server(data, get_conv_log_filename())
672
+
673
+ for i, state in enumerate([state0, state1]):
674
+ output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
675
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
676
+ with open(output_file, 'w') as f:
677
+ save_any_image(state.output, f)
678
+ save_image_file_on_log_server(output_file)
679
+
680
+ def generate_igm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
681
+ if not model_name0:
682
+ raise gr.Warning("Model name A cannot be empty.")
683
+ if not model_name1:
684
+ raise gr.Warning("Model name B cannot be empty.")
685
+ state0 = ImageStateIG(model_name0)
686
+ state1 = ImageStateIG(model_name1)
687
+ ip = get_ip(request)
688
+ igm_logger.info(f"generate. ip: {ip}")
689
+ start_tstamp = time.time()
690
+ # Remove ### Model (A|B): from model name
691
+ model_name0 = re.sub(r"### Model A: ", "", model_name0)
692
+ model_name1 = re.sub(r"### Model B: ", "", model_name1)
693
+ generated_image0, generated_image1, text = gen_func(model_name0, model_name1)
694
+ state0.prompt = text
695
+ state1.prompt = text
696
+ state0.output = generated_image0
697
+ state1.output = generated_image1
698
+ state0.model_name = model_name0
699
+ state1.model_name = model_name1
700
+
701
+ yield state0, state1, generated_image0, generated_image1, text
702
+
703
+ finish_tstamp = time.time()
704
+ # logger.info(f"===output===: {output}")
705
+
706
+ with open(get_conv_log_filename(), "a") as fout:
707
+ data = {
708
+ "tstamp": round(finish_tstamp, 4),
709
+ "type": "chat",
710
+ "model": model_name0,
711
+ "gen_params": {},
712
+ "start": round(start_tstamp, 4),
713
+ "finish": round(finish_tstamp, 4),
714
+ "state": state0.dict(),
715
+ "ip": get_ip(request),
716
+ }
717
+ fout.write(json.dumps(data) + "\n")
718
+ append_json_item_on_log_server(data, get_conv_log_filename())
719
+ data = {
720
+ "tstamp": round(finish_tstamp, 4),
721
+ "type": "chat",
722
+ "model": model_name1,
723
+ "gen_params": {},
724
+ "start": round(start_tstamp, 4),
725
+ "finish": round(finish_tstamp, 4),
726
+ "state": state1.dict(),
727
+ "ip": get_ip(request),
728
+ }
729
+ fout.write(json.dumps(data) + "\n")
730
+ append_json_item_on_log_server(data, get_conv_log_filename())
731
+
732
+ for i, state in enumerate([state0, state1]):
733
+ output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
734
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
735
+ with open(output_file, 'w') as f:
736
+ save_any_image(state.output, f)
737
+ save_image_file_on_log_server(output_file)
738
+
739
+
740
+ def generate_igm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
741
+ if not text:
742
+ raise gr.Warning("Prompt cannot be empty.")
743
+ state0 = ImageStateIG(model_name0)
744
+ state1 = ImageStateIG(model_name1)
745
+ ip = get_ip(request)
746
+ igm_logger.info(f"generate. ip: {ip}")
747
+ start_tstamp = time.time()
748
+ model_name0 = ""
749
+ model_name1 = ""
750
+ generated_image0, generated_image1, model_name0, model_name1 = gen_func(text, model_name0, model_name1)
751
+ state0.prompt = text
752
+ state1.prompt = text
753
+ state0.output = generated_image0
754
+ state1.output = generated_image1
755
+ state0.model_name = model_name0
756
+ state1.model_name = model_name1
757
+ if generated_image0 == '' and generated_image1 == '':
758
+ with open(get_conv_log_filename(), "a") as fout:
759
+ data = {
760
+ "type": "chat",
761
+ "model": model_name0,
762
+ "gen_params": {},
763
+ "start": round(start_tstamp, 4),
764
+ "state": state0.dict(),
765
+ "ip": get_ip(request),
766
+ }
767
+ fout.write(json.dumps(data) + "\n")
768
+ append_json_item_on_log_server(data, get_conv_log_filename())
769
+ data = {
770
+ "type": "chat",
771
+ "model": model_name1,
772
+ "gen_params": {},
773
+ "start": round(start_tstamp, 4),
774
+ "state": state1.dict(),
775
+ "ip": get_ip(request),
776
+ }
777
+ fout.write(json.dumps(data) + "\n")
778
+ append_json_item_on_log_server(data, get_conv_log_filename())
779
+ raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
780
+
781
+
782
+ yield state0, state1, generated_image0, generated_image1, \
783
+ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
784
+
785
+ finish_tstamp = time.time()
786
+ # logger.info(f"===output===: {output}")
787
+
788
+ with open(get_conv_log_filename(), "a") as fout:
789
+ data = {
790
+ "tstamp": round(finish_tstamp, 4),
791
+ "type": "chat",
792
+ "model": model_name0,
793
+ "gen_params": {},
794
+ "start": round(start_tstamp, 4),
795
+ "finish": round(finish_tstamp, 4),
796
+ "state": state0.dict(),
797
+ "ip": get_ip(request),
798
+ }
799
+ fout.write(json.dumps(data) + "\n")
800
+ append_json_item_on_log_server(data, get_conv_log_filename())
801
+ data = {
802
+ "tstamp": round(finish_tstamp, 4),
803
+ "type": "chat",
804
+ "model": model_name1,
805
+ "gen_params": {},
806
+ "start": round(start_tstamp, 4),
807
+ "finish": round(finish_tstamp, 4),
808
+ "state": state1.dict(),
809
+ "ip": get_ip(request),
810
+ }
811
+ fout.write(json.dumps(data) + "\n")
812
+ append_json_item_on_log_server(data, get_conv_log_filename())
813
+
814
+ for i, state in enumerate([state0, state1]):
815
+ output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
816
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
817
+ with open(output_file, 'w') as f:
818
+ save_any_image(state.output, f)
819
+ save_image_file_on_log_server(output_file)
820
+
821
+ def generate_igm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
822
+ state0 = ImageStateIG(model_name0)
823
+ state1 = ImageStateIG(model_name1)
824
+ ip = get_ip(request)
825
+ igm_logger.info(f"generate. ip: {ip}")
826
+ start_tstamp = time.time()
827
+ # model_name0 = re.sub(r"### Model A: ", "", model_name0)
828
+ # model_name1 = re.sub(r"### Model B: ", "", model_name1)
829
+ model_name0 = ""
830
+ model_name1 = ""
831
+ generated_image0, generated_image1, model_name0, model_name1, text = gen_func(model_name0, model_name1)
832
+ state0.prompt = text
833
+ state1.prompt = text
834
+ state0.output = generated_image0
835
+ state1.output = generated_image1
836
+ state0.model_name = model_name0
837
+ state1.model_name = model_name1
838
+
839
+ yield state0, state1, generated_image0, generated_image1, text,\
840
+ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
841
+
842
+ finish_tstamp = time.time()
843
+ # logger.info(f"===output===: {output}")
844
+
845
+ with open(get_conv_log_filename(), "a") as fout:
846
+ data = {
847
+ "tstamp": round(finish_tstamp, 4),
848
+ "type": "chat",
849
+ "model": model_name0,
850
+ "gen_params": {},
851
+ "start": round(start_tstamp, 4),
852
+ "finish": round(finish_tstamp, 4),
853
+ "state": state0.dict(),
854
+ "ip": get_ip(request),
855
+ }
856
+ fout.write(json.dumps(data) + "\n")
857
+ append_json_item_on_log_server(data, get_conv_log_filename())
858
+ data = {
859
+ "tstamp": round(finish_tstamp, 4),
860
+ "type": "chat",
861
+ "model": model_name1,
862
+ "gen_params": {},
863
+ "start": round(start_tstamp, 4),
864
+ "finish": round(finish_tstamp, 4),
865
+ "state": state1.dict(),
866
+ "ip": get_ip(request),
867
+ }
868
+ fout.write(json.dumps(data) + "\n")
869
+ append_json_item_on_log_server(data, get_conv_log_filename())
870
+
871
+ for i, state in enumerate([state0, state1]):
872
+ output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
873
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
874
+ with open(output_file, 'w') as f:
875
+ save_any_image(state.output, f)
876
+ save_image_file_on_log_server(output_file)
877
+
878
+ def generate_ie(gen_func, state, source_text, target_text, instruct_text, source_image, model_name, request: gr.Request):
879
+ if not source_text:
880
+ raise gr.Warning("Source prompt cannot be empty.")
881
+ if not target_text:
882
+ raise gr.Warning("Target prompt cannot be empty.")
883
+ if not instruct_text:
884
+ raise gr.Warning("Instruction prompt cannot be empty.")
885
+ if not source_image:
886
+ raise gr.Warning("Source image cannot be empty.")
887
+ if not model_name:
888
+ raise gr.Warning("Model name cannot be empty.")
889
+ state = ImageStateIE(model_name)
890
+ ip = get_ip(request)
891
+ ig_logger.info(f"generate. ip: {ip}")
892
+ start_tstamp = time.time()
893
+ generated_image = gen_func(source_text, target_text, instruct_text, source_image, model_name)
894
+ state.source_prompt = source_text
895
+ state.target_prompt = target_text
896
+ state.instruct_prompt = instruct_text
897
+ state.source_image = source_image
898
+ state.output = generated_image
899
+ state.model_name = model_name
900
+
901
+ if generated_image == '':
902
+ with open(get_conv_log_filename(), "a") as fout:
903
+ data = {
904
+ "type": "chat",
905
+ "model": model_name,
906
+ "gen_params": {},
907
+ "start": round(start_tstamp, 4),
908
+ "state": state.dict(),
909
+ "ip": get_ip(request),
910
+ }
911
+ fout.write(json.dumps(data) + "\n")
912
+ append_json_item_on_log_server(data, get_conv_log_filename())
913
+ raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
914
+
915
+ yield state, generated_image
916
+
917
+ finish_tstamp = time.time()
918
+ # logger.info(f"===output===: {output}")
919
+
920
+ with open(get_conv_log_filename(), "a") as fout:
921
+ data = {
922
+ "tstamp": round(finish_tstamp, 4),
923
+ "type": "chat",
924
+ "model": model_name,
925
+ "gen_params": {},
926
+ "start": round(start_tstamp, 4),
927
+ "finish": round(finish_tstamp, 4),
928
+ "state": state.dict(),
929
+ "ip": get_ip(request),
930
+ }
931
+ fout.write(json.dumps(data) + "\n")
932
+ append_json_item_on_log_server(data, get_conv_log_filename())
933
+
934
+ src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
935
+ os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
936
+ with open(src_img_file, 'w') as f:
937
+ save_any_image(state.source_image, f)
938
+ output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
939
+ with open(output_file, 'w') as f:
940
+ save_any_image(state.output, f)
941
+ save_image_file_on_log_server(src_img_file)
942
+ save_image_file_on_log_server(output_file)
943
+
944
+ def generate_ie_museum(gen_func, state, model_name, request: gr.Request):
945
+ if not model_name:
946
+ raise gr.Warning("Model name cannot be empty.")
947
+ state = ImageStateIE(model_name)
948
+ ip = get_ip(request)
949
+ ig_logger.info(f"generate. ip: {ip}")
950
+ start_tstamp = time.time()
951
+ source_image, generated_image, source_text, target_text, instruct_text = gen_func(model_name)
952
+ state.source_prompt = source_text
953
+ state.target_prompt = target_text
954
+ state.instruct_prompt = instruct_text
955
+ state.source_image = source_image
956
+ state.output = generated_image
957
+ state.model_name = model_name
958
+
959
+ yield state, generated_image, source_image, source_text, target_text, instruct_text
960
+
961
+ finish_tstamp = time.time()
962
+ # logger.info(f"===output===: {output}")
963
+
964
+ with open(get_conv_log_filename(), "a") as fout:
965
+ data = {
966
+ "tstamp": round(finish_tstamp, 4),
967
+ "type": "chat",
968
+ "model": model_name,
969
+ "gen_params": {},
970
+ "start": round(start_tstamp, 4),
971
+ "finish": round(finish_tstamp, 4),
972
+ "state": state.dict(),
973
+ "ip": get_ip(request),
974
+ }
975
+ fout.write(json.dumps(data) + "\n")
976
+ append_json_item_on_log_server(data, get_conv_log_filename())
977
+
978
+ src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
979
+ os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
980
+ with open(src_img_file, 'w') as f:
981
+ save_any_image(state.source_image, f)
982
+ output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
983
+ with open(output_file, 'w') as f:
984
+ save_any_image(state.output, f)
985
+ save_image_file_on_log_server(src_img_file)
986
+ save_image_file_on_log_server(output_file)
987
+
988
+
989
+ def generate_iem(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
990
+ if not source_text:
991
+ raise gr.Warning("Source prompt cannot be empty.")
992
+ if not target_text:
993
+ raise gr.Warning("Target prompt cannot be empty.")
994
+ if not instruct_text:
995
+ raise gr.Warning("Instruction prompt cannot be empty.")
996
+ if not source_image:
997
+ raise gr.Warning("Source image cannot be empty.")
998
+ if not model_name0:
999
+ raise gr.Warning("Model name A cannot be empty.")
1000
+ if not model_name1:
1001
+ raise gr.Warning("Model name B cannot be empty.")
1002
+ state0 = ImageStateIE(model_name0)
1003
+ state1 = ImageStateIE(model_name1)
1004
+ ip = get_ip(request)
1005
+ igm_logger.info(f"generate. ip: {ip}")
1006
+ start_tstamp = time.time()
1007
+ model_name0 = re.sub(r"### Model A: ", "", model_name0)
1008
+ model_name1 = re.sub(r"### Model B: ", "", model_name1)
1009
+ generated_image0, generated_image1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
1010
+ state0.source_prompt = source_text
1011
+ state0.target_prompt = target_text
1012
+ state0.instruct_prompt = instruct_text
1013
+ state0.source_image = source_image
1014
+ state0.output = generated_image0
1015
+ state0.model_name = model_name0
1016
+ state1.source_prompt = source_text
1017
+ state1.target_prompt = target_text
1018
+ state1.instruct_prompt = instruct_text
1019
+ state1.source_image = source_image
1020
+ state1.output = generated_image1
1021
+ state1.model_name = model_name1
1022
+
1023
+ if generated_image0 == '' and generated_image1 == '':
1024
+ with open(get_conv_log_filename(), "a") as fout:
1025
+ data = {
1026
+ "type": "chat",
1027
+ "model": model_name0,
1028
+ "gen_params": {},
1029
+ "start": round(start_tstamp, 4),
1030
+ "state": state0.dict(),
1031
+ "ip": get_ip(request),
1032
+ }
1033
+ fout.write(json.dumps(data) + "\n")
1034
+ append_json_item_on_log_server(data, get_conv_log_filename())
1035
+ data = {
1036
+ "type": "chat",
1037
+ "model": model_name1,
1038
+ "gen_params": {},
1039
+ "start": round(start_tstamp, 4),
1040
+ "state": state1.dict(),
1041
+ "ip": get_ip(request),
1042
+ }
1043
+ fout.write(json.dumps(data) + "\n")
1044
+ append_json_item_on_log_server(data, get_conv_log_filename())
1045
+ raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
1046
+
1047
+
1048
+ yield state0, state1, generated_image0, generated_image1
1049
+
1050
+ finish_tstamp = time.time()
1051
+ # logger.info(f"===output===: {output}")
1052
+
1053
+ with open(get_conv_log_filename(), "a") as fout:
1054
+ data = {
1055
+ "tstamp": round(finish_tstamp, 4),
1056
+ "type": "chat",
1057
+ "model": model_name0,
1058
+ "gen_params": {},
1059
+ "start": round(start_tstamp, 4),
1060
+ "finish": round(finish_tstamp, 4),
1061
+ "state": state0.dict(),
1062
+ "ip": get_ip(request),
1063
+ }
1064
+ fout.write(json.dumps(data) + "\n")
1065
+ append_json_item_on_log_server(data, get_conv_log_filename())
1066
+ data = {
1067
+ "tstamp": round(finish_tstamp, 4),
1068
+ "type": "chat",
1069
+ "model": model_name1,
1070
+ "gen_params": {},
1071
+ "start": round(start_tstamp, 4),
1072
+ "finish": round(finish_tstamp, 4),
1073
+ "state": state1.dict(),
1074
+ "ip": get_ip(request),
1075
+ }
1076
+ fout.write(json.dumps(data) + "\n")
1077
+ append_json_item_on_log_server(data, get_conv_log_filename())
1078
+
1079
+ for i, state in enumerate([state0, state1]):
1080
+ src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
1081
+ os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
1082
+ with open(src_img_file, 'w') as f:
1083
+ save_any_image(state.source_image, f)
1084
+ output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
1085
+ with open(output_file, 'w') as f:
1086
+ save_any_image(state.output, f)
1087
+ save_image_file_on_log_server(src_img_file)
1088
+ save_image_file_on_log_server(output_file)
1089
+
1090
+ def generate_iem_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
1091
+ if not model_name0:
1092
+ raise gr.Warning("Model name A cannot be empty.")
1093
+ if not model_name1:
1094
+ raise gr.Warning("Model name B cannot be empty.")
1095
+ state0 = ImageStateIE(model_name0)
1096
+ state1 = ImageStateIE(model_name1)
1097
+ ip = get_ip(request)
1098
+ igm_logger.info(f"generate. ip: {ip}")
1099
+ start_tstamp = time.time()
1100
+ model_name0 = re.sub(r"### Model A: ", "", model_name0)
1101
+ model_name1 = re.sub(r"### Model B: ", "", model_name1)
1102
+ source_image, generated_image0, generated_image1, source_text, target_text, instruct_text = gen_func(model_name0, model_name1)
1103
+ state0.source_prompt = source_text
1104
+ state0.target_prompt = target_text
1105
+ state0.instruct_prompt = instruct_text
1106
+ state0.source_image = source_image
1107
+ state0.output = generated_image0
1108
+ state0.model_name = model_name0
1109
+ state1.source_prompt = source_text
1110
+ state1.target_prompt = target_text
1111
+ state1.instruct_prompt = instruct_text
1112
+ state1.source_image = source_image
1113
+ state1.output = generated_image1
1114
+ state1.model_name = model_name1
1115
+
1116
+ yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text
1117
+
1118
+ finish_tstamp = time.time()
1119
+ # logger.info(f"===output===: {output}")
1120
+
1121
+ with open(get_conv_log_filename(), "a") as fout:
1122
+ data = {
1123
+ "tstamp": round(finish_tstamp, 4),
1124
+ "type": "chat",
1125
+ "model": model_name0,
1126
+ "gen_params": {},
1127
+ "start": round(start_tstamp, 4),
1128
+ "finish": round(finish_tstamp, 4),
1129
+ "state": state0.dict(),
1130
+ "ip": get_ip(request),
1131
+ }
1132
+ fout.write(json.dumps(data) + "\n")
1133
+ append_json_item_on_log_server(data, get_conv_log_filename())
1134
+ data = {
1135
+ "tstamp": round(finish_tstamp, 4),
1136
+ "type": "chat",
1137
+ "model": model_name1,
1138
+ "gen_params": {},
1139
+ "start": round(start_tstamp, 4),
1140
+ "finish": round(finish_tstamp, 4),
1141
+ "state": state1.dict(),
1142
+ "ip": get_ip(request),
1143
+ }
1144
+ fout.write(json.dumps(data) + "\n")
1145
+ append_json_item_on_log_server(data, get_conv_log_filename())
1146
+
1147
+ for i, state in enumerate([state0, state1]):
1148
+ src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
1149
+ os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
1150
+ with open(src_img_file, 'w') as f:
1151
+ save_any_image(state.source_image, f)
1152
+ output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
1153
+ with open(output_file, 'w') as f:
1154
+ save_any_image(state.output, f)
1155
+ save_image_file_on_log_server(src_img_file)
1156
+ save_image_file_on_log_server(output_file)
1157
+
1158
+
1159
+ def generate_iem_annoy(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
1160
+ if not source_text:
1161
+ raise gr.Warning("Source prompt cannot be empty.")
1162
+ if not target_text:
1163
+ raise gr.Warning("Target prompt cannot be empty.")
1164
+ if not instruct_text:
1165
+ raise gr.Warning("Instruction prompt cannot be empty.")
1166
+ if not source_image:
1167
+ raise gr.Warning("Source image cannot be empty.")
1168
+ state0 = ImageStateIE(model_name0)
1169
+ state1 = ImageStateIE(model_name1)
1170
+ ip = get_ip(request)
1171
+ igm_logger.info(f"generate. ip: {ip}")
1172
+ start_tstamp = time.time()
1173
+ model_name0 = ""
1174
+ model_name1 = ""
1175
+ generated_image0, generated_image1, model_name0, model_name1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
1176
+ state0.source_prompt = source_text
1177
+ state0.target_prompt = target_text
1178
+ state0.instruct_prompt = instruct_text
1179
+ state0.source_image = source_image
1180
+ state0.output = generated_image0
1181
+ state0.model_name = model_name0
1182
+ state1.source_prompt = source_text
1183
+ state1.target_prompt = target_text
1184
+ state1.instruct_prompt = instruct_text
1185
+ state1.source_image = source_image
1186
+ state1.output = generated_image1
1187
+ state1.model_name = model_name1
1188
+ if generated_image0 == '' and generated_image1 == '':
1189
+ with open(get_conv_log_filename(), "a") as fout:
1190
+ data = {
1191
+ "type": "chat",
1192
+ "model": model_name0,
1193
+ "gen_params": {},
1194
+ "start": round(start_tstamp, 4),
1195
+ "state": state0.dict(),
1196
+ "ip": get_ip(request),
1197
+ }
1198
+ fout.write(json.dumps(data) + "\n")
1199
+ append_json_item_on_log_server(data, get_conv_log_filename())
1200
+ data = {
1201
+ "type": "chat",
1202
+ "model": model_name1,
1203
+ "gen_params": {},
1204
+ "start": round(start_tstamp, 4),
1205
+ "state": state1.dict(),
1206
+ "ip": get_ip(request),
1207
+ }
1208
+ fout.write(json.dumps(data) + "\n")
1209
+ append_json_item_on_log_server(data, get_conv_log_filename())
1210
+ raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
1211
+
1212
+
1213
+ yield state0, state1, generated_image0, generated_image1, \
1214
+ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
1215
+
1216
+ finish_tstamp = time.time()
1217
+ # logger.info(f"===output===: {output}")
1218
+
1219
+ with open(get_conv_log_filename(), "a") as fout:
1220
+ data = {
1221
+ "tstamp": round(finish_tstamp, 4),
1222
+ "type": "chat",
1223
+ "model": model_name0,
1224
+ "gen_params": {},
1225
+ "start": round(start_tstamp, 4),
1226
+ "finish": round(finish_tstamp, 4),
1227
+ "state": state0.dict(),
1228
+ "ip": get_ip(request),
1229
+ }
1230
+ fout.write(json.dumps(data) + "\n")
1231
+ append_json_item_on_log_server(data, get_conv_log_filename())
1232
+ data = {
1233
+ "tstamp": round(finish_tstamp, 4),
1234
+ "type": "chat",
1235
+ "model": model_name1,
1236
+ "gen_params": {},
1237
+ "start": round(start_tstamp, 4),
1238
+ "finish": round(finish_tstamp, 4),
1239
+ "state": state1.dict(),
1240
+ "ip": get_ip(request),
1241
+ }
1242
+ fout.write(json.dumps(data) + "\n")
1243
+ append_json_item_on_log_server(data, get_conv_log_filename())
1244
+
1245
+ for i, state in enumerate([state0, state1]):
1246
+ src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
1247
+ os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
1248
+ with open(src_img_file, 'w') as f:
1249
+ save_any_image(state.source_image, f)
1250
+ output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
1251
+ with open(output_file, 'w') as f:
1252
+ save_any_image(state.output, f)
1253
+ save_image_file_on_log_server(src_img_file)
1254
+ save_image_file_on_log_server(output_file)
1255
+
1256
+ def generate_iem_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
1257
+ state0 = ImageStateIE(model_name0)
1258
+ state1 = ImageStateIE(model_name1)
1259
+ ip = get_ip(request)
1260
+ igm_logger.info(f"generate. ip: {ip}")
1261
+ start_tstamp = time.time()
1262
+ model_name0 = ""
1263
+ model_name1 = ""
1264
+ source_image, generated_image0, generated_image1, source_text, target_text, instruct_text, model_name0, model_name1 = gen_func(model_name0, model_name1)
1265
+ state0.source_prompt = source_text
1266
+ state0.target_prompt = target_text
1267
+ state0.instruct_prompt = instruct_text
1268
+ state0.source_image = source_image
1269
+ state0.output = generated_image0
1270
+ state0.model_name = model_name0
1271
+ state1.source_prompt = source_text
1272
+ state1.target_prompt = target_text
1273
+ state1.instruct_prompt = instruct_text
1274
+ state1.source_image = source_image
1275
+ state1.output = generated_image1
1276
+ state1.model_name = model_name1
1277
+
1278
+ yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text, \
1279
+ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
1280
+
1281
+ finish_tstamp = time.time()
1282
+ # logger.info(f"===output===: {output}")
1283
+
1284
+ with open(get_conv_log_filename(), "a") as fout:
1285
+ data = {
1286
+ "tstamp": round(finish_tstamp, 4),
1287
+ "type": "chat",
1288
+ "model": model_name0,
1289
+ "gen_params": {},
1290
+ "start": round(start_tstamp, 4),
1291
+ "finish": round(finish_tstamp, 4),
1292
+ "state": state0.dict(),
1293
+ "ip": get_ip(request),
1294
+ }
1295
+ fout.write(json.dumps(data) + "\n")
1296
+ append_json_item_on_log_server(data, get_conv_log_filename())
1297
+ data = {
1298
+ "tstamp": round(finish_tstamp, 4),
1299
+ "type": "chat",
1300
+ "model": model_name1,
1301
+ "gen_params": {},
1302
+ "start": round(start_tstamp, 4),
1303
+ "finish": round(finish_tstamp, 4),
1304
+ "state": state1.dict(),
1305
+ "ip": get_ip(request),
1306
+ }
1307
+ fout.write(json.dumps(data) + "\n")
1308
+ append_json_item_on_log_server(data, get_conv_log_filename())
1309
+
1310
+ for i, state in enumerate([state0, state1]):
1311
+ src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
1312
+ os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
1313
+ with open(src_img_file, 'w') as f:
1314
+ save_any_image(state.source_image, f)
1315
+ output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
1316
+ with open(output_file, 'w') as f:
1317
+ save_any_image(state.output, f)
1318
+ save_image_file_on_log_server(src_img_file)
1319
+ save_image_file_on_log_server(output_file)
1320
+
1321
+ def generate_vg(gen_func, state, text, model_name, request: gr.Request):
1322
+ if not text:
1323
+ raise gr.Warning("Prompt cannot be empty.")
1324
+ if not model_name:
1325
+ raise gr.Warning("Model name cannot be empty.")
1326
+ state = VideoStateVG(model_name)
1327
+ ip = get_ip(request)
1328
+ vg_logger.info(f"generate. ip: {ip}")
1329
+ start_tstamp = time.time()
1330
+ generated_video = gen_func(text, model_name)
1331
+ state.prompt = text
1332
+ state.output = generated_video
1333
+ state.model_name = model_name
1334
+ if generated_video == '':
1335
+ with open(get_conv_log_filename(), "a") as fout:
1336
+ data = {
1337
+ "type": "chat",
1338
+ "model": model_name,
1339
+ "gen_params": {},
1340
+ "start": round(start_tstamp, 4),
1341
+ "state": state.dict(),
1342
+ "ip": get_ip(request),
1343
+ }
1344
+ fout.write(json.dumps(data) + "\n")
1345
+ append_json_item_on_log_server(data, get_conv_log_filename())
1346
+ raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
1347
+
1348
+
1349
+ # yield state, generated_video
1350
+
1351
+ finish_tstamp = time.time()
1352
+
1353
+ with open(get_conv_log_filename(), "a") as fout:
1354
+ data = {
1355
+ "tstamp": round(finish_tstamp, 4),
1356
+ "type": "chat",
1357
+ "model": model_name,
1358
+ "gen_params": {},
1359
+ "start": round(start_tstamp, 4),
1360
+ "finish": round(finish_tstamp, 4),
1361
+ "state": state.dict(),
1362
+ "ip": get_ip(request),
1363
+ }
1364
+ fout.write(json.dumps(data) + "\n")
1365
+ append_json_item_on_log_server(data, get_conv_log_filename())
1366
+
1367
+ output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
1368
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
1369
+ if model_name.startswith('fal'):
1370
+ r = requests.get(state.output)
1371
+ with open(output_file, 'wb') as outfile:
1372
+ outfile.write(r.content)
1373
+ else:
1374
+ print("======== video shape: ========")
1375
+ print(state.output.shape)
1376
+ # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
1377
+ if state.output.shape[-1] != 3:
1378
+ state.output = state.output.permute(0, 2, 3, 1)
1379
+ imageio.mimwrite(output_file, state.output, fps=8, quality=9)
1380
+
1381
+ save_video_file_on_log_server(output_file)
1382
+ yield state, output_file
1383
+
1384
+ def generate_vg_museum(gen_func, state, model_name, request: gr.Request):
1385
+ state = VideoStateVG(model_name)
1386
+ ip = get_ip(request)
1387
+ vg_logger.info(f"generate. ip: {ip}")
1388
+ start_tstamp = time.time()
1389
+ generated_video, text = gen_func(model_name)
1390
+ state.prompt = text
1391
+ state.output = generated_video
1392
+ state.model_name = model_name
1393
+
1394
+ # yield state, generated_video
1395
+
1396
+ finish_tstamp = time.time()
1397
+
1398
+ with open(get_conv_log_filename(), "a") as fout:
1399
+ data = {
1400
+ "tstamp": round(finish_tstamp, 4),
1401
+ "type": "chat",
1402
+ "model": model_name,
1403
+ "gen_params": {},
1404
+ "start": round(start_tstamp, 4),
1405
+ "finish": round(finish_tstamp, 4),
1406
+ "state": state.dict(),
1407
+ "ip": get_ip(request),
1408
+ }
1409
+ fout.write(json.dumps(data) + "\n")
1410
+ append_json_item_on_log_server(data, get_conv_log_filename())
1411
+
1412
+ output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
1413
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
1414
+
1415
+ r = requests.get(state.output)
1416
+ with open(output_file, 'wb') as outfile:
1417
+ outfile.write(r.content)
1418
+
1419
+ save_video_file_on_log_server(output_file)
1420
+ yield state, output_file, text
1421
+
1422
+
1423
+ def generate_vgm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
1424
+ if not text:
1425
+ raise gr.Warning("Prompt cannot be empty.")
1426
+ if not model_name0:
1427
+ raise gr.Warning("Model name A cannot be empty.")
1428
+ if not model_name1:
1429
+ raise gr.Warning("Model name B cannot be empty.")
1430
+ state0 = VideoStateVG(model_name0)
1431
+ state1 = VideoStateVG(model_name1)
1432
+ ip = get_ip(request)
1433
+ igm_logger.info(f"generate. ip: {ip}")
1434
+ start_tstamp = time.time()
1435
+ # Remove ### Model (A|B): from model name
1436
+ model_name0 = re.sub(r"### Model A: ", "", model_name0)
1437
+ model_name1 = re.sub(r"### Model B: ", "", model_name1)
1438
+ generated_video0, generated_video1 = gen_func(text, model_name0, model_name1)
1439
+ state0.prompt = text
1440
+ state1.prompt = text
1441
+ state0.output = generated_video0
1442
+ state1.output = generated_video1
1443
+ state0.model_name = model_name0
1444
+ state1.model_name = model_name1
1445
+ if generated_video0 == '' and generated_video1 == '':
1446
+ with open(get_conv_log_filename(), "a") as fout:
1447
+ data = {
1448
+ "type": "chat",
1449
+ "model": model_name0,
1450
+ "gen_params": {},
1451
+ "start": round(start_tstamp, 4),
1452
+ "state": state0.dict(),
1453
+ "ip": get_ip(request),
1454
+ }
1455
+ fout.write(json.dumps(data) + "\n")
1456
+ append_json_item_on_log_server(data, get_conv_log_filename())
1457
+ data = {
1458
+ "type": "chat",
1459
+ "model": model_name1,
1460
+ "gen_params": {},
1461
+ "start": round(start_tstamp, 4),
1462
+ "state": state1.dict(),
1463
+ "ip": get_ip(request),
1464
+ }
1465
+ fout.write(json.dumps(data) + "\n")
1466
+ append_json_item_on_log_server(data, get_conv_log_filename())
1467
+ raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
1468
+
1469
+ # yield state0, state1, generated_video0, generated_video1
1470
+ print("====== model name =========")
1471
+ print(state0.model_name)
1472
+ print(state1.model_name)
1473
+
1474
+
1475
+ finish_tstamp = time.time()
1476
+
1477
+
1478
+ with open(get_conv_log_filename(), "a") as fout:
1479
+ data = {
1480
+ "tstamp": round(finish_tstamp, 4),
1481
+ "type": "chat",
1482
+ "model": model_name0,
1483
+ "gen_params": {},
1484
+ "start": round(start_tstamp, 4),
1485
+ "finish": round(finish_tstamp, 4),
1486
+ "state": state0.dict(),
1487
+ "ip": get_ip(request),
1488
+ }
1489
+ fout.write(json.dumps(data) + "\n")
1490
+ append_json_item_on_log_server(data, get_conv_log_filename())
1491
+ data = {
1492
+ "tstamp": round(finish_tstamp, 4),
1493
+ "type": "chat",
1494
+ "model": model_name1,
1495
+ "gen_params": {},
1496
+ "start": round(start_tstamp, 4),
1497
+ "finish": round(finish_tstamp, 4),
1498
+ "state": state1.dict(),
1499
+ "ip": get_ip(request),
1500
+ }
1501
+ fout.write(json.dumps(data) + "\n")
1502
+ append_json_item_on_log_server(data, get_conv_log_filename())
1503
+
1504
+ for i, state in enumerate([state0, state1]):
1505
+ output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
1506
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
1507
+ print(state.model_name)
1508
+
1509
+ if state.model_name.startswith('fal'):
1510
+ r = requests.get(state.output)
1511
+ with open(output_file, 'wb') as outfile:
1512
+ outfile.write(r.content)
1513
+ else:
1514
+ print("======== video shape: ========")
1515
+ print(state.output)
1516
+ print(state.output.shape)
1517
+ # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
1518
+ if state.output.shape[-1] != 3:
1519
+ state.output = state.output.permute(0, 2, 3, 1)
1520
+ imageio.mimwrite(output_file, state.output, fps=8, quality=9)
1521
+ save_video_file_on_log_server(output_file)
1522
+ yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4'
1523
+
1524
+ def generate_vgm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
1525
+ state0 = VideoStateVG(model_name0)
1526
+ state1 = VideoStateVG(model_name1)
1527
+ ip = get_ip(request)
1528
+ igm_logger.info(f"generate. ip: {ip}")
1529
+ start_tstamp = time.time()
1530
+ # Remove ### Model (A|B): from model name
1531
+ model_name0 = re.sub(r"### Model A: ", "", model_name0)
1532
+ model_name1 = re.sub(r"### Model B: ", "", model_name1)
1533
+ generated_video0, generated_video1, text = gen_func(model_name0, model_name1)
1534
+ state0.prompt = text
1535
+ state1.prompt = text
1536
+ state0.output = generated_video0
1537
+ state1.output = generated_video1
1538
+ state0.model_name = model_name0
1539
+ state1.model_name = model_name1
1540
+
1541
+ # yield state0, state1, generated_video0, generated_video1
1542
+ print("====== model name =========")
1543
+ print(state0.model_name)
1544
+ print(state1.model_name)
1545
+
1546
+
1547
+ finish_tstamp = time.time()
1548
+
1549
+
1550
+ with open(get_conv_log_filename(), "a") as fout:
1551
+ data = {
1552
+ "tstamp": round(finish_tstamp, 4),
1553
+ "type": "chat",
1554
+ "model": model_name0,
1555
+ "gen_params": {},
1556
+ "start": round(start_tstamp, 4),
1557
+ "finish": round(finish_tstamp, 4),
1558
+ "state": state0.dict(),
1559
+ "ip": get_ip(request),
1560
+ }
1561
+ fout.write(json.dumps(data) + "\n")
1562
+ append_json_item_on_log_server(data, get_conv_log_filename())
1563
+ data = {
1564
+ "tstamp": round(finish_tstamp, 4),
1565
+ "type": "chat",
1566
+ "model": model_name1,
1567
+ "gen_params": {},
1568
+ "start": round(start_tstamp, 4),
1569
+ "finish": round(finish_tstamp, 4),
1570
+ "state": state1.dict(),
1571
+ "ip": get_ip(request),
1572
+ }
1573
+ fout.write(json.dumps(data) + "\n")
1574
+ append_json_item_on_log_server(data, get_conv_log_filename())
1575
+
1576
+ for i, state in enumerate([state0, state1]):
1577
+ output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
1578
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
1579
+ print(state.model_name)
1580
+
1581
+ r = requests.get(state.output)
1582
+ with open(output_file, 'wb') as outfile:
1583
+ outfile.write(r.content)
1584
+
1585
+ save_video_file_on_log_server(output_file)
1586
+ yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', text
1587
+
1588
+
1589
+ def generate_vgm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
1590
+ if not text:
1591
+ raise gr.Warning("Prompt cannot be empty.")
1592
+ state0 = VideoStateVG(model_name0)
1593
+ state1 = VideoStateVG(model_name1)
1594
+ ip = get_ip(request)
1595
+ vgm_logger.info(f"generate. ip: {ip}")
1596
+ start_tstamp = time.time()
1597
+ model_name0 = ""
1598
+ model_name1 = ""
1599
+ generated_video0, generated_video1, model_name0, model_name1 = gen_func(text, model_name0, model_name1)
1600
+ state0.prompt = text
1601
+ state1.prompt = text
1602
+ state0.output = generated_video0
1603
+ state1.output = generated_video1
1604
+ state0.model_name = model_name0
1605
+ state1.model_name = model_name1
1606
+ if generated_video0 == '' and generated_video1 == '':
1607
+ with open(get_conv_log_filename(), "a") as fout:
1608
+ data = {
1609
+ "type": "chat",
1610
+ "model": model_name0,
1611
+ "gen_params": {},
1612
+ "start": round(start_tstamp, 4),
1613
+ "state": state0.dict(),
1614
+ "ip": get_ip(request),
1615
+ }
1616
+ fout.write(json.dumps(data) + "\n")
1617
+ append_json_item_on_log_server(data, get_conv_log_filename())
1618
+ data = {
1619
+ "type": "chat",
1620
+ "model": model_name1,
1621
+ "gen_params": {},
1622
+ "start": round(start_tstamp, 4),
1623
+ "state": state1.dict(),
1624
+ "ip": get_ip(request),
1625
+ }
1626
+ fout.write(json.dumps(data) + "\n")
1627
+ append_json_item_on_log_server(data, get_conv_log_filename())
1628
+ raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
1629
+
1630
+
1631
+ # yield state0, state1, generated_video0, generated_video1, \
1632
+ # gr.Markdown(f"### Model A: {model_name0}"), gr.Markdown(f"### Model B: {model_name1}")
1633
+
1634
+ finish_tstamp = time.time()
1635
+ # logger.info(f"===output===: {output}")
1636
+
1637
+ with open(get_conv_log_filename(), "a") as fout:
1638
+ data = {
1639
+ "tstamp": round(finish_tstamp, 4),
1640
+ "type": "chat",
1641
+ "model": model_name0,
1642
+ "gen_params": {},
1643
+ "start": round(start_tstamp, 4),
1644
+ "finish": round(finish_tstamp, 4),
1645
+ "state": state0.dict(),
1646
+ "ip": get_ip(request),
1647
+ }
1648
+ fout.write(json.dumps(data) + "\n")
1649
+ append_json_item_on_log_server(data, get_conv_log_filename())
1650
+ data = {
1651
+ "tstamp": round(finish_tstamp, 4),
1652
+ "type": "chat",
1653
+ "model": model_name1,
1654
+ "gen_params": {},
1655
+ "start": round(start_tstamp, 4),
1656
+ "finish": round(finish_tstamp, 4),
1657
+ "state": state1.dict(),
1658
+ "ip": get_ip(request),
1659
+ }
1660
+ fout.write(json.dumps(data) + "\n")
1661
+ append_json_item_on_log_server(data, get_conv_log_filename())
1662
+
1663
+ for i, state in enumerate([state0, state1]):
1664
+ output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
1665
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
1666
+ if state.model_name.startswith('fal'):
1667
+ r = requests.get(state.output)
1668
+ with open(output_file, 'wb') as outfile:
1669
+ outfile.write(r.content)
1670
+ else:
1671
+ print("======== video shape: ========")
1672
+ print(state.output.shape)
1673
+ # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
1674
+ if state.output.shape[-1] != 3:
1675
+ state.output = state.output.permute(0, 2, 3, 1)
1676
+ imageio.mimwrite(output_file, state.output, fps=8, quality=9)
1677
+ save_video_file_on_log_server(output_file)
1678
+
1679
+ yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', \
1680
+ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
1681
+
1682
+ def generate_vgm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
1683
+ state0 = VideoStateVG(model_name0)
1684
+ state1 = VideoStateVG(model_name1)
1685
+ ip = get_ip(request)
1686
+ vgm_logger.info(f"generate. ip: {ip}")
1687
+ start_tstamp = time.time()
1688
+ model_name0 = ""
1689
+ model_name1 = ""
1690
+ generated_video0, generated_video1, model_name0, model_name1, text = gen_func(model_name0, model_name1)
1691
+ state0.prompt = text
1692
+ state1.prompt = text
1693
+ state0.output = generated_video0
1694
+ state1.output = generated_video1
1695
+ state0.model_name = model_name0
1696
+ state1.model_name = model_name1
1697
+
1698
+ # yield state0, state1, generated_video0, generated_video1, \
1699
+ # gr.Markdown(f"### Model A: {model_name0}"), gr.Markdown(f"### Model B: {model_name1}")
1700
+
1701
+ finish_tstamp = time.time()
1702
+ # logger.info(f"===output===: {output}")
1703
+
1704
+ with open(get_conv_log_filename(), "a") as fout:
1705
+ data = {
1706
+ "tstamp": round(finish_tstamp, 4),
1707
+ "type": "chat",
1708
+ "model": model_name0,
1709
+ "gen_params": {},
1710
+ "start": round(start_tstamp, 4),
1711
+ "finish": round(finish_tstamp, 4),
1712
+ "state": state0.dict(),
1713
+ "ip": get_ip(request),
1714
+ }
1715
+ fout.write(json.dumps(data) + "\n")
1716
+ append_json_item_on_log_server(data, get_conv_log_filename())
1717
+ data = {
1718
+ "tstamp": round(finish_tstamp, 4),
1719
+ "type": "chat",
1720
+ "model": model_name1,
1721
+ "gen_params": {},
1722
+ "start": round(start_tstamp, 4),
1723
+ "finish": round(finish_tstamp, 4),
1724
+ "state": state1.dict(),
1725
+ "ip": get_ip(request),
1726
+ }
1727
+ fout.write(json.dumps(data) + "\n")
1728
+ append_json_item_on_log_server(data, get_conv_log_filename())
1729
+
1730
+ for i, state in enumerate([state0, state1]):
1731
+ output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
1732
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
1733
+
1734
+ r = requests.get(state.output)
1735
+ with open(output_file, 'wb') as outfile:
1736
+ outfile.write(r.content)
1737
+
1738
+ save_video_file_on_log_server(output_file)
1739
+
1740
+ yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', text,\
1741
+ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)