import gradio as gr
from response_db import ResponseDb
from response_db import get_code
from create_cache import Game_Cache
import numpy as np
from PIL import Image
import pandas as pd
import torch
import pickle
import uuid
import nltk
# Fetch NLTK's "punkt" resource at import time (presumably required by the
# model code; it is not referenced directly in the visible part of this file).
nltk.download('punkt')
# Response log: answered turns and the final guess are recorded via db.add(...).
db = ResponseDb()
# Page CSS: chat-bubble layout for the conversation panel (.chatbot/.msg.*)
# and the N/A button. The N/A label is white so it stays readable on the red
# background (red text on a red background was invisible).
css = """
.chatbot {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white;align-self:self-end}
.msg.bot {background-color:lightgray}
.na_button {background-color:red;color:white}
"""
# JavaScript passed as demo.load's _js hook: returns the page's URL query
# parameters as a plain object (e.g. {"p": "12"}).
get_window_url_params = """
function(url_params) {
console.log(url_params);
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
return url_params;
}
"""
# Qualification task IDs -> pilot-study row used instead. Games played under
# these IDs are not written to `db`.
quals = {1001:99, 1002:136, 1003:56, 1004:105}
# Load both model bundles once at import time. Each call returns
# (question, simulated-response, ground-truth-response, caption) models; the
# ``_yn`` bundle is selected by Game_Session for yes/no games. Every
# Game_Session holds references to these same objects.
from model.run_question_asking_model import return_modules, return_modules_yn
question_model, response_model_simul, response_model_gtruth, caption_model = return_modules()
question_model_yn, response_model_simul_yn, response_model_gtruth_yn, caption_model_yn = return_modules_yn()
class Game_Session:
    """State for one player's game: model handles, belief over images, Q/A history.

    Every session references the same module-level model objects; the
    ``curr_*`` attributes point at either the open-ended models or the
    ``*_yn`` (yes/no) models, depending on ``yn``.
    """

    def __init__(self, taskid, yn, hard_setting):
        """Create an empty session.

        Args:
            taskid: Accepted for callers' convenience; not stored.
            yn: Truthy for yes/no games; selects the ``*_yn`` models.
            hard_setting: Stored unchanged as ``self.hard_setting``.
        """
        self.yn = yn
        self.hard_setting = hard_setting
        # Module-level models are only read here, so no ``global`` statement
        # is needed. The objects are shared by every session.
        self.question_model = question_model
        self.response_model_simul = response_model_simul
        self.response_model_gtruth = response_model_gtruth
        self.caption_model = caption_model
        self.question_model_yn = question_model_yn
        self.response_model_simul_yn = response_model_simul_yn
        self.response_model_gtruth_yn = response_model_gtruth_yn
        self.caption_model_yn = caption_model_yn
        # Per-game data, populated by the caller before the first question.
        self.image_files = None
        self.images_np = None
        self.p_y_x = None  # current belief over the candidate images
        self.p_r_qy = None  # per-image likelihood of the last answer
        self.p_y_xqr = None  # renormalised belief after the last answer
        self.captions = None
        self.questions = None
        self.target_questions = None
        # Alternating transcript: even indices are model questions, odd are answers.
        self.history = []
        self.game_id = str(uuid.uuid4())
        self.set_curr_models()

    def set_curr_models(self):
        """Point the ``curr_*`` attributes at the model set matching ``self.yn``."""
        if self.yn:
            self.curr_question_model = self.question_model_yn
            self.curr_caption_model = self.caption_model_yn
            self.curr_response_model_simul = self.response_model_simul_yn
            self.curr_response_model_gtruth = self.response_model_gtruth_yn
        else:
            self.curr_question_model = self.question_model
            self.curr_caption_model = self.caption_model
            self.curr_response_model_simul = self.response_model_simul
            self.curr_response_model_gtruth = self.response_model_gtruth

    def get_next_question(self):
        """Ask the current question model for the best next question given the current belief."""
        return self.curr_question_model.select_best_question(
            self.p_y_x, self.questions, self.images_np, self.captions, self.curr_response_model_simul
        )

    def get_model_gtruth_response(self, question):
        """Answer ``question`` for the target image (index 0) with the ground-truth response model.

        NOTE(review): this always uses the open-ended ``response_model_gtruth``,
        even when ``self.yn`` is set (``curr_response_model_gtruth`` would be the
        yes/no variant); confirm that is intended.
        """
        return self.response_model_gtruth.get_response(
            question, self.images_np[0], self.captions[0], self.target_questions, is_a=self.yn
        )
def _render_history(history):
    """Return ``history`` as chat-bubble HTML.

    Uses the ``.chatbot`` / ``.msg.bot`` / ``.msg.user`` classes from the page
    CSS: even indices are the model's questions, odd indices the player's answers.
    """
    bubbles = "".join(
        "<div class='msg {}'>{}</div>".format("bot" if m % 2 == 0 else "user", msg)
        for m, msg in enumerate(history)
    )
    return "<div class='chatbot'>" + bubbles + "</div>"


def _control_updates(show_dropdown, show_yes_no, reward_code_update, warning_update):
    """Build updates for [answer, na_box, clear_box, submit, taskid, start_button,
    yes_box, no_box, reward_code, vocab_warning], in that order.

    ``show_dropdown`` toggles the open-ended controls (dropdown, N/A, reset,
    submit); ``show_yes_no`` toggles the Yes/No buttons. The task-ID field and
    Start button are always hidden while a game is running.
    """
    if show_dropdown:
        dropdown_update = gr.Dropdown.update(visible=True, value='')
    else:
        dropdown_update = gr.Dropdown.update(visible=False)
    return (
        dropdown_update,
        gr.Button.update(visible=show_dropdown),  # na_box
        gr.Button.update(visible=show_dropdown),  # clear_box
        gr.Button.update(visible=show_dropdown),  # submit
        gr.Number.update(visible=False),  # taskid
        gr.Button.update(visible=False),  # start_button
        gr.Button.update(visible=show_yes_no),  # yes_box
        gr.Button.update(visible=show_yes_no),  # no_box
        reward_code_update,
        warning_update,
    )


def ask_a_question(input, taskid, gs):
    """Handle one answer: update the belief over images, then finish or ask again.

    Args:
        input: The player's answer (a dropdown word, or "n/a" / "yes" / "no"
            from the buttons). The name shadows the builtin but is kept for
            compatibility.
        taskid: Task ID from the hidden field; qualification IDs (keys of
            ``quals``) are not logged to ``db``.
        gs: The player's Game_Session (Gradio state); mutated and returned.

    Returns:
        12 values for [conversation, game_session_state, answer, na_box,
        clear_box, submit, taskid, start_button, yes_box, no_box, reward_code,
        vocab_warning].
    """
    vocab = gs.curr_response_model_simul.model.config.label2id
    if input not in ("n/a", "yes", "no") and input not in vocab:
        # Unknown word: keep the transcript unchanged and show the vocabulary warning.
        return (_render_history(gs.history), gs) + _control_updates(
            True, False, gr.Textbox.update(visible=False), gr.HTML.update(visible=True)
        )

    question = gs.history[-1]
    gs.history.append(input)
    # Multiply the current belief by the likelihood of this answer and
    # renormalise; if every image has zero mass the belief becomes all zeros.
    gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(input, question, gs.images_np, gs.captions)
    unnormalised = gs.p_y_x * gs.p_r_qy
    total = torch.sum(unnormalised)
    gs.p_y_xqr = unnormalised / total if total != 0 else torch.zeros_like(unnormalised)
    gs.p_y_x = gs.p_y_xqr
    gs.questions.remove(question)
    if taskid not in quals:
        db.add(gs.game_id, taskid, len(gs.history) // 2 - 1, question, input)

    top_prob = torch.max(gs.p_y_x).item()
    top_pred = torch.argmax(gs.p_y_x).item()
    # Stop when confident, or once 10 questions have been answered (the history
    # would exceed 19 entries with another question appended). Deciding before
    # asking avoids generating a question that would only be discarded.
    finished = top_prob > 0.8 or len(gs.history) + 1 > 19

    if finished:
        if taskid not in quals:
            # NOTE: the logged guess is 0-based; the player-facing message is 1-based.
            db.add(gs.game_id, taskid, len(gs.history) // 2, f"Guess: Image {top_pred}", "")
        html = _render_history(gs.history)
        html += f"<p>The model identified Image {top_pred+1} as the image. Please select a new task ID to continue.</p>"
        finish_html = "<p>Congratulations on finishing the game! Please copy the Task Finish Code below to MTurk to complete your task. You can now exit this window.</p>"
        return (html, gs) + _control_updates(
            False,
            False,
            gr.Textbox.update(value=get_code(taskid, gs.history, top_pred), visible=True),
            gr.HTML.update(value=finish_html, visible=True),
        )

    gs.history.append(gs.get_next_question())
    html = _render_history(gs.history)
    # Open-ended games answer via the dropdown; yes/no games via the Yes/No buttons.
    return (html, gs) + _control_updates(
        not gs.yn, bool(gs.yn), gr.Textbox.update(visible=False), gr.HTML.update(visible=False)
    )
def set_images(taskid):
    """Start a game for ``taskid`` and return the initial UI state.

    ``taskid`` indexes a row of pilot-study.csv (qualification IDs in ``quals``
    are first mapped to a fixed row); that row's ``mscoco-id`` selects the
    pickled game cache under cache-soft/.

    Returns:
        21 values for [img0, img1..img10, game_session_state, conversation,
        answer, na_box, clear_box, submit, taskid, start_button, yes_box,
        no_box]. Image 1 is returned twice: for the enlarged ``img0`` and ``img1``.
    """
    pilot_study = pd.read_csv("pilot-study.csv")
    row = int(taskid)
    if row in quals:
        row = quals[row]
    mscoco_id = int(pilot_study['mscoco-id'].tolist()[row])

    # NOTE(review): pickle.load can execute arbitrary code. The path is built
    # from an integer, so it stays inside cache-soft/, but the cache files
    # themselves must be trusted.
    with open(f'cache-soft/{mscoco_id}.p', 'rb') as fp:
        game_cache = pickle.load(fp)
    gs = Game_Session(mscoco_id, game_cache.yn, game_cache.hard_setting)

    # Paths for the gr.Image outputs; Image 1 is the player's target image.
    display_paths = [f"./mscoco-images/val2014/{game_cache.image_files[k]}" for k in range(10)]
    # Dropping the 15-character "./mscoco-images" prefix leaves "/val2014/<file>",
    # the form handed to the caption and question models.
    gs.image_files = [path[15:] for path in display_paths]
    gs.images_np = []
    for name in gs.image_files:
        with Image.open(f"./mscoco-images/{name}") as im:
            pixels = np.asarray(im)
        # Greyscale images load as 2-D arrays; stack them to three channels.
        if pixels.ndim == 2:
            pixels = np.dstack([pixels] * 3)
        gs.images_np.append(pixels)

    # Start from a uniform belief over the 10 candidate images.
    gs.p_y_x = (torch.ones(10) / 10).to(gs.curr_question_model.device)
    gs.captions = gs.curr_caption_model.get_captions(gs.image_files)
    gs.questions, gs.target_questions = gs.curr_question_model.get_questions(gs.image_files, gs.captions, 0)
    # NOTE(review): the question model is a module-level object shared by every
    # session, so this replaces the question bank for all games; if requests
    # are processed concurrently, sessions could interfere — confirm.
    gs.curr_question_model.reset_question_bank()
    gs.curr_question_model.question_bank = game_cache.question_dict
    first_question = gs.curr_question_model.select_best_question(
        gs.p_y_x, gs.questions, gs.images_np, gs.captions, gs.curr_response_model_simul
    )
    gs.history.append(first_question)
    # Show the first question with the same chat-bubble markup used for later turns.
    first_question_html = "<div class='chatbot'><div class='msg bot'>{}</div></div>".format(first_question)

    # Open-ended games answer via the dropdown; yes/no games via the Yes/No buttons.
    open_ended = not gs.yn
    controls = (
        gr.Dropdown.update(visible=True, value='') if open_ended else gr.Dropdown.update(visible=False),
        gr.Button.update(visible=open_ended),  # na_box
        gr.Button.update(visible=open_ended),  # clear_box
        gr.Button.update(visible=open_ended),  # submit
        gr.Number.update(visible=False),  # taskid
        gr.Button.update(visible=False),  # start_button
        gr.Button.update(visible=not open_ended),  # yes_box
        gr.Button.update(visible=not open_ended),  # no_box
    )
    return (display_paths[0], *display_paths, gs, first_question_html) + controls
def reset_dropdown():
    """Clear the answer dropdown's current selection while keeping it on screen."""
    cleared_dropdown = gr.Dropdown.update(value='', visible=True)
    return cleared_dropdown
# Instructions shown at the top of the page.
INSTRUCTIONS_HTML = """
<h1>Image Q&amp;A Guessing Game</h1>
<ul>
<li>Imagine you are playing 20-questions with an AI model.</li>
<li>The AI model plays the role of the question asker. You play the role of the responder.</li>
<li>There are 10 images. Your image is Image 1. The other images are distraction images. The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.</li>
<li>The goal of the model is to accurately guess the correct image (i.e. Image 1) in as few turns as possible.</li>
<li>Your goal is to help the model guess the image by answering as clearly and accurately as possible.</li>
<li>(Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)</li>
</ul>
<h3>Selecting N/A:</h3>
<ul>
<li>In some games, there will be an N/A option. Please select N/A only if the question is unanswerable BECAUSE IT DOES NOT APPLY TO THE IMAGE.</li>
<li>Otherwise, please select the closest possible option.</li>
<li>e.g. Q: "What is the dog doing?" Please select N/A if there is no dog in the image.</li>
</ul>
"""

# Shown when a typed answer is not in the response model's vocabulary.
VOCAB_WARNING_HTML = (
    "<p>The word you typed in is not a valid word in the model vocabulary. "
    "Please clear it and select a valid word from the dropdown menu.</p>"
)


def _load_taskid(params):
    """Fill the hidden Task ID field from the ``p`` URL query parameter.

    If the page was opened without ``p``, reveal the field instead of raising,
    so the ID can be entered manually.
    """
    params = params or {}
    if "p" not in params:
        return gr.Number.update(visible=True)
    return gr.Number.update(value=int(params["p"]))


with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
    gr.HTML(INSTRUCTIONS_HTML)
    with gr.Column():
        with gr.Row():
            # Hidden: filled from the URL on page load (see demo.load below).
            taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", visible=False)
            start_button = gr.Button("Start")
        with gr.Column() as img_block:
            with gr.Row():
                img1 = gr.Image(label="Image 1", show_label=True)
                img2 = gr.Image(label="Image 2", show_label=True)
                img3 = gr.Image(label="Image 3", show_label=True)
                img4 = gr.Image(label="Image 4", show_label=True)
                img5 = gr.Image(label="Image 5", show_label=True)
            with gr.Row():
                img6 = gr.Image(label="Image 6", show_label=True)
                img7 = gr.Image(label="Image 7", show_label=True)
                img8 = gr.Image(label="Image 8", show_label=True)
                img9 = gr.Image(label="Image 9", show_label=True)
                img10 = gr.Image(label="Image 10", show_label=True)
    conversation = gr.HTML()
    game_session_state = gr.State()

    # The dropdown only appears in open-ended games, which validate answers
    # against response_model_simul's vocabulary, so list that same vocabulary.
    # Words are shown alphabetically-led first, then those starting with a digit.
    full_vocab_dict = response_model_simul.model.config.label2id
    vocab_list_numbers, vocab_list_letters = [], []
    for word in full_vocab_dict:
        if word is None or word == "None":
            continue
        if word[0] in "0123456789":
            vocab_list_numbers.append(word)
        else:
            vocab_list_letters.append(word)
    with gr.Row():
        answer = gr.Dropdown(
            vocab_list_letters + vocab_list_numbers,
            label="Answer the given question.",
            info="If you cannot find your exact answer, pick the word you feel would be most appropriate. ONLY SELECT N/A IF THE QUESTION DOES NOT APPLY TO THE IMAGE.",
            visible=False,
        )
        clear_box = gr.Button("Reset Selection \n(Use this to clear the dropdown selection.)", visible=False)
    with gr.Row():
        vocab_warning = gr.HTML(VOCAB_WARNING_HTML, visible=False)
    # Hidden constant inputs so the N/A / Yes / No buttons can reuse ask_a_question.
    null_answer = gr.Textbox("n/a", visible=False)
    yes_answer = gr.Textbox("yes", visible=False)
    no_answer = gr.Textbox("no", visible=False)
    with gr.Column():
        with gr.Row():
            yes_box = gr.Button("Yes", visible=False)
            no_box = gr.Button("No", visible=False)
    with gr.Column():
        with gr.Row():
            na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
            submit = gr.Button("Submit", visible=False)
    with gr.Row():
        reward_code = gr.Textbox("", label="Task Finish Code", visible=False)
    with gr.Column() as img_block0:
        with gr.Row():
            # Enlarged copy of the player's image (Image 1).
            img0 = gr.Image(label="Image 1", show_label=True).style(height=700, width=700)

    ### Button click events
    # Output order must match set_images' return tuple.
    start_button.click(
        fn=set_images,
        inputs=taskid,
        outputs=[img0, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10,
                 game_session_state, conversation, answer, na_box, clear_box, submit,
                 taskid, start_button, yes_box, no_box],
    )
    # Output order must match ask_a_question's return tuple.
    answer_outputs = [conversation, game_session_state, answer, na_box, clear_box, submit,
                      taskid, start_button, yes_box, no_box, reward_code, vocab_warning]
    submit.click(fn=ask_a_question, inputs=[answer, taskid, game_session_state], outputs=answer_outputs)
    for button, fixed_answer in ((na_box, null_answer), (yes_box, yes_answer), (no_box, no_answer)):
        button.click(fn=ask_a_question, inputs=[fixed_answer, taskid, game_session_state], outputs=answer_outputs)
    clear_box.click(fn=reset_dropdown, inputs=[], outputs=[answer])

    url_params = gr.JSON({}, visible=False, label="URL Params")
    demo.load(fn=_load_taskid, inputs=[url_params], outputs=taskid, _js=get_window_url_params)

demo.launch()