from ..root import (
    DATASETS,
    QUESTION_PLACEHOLDER,
    IMAGE_PLACEHOLDER,
    BOXES_PLACEHOLDER,
)
from ..utils import MInstrDataset
from ..utils.flickr30k_entities_utils import PHRASE_ST_PLACEHOLDER, PHRASE_ED_PLACEHOLDER


@DATASETS.register_module()
class GPT4Gen(MInstrDataset):
    """Grounded question/answer dataset built from GPT4Gen-style records.

    Each raw record is read through ``self.get_raw_item`` and is expected to
    contain the keys ``img_path``, ``boxes``, ``question`` and
    ``question_boxes_seq``, plus ``answer`` (version ``'a'``) or
    ``cot_with_ans`` (versions ``'c'`` / ``'bc'``; ``'bc'`` additionally
    needs ``answer_boxes_seq``).

    Args:
        version: Which answer target to emit:
            * ``'a'``  -- ``raw['answer']`` as-is, no answer boxes.
            * ``'c'``  -- ``raw['cot_with_ans']`` with all phrase markers
              removed, no answer boxes.
            * ``'bc'`` -- ``raw['cot_with_ans']`` with phrase-end markers
              turned into box placeholders, plus ``raw['answer_boxes_seq']``.
            (Presumably "a" = answer, "c" = chain-of-thought,
            "bc" = boxed chain-of-thought -- naming not confirmed here.)

    Raises:
        ValueError: if ``version`` is not one of ``'a'``, ``'c'``, ``'bc'``.
    """

    _VERSIONS = ('a', 'c', 'bc')

    def __init__(self, *args, version, **kwargs):
        # Validate before the (possibly expensive) base-class initialisation,
        # and with a real exception so the check survives ``python -O``.
        if version not in self._VERSIONS:
            raise ValueError(
                f"GPT4Gen: unsupported version {version!r}; expected one of {self._VERSIONS}"
            )
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.version = version

    @staticmethod
    def _mark_phrases(text, end_replacement):
        """Drop phrase-start markers and replace phrase-end markers with *end_replacement*."""
        return text.replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, end_replacement)

    def __getitem__(self, item):
        """Build one conversation sample (human question + gpt answer).

        Returns:
            dict with ``'image'`` (from ``self.get_image``), ``'target'``
            (``{'boxes': raw['boxes']}``) and ``'conversations'`` -- a
            two-turn list whose ``boxes_seq`` entries index into those boxes.
        """
        raw = self.get_raw_item(item)
        image = self.get_image(raw['img_path'])
        boxes = raw['boxes']

        # Each grounded phrase in the question gets a box placeholder after it;
        # query_boxes_seq says which boxes fill those placeholders.
        question = self._mark_phrases(raw['question'], BOXES_PLACEHOLDER)
        final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
        query_boxes_seq = raw['question_boxes_seq']

        if self.version == 'a':
            final_answer = raw['answer']
            answer_boxes_seq = None
        elif self.version == 'c':
            # Plain-text reasoning: strip the phrase markers entirely.
            final_answer = self._mark_phrases(raw['cot_with_ans'], '')
            answer_boxes_seq = None
        elif self.version == 'bc':
            # Grounded reasoning: phrase ends become box placeholders.
            final_answer = self._mark_phrases(raw['cot_with_ans'], BOXES_PLACEHOLDER)
            answer_boxes_seq = raw['answer_boxes_seq']
        else:
            # Only reachable if self.version was changed after construction.
            raise ValueError(f"GPT4Gen: unsupported version {self.version!r}")

        ret = {
            'image': image,
            'target': {'boxes': boxes},
            'conversations': [
                {
                    'from': 'human',
                    'value': final_question,
                    'boxes_seq': query_boxes_seq,
                },
                {
                    'from': 'gpt',
                    'value': final_answer,
                    'boxes_seq': answer_boxes_seq,
                },
            ],
        }
        return ret