Spaces:
Running
on
Zero
Running
on
Zero
import datetime
import json
import os
import time
import uuid
from pathlib import Path

import gradio as gr
import regex as re

# Local imports stay before imageio/diffusers/torch so any names the star
# import provides are shadowed exactly as before.
from .utils import *
from .log_utils import build_logger
from .constants import IMAGE_DIR, VIDEO_DIR

import imageio
from diffusers.utils import load_image
import torch
# One logger per arena mode. Prefix key:
#   ig / igm -> image generation (single-model direct chat / side-by-side & battle)
#   ie / iem -> image editing    (single-model direct chat / side-by-side & battle)
#   vg / vgm -> video generation (single-model direct chat / side-by-side & battle)
ig_logger = build_logger(
    "gradio_web_server_image_generation",
    "gr_web_image_generation.log",
)
igm_logger = build_logger(
    "gradio_web_server_image_generation_multi",
    "gr_web_image_generation_multi.log",
)
ie_logger = build_logger(
    "gradio_web_server_image_editing",
    "gr_web_image_editing.log",
)
iem_logger = build_logger(
    "gradio_web_server_image_editing_multi",
    "gr_web_image_editing_multi.log",
)
vg_logger = build_logger(
    "gradio_web_server_video_generation",
    "gr_web_video_generation.log",
)
vgm_logger = build_logger(
    "gradio_web_server_video_generation_multi",
    "gr_web_video_generation_multi.log",
)
def save_any_image(image_file, file_path):
    """Save an image to ``file_path`` in JPEG format.

    Args:
        image_file: Either a string handed to ``diffusers.utils.load_image``
            (a local path or URL), or an object with a PIL-style ``save``
            method such as ``PIL.Image.Image``.
        file_path: Destination passed straight to ``save``: a filesystem path
            or a writable binary file object.
    """
    image = load_image(image_file) if isinstance(image_file, str) else image_file
    # JPEG has no alpha/palette support; PIL raises OSError when asked to
    # write e.g. "RGBA" or "P" as JPEG (uploaded PNGs are often RGBA), so
    # flatten anything outside the JPEG-writable modes to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if getattr(image, "mode", "RGB") not in jpeg_modes:
        image = image.convert("RGB")
    image.save(file_path, 'JPEG')
def vote_last_response_ig(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model image-generation result.

    Appends one JSON line (timestamp, vote type, model selector, serialized
    state, client IP) to the conversation log and mirrors it to the log
    server. Then saves ``state.output`` as
    ``{IMAGE_DIR}/generation/<conv_id>.jpg`` and uploads that file.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    # Mirror only after the local handle is closed, so the new line is
    # flushed in case the helper reads the file.
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Pass the path so PIL opens the file in binary mode itself. The previous
    # code handed it a text-mode ('w') handle for JPEG bytes.
    save_any_image(state.output, output_file)
    save_image_file_on_log_server(output_file)
def vote_last_response_igm(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model (side-by-side / battle) image-generation round.

    Appends one JSON line with every model selector and serialized state to
    the conversation log and mirrors it to the log server. Then saves each
    state's output as ``{IMAGE_DIR}/generation/<conv_id>.jpg`` and uploads it.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "models": list(model_selectors),
        "states": [s.dict() for s in states],
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    # Mirror only after the local handle is closed (flushed).
    append_json_item_on_log_server(data, get_conv_log_filename())

    for state in states:
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Let PIL open the path in binary mode instead of a text-mode handle.
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def vote_last_response_ie(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model image-editing result.

    Appends one JSON line to the conversation log and mirrors it to the log
    server. Then saves the edited output and the source image under
    ``{IMAGE_DIR}/edition/`` as ``<conv_id>.jpg`` and ``<conv_id>_source.jpg``
    and uploads both files.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    # Mirror only after the local handle is closed (flushed).
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
    source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Let PIL open the paths in binary mode instead of text-mode handles.
    save_any_image(state.output, output_file)
    save_any_image(state.source_image, source_file)
    save_image_file_on_log_server(output_file)
    save_image_file_on_log_server(source_file)
def vote_last_response_iem(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model (side-by-side / battle) image-editing round.

    Appends one JSON line with every model selector and serialized state to
    the conversation log and mirrors it to the log server. Then, for each
    state, saves the edited output and the source image under
    ``{IMAGE_DIR}/edition/`` and uploads both.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "models": list(model_selectors),
        "states": [s.dict() for s in states],
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    # Mirror only after the local handle is closed (flushed).
    append_json_item_on_log_server(data, get_conv_log_filename())

    for state in states:
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
        source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Let PIL open the paths in binary mode instead of text-mode handles.
        save_any_image(state.output, output_file)
        save_any_image(state.source_image, source_file)
        save_image_file_on_log_server(output_file)
        save_image_file_on_log_server(source_file)
def vote_last_response_vg(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model video-generation result.

    Appends one JSON line to the conversation log and mirrors it to the log
    server. Then writes the video to ``{VIDEO_DIR}/generation/<conv_id>.mp4``
    and uploads it. A torch tensor output from a non-``fal`` model is encoded
    with imageio. Any other output (``fal`` models, or e.g. a URL string) is
    downloaded with ``requests``. This matches ``vote_last_response_vgm``;
    previously a non-tensor, non-``fal`` output crashed on ``.shape``.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    # Mirror only after the local handle is closed (flushed).
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    if state.model_name.startswith('fal'):
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)
    elif isinstance(state.output, torch.Tensor):
        print("======== video shape: ========")
        print(state.output.shape)
        # Assumes [num_frames, height, width, channels] is wanted; anything
        # not ending in 3 channels is treated as channel-first and permuted.
        # NOTE(review): this mutates state.output in place on the state.
        if state.output.shape[-1] != 3:
            state.output = state.output.permute(0, 2, 3, 1)
        imageio.mimwrite(output_file, state.output, fps=8, quality=9)
    else:
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)
    save_video_file_on_log_server(output_file)
def vote_last_response_vgm(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model (side-by-side / battle) video round.

    Writes one JSON log line and mirrors it to the log server. For each state
    it writes ``{VIDEO_DIR}/generation/<conv_id>.mp4``. Tensor outputs from
    non-``fal`` models are encoded with imageio; everything else is fetched
    with ``requests``. Each file is then uploaded.
    """
    with open(get_conv_log_filename(), "a") as fout:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": list(model_selectors),
            "states": [st.dict() for st in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(record) + "\n")
        append_json_item_on_log_server(record, get_conv_log_filename())

    for st in states:
        video_path = f'{VIDEO_DIR}/generation/{st.conv_id}.mp4'
        os.makedirs(os.path.dirname(video_path), exist_ok=True)
        encode_locally = (
            not st.model_name.startswith('fal')
            and isinstance(st.output, torch.Tensor)
        )
        if encode_locally:
            print("======== video shape: ========")
            print(st.output.shape)
            # Channel-first frames are permuted to channel-last before writing.
            if st.output.shape[-1] != 3:
                st.output = st.output.permute(0, 2, 3, 1)
            imageio.mimwrite(video_path, st.output, fps=8, quality=9)
        else:
            response = requests.get(st.output)
            with open(video_path, 'wb') as outfile:
                outfile.write(response.content)
        save_video_file_on_log_server(video_path)
## Image Generation (IG) Single Model Direct Chat | |
def upvote_last_response_ig(state, model_selector, request: gr.Request):
    """Handle the upvote button in single-model image generation.

    Returns an empty string followed by three disabled-button updates.
    """
    ig_logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response_ig(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response_ig(state, model_selector, request: gr.Request):
    """Handle the downvote button in single-model image generation.

    Returns an empty string followed by three disabled-button updates.
    """
    ig_logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response_ig(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response_ig(state, model_selector, request: gr.Request):
    """Handle the flag button in single-model image generation.

    Returns an empty string followed by three disabled-button updates.
    """
    ig_logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response_ig(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
## Image Generation Multi (IGM) Side-by-Side and Battle | |
def leftvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "leftvote" for multi-model image generation.

    Returns ``""``, four disabled-button updates, then two visible Markdown
    name labels. When ``model_selector0`` is empty (anonymous battle), the
    labels are "### Model A/B: <second '_'-separated part of model_name>".
    Otherwise they are the raw model names.
    """
    igm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        label_a = f"### Model A: {state0.model_name.split('_')[1]}"
        label_b = f"### Model B: {state1.model_name.split('_')[1]}"
    else:
        label_a, label_b = state0.model_name, state1.model_name
    reveal = (gr.Markdown(label_a, visible=True), gr.Markdown(label_b, visible=True))
    return ("",) + (disable_btn,) * 4 + reveal
def rightvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "rightvote" for multi-model image generation.

    Returns ``""``, four disabled-button updates, then two visible Markdown
    name labels. When ``model_selector0`` is empty (anonymous battle), the
    labels are "### Model A/B: <second '_'-separated part of model_name>".
    Otherwise they are the raw model names. The stray debug
    ``print(model_selector0)`` has been removed.
    """
    igm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        return ("",) + (disable_btn,) * 4 + (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + (
        gr.Markdown(state0.model_name, visible=True),
        gr.Markdown(state1.model_name, visible=True),
    )
def tievote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "tievote" for multi-model image generation.

    Returns ``""``, four disabled-button updates, then two visible Markdown
    name labels ("### Model A/B: ..." in anonymous mode, raw names otherwise).
    """
    igm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    anonymous = model_selector0 == ""
    label_a = f"### Model A: {state0.model_name.split('_')[1]}" if anonymous else state0.model_name
    label_b = f"### Model B: {state1.model_name.split('_')[1]}" if anonymous else state1.model_name
    reveal = (gr.Markdown(label_a, visible=True), gr.Markdown(label_b, visible=True))
    return ("",) + (disable_btn,) * 4 + reveal
def bothbad_vote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "bothbad_vote" for multi-model image generation.

    Returns ``""``, four disabled-button updates, then two visible Markdown
    name labels ("### Model A/B: ..." in anonymous mode, raw names otherwise).
    """
    igm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        label_a, label_b = state0.model_name, state1.model_name
    else:
        label_a = f"### Model A: {state0.model_name.split('_')[1]}"
        label_b = f"### Model B: {state1.model_name.split('_')[1]}"
    reveal = (gr.Markdown(label_a, visible=True), gr.Markdown(label_b, visible=True))
    return ("",) + (disable_btn,) * 4 + reveal
## Image Editing (IE) Single Model Direct Chat | |
def upvote_last_response_ie(state, model_selector, request: gr.Request):
    """Handle the upvote button for single-model image editing.

    Returns ``("", "", <fresh 512x512 PIL gr.Image>, "")`` followed by three
    disabled-button updates.
    """
    ie_logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response_ie(state, "upvote", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "") + (disable_btn,) * 3
def downvote_last_response_ie(state, model_selector, request: gr.Request):
    """Handle the downvote button for single-model image editing.

    Returns ``("", "", <fresh 512x512 PIL gr.Image>, "")`` followed by three
    disabled-button updates.
    """
    ie_logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response_ie(state, "downvote", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "") + (disable_btn,) * 3
def flag_last_response_ie(state, model_selector, request: gr.Request):
    """Handle the flag button for single-model image editing.

    Returns ``("", "", <fresh 512x512 PIL gr.Image>, "")`` followed by three
    disabled-button updates.
    """
    ie_logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response_ie(state, "flag", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "") + (disable_btn,) * 3
## Image Editing Multi (IEM) Side-by-Side and Battle | |
def leftvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "leftvote" for multi-model image editing.

    Returns two Markdown name updates, then
    ``("", "", <fresh 512x512 PIL gr.Image>, "")``, then four disabled-button
    updates. The names are visible "### Model A/B: ..." labels when
    ``model_selector0`` is empty, and hidden raw model names otherwise.
    """
    iem_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        names = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        names = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    resets = ("", "", gr.Image(height=512, width=512, type="pil"), "")
    return names + resets + (disable_btn,) * 4
def rightvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "rightvote" for multi-model image editing.

    Returns two Markdown name updates (visible "### Model A/B: ..." labels in
    anonymous mode, hidden raw names otherwise), then
    ``("", "", <fresh 512x512 PIL gr.Image>, "")``, then four disabled-button
    updates.
    """
    iem_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        names = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        names = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    resets = ("", "", gr.Image(height=512, width=512, type="pil"), "")
    return names + resets + (disable_btn,) * 4
def tievote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "tievote" for multi-model image editing.

    Returns two Markdown name updates (visible "### Model A/B: ..." labels in
    anonymous mode, hidden raw names otherwise), then
    ``("", "", <fresh 512x512 PIL gr.Image>, "")``, then four disabled-button
    updates.
    """
    iem_logger.info(f"tievote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        names = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        names = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    resets = ("", "", gr.Image(height=512, width=512, type="pil"), "")
    return names + resets + (disable_btn,) * 4
def bothbad_vote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "bothbad_vote" for multi-model image editing.

    Returns two Markdown name updates (visible "### Model A/B: ..." labels in
    anonymous mode, hidden raw names otherwise), then
    ``("", "", <fresh 512x512 PIL gr.Image>, "")``, then four disabled-button
    updates.
    """
    iem_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        names = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        names = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    resets = ("", "", gr.Image(height=512, width=512, type="pil"), "")
    return names + resets + (disable_btn,) * 4
## Video Generation (VG) Single Model Direct Chat | |
def upvote_last_response_vg(state, model_selector, request: gr.Request):
    """Handle the upvote button in single-model video generation.

    Returns an empty string followed by three disabled-button updates.
    """
    vg_logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response_vg(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response_vg(state, model_selector, request: gr.Request):
    """Handle the downvote button in single-model video generation.

    Returns an empty string followed by three disabled-button updates.
    """
    vg_logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response_vg(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response_vg(state, model_selector, request: gr.Request):
    """Handle the flag button in single-model video generation.

    Returns an empty string followed by three disabled-button updates.
    """
    vg_logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response_vg(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
## Video Generation Multi (VGM) Side-by-Side and Battle
def leftvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "leftvote" for multi-model video generation.

    Returns ``""``, four disabled-button updates, then two Markdown name
    updates: visible "### Model A/B: ..." labels when ``model_selector0`` is
    empty, and hidden raw model names otherwise.
    """
    vgm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        name_updates = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        name_updates = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + name_updates
def rightvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "rightvote" for multi-model video generation.

    Returns ``""``, four disabled-button updates, then two Markdown name
    updates (visible "### Model A/B: ..." labels in anonymous mode, hidden raw
    names otherwise).
    """
    vgm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        name_updates = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        name_updates = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + name_updates
def tievote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "tievote" for multi-model video generation.

    Returns ``""``, four disabled-button updates, then two Markdown name
    updates (visible "### Model A/B: ..." labels in anonymous mode, hidden raw
    names otherwise).
    """
    vgm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        name_updates = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        name_updates = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + name_updates
def bothbad_vote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a "bothbad_vote" for multi-model video generation.

    Returns ``""``, four disabled-button updates, then two Markdown name
    updates (visible "### Model A/B: ..." labels in anonymous mode, hidden raw
    names otherwise).
    """
    vgm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        name_updates = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        name_updates = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + name_updates
share_js = """ | |
function (a, b, c, d) { | |
const captureElement = document.querySelector('#share-region-named'); | |
html2canvas(captureElement) | |
.then(canvas => { | |
canvas.style.display = 'none' | |
document.body.appendChild(canvas) | |
return canvas | |
}) | |
.then(canvas => { | |
const image = canvas.toDataURL('image/png') | |
const a = document.createElement('a') | |
a.setAttribute('download', 'chatbot-arena.png') | |
a.setAttribute('href', image) | |
a.click() | |
canvas.remove() | |
}); | |
return [a, b, c, d]; | |
} | |
""" | |
def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a share click for multi-model image generation.

    A "share" vote is recorded only when both states exist. Returns None.
    """
    igm_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    vote_last_response_igm(
        [state0, state1], "share", [model_selector0, model_selector1], request
    )
def share_click_iem(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a share click for multi-model image editing.

    A "share" vote is recorded only when both states exist. Returns None.
    """
    iem_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    vote_last_response_iem(
        [state0, state1], "share", [model_selector0, model_selector1], request
    )
## All Generation Gradio Interface | |
class ImageStateIG:
    """Per-conversation state for image generation.

    ``conv_id`` is a random 32-char hex id. ``dict()`` returns only the
    JSON-loggable fields (``conv_id``, ``model_name``, ``prompt``); ``output``
    is kept on the object but not serialized.
    """

    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.prompt = None
        self.output = None

    def dict(self):
        logged = ("conv_id", "model_name", "prompt")
        return {field: getattr(self, field) for field in logged}
class ImageStateIE:
    """Per-conversation state for image editing.

    ``dict()`` returns the JSON-loggable fields (id, model name and the three
    prompts). ``source_image`` and ``output`` are kept on the object but not
    serialized.
    """

    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.source_prompt = None
        self.target_prompt = None
        self.instruct_prompt = None
        self.source_image = None
        self.output = None

    def dict(self):
        logged = (
            "conv_id",
            "model_name",
            "source_prompt",
            "target_prompt",
            "instruct_prompt",
        )
        return {field: getattr(self, field) for field in logged}
class VideoStateVG:
    """Per-conversation state for video generation.

    ``dict()`` returns only the JSON-loggable fields (``conv_id``,
    ``model_name``, ``prompt``); ``output`` is kept on the object but not
    serialized.
    """

    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.prompt = None
        self.output = None

    def dict(self):
        logged = ("conv_id", "model_name", "prompt")
        return {field: getattr(self, field) for field in logged}
def generate_ig(gen_func, state, text, model_name, request: gr.Request):
    """Generate one image with a single model, then log and persist it.

    Gradio generator handler. Calls ``gen_func(text, model_name)``, yields
    ``(state, generated_image)`` to the UI, then appends a log record and
    saves the image as ``{IMAGE_DIR}/generation/<conv_id>.jpg``. The incoming
    ``state`` is replaced by a fresh ``ImageStateIG``.

    Raises:
        gr.Error: empty prompt or model name, or ``gen_func`` returned ``''``
            (reported to the user as an NSFW-filter block). ``gr.Warning`` is
            a function returning None, so the old ``raise gr.Warning(...)``
            actually raised TypeError.
    """
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    state = ImageStateIG(model_name)
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image = gen_func(text, model_name)
    state.prompt = text
    state.output = generated_image
    state.model_name = model_name

    def _append_to_conv_log(record):
        # Write and close locally first so the line is flushed, then mirror it.
        with open(get_conv_log_filename(), "a") as fout:
            fout.write(json.dumps(record) + "\n")
        append_json_item_on_log_server(record, get_conv_log_filename())

    if generated_image == '':
        _append_to_conv_log({
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "state": state.dict(),
            "ip": ip,
        })
        raise gr.Error("Input prompt is blocked by the NSFW filter, please input safe content and try again!")

    yield state, generated_image

    finish_tstamp = time.time()
    _append_to_conv_log({
        "tstamp": round(finish_tstamp, 4),
        "type": "chat",
        "model": model_name,
        "gen_params": {},
        "start": round(start_tstamp, 4),
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "ip": ip,
    })
    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Pass the path so PIL writes in binary mode (was a text-mode handle).
    save_any_image(state.output, output_file)
    save_image_file_on_log_server(output_file)
def generate_ig_museum(gen_func, state, model_name, request: gr.Request):
    """Fetch a pre-generated ("museum") image and its prompt for one model.

    Gradio generator handler. Calls ``gen_func(model_name)``, which returns
    ``(image, prompt)``, and yields ``(state, image, prompt)``. It then
    appends a log record and saves the image as
    ``{IMAGE_DIR}/generation/<conv_id>.jpg``. The incoming ``state`` is
    replaced by a fresh ``ImageStateIG``.

    Raises:
        gr.Error: empty model name. This replaces ``raise gr.Warning``, which
            raised TypeError because ``gr.Warning`` returns None.
    """
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    state = ImageStateIG(model_name)
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image, text = gen_func(model_name)
    state.prompt = text
    state.output = generated_image
    state.model_name = model_name

    yield state, generated_image, text

    finish_tstamp = time.time()
    data = {
        "tstamp": round(finish_tstamp, 4),
        "type": "chat",
        "model": model_name,
        "gen_params": {},
        "start": round(start_tstamp, 4),
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "ip": ip,
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    # Mirror only after the local handle is closed (flushed).
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Pass the path so PIL writes in binary mode (was a text-mode handle).
    save_any_image(state.output, output_file)
    save_image_file_on_log_server(output_file)
def generate_igm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate a side-by-side image pair for two named models.

    Gradio generator handler. Strips any "### Model A: " / "### Model B: "
    heading prefix from the model names, calls
    ``gen_func(text, model_name0, model_name1)``, and yields
    ``(state0, state1, image0, image1)``. It then appends one log record per
    model and saves both images under ``{IMAGE_DIR}/generation/``. The
    incoming states are replaced by fresh ``ImageStateIG`` objects.

    Raises:
        gr.Error: empty prompt or model name, or both images came back as
            ``''`` (reported as an NSFW-filter block). This replaces
            ``raise gr.Warning``, which raised TypeError because
            ``gr.Warning`` returns None.
    """
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Error("Model name B cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Remove "### Model (A|B): " heading prefixes from the model names.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(text, model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1
    pairs = ((model_name0, state0), (model_name1, state1))

    def _append_to_conv_log(records):
        # Write and close locally first so the lines are flushed, then mirror.
        with open(get_conv_log_filename(), "a") as fout:
            for record in records:
                fout.write(json.dumps(record) + "\n")
        for record in records:
            append_json_item_on_log_server(record, get_conv_log_filename())

    if generated_image0 == '' and generated_image1 == '':
        _append_to_conv_log([
            {
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "state": st.dict(),
                "ip": ip,
            }
            for name, st in pairs
        ])
        raise gr.Error("Input prompt is blocked by the NSFW filter, please input safe content and try again!")

    yield state0, state1, generated_image0, generated_image1

    finish_tstamp = time.time()
    _append_to_conv_log([
        {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": st.dict(),
            "ip": ip,
        }
        for name, st in pairs
    ])
    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Pass the path so PIL writes in binary mode (was a text-mode handle).
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def generate_igm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Fetch a pre-generated ("museum") image pair and prompt for two named models.

    Gradio generator handler. Strips any "### Model A: " / "### Model B: "
    prefix, calls ``gen_func(model_name0, model_name1)``, which returns
    ``(image0, image1, prompt)``, and yields
    ``(state0, state1, image0, image1, prompt)``. It then logs one record per
    model and saves both images under ``{IMAGE_DIR}/generation/``.

    Raises:
        gr.Error: empty model name. This replaces ``raise gr.Warning``, which
            raised TypeError because ``gr.Warning`` returns None.
    """
    if not model_name0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Error("Model name B cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Remove "### Model (A|B): " heading prefixes from the model names.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1, text = gen_func(model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1

    yield state0, state1, generated_image0, generated_image1, text

    finish_tstamp = time.time()
    records = [
        {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": st.dict(),
            "ip": ip,
        }
        for name, st in ((model_name0, state0), (model_name1, state1))
    ]
    with open(get_conv_log_filename(), "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    # Mirror only after the local handle is closed (flushed).
    for record in records:
        append_json_item_on_log_server(record, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Pass the path so PIL writes in binary mode (was a text-mode handle).
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def generate_igm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate an anonymous-battle image pair; ``gen_func`` picks the models.

    Gradio generator handler. Calls ``gen_func(text, "", "")``, which returns
    ``(image0, image1, model_name0, model_name1)``. It yields the two states
    and images plus hidden "### Model A/B: <name>" Markdown labels, then logs
    one record per model and saves both images under
    ``{IMAGE_DIR}/generation/``.

    Raises:
        gr.Error: empty prompt, or both images came back as ``''`` (reported
            as an NSFW-filter block). This replaces ``raise gr.Warning``,
            which raised TypeError because ``gr.Warning`` returns None.
    """
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Blank names ask gen_func to choose the models; it returns the chosen names.
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(text, "", "")
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1
    pairs = ((model_name0, state0), (model_name1, state1))

    def _append_to_conv_log(records):
        # Write and close locally first so the lines are flushed, then mirror.
        with open(get_conv_log_filename(), "a") as fout:
            for record in records:
                fout.write(json.dumps(record) + "\n")
        for record in records:
            append_json_item_on_log_server(record, get_conv_log_filename())

    if generated_image0 == '' and generated_image1 == '':
        _append_to_conv_log([
            {
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "state": st.dict(),
                "ip": ip,
            }
            for name, st in pairs
        ])
        raise gr.Error("Input prompt is blocked by the NSFW filter, please input safe content and try again!")

    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()
    _append_to_conv_log([
        {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": st.dict(),
            "ip": ip,
        }
        for name, st in pairs
    ])
    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Pass the path so PIL writes in binary mode (was a text-mode handle).
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def generate_igm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Run an anonymous side-by-side image-generation round from the museum.

    The incoming states and model names are discarded: two fresh
    ``ImageStateIG`` objects are created and ``gen_func("", "")`` chooses the
    models, returning ``(image0, image1, model_name0, model_name1, prompt)``.

    Yields:
        ``(state0, state1, image0, image1, prompt, header_a, header_b)`` once,
        where the headers are ``gr.Markdown`` components created with
        ``visible=False``. When the generator is resumed, one JSON record per
        model is appended to the conversation log (and mirrored with
        ``append_json_item_on_log_server``), and each output image is saved as
        ``{IMAGE_DIR}/generation/<conv_id>.jpg`` and uploaded with
        ``save_image_file_on_log_server``.
    """
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Blank names are passed so gen_func picks the models itself (presumably at
    # random -- confirm against gen_func) and reports the chosen names back.
    model_name0 = ""
    model_name1 = ""
    generated_image0, generated_image1, model_name0, model_name1, text = gen_func(model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1
    # Model headers are sent hidden (visible=False).
    yield (
        state0,
        state1,
        generated_image0,
        generated_image1,
        text,
        gr.Markdown(f"### Model A: {model_name0}", visible=False),
        gr.Markdown(f"### Model B: {model_name1}", visible=False),
    )
    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One record per model, in A/B order.
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())
    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Hand save_any_image the path rather than a text-mode ('w') handle:
        # JPEG data is binary, and `.save(path, 'JPEG')` opens the file in
        # binary mode itself.
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)
def generate_ie(gen_func, state, source_text, target_text, instruct_text, source_image, model_name, request: gr.Request):
    """Edit ``source_image`` with one model and persist the result.

    Args:
        gen_func: called as ``gen_func(source_text, target_text, instruct_text,
            source_image, model_name)``; returns the edited image, or ``''``
            when the prompt is blocked by the NSFW filter.
        state: ignored; replaced by a fresh ``ImageStateIE``.
        source_text, target_text, instruct_text: required editing prompts.
        source_image: required input image (a path/URL string or an object
            with ``.save``, the two forms ``save_any_image`` accepts).
        model_name: required model identifier.
        request: Gradio request, used for the client IP.

    Yields:
        ``(state, edited_image)`` once. When resumed, appends a JSON log
        record (mirrored to the log server) and saves ``<conv_id>_src.jpg``
        and ``<conv_id>_out.jpg`` under ``{IMAGE_DIR}/edition/``, uploading
        both.

    Raises:
        ``gr.Warning(...)`` on a missing input or a blocked prompt.
        NOTE(review): Gradio documents ``gr.Warning`` as a toast helper rather
        than an exception class -- confirm ``raise`` behaves as intended.
    """
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = ImageStateIE(model_name)
    ip = get_ip(request)
    # Image *editing* traffic belongs in the editing logger (was ig_logger,
    # the image-generation one).
    ie_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    generated_image = gen_func(source_text, target_text, instruct_text, source_image, model_name)
    state.source_prompt = source_text
    state.target_prompt = target_text
    state.instruct_prompt = instruct_text
    state.source_image = source_image
    state.output = generated_image
    state.model_name = model_name
    if generated_image == '':
        # Blocked attempt: logged without "tstamp"/"finish", then aborted.
        with open(get_conv_log_filename(), "a") as fout:
            data = {
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    yield state, generated_image
    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
    os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
    # Pass paths, not text-mode ('w') handles: JPEG data is binary, and
    # `.save(path, 'JPEG')` opens the file in binary mode itself.
    save_any_image(state.source_image, src_img_file)
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
    save_any_image(state.output, output_file)
    save_image_file_on_log_server(src_img_file)
    save_image_file_on_log_server(output_file)
def generate_ie_museum(gen_func, state, model_name, request: gr.Request):
    """Show a museum image-editing sample for one model and persist it.

    Args:
        gen_func: called as ``gen_func(model_name)``; returns
            ``(source_image, edited_image, source_text, target_text,
            instruct_text)``.
        state: ignored; replaced by a fresh ``ImageStateIE``.
        model_name: required model identifier.
        request: Gradio request, used for the client IP.

    Yields:
        ``(state, edited_image, source_image, source_text, target_text,
        instruct_text)`` once. When resumed, appends a JSON log record
        (mirrored to the log server) and saves/uploads ``<conv_id>_src.jpg``
        and ``<conv_id>_out.jpg`` under ``{IMAGE_DIR}/edition/``.

    Raises:
        ``gr.Warning(...)`` when ``model_name`` is empty. NOTE(review): Gradio
        documents ``gr.Warning`` as a toast helper rather than an exception
        class -- confirm ``raise`` behaves as intended.
    """
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    state = ImageStateIE(model_name)
    ip = get_ip(request)
    # Image *editing* traffic belongs in the editing logger (was ig_logger).
    ie_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    source_image, generated_image, source_text, target_text, instruct_text = gen_func(model_name)
    state.source_prompt = source_text
    state.target_prompt = target_text
    state.instruct_prompt = instruct_text
    state.source_image = source_image
    state.output = generated_image
    state.model_name = model_name
    yield state, generated_image, source_image, source_text, target_text, instruct_text
    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())
    src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
    os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
    # Pass paths, not text-mode ('w') handles: JPEG data is binary.
    save_any_image(state.source_image, src_img_file)
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
    save_any_image(state.output, output_file)
    save_image_file_on_log_server(src_img_file)
    save_image_file_on_log_server(output_file)
def generate_iem(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    """Edit ``source_image`` with two named models side by side and persist both.

    Args:
        gen_func: called as ``gen_func(source_text, target_text, instruct_text,
            source_image, model_name0, model_name1)``; returns
            ``(edited_image0, edited_image1)``, each ``''`` when blocked by the
            NSFW filter.
        state0, state1: ignored; replaced by fresh ``ImageStateIE`` objects.
        source_text, target_text, instruct_text, source_image: required inputs
            shared by both models.
        model_name0, model_name1: required; any ``"### Model A: "`` /
            ``"### Model B: "`` substring is removed before use.
        request: Gradio request, used for the client IP.

    Yields:
        ``(state0, state1, edited_image0, edited_image1)`` once. When resumed,
        appends one JSON log record per model and saves/uploads each state's
        source and output image under ``{IMAGE_DIR}/edition/``.

    Raises:
        ``gr.Warning(...)`` on a missing input, or when *both* outputs are
        blocked. NOTE(review): if only one output is ``''``, the save step
        passes ``''`` to ``save_any_image`` -- confirm that cannot happen.
        Gradio documents ``gr.Warning`` as a toast helper rather than an
        exception class; confirm ``raise`` behaves as intended.
    """
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Multi-model *editing* traffic belongs in the editing logger (was igm_logger).
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Strip the Markdown header prefix the UI shows above each model.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    for state, model_name, generated_image in ((state0, model_name0, generated_image0), (state1, model_name1, generated_image1)):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = generated_image
        state.model_name = model_name
    if generated_image0 == '' and generated_image1 == '':
        # Blocked attempt: logged without "tstamp"/"finish", then aborted.
        with open(get_conv_log_filename(), "a") as fout:
            for model_name, state in ((model_name0, state0), (model_name1, state1)):
                data = {
                    "type": "chat",
                    "model": model_name,
                    "gen_params": {},
                    "start": round(start_tstamp, 4),
                    "state": state.dict(),
                    "ip": get_ip(request),
                }
                fout.write(json.dumps(data) + "\n")
                append_json_item_on_log_server(data, get_conv_log_filename())
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    yield state0, state1, generated_image0, generated_image1
    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())
    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # Pass paths, not text-mode ('w') handles: JPEG data is binary.
        save_any_image(state.source_image, src_img_file)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_iem_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Show a museum image-editing sample for two named models and persist it.

    Args:
        gen_func: called as ``gen_func(model_name0, model_name1)``; returns
            ``(source_image, edited_image0, edited_image1, source_text,
            target_text, instruct_text)``.
        state0, state1: ignored; replaced by fresh ``ImageStateIE`` objects.
        model_name0, model_name1: required; any ``"### Model A: "`` /
            ``"### Model B: "`` substring is removed before use.
        request: Gradio request, used for the client IP.

    Yields:
        ``(state0, state1, edited_image0, edited_image1, source_image,
        source_text, target_text, instruct_text)`` once. When resumed, appends
        one JSON log record per model and saves/uploads each state's source and
        output image under ``{IMAGE_DIR}/edition/``.

    Raises:
        ``gr.Warning(...)`` when a model name is empty. NOTE(review): Gradio
        documents ``gr.Warning`` as a toast helper rather than an exception
        class -- confirm ``raise`` behaves as intended.
    """
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Multi-model *editing* traffic belongs in the editing logger (was igm_logger).
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Strip the Markdown header prefix the UI shows above each model.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    source_image, generated_image0, generated_image1, source_text, target_text, instruct_text = gen_func(model_name0, model_name1)
    for state, model_name, generated_image in ((state0, model_name0, generated_image0), (state1, model_name1, generated_image1)):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = generated_image
        state.model_name = model_name
    yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text
    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())
    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # Pass paths, not text-mode ('w') handles: JPEG data is binary.
        save_any_image(state.source_image, src_img_file)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_iem_annoy(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    """Anonymous side-by-side image editing of a user-supplied image.

    Args:
        gen_func: called as ``gen_func(source_text, target_text, instruct_text,
            source_image, "", "")``; picks the two models itself and returns
            ``(edited_image0, edited_image1, model_name0, model_name1)``, each
            image ``''`` when blocked by the NSFW filter.
        state0, state1: ignored; replaced by fresh ``ImageStateIE`` objects.
        source_text, target_text, instruct_text, source_image: required inputs.
        model_name0, model_name1: only passed to the ``ImageStateIE``
            constructors; the names used afterwards come from ``gen_func``.
        request: Gradio request, used for the client IP.

    Yields:
        ``(state0, state1, edited_image0, edited_image1, header_a, header_b)``
        once, the headers being ``gr.Markdown`` created with ``visible=False``.
        When resumed, appends one JSON log record per model and saves/uploads
        each state's source and output image under ``{IMAGE_DIR}/edition/``.

    Raises:
        ``gr.Warning(...)`` on a missing input, or when *both* outputs are
        blocked. NOTE(review): Gradio documents ``gr.Warning`` as a toast
        helper rather than an exception class -- confirm ``raise`` behaves as
        intended.
    """
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Multi-model *editing* traffic belongs in the editing logger (was igm_logger).
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Blank names so gen_func chooses the models and reports them back.
    model_name0 = ""
    model_name1 = ""
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    for state, model_name, generated_image in ((state0, model_name0, generated_image0), (state1, model_name1, generated_image1)):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = generated_image
        state.model_name = model_name
    if generated_image0 == '' and generated_image1 == '':
        # Blocked attempt: logged without "tstamp"/"finish", then aborted.
        with open(get_conv_log_filename(), "a") as fout:
            for model_name, state in ((model_name0, state0), (model_name1, state1)):
                data = {
                    "type": "chat",
                    "model": model_name,
                    "gen_params": {},
                    "start": round(start_tstamp, 4),
                    "state": state.dict(),
                    "ip": get_ip(request),
                }
                fout.write(json.dumps(data) + "\n")
                append_json_item_on_log_server(data, get_conv_log_filename())
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    yield (
        state0,
        state1,
        generated_image0,
        generated_image1,
        gr.Markdown(f"### Model A: {model_name0}", visible=False),
        gr.Markdown(f"### Model B: {model_name1}", visible=False),
    )
    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())
    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # Pass paths, not text-mode ('w') handles: JPEG data is binary.
        save_any_image(state.source_image, src_img_file)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_iem_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Anonymous side-by-side image-editing round from the museum.

    Args:
        gen_func: called as ``gen_func("", "")``; returns
            ``(source_image, edited_image0, edited_image1, source_text,
            target_text, instruct_text, model_name0, model_name1)``.
        state0, state1: ignored; replaced by fresh ``ImageStateIE`` objects.
        model_name0, model_name1: only passed to the ``ImageStateIE``
            constructors; the names used afterwards come from ``gen_func``.
        request: Gradio request, used for the client IP.

    Yields:
        ``(state0, state1, edited_image0, edited_image1, source_image,
        source_text, target_text, instruct_text, header_a, header_b)`` once,
        the headers being ``gr.Markdown`` created with ``visible=False``. When
        resumed, appends one JSON log record per model and saves/uploads each
        state's source and output image under ``{IMAGE_DIR}/edition/``.
    """
    state0 = ImageStateIE(model_name0)
    state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Multi-model *editing* traffic belongs in the editing logger (was igm_logger).
    iem_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Blank names so gen_func chooses the models and reports them back.
    model_name0 = ""
    model_name1 = ""
    source_image, generated_image0, generated_image1, source_text, target_text, instruct_text, model_name0, model_name1 = gen_func(model_name0, model_name1)
    for state, model_name, generated_image in ((state0, model_name0, generated_image0), (state1, model_name1, generated_image1)):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = generated_image
        state.model_name = model_name
    yield (
        state0,
        state1,
        generated_image0,
        generated_image1,
        source_image,
        source_text,
        target_text,
        instruct_text,
        gr.Markdown(f"### Model A: {model_name0}", visible=False),
        gr.Markdown(f"### Model B: {model_name1}", visible=False),
    )
    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())
    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # Pass paths, not text-mode ('w') handles: JPEG data is binary.
        save_any_image(state.source_image, src_img_file)
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_vg(gen_func, state, text, model_name, request: gr.Request):
    """Generate a single video for ``text`` with ``model_name`` and persist it.

    ``gen_func(text, model_name)`` returns ``''`` when the NSFW filter blocks
    the prompt. Otherwise it returns the video: a URL (downloaded with
    ``requests``) for models whose name starts with ``fal``, or a frame tensor
    encoded locally with ``imageio.mimwrite``.

    The incoming ``state`` is replaced by a fresh ``VideoStateVG``. A JSON
    record is appended to the conversation log (and mirrored to the log
    server), and the video is written to ``{VIDEO_DIR}/generation/<conv_id>.mp4``
    and uploaded. Finally ``(state, video_path)`` is yielded once.

    NOTE(review): Gradio documents ``gr.Warning`` as a toast helper rather than
    an exception class -- confirm ``raise gr.Warning(...)`` behaves as intended.
    """
    for value, message in ((text, "Prompt cannot be empty."), (model_name, "Model name cannot be empty.")):
        if not value:
            raise gr.Warning(message)

    state = VideoStateVG(model_name)
    vg_logger.info(f"generate. ip: {get_ip(request)}")
    started = time.time()
    video = gen_func(text, model_name)
    state.prompt = text
    state.output = video
    state.model_name = model_name

    def append_record(finished=None):
        # Blocked attempts (finished is None) are logged without "tstamp"/"finish".
        with open(get_conv_log_filename(), "a") as log_fp:
            record = {}
            if finished is not None:
                record["tstamp"] = round(finished, 4)
            record["type"] = "chat"
            record["model"] = model_name
            record["gen_params"] = {}
            record["start"] = round(started, 4)
            if finished is not None:
                record["finish"] = round(finished, 4)
            record["state"] = state.dict()
            record["ip"] = get_ip(request)
            log_fp.write(json.dumps(record) + "\n")
            append_json_item_on_log_server(record, get_conv_log_filename())

    if video == '':
        append_record()
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")

    append_record(time.time())

    video_path = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    if not model_name.startswith('fal'):
        frames = state.output
        print("======== video shape: ========")
        print(frames.shape)
        # mimwrite is fed [num_frames, height, width, channels]; a last axis other
        # than 3 is assumed to mean channels-first [frames, channels, h, w].
        if frames.shape[-1] != 3:
            frames = frames.permute(0, 2, 3, 1)
            state.output = frames
        imageio.mimwrite(video_path, frames, fps=8, quality=9)
    else:
        # fal models hand back a URL: fetch it first, then write the bytes as-is.
        response = requests.get(state.output)
        with open(video_path, 'wb') as video_fp:
            video_fp.write(response.content)
    save_video_file_on_log_server(video_path)
    yield state, video_path
def generate_vg_museum(gen_func, state, model_name, request: gr.Request):
    """Fetch a museum video for ``model_name`` and persist it.

    ``gen_func(model_name)`` returns ``(video_url, prompt)``; the URL is
    downloaded with ``requests`` into ``{VIDEO_DIR}/generation/<conv_id>.mp4``.
    The incoming ``state`` is replaced by a fresh ``VideoStateVG``. A JSON
    record is appended to the conversation log (and mirrored to the log
    server), the file is uploaded, and ``(state, video_path, prompt)`` is
    yielded once.
    """
    state = VideoStateVG(model_name)
    vg_logger.info(f"generate. ip: {get_ip(request)}")
    started = time.time()
    video_url, prompt = gen_func(model_name)
    state.prompt, state.output, state.model_name = prompt, video_url, model_name

    finished = time.time()
    with open(get_conv_log_filename(), "a") as log_fp:
        record = dict(
            tstamp=round(finished, 4),
            type="chat",
            model=model_name,
            gen_params={},
            start=round(started, 4),
            finish=round(finished, 4),
            state=state.dict(),
            ip=get_ip(request),
        )
        log_fp.write(json.dumps(record) + "\n")
        append_json_item_on_log_server(record, get_conv_log_filename())

    video_path = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    # Download before opening the destination so a failed request leaves no file.
    response = requests.get(state.output)
    with open(video_path, 'wb') as video_fp:
        video_fp.write(response.content)
    save_video_file_on_log_server(video_path)
    yield state, video_path, prompt
def generate_vgm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate videos for ``text`` with two named models side by side.

    Args:
        gen_func: called as ``gen_func(text, model_name0, model_name1)``;
            returns ``(video0, video1)``, each ``''`` when blocked by the NSFW
            filter, otherwise a URL (``fal*`` models) or a frame tensor.
        state0, state1: ignored; replaced by fresh ``VideoStateVG`` objects.
        text: required prompt.
        model_name0, model_name1: required; any ``"### Model A: "`` /
            ``"### Model B: "`` substring is removed before use.
        request: Gradio request, used for the client IP.

    Yields:
        ``(state0, state1, path0, path1)`` once, after one JSON log record per
        model has been appended and both videos have been written to
        ``{VIDEO_DIR}/generation/<conv_id>.mp4`` and uploaded.

    Raises:
        ``gr.Warning(...)`` on a missing input, or when *both* outputs are
        blocked. NOTE(review): Gradio documents ``gr.Warning`` as a toast
        helper rather than an exception class -- confirm ``raise`` behaves as
        intended.
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    # Side-by-side *video* traffic belongs in the video logger (was igm_logger,
    # the image one); matches generate_vgm_annoy.
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Remove ### Model (A|B): from model name
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_video0, generated_video1 = gen_func(text, model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    if generated_video0 == '' and generated_video1 == '':
        # Blocked attempt: logged without "tstamp"/"finish", then aborted.
        with open(get_conv_log_filename(), "a") as fout:
            for model_name, state in ((model_name0, state0), (model_name1, state1)):
                data = {
                    "type": "chat",
                    "model": model_name,
                    "gen_params": {},
                    "start": round(start_tstamp, 4),
                    "state": state.dict(),
                    "ip": get_ip(request),
                }
                fout.write(json.dumps(data) + "\n")
                append_json_item_on_log_server(data, get_conv_log_filename())
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")
    print("====== model name =========")
    print(state0.model_name)
    print(state1.model_name)
    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())
    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        print(state.model_name)
        if state.model_name.startswith('fal'):
            # fal models hand back a URL: download the encoded file as-is.
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        else:
            print("======== video shape: ========")
            print(state.output)
            print(state.output.shape)
            # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
            if state.output.shape[-1] != 3:
                state.output = state.output.permute(0, 2, 3, 1)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)
    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4'
def generate_vgm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Fetch museum videos for two named models side by side and persist them.

    Args:
        gen_func: called as ``gen_func(model_name0, model_name1)``; returns
            ``(video_url0, video_url1, prompt)``.
        state0, state1: ignored; replaced by fresh ``VideoStateVG`` objects.
        model_name0, model_name1: any ``"### Model A: "`` / ``"### Model B: "``
            substring is removed before use.
        request: Gradio request, used for the client IP.

    Yields:
        ``(state0, state1, path0, path1, prompt)`` once, after one JSON log
        record per model has been appended and both URLs have been downloaded
        with ``requests`` into ``{VIDEO_DIR}/generation/<conv_id>.mp4`` and
        uploaded.
    """
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    # Side-by-side *video* traffic belongs in the video logger (was igm_logger,
    # the image one).
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Remove ### Model (A|B): from model name
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_video0, generated_video1, text = gen_func(model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    print("====== model name =========")
    print(state0.model_name)
    print(state1.model_name)
    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())
    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        print(state.model_name)
        r = requests.get(state.output)
        with open(output_file, 'wb') as outfile:
            outfile.write(r.content)
        save_video_file_on_log_server(output_file)
    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', text
def generate_vgm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Anonymous side-by-side video generation for a user-supplied prompt.

    ``gen_func(text, "", "")`` picks both models and returns
    ``(video0, video1, model_name0, model_name1)``. Each video is ``''`` when
    blocked by the NSFW filter; otherwise it is a URL (for ``fal*`` models) or
    a frame tensor encoded locally with ``imageio.mimwrite``.

    The incoming states are replaced by fresh ``VideoStateVG`` objects (built
    from the incoming names). One JSON record per model is appended to the
    conversation log (and mirrored to the log server), and both videos are
    written to ``{VIDEO_DIR}/generation/<conv_id>.mp4`` and uploaded. Finally
    ``(state0, state1, path0, path1, header_a, header_b)`` is yielded once,
    the headers being ``gr.Markdown`` created with ``visible=False``.

    NOTE(review): Gradio documents ``gr.Warning`` as a toast helper rather than
    an exception class -- confirm ``raise gr.Warning(...)`` behaves as intended.
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")

    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    vgm_logger.info(f"generate. ip: {get_ip(request)}")
    started = time.time()

    # Blank names: gen_func chooses both models and reports them back.
    video0, video1, model_name0, model_name1 = gen_func(text, "", "")
    pairs = ((model_name0, state0, video0), (model_name1, state1, video1))
    for name, st, video in pairs:
        st.prompt = text
        st.output = video
        st.model_name = name

    def append_records(finished=None):
        # One record per model; blocked attempts (finished is None) omit
        # the "tstamp"/"finish" keys.
        with open(get_conv_log_filename(), "a") as log_fp:
            for name, st, _video in pairs:
                record = {}
                if finished is not None:
                    record["tstamp"] = round(finished, 4)
                record["type"] = "chat"
                record["model"] = name
                record["gen_params"] = {}
                record["start"] = round(started, 4)
                if finished is not None:
                    record["finish"] = round(finished, 4)
                record["state"] = st.dict()
                record["ip"] = get_ip(request)
                log_fp.write(json.dumps(record) + "\n")
                append_json_item_on_log_server(record, get_conv_log_filename())

    if video0 == '' and video1 == '':
        append_records()
        raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!")

    append_records(time.time())

    for st in (state0, state1):
        target = f'{VIDEO_DIR}/generation/{st.conv_id}.mp4'
        os.makedirs(os.path.dirname(target), exist_ok=True)
        if not st.model_name.startswith('fal'):
            print("======== video shape: ========")
            print(st.output.shape)
            # mimwrite is fed [num_frames, height, width, channels]; a last axis
            # other than 3 is assumed to mean channels-first and permuted.
            if st.output.shape[-1] != 3:
                st.output = st.output.permute(0, 2, 3, 1)
            imageio.mimwrite(target, st.output, fps=8, quality=9)
        else:
            # fal models hand back a URL: fetch it first, then write the bytes.
            resp = requests.get(st.output)
            with open(target, 'wb') as video_fp:
                video_fp.write(resp.content)
        save_video_file_on_log_server(target)

    yield (
        state0,
        state1,
        f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4',
        f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4',
        gr.Markdown(f"### Model A: {model_name0}", visible=False),
        gr.Markdown(f"### Model B: {model_name1}", visible=False),
    )
def generate_vgm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Run one side-by-side video round whose videos are fetched by URL.

    ``gen_func`` is called with two empty model names and must return
    ``(video0, video1, model_name0, model_name1, prompt)``, where each video is a
    URL that can be downloaded with ``requests.get``. The returned model names
    replace the ones passed in.

    Side effects:
      * appends one "chat" record per model to the conversation log file and
        mirrors each record to the log server;
      * downloads each video to ``{VIDEO_DIR}/generation/<conv_id>.mp4`` and
        uploads that file to the log server.

    Args:
        gen_func: Callable ``(model_name0, model_name1) -> (video0, video1,
            model_name0, model_name1, prompt)``.
        state0, state1: Not read; fresh ``VideoStateVG`` objects replace them.
        model_name0, model_name1: Names used to construct the fresh states
            before ``gen_func`` reports the actual models.
        request: Gradio request, used only to record the client IP.

    Yields:
        Exactly once:
        ``(state0, state1, video_path0, video_path1, prompt, md_a, md_b)``,
        where ``md_a``/``md_b`` are ``gr.Markdown`` model labels created with
        ``visible=False`` (presumably to keep the round anonymous until voting).

    Raises:
        requests.HTTPError: a video URL answered with an error status.
        requests.Timeout: a video server stopped responding.
    """
    state0 = VideoStateVG(model_name0)
    state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    vgm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()

    # gen_func receives empty names and reports back which models it used.
    generated_video0, generated_video1, model_name0, model_name1, text = gen_func("", "")
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    finish_tstamp = time.time()

    # Resolve the log file once so both records land in the same file.
    log_filename = get_conv_log_filename()
    with open(log_filename, "a") as fout:
        for name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, log_filename)

    output_files = []
    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # The timeout bounds connect time and the gap between received bytes,
        # not the total download, so large videos are still allowed. Checking
        # the status avoids saving an HTTP error page as an .mp4 file.
        response = requests.get(state.output, timeout=60)
        response.raise_for_status()
        with open(output_file, 'wb') as outfile:
            outfile.write(response.content)
        save_video_file_on_log_server(output_file)
        output_files.append(output_file)

    yield state0, state1, output_files[0], output_files[1], text, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), \
        gr.Markdown(f"### Model B: {model_name1}", visible=False)