CYF200127's picture
Upload 235 files
3e1d9f3 verified
raw
history blame
7.8 kB
from ..root import (
DATASETS,
QUESTION_PLACEHOLDER,
IMAGE_PLACEHOLDER,
BOXES_PLACEHOLDER,
)
from ..utils import MInstrDataset
def prepare_sentence(sent):
    """Flatten a tokenized sentence into text plus the box references it mentions.

    Every token that is a ``list`` (a box reference) is replaced in the text by
    ``BOXES_PLACEHOLDER`` and collected, in order, into the returned box sequence;
    any other token is kept verbatim. Tokens are joined with single spaces.

    Args:
        sent: iterable of tokens (strings or lists of box references).

    Returns:
        ``(text, box_seq)`` tuple.
    """
    words, box_seq = [], []
    for token in sent:
        is_box_ref = isinstance(token, list)
        words.append(BOXES_PLACEHOLDER if is_box_ref else token)
        if is_box_ref:
            box_seq.append(token)
    return " ".join(words), box_seq
def prepare_choice(pack_choices, label_index, *, options='ABCDEFG'):
    """Render packed choices as a lettered option list plus the matching answer sentence.

    Args:
        pack_choices: iterable of ``(text, box_seq)`` pairs, one per choice.
        label_index: index of the correct choice within ``pack_choices``.
        options: letters assigned to the choices, in order.

    Returns:
        ``((options_text, box_seq), (label_sentence, []))`` where ``options_text`` is
        ``"(A) ... (B) ..."``, ``box_seq`` concatenates every choice's boxes in order,
        and ``label_sentence`` is ``"The answer is (X)."``.

    Raises:
        ValueError: if there are more choices than option letters (``zip`` would
            otherwise silently drop the extra choices).
        IndexError: if ``label_index`` does not address one of the choices (a negative
            index would otherwise pick a letter that disagrees with ``choices[label_index]``).
    """
    choices = list(pack_choices)
    if len(choices) > len(options):
        raise ValueError(
            f"{len(choices)} choices but only {len(options)} option letters ({options!r})"
        )
    if not 0 <= label_index < len(choices):
        raise IndexError(f"label_index {label_index} out of range for {len(choices)} choices")

    options_text = " ".join(f"({letter}) {choice[0]}" for choice, letter in zip(choices, options))
    box_seq = [box for choice in choices for box in choice[1]]
    label_sentence = f"The answer is ({options[label_index]})."
    return (options_text, box_seq), (label_sentence, [])
def merge(packs, *, prefixs, postfixs=None):
    """Concatenate several ``(text, box_seq)`` packs into a single pack.

    Each pack contributes ``prefix``, its text, and ``postfix`` (an empty/falsy prefix
    or postfix is skipped; the pack text itself is always kept). All pieces are joined
    with single spaces, and the box sequences are concatenated in pack order.

    Args:
        packs: sequence of ``(text, box_seq)`` pairs.
        prefixs: one prefix per pack.
        postfixs: one postfix per pack; defaults to ``''`` for every pack.

    Returns:
        ``(text, box_seq)`` tuple.

    Raises:
        AssertionError: if the three sequences differ in length.
    """
    if postfixs is None:
        postfixs = [''] * len(packs)
    assert len(packs) == len(prefixs) == len(postfixs), f"{len(packs)},{len(prefixs)},{len(postfixs)}"

    text_parts, box_seq = [], []
    for pack, head, tail in zip(packs, prefixs, postfixs):
        if head:
            text_parts.append(head)
        text_parts.append(pack[0])
        if tail:
            text_parts.append(tail)
        box_seq += pack[1]
    return " ".join(text_parts), box_seq
@DATASETS.register_module()
class VCRDataset(MInstrDataset):
    """VCR samples rendered as alternating human/gpt conversation turns.

    ``version`` selects the layout built in ``__getitem__`` (turns alternate
    human, gpt, human, gpt):

        q-a        human: question                        gpt: gold answer
        q-ra       human: question                        gpt: gold rationale + gold answer
        qc-a       human: question + answer options       gpt: answer letter
        qc-ra      human: question + answer options       gpt: gold rationale + answer letter
        qc-rac     human: question + answer options       gpt: gold rationale + gold answer + letter
        qa-r       human: question + gold answer          gpt: gold rationale
        qac-r      human: question + gold answer
                          + rationale options             gpt: rationale letter
        q-a-q-r    human: question -> gpt: gold answer -> human: explain -> gpt: gold rationale
        qc-a-qc-r  human: question + answer options -> gpt: answer letter
                   -> human: rationale options -> gpt: rationale letter

    Versions intended for evaluation:
        A: 'qc-a', 'qc-ra', 'qc-rac'
        R: 'qac-r', 'qc-a-qc-r'
    """

    # Every layout __getitem__ knows how to build.
    _VERSIONS = (
        'q-a', 'q-ra',
        'qc-a', 'qc-ra', 'qc-rac',  # for evaluation: A
        'qa-r', 'q-a-q-r',
        'qac-r', 'qc-a-qc-r',  # for evaluation: R
    )
    # Instruction sentences placed in the prompts (runtime text: keep byte-identical).
    _PICK_ANSWER = 'You should decide on the best choice and output the corresponding letter.'
    _EXPLAIN = 'You should explain the reason for the above answer.'
    _PICK_RATIONALE = ('You should decide on the best choice that explains the above answer '
                       'and output the corresponding letter.')

    def __init__(self, *args, version, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.version = version
        assert version in self._VERSIONS, f"unsupported VCR version: {version!r}"

    def __getitem__(self, index, force_answer_label=None, force_rationale_label=None):
        """Build one conversation sample.

        Args:
            index: raw item index passed to ``get_raw_item``.
            force_answer_label: if not None, used instead of the item's 'answer_label'
                (which is then never read).
            force_rationale_label: if not None, used instead of the item's
                'rationale_label' (which is then never read).

        Returns:
            dict with 'image', 'target' (``{'boxes': [...]}``) and 'conversations'
            (list of ``{'from', 'value', 'boxes_seq'}`` dicts).

        Raises:
            ValueError: if ``self.version`` is not a supported layout.
        """
        item = self.get_raw_item(index)
        image = self.get_image(item['img_fn'])
        # Keep only the first four values of each box entry; the extra value(s) are
        # presumably a detection score ('boxes_with_prob' upstream) — TODO confirm.
        boxes = [box[:4] for box in item['boxes']]

        question_pack = prepare_sentence(item['question'])
        answer_pack_choices = [prepare_sentence(s) for s in item['answer_choices']]
        rationale_pack_choices = [prepare_sentence(s) for s in item['rationale_choices']]

        # Raw labels are read only when not forced, so unlabeled items work if both are forced.
        answer_label = item['answer_label'] if force_answer_label is None else force_answer_label
        rationale_label = (item['rationale_label'] if force_rationale_label is None
                           else force_rationale_label)

        answer_choices_pack, answer_choice = prepare_choice(answer_pack_choices, answer_label)
        rationale_choices_pack, rationale_choice = prepare_choice(rationale_pack_choices, rationale_label)
        answer_gold_pack = answer_pack_choices[answer_label]
        rationale_gold_pack = rationale_pack_choices[rationale_label]

        # Prompts shared by several layouts (pure, cheap to build).
        question_only = merge([question_pack], prefixs=['QUESTION:'])
        question_with_options = merge(
            [question_pack, answer_choices_pack],
            prefixs=['QUESTION:', '\nOPTIONS:'],
            postfixs=['', self._PICK_ANSWER],
        )

        version = self.version
        if version == 'q-a':
            final_packs = [question_only, answer_gold_pack]
        elif version == 'q-ra':
            final_packs = [
                question_only,
                merge([rationale_gold_pack, answer_gold_pack], prefixs=['', '']),
            ]
        elif version == 'qc-a':
            final_packs = [question_with_options, answer_choice]
        elif version == 'qc-ra':
            final_packs = [
                question_with_options,
                merge([rationale_gold_pack, answer_choice], prefixs=['', '']),
            ]
        elif version == 'qc-rac':
            final_packs = [
                question_with_options,
                merge([rationale_gold_pack, answer_gold_pack, answer_choice], prefixs=['', '', '']),
            ]
        elif version == 'qa-r':
            final_packs = [
                merge(
                    [question_pack, answer_gold_pack],
                    prefixs=['QUESTION:', '\nANSWER:'],
                    postfixs=['', self._EXPLAIN],
                ),
                rationale_gold_pack,
            ]
        elif version == 'qac-r':
            final_packs = [
                merge(
                    [question_pack, answer_gold_pack, rationale_choices_pack],
                    prefixs=['QUESTION:', '\nANSWER:', '\nRATIONALE OPTIONS:'],
                    postfixs=['', '', self._PICK_RATIONALE],
                ),
                rationale_choice,
            ]
        elif version == 'q-a-q-r':
            final_packs = [
                question_only,
                answer_gold_pack,
                # boxes_seq is a list, like every other turn's (was an empty tuple).
                (self._EXPLAIN, []),
                rationale_gold_pack,
            ]
        elif version == 'qc-a-qc-r':
            final_packs = [
                question_with_options,
                answer_choice,
                merge(
                    [rationale_choices_pack],
                    prefixs=['RATIONALE OPTIONS:'],
                    postfixs=[self._PICK_RATIONALE],
                ),
                rationale_choice,
            ]
        else:
            # Unreachable after __init__'s check unless self.version was reassigned.
            # Raise explicitly: an `assert False` disappears under `python -O` and would
            # leave final_packs unbound.
            raise ValueError(f"unsupported VCR version: {version!r}")

        roles = ('human', 'gpt')
        conversations = [
            {'from': roles[idx % 2], 'value': pack[0], 'boxes_seq': pack[1]}
            for idx, pack in enumerate(final_packs)
        ]
        # Wrap the first human turn in the dataset's prompt template.
        conversations[0]['value'] = self.get_template().replace(QUESTION_PLACEHOLDER, conversations[0]['value'])
        return {
            'image': image,
            'target': {'boxes': boxes},
            'conversations': conversations,
        }
@DATASETS.register_module()
class VCRPredDataset(VCRDataset):
    """Prediction-time wrapper around ``VCRDataset`` restricted to evaluation layouts.

    For the R layouts ('qac-r', 'qc-a-qc-r') every raw item is expanded into 4
    samples, one per forced answer label 0..3, so ``len`` is 4x the raw count.
    For the A layouts each raw item yields one sample with both labels forced to 0.
    In all cases the labels are placeholders, and a warning is appended to the last
    conversation turn to say so.
    """

    def __init__(self, *args, version, **kwargs):
        super().__init__(*args, version=version, **kwargs)
        answer_eval_versions = ['qc-a', 'qc-ra', 'qc-rac']  # for evaluation: A
        rationale_eval_versions = ['qac-r', 'qc-a-qc-r']  # for evaluation: R
        assert version in answer_eval_versions + rationale_eval_versions
        self.is_pred_for_r = version in rationale_eval_versions

    def __len__(self):
        raw_count = super().__len__()
        return raw_count * 4 if self.is_pred_for_r else raw_count

    # noinspection PyMethodOverriding
    def __getitem__(self, index):
        if not self.is_pred_for_r:
            ret = super().__getitem__(index, force_answer_label=0, force_rationale_label=0)
        else:
            # index enumerates (raw item, answer label) pairs, 4 labels per item.
            item_index, answer_index = divmod(index, 4)
            ret = super().__getitem__(item_index, force_answer_label=answer_index, force_rationale_label=0)
        ret['conversations'][-1]['value'] += "WARNING: answer and rationale here are just placeholders. we have no real anno."
        return ret