K-Sort-Arena / serve /Ksort.py
ksort's picture
Update API
b177a48
raw
history blame
13.1 kB
import gradio as gr
from PIL import Image, ImageDraw, ImageFont, ImageOps
import os
from .constants import KSORT_IMAGE_DIR
from .constants import COLOR1, COLOR2, COLOR3, COLOR4
from .vote_utils import save_any_image
from .utils import disable_btn, enable_btn, invisible_btn
from .upload import create_remote_directory, upload_image, upload_informance, upload_ssh_all
import json
def reset_level(Top_btn):
    """Map a "Top N" button label to its 0-based vote level.

    Args:
        Top_btn: One of "Top 1", "Top 2", "Top 3", "Top 4".

    Returns:
        0, 1, 2 or 3 respectively.

    Raises:
        ValueError: If ``Top_btn`` is not one of the four known labels.
            (Previously such input surfaced as an UnboundLocalError.)
    """
    levels = {"Top 1": 0, "Top 2": 1, "Top 3": 2, "Top 4": 3}
    try:
        return levels[Top_btn]
    except KeyError:
        raise ValueError(f"Unknown rank button label: {Top_btn!r}") from None
def reset_rank(windows, rank, vote_level):
    """Record ``vote_level`` as the rank of the model shown in ``windows``.

    Args:
        windows: Window label, "Model A" through "Model D". Any other value
            leaves ``rank`` untouched.
        rank: Mutable 4-slot sequence, updated in place.
        vote_level: Value stored in the selected slot.

    Returns:
        The same ``rank`` object.
    """
    slot_by_window = {"Model A": 0, "Model B": 1, "Model C": 2, "Model D": 3}
    slot = slot_by_window.get(windows)
    if slot is not None:
        rank[slot] = vote_level
    return rank
def reset_btn_rank(windows, rank, btn, vote_level):
    """Apply a per-window rank button ("1".."4") press.

    The button label is converted to a 0-based level ("1" -> 0, ..., "4" -> 3).
    If ``windows`` names one of "Model A".."Model D", that model's slot in
    ``rank`` is set to the level (in place).

    Args:
        windows: Window label of the pressed button.
        rank: Mutable 4-slot sequence of per-model levels.
        btn: Button label, "1" through "4".
        vote_level: Current level; returned unchanged if ``btn`` is unknown.

    Returns:
        ``(rank, vote_level)``. ``rank`` is the same object that was passed in.
        ``vote_level`` is the pressed button's level when ``btn`` is
        recognised, even if ``windows`` is not.
    """
    level_by_btn = {"1": 0, "2": 1, "3": 2, "4": 3}
    slot_by_window = {"Model A": 0, "Model B": 1, "Model C": 2, "Model D": 3}

    pressed_level = level_by_btn.get(btn)
    if pressed_level is None:
        # Unknown button: nothing to record, keep the previous level.
        return (rank, vote_level)

    slot = slot_by_window.get(windows)
    if slot is not None:
        rank[slot] = pressed_level
    return (rank, pressed_level)
def reset_vote_text(rank):
    """Format stored ranks for display.

    Each 0-based level is shown 1-based, and unset slots are shown as "None".
    Every entry is followed by a single space, including the last one, so
    ``[0, None, 3]`` renders as ``"1 None 4 "``.
    """
    pieces = []
    for value in rank:
        label = "None" if value is None else str(value + 1)
        pieces.append(label + " ")
    return "".join(pieces)
def clear_rank(rank, vote_level):
    """Reset every slot of ``rank`` to None in place and the vote level to 0.

    Returns:
        ``(rank, 0)``, where ``rank`` is the same object that was passed in.
    """
    rank[:] = [None] * len(rank)
    return rank, 0
def revote_windows(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level):
    """Start a new vote on the same four images.

    The images are passed through unchanged. ``rank`` is cleared to None in
    place and the vote level is reset to 0.

    Returns:
        ``(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, 0)``.
    """
    rank[:] = [None] * len(rank)
    images = (generate_ig0, generate_ig1, generate_ig2, generate_ig3)
    return images + (rank, 0)
def reset_submit(rank):
    """Choose the submit-button state from how complete the ranking is.

    Returns ``disable_btn`` while any slot of ``rank`` is still None, and
    ``enable_btn`` once every slot is set. Both are imported from ``.utils``;
    presumably they are Gradio update objects (TODO confirm).
    """
    has_unranked = any(value is None for value in rank)
    return disable_btn if has_unranked else enable_btn
def reset_mode(mode):
    """Return the 25 component updates that switch the voting UI to ``mode``.

    The outputs are positional (the caller binds them to components). In order:
      * 5 updates: hidden in "Best" mode, shown and interactive in "Rank" mode;
      * 16 updates: shown and interactive in "Best" mode, hidden in "Rank" mode;
      * 3 updates: all shown. In "Best" all three are interactive; in "Rank"
        only the last one is;
      * 1 hidden Textbox whose value is the *other* mode's name.

    Raises:
        ValueError: If ``mode`` is neither "Best" nor "Rank".
    """
    if mode == "Best":
        first_group = (gr.update(visible=False, interactive=False),) * 5
        second_group = (gr.update(visible=True, interactive=True),) * 16
        tail_group = (gr.update(visible=True, interactive=True),) * 3
        mode_box = gr.Textbox(value="Rank", visible=False, interactive=False)
    elif mode == "Rank":
        first_group = (gr.update(visible=True, interactive=True),) * 5
        second_group = (gr.update(visible=False, interactive=False),) * 16
        tail_group = (gr.update(visible=True, interactive=False),) * 2
        tail_group = tail_group + (gr.update(visible=True, interactive=True),)
        mode_box = gr.Textbox(value="Best", visible=False, interactive=False)
    else:
        raise ValueError("Undefined mode")
    return first_group + second_group + tail_group + (mode_box,)
def reset_chatbot(mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3):
    """Pass the four images through unchanged; ``mode`` is accepted but unused."""
    images = (generate_ig0, generate_ig1, generate_ig2, generate_ig3)
    return images
def get_json_filename(conv_id):
    """Return the path of the per-conversation JSON log, creating its directory.

    The path is ``{KSORT_IMAGE_DIR}/{conv_id}/json/information.json``. The
    path is also printed to stdout.
    """
    output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/json/'
    # exist_ok=True replaces the racy "exists() then makedirs()" pair: a
    # concurrent creator of the same directory no longer raises FileExistsError.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "information.json")
    print(output_file)
    return output_file
def get_img_filename(conv_id, i):
    """Return the path for the ``i``-th image of a conversation, creating its directory.

    The path is ``{KSORT_IMAGE_DIR}/{conv_id}/image/{i}.jpg``. The path is
    also printed to stdout.
    """
    output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/image/'
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{i}.jpg")
    print(output_file)
    return output_file
def vote_submit(states, rank, request: gr.Request):
    """Persist a vote locally.

    Saves each state's ``output`` image via ``save_any_image``, then appends
    one JSON line ``{"models_name": [...], "img_rank": [...]}`` to the
    conversation's ``information.json``. The conversation id is taken from
    ``states[0]``. ``request`` is accepted but not used.
    """
    conv_id = states[0].conv_id
    for idx, state in enumerate(states):
        image_path = get_img_filename(conv_id, idx)
        save_any_image(state.output, image_path)

    log_path = get_json_filename(conv_id)
    with open(log_path, "a") as fout:
        record = {
            "models_name": [state.model_name for state in states],
            "img_rank": list(rank),
        }
        fout.write(json.dumps(record) + "\n")
def vote_ssh_submit(states, rank):
    """Upload a vote to the remote store, then update model skill ratings.

    Creates the remote directory for the conversation (id taken from
    ``states[0]``) and hands the states plus a ``result.json`` payload
    ``{"models_name": [...], "img_rank": [...]}`` to ``upload_ssh_all``.
    Finally it calls ``update_skill(rank, model_names)``.
    """
    conv_id = states[0].conv_id
    remote_dir = create_remote_directory(conv_id)
    payload = {
        "models_name": [state.model_name for state in states],
        "img_rank": list(rank),
    }
    remote_file = os.path.join(remote_dir, "result.json")
    upload_ssh_all(states, remote_dir, payload, remote_file)

    # Imported lazily inside the function; presumably to avoid an import cycle
    # or a heavy import at module load (TODO confirm).
    from .update_skill import update_skill
    update_skill(rank, [state.model_name for state in states])
def submit_response_igm(
    state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, rank, request: gr.Request
):
    """Submit a vote, then lock the UI and reveal the model names.

    The vote is saved locally (``vote_submit``) and remotely
    (``vote_ssh_submit``).

    Returns:
        Six ``disable_btn`` values, four visible Markdown labels and one more
        ``disable_btn``. When ``model_selector0`` is empty (treated as
        anonymous mode), each label reads "### Model X: <name>", where <name>
        is the second "_"-separated field of the model name. Otherwise the
        full model name is shown.
    """
    states = [state0, state1, state2, state3]
    vote_submit(states, rank, request)
    vote_ssh_submit(states, rank)

    anonymous = model_selector0 == ""
    labels = []
    for letter, state in zip("ABCD", states):
        if anonymous:
            # NOTE(review): raises IndexError if model_name has no "_".
            short_name = state.model_name.split('_')[1]
            labels.append(gr.Markdown(f"### Model {letter}: {short_name}", visible=True))
        else:
            labels.append(gr.Markdown(state.model_name, visible=True))
    return (disable_btn,) * 6 + tuple(labels) + (disable_btn,)
def submit_response_rank_igm(
    state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, rank, right_vote_text, request: gr.Request
):
    """Submit a typed ranking if it was validated; otherwise keep the UI editable.

    ``right_vote_text`` is the status string ("right"/"wrong"). Both branches
    return "wrong" in that slot, which resets the status.

    Returns:
        When not validated: 19 ``enable_btn``, "wrong" and four hidden empty
        Markdown labels. When validated: the vote is saved locally and
        remotely, and the result is 19 ``disable_btn``, "wrong" and four
        visible model-name labels (anonymised as "### Model X: <name>" when
        ``model_selector0`` is empty).
    """
    print(rank)
    if right_vote_text != "right":
        hidden_label = gr.Markdown("", visible=False)
        return (enable_btn,) * 19 + ("wrong",) + (hidden_label,) * 4

    states = [state0, state1, state2, state3]
    vote_submit(states, rank, request)
    vote_ssh_submit(states, rank)

    anonymous = model_selector0 == ""
    labels = []
    for letter, state in zip("ABCD", states):
        if anonymous:
            # NOTE(review): raises IndexError if model_name has no "_".
            short_name = state.model_name.split('_')[1]
            labels.append(gr.Markdown(f"### Model {letter}: {short_name}", visible=True))
        else:
            labels.append(gr.Markdown(state.model_name, visible=True))
    return (disable_btn,) * 19 + ("wrong",) + tuple(labels)
def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox):
    """Parse a typed ranking and render each image with its rank frame and label.

    Every digit character of ``vote_textbox`` is read in order as the rank of
    the corresponding image (1st digit -> image 0, ...). All other characters
    are ignored, so "1 3 2 4", "1,3,2,4" and "1324" are equivalent.

    Returns:
        A list ``images + [rank_text, status, rank]``:
          * valid input: the four images resized to 512x512, framed in the
            rank's colour and labelled with the matching Top*_text; the
            digits each followed by a space (e.g. "1 3 2 4 "); "right"; and
            the ranks as ints exactly as typed (1-4).
          * invalid input (not exactly four digits, or a digit outside 1-4):
            the four input images unchanged, "error rank", "wrong", and
            ``[None, None, None, None]``.
    """
    generate_ig = [generate_ig0, generate_ig1, generate_ig2, generate_ig3]
    rank_list = [char for char in vote_textbox if char.isdigit()]

    # Validate up front so bad input is rejected before any image work.
    valid_digits = ("1", "2", "3", "4")
    if len(rank_list) != 4 or any(digit not in valid_digits for digit in rank_list):
        return generate_ig + ["error rank"] + ["wrong"] + [[None, None, None, None]]

    # rank digit -> (frame/text colour, label text)
    style_by_digit = {
        "1": (COLOR1, Top1_text),
        "2": (COLOR2, Top2_text),
        "3": (COLOR3, Top3_text),
        "4": (COLOR4, Top4_text),
    }
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    resample = Image.LANCZOS
    border_size = 10  # Size of the border
    text_position = (180, 25)
    font = ImageFont.truetype("./serve/Arial.ttf", 66)  # load once, not per image

    chatbot = []
    for image, digit in zip(generate_ig, rank_list):
        color, label = style_by_digit[digit]
        canvas = Image.fromarray(image).convert("RGBA").resize((512, 512), resample)
        canvas = ImageOps.expand(canvas, border=border_size, fill=color)
        ImageDraw.Draw(canvas).text(text_position, label, font=font, fill=color)
        chatbot.append(canvas.convert("RGB"))

    rank_str = "".join(digit + " " for digit in rank_list)
    # NOTE(review): these ranks are 1-based as typed. reset_btn_rank instead
    # stores 0-based levels (button "1" -> 0), and reset_vote_text adds 1 for
    # display. If this list reaches vote_submit/update_skill, confirm which
    # convention they expect.
    rank = [int(digit) for digit in rank_list]
    return chatbot + [rank_str] + ["right"] + [rank]
# def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
# base_image = Image.fromarray(image).convert("RGBA")
# if vote_level == 0:
# txt_layer = Image.new('RGBA', base_image.size, (0, 255, 0, 64))
# elif vote_level == 1:
# txt_layer = Image.new('RGBA', base_image.size, (0, 255, 255, 64))
# elif vote_level == 2:
# txt_layer = Image.new('RGBA', base_image.size, (255, 0, 255, 64))
# elif vote_level == 3:
# txt_layer = Image.new('RGBA', base_image.size, (255, 0, 0, 64))
# draw = ImageDraw.Draw(txt_layer)
# font = ImageFont.truetype("./serve/Arial.ttf", 86)
# text_position = (156, 212)
# if vote_level == 0:
# text_color = (0, 255, 0, 200)
# draw.text(text_position, Top1_text, font=font, fill=text_color)
# elif vote_level == 1:
# text_color = (0, 255, 255, 200)
# draw.text(text_position, Top2_text, font=font, fill=text_color)
# elif vote_level == 2:
# text_color = (255, 0, 255, 200)
# draw.text(text_position, Top3_text, font=font, fill=text_color)
# elif vote_level == 3:
# text_color = (255, 0, 0, 200)
# draw.text(text_position, Top4_text, font=font, fill=text_color)
# combined = Image.alpha_composite(base_image, txt_layer)
# combined = combined.convert("RGB")
# return combined
def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
    """Render ``image`` as a framed, labelled 512x512 tile for ``vote_level``.

    The image (anything ``Image.fromarray`` accepts) is resized to 512x512.
    It gets a 10-pixel border in the level's colour (COLOR1..COLOR4 for
    levels 0..3), and the matching Top*_text is drawn near the top in the
    same colour.

    Returns:
        An RGB PIL image.

    Raises:
        ValueError: If ``vote_level`` is not 0-3 (previously an
            UnboundLocalError).
    """
    if vote_level not in (0, 1, 2, 3):
        raise ValueError(f"vote_level must be 0-3, got {vote_level!r}")

    # level -> (frame/text colour, label text)
    border_color, label = {
        0: (COLOR1, Top1_text),
        1: (COLOR2, Top2_text),
        2: (COLOR3, Top3_text),
        3: (COLOR4, Top4_text),
    }[vote_level]

    base_image = Image.fromarray(image).convert("RGBA")
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    base_image = base_image.resize((512, 512), Image.LANCZOS)
    border_size = 10  # Size of the border
    base_image = ImageOps.expand(base_image, border=border_size, fill=border_color)
    draw = ImageDraw.Draw(base_image)
    font = ImageFont.truetype("./serve/Arial.ttf", 66)
    text_position = (180, 25)
    draw.text(text_position, label, font=font, fill=border_color)
    return base_image.convert("RGB")
def add_green_border(image):
    """Return a copy of the PIL ``image`` with a 10-pixel pure-green frame on every side."""
    green = (0, 255, 0)  # RGB for green
    return ImageOps.expand(image, border=10, fill=green)