|
from ..root import ( |
|
DATASETS, |
|
QUESTION_PLACEHOLDER, |
|
IMAGE_PLACEHOLDER, |
|
BOXES_PLACEHOLDER, |
|
) |
|
from ..utils import MInstrDataset |
|
|
|
|
|
def prepare_sentence(sent): |
|
ret_str = [] |
|
ret_box_seq = [] |
|
for word in sent: |
|
if isinstance(word, list): |
|
ret_str.append(BOXES_PLACEHOLDER) |
|
ret_box_seq.append(word) |
|
else: |
|
ret_str.append(word) |
|
return " ".join(ret_str), ret_box_seq |
|
|
|
|
|
def prepare_choice(pack_choices, label_index, *, options='ABCDEFG'): |
|
ret_str = [] |
|
ret_box_seq = [] |
|
for pack, op in zip(pack_choices, options): |
|
ret_str.append(f"({op}) {pack[0]}") |
|
ret_box_seq.extend(pack[1]) |
|
ret_pack = (" ".join(ret_str), ret_box_seq) |
|
label_choice = f"The answer is ({options[label_index]})." |
|
return ret_pack, (label_choice, []) |
|
|
|
|
|
def merge(packs, *, prefixs, postfixs=None): |
|
if postfixs is None: |
|
postfixs = ['' for _ in range(len(packs))] |
|
assert len(packs) == len(prefixs) == len(postfixs), f"{len(packs)},{len(prefixs)},{len(postfixs)}" |
|
ret_str = [] |
|
ret_box_seq = [] |
|
for pack, prefix, postfix in zip(packs, prefixs, postfixs): |
|
if prefix: |
|
ret_str.append(prefix) |
|
ret_str.append(pack[0]) |
|
if postfix: |
|
ret_str.append(postfix) |
|
ret_box_seq.extend(pack[1]) |
|
return " ".join(ret_str), ret_box_seq |
|
|
|
|
|
@DATASETS.register_module() |
|
class VCRDataset(MInstrDataset): |
|
def __init__(self, *args, version, **kwargs): |
|
super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) |
|
self.version = version |
|
assert version in [ |
|
'q-a', 'q-ra', |
|
'qc-a', 'qc-ra', 'qc-rac', |
|
'qa-r', 'q-a-q-r', |
|
'qac-r', 'qc-a-qc-r', |
|
] |
|
|
|
|
|
|
|
|
|
def __getitem__(self, index, force_answer_label=None, force_rationale_label=None): |
|
item = self.get_raw_item(index) |
|
image = self.get_image(item['img_fn']) |
|
|
|
boxes_with_prob = item['boxes'] |
|
boxes = [box[:4] for box in boxes_with_prob] |
|
|
|
question = item['question'] |
|
answer_choices = item['answer_choices'] |
|
rationale_choices = item['rationale_choices'] |
|
if force_answer_label is not None: |
|
answer_label = force_answer_label |
|
else: |
|
answer_label = item['answer_label'] |
|
if force_rationale_label is not None: |
|
rationale_label = force_rationale_label |
|
else: |
|
rationale_label = item['rationale_label'] |
|
|
|
question_pack = prepare_sentence(question) |
|
answer_pack_choices = [prepare_sentence(_) for _ in answer_choices] |
|
rationale_pack_choices = [prepare_sentence(_) for _ in rationale_choices] |
|
|
|
answer_choices_pack, answer_choice = prepare_choice(answer_pack_choices, answer_label) |
|
rationale_choices_pack, rationale_choice = prepare_choice(rationale_pack_choices, rationale_label) |
|
answer_gold_pack = answer_pack_choices[answer_label] |
|
rationale_gold_pack = rationale_pack_choices[rationale_label] |
|
|
|
version = self.version |
|
if version == 'q-a': |
|
final_packs = [ |
|
merge([question_pack], prefixs=['QUESTION:'], ), |
|
answer_gold_pack, |
|
] |
|
elif version == 'q-ra': |
|
final_packs = [ |
|
merge([question_pack], prefixs=['QUESTION:'], ), |
|
merge([rationale_gold_pack, answer_gold_pack], prefixs=['', '']), |
|
] |
|
elif version == 'qc-a': |
|
final_packs = [ |
|
merge([question_pack, answer_choices_pack], prefixs=['QUESTION:', '\nOPTIONS:'], postfixs=['', 'You should decide on the best choice and output the corresponding letter.']), |
|
answer_choice, |
|
] |
|
elif version == 'qc-ra': |
|
final_packs = [ |
|
merge([question_pack, answer_choices_pack], prefixs=['QUESTION:', '\nOPTIONS:'], postfixs=['', 'You should decide on the best choice and output the corresponding letter.']), |
|
merge([rationale_gold_pack, answer_choice], prefixs=['', '']), |
|
] |
|
elif version == 'qc-rac': |
|
final_packs = [ |
|
merge([question_pack, answer_choices_pack], prefixs=['QUESTION:', '\nOPTIONS:'], postfixs=['', 'You should decide on the best choice and output the corresponding letter.']), |
|
merge([rationale_gold_pack, answer_gold_pack, answer_choice], prefixs=['', '', '']), |
|
] |
|
elif version == 'qa-r': |
|
final_packs = [ |
|
merge([question_pack, answer_gold_pack], prefixs=['QUESTION:', '\nANSWER:'], postfixs=['', 'You should explain the reason for the above answer.']), |
|
rationale_gold_pack, |
|
] |
|
elif version == 'qac-r': |
|
final_packs = [ |
|
merge([question_pack, answer_gold_pack, rationale_choices_pack], prefixs=['QUESTION:', '\nANSWER:', '\nRATIONALE OPTIONS:'], postfixs=['', '', 'You should decide on the best choice that explains the above answer and output the corresponding letter.']), |
|
rationale_choice, |
|
] |
|
elif version == 'q-a-q-r': |
|
final_packs = [ |
|
merge([question_pack], prefixs=['QUESTION:'], ), |
|
answer_gold_pack, |
|
('You should explain the reason for the above answer.', ()), |
|
rationale_gold_pack, |
|
] |
|
elif version == 'qc-a-qc-r': |
|
final_packs = [ |
|
merge([question_pack, answer_choices_pack], prefixs=['QUESTION:', '\nOPTIONS:'], postfixs=['', 'You should decide on the best choice and output the corresponding letter.']), |
|
answer_choice, |
|
merge([rationale_choices_pack], prefixs=['RATIONALE OPTIONS:'], postfixs=['You should decide on the best choice that explains the above answer and output the corresponding letter.']), |
|
rationale_choice, |
|
] |
|
else: |
|
assert False |
|
|
|
conversations = [] |
|
roles = ['human', 'gpt'] |
|
for idx, pack in enumerate(final_packs): |
|
conversations.append({ |
|
'from': roles[idx % 2], |
|
'value': pack[0], |
|
'boxes_seq': pack[1], |
|
}) |
|
conversations[0]['value'] = self.get_template().replace(QUESTION_PLACEHOLDER, conversations[0]['value']) |
|
|
|
ret = { |
|
'image': image, |
|
'target': {'boxes': boxes}, |
|
'conversations': conversations, |
|
} |
|
return ret |
|
|
|
|
|
@DATASETS.register_module() |
|
class VCRPredDataset(VCRDataset): |
|
def __init__(self, *args, version, **kwargs): |
|
super().__init__(*args, version=version, **kwargs) |
|
assert version in [ |
|
'qc-a', 'qc-ra', 'qc-rac', |
|
'qac-r', 'qc-a-qc-r', |
|
] |
|
self.is_pred_for_r = version in [ |
|
'qac-r', 'qc-a-qc-r', |
|
] |
|
|
|
def __len__(self): |
|
if self.is_pred_for_r: |
|
return super().__len__() * 4 |
|
else: |
|
return super().__len__() |
|
|
|
|
|
def __getitem__(self, index): |
|
if self.is_pred_for_r: |
|
item_index = index // 4 |
|
answer_index = index % 4 |
|
ret = super().__getitem__(item_index, force_answer_label=answer_index, force_rationale_label=0) |
|
else: |
|
ret = super().__getitem__(index, force_answer_label=0, force_rationale_label=0) |
|
ret['conversations'][-1]['value'] += "WARNING: answer and rationale here are just placeholders. we have no real anno." |
|
return ret |
|
|