import json

from ..root import DATASETS, IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER, POINTS_PLACEHOLDER
from ..utils import MInstrDataset

# Program functions whose result objects are always referenced in the
# chain-of-thought when references are enabled.
_ALWAYS_REF_FUNCTIONS = frozenset({
    'unique', 'union', 'intersect', 'relate',
    'same_size', 'same_shape', 'same_material', 'same_color',
})
# Program functions whose result objects are referenced only when the very
# next program step is one of _COUNTING_FUNCTIONS.
_FILTER_FUNCTIONS = frozenset({
    'scene', 'filter_color', 'filter_material', 'filter_shape', 'filter_size',
})
_COUNTING_FUNCTIONS = frozenset({'exist', 'count'})

_VALID_QTYPES = ('q',)
_VALID_ATYPES = ('a', 's', 'bs')


@DATASETS.register_module()
class ClevrDataset(MInstrDataset):
    """CLEVR question-answering dataset in conversation format.

    ``version`` has the form ``"<qtype>-<atype>"``:

    * ``qtype`` -- only ``'q'`` is supported: the question text is
      substituted into the template returned by ``get_template()``.
    * ``atype`` -- ``'a'``: the answer sentence only; ``'s'``: the program
      steps rendered by :func:`clevr_ss_cot` followed by the answer sentence;
      ``'bs'``: like ``'s'`` but with point references to scene objects.

    NOTE(review): ``'bs'`` needs scene data; with ``scene_graph_file=None``
    the scene is ``None`` and :func:`clevr_ss_cot` fails when it indexes it.
    """

    def __init__(self, *args, scene_graph_file, version, **kwargs):
        """Initialise the dataset.

        Args:
            *args: Forwarded to ``MInstrDataset.__init__``.
            scene_graph_file: Path to a text file holding one JSON scene per
                line, or ``None`` when no scene data is available.
            version: ``"<qtype>-<atype>"`` string, see the class docstring.
            **kwargs: Forwarded to ``MInstrDataset.__init__``.

        Raises:
            ValueError: If ``version`` does not split into exactly two
                ``'-'``-separated parts, or either part is unsupported.
        """
        super().__init__(
            *args,
            **kwargs,
            placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER),
        )
        self.scene_graph_file = scene_graph_file
        self.version = version

        qtype, atype = version.split('-')
        # Explicit raises instead of ``assert``: asserts are removed under
        # ``python -O`` and would let a bad version through to __getitem__.
        if qtype not in _VALID_QTYPES:
            raise ValueError(
                f"unsupported question type {qtype!r} in version {version!r}; "
                f"expected one of {_VALID_QTYPES}"
            )
        if atype not in _VALID_ATYPES:
            raise ValueError(
                f"unsupported answer type {atype!r} in version {version!r}; "
                f"expected one of {_VALID_ATYPES}"
            )
        self.qtype = qtype
        self.atype = atype

        if scene_graph_file is None:
            self.scene_graph = None
        else:
            # Keep the raw lines and decode lazily in get_raw_item; the
            # context manager guarantees the handle is closed.
            with open(scene_graph_file, 'r', encoding='utf8') as scene_file:
                self.scene_graph = list(scene_file)

    def get_raw_item(self, index):
        """Return the decoded ``(question, scene)`` pair for ``index``.

        ``scene`` is ``None`` when no scene graph file was given; otherwise
        it is the scene line selected by ``question['image_index']``.
        """
        question = json.loads(self.data[index])
        if self.scene_graph is None:
            scene = None
        else:
            scene = json.loads(self.scene_graph[question['image_index']])
        return question, scene

    def __getitem__(self, index):
        """Build one conversation sample.

        Returns:
            A dict with keys ``'image'``, ``'target'`` (``{'points': ...}``)
            and ``'conversations'`` (a human turn followed by a gpt turn,
            each carrying ``'from'``, ``'value'`` and ``'points_seq'``).
        """
        question, scene = self.get_raw_item(index)
        image = self.get_image(question['image_filename'])

        if self.atype == 'a':
            boxes = []
            answer = f"The answer is {question['answer']}."
            answer_boxes_seq = []
        elif self.atype in ('s', 'bs'):
            # 's' and 'bs' differ only in whether object references are added.
            answer, boxes, answer_boxes_seq = clevr_ss_cot(
                obj=question,
                scene=scene,
                add_ref=self.atype == 'bs',
            )
            answer += f" The answer is {question['answer']}."
        else:
            # Unreachable unless self.atype was modified after __init__.
            raise AssertionError(f"unexpected answer type: {self.atype!r}")

        if self.qtype == 'q':
            query_boxes_seq = []
            final_query = self.get_template().replace(QUESTION_PLACEHOLDER, question['question'])
        else:
            # Unreachable unless self.qtype was modified after __init__.
            raise AssertionError(f"unexpected question type: {self.qtype!r}")

        return {
            'image': image,
            'target': {'points': boxes},
            'conversations': [
                {
                    'from': 'human',
                    'value': final_query,
                    'points_seq': query_boxes_seq,
                },
                {
                    'from': 'gpt',
                    'value': answer,
                    'points_seq': answer_boxes_seq,
                },
            ],
        }


def get_boxes_idx(boxes_list, refs):
    """Map each entry of ``refs`` to its index in ``boxes_list``.

    Entries not yet present are appended, so ``boxes_list`` is mutated in
    place and ends up holding every distinct entry seen so far.

    Args:
        boxes_list: Accumulator list, extended in place.
        refs: Iterable of entries to look up (compared with ``==``).

    Returns:
        List of indices into ``boxes_list``, one per entry of ``refs``.
    """
    indices = []
    for box in refs:
        if box not in boxes_list:
            boxes_list.append(box)
        # index() returns the first equal entry, which is also the one that
        # was just appended when the box was new.
        indices.append(boxes_list.index(box))
    return indices


def clevr_ss_cot(obj, scene, add_ref=False):
    """Render a question's program as a step-by-step chain-of-thought string.

    Each program step becomes ``"<function>[:<value>][<inputs>]"`` and the
    steps are joined with ``" -> "``. With ``add_ref`` enabled, selected
    steps additionally carry ``POINTS_PLACEHOLDER`` (when the step's ``ans``
    is non-empty) or ``" Found no object."`` (when it is empty).

    Args:
        obj: Question dict; ``obj['program']`` is the list of step dicts with
            keys ``'function'``, ``'inputs'``, optionally ``'value_inputs'``,
            and ``'ans'`` (read only for steps that receive references).
        scene: Scene dict; ``scene['objects'][i]['pixel_coords']`` is read for
            referenced objects. Only used when ``add_ref`` is true.
        add_ref: Whether to attach object references to eligible steps.

    Returns:
        Tuple ``(text, boxes, seq)``: the rendered string, the de-duplicated
        list of referenced points (first two ``pixel_coords`` components --
        presumably x and y, TODO confirm), and one index list into ``boxes``
        per placeholder emitted, in order of appearance.
    """
    program = obj['program']
    cot = []
    boxes = []
    seq = []

    def can_add_ref(step_idx, step):
        """Tell whether the given step should carry object references."""
        function = step['function']
        if function in _ALWAYS_REF_FUNCTIONS:
            return True
        if function in _FILTER_FUNCTIONS:
            # Filters are referenced only right before an exist/count step.
            next_idx = step_idx + 1
            return next_idx < len(program) and program[next_idx]['function'] in _COUNTING_FUNCTIONS
        return False

    for step_idx, step in enumerate(program):
        value_inputs = step.get('value_inputs')
        func = f"{step['function']}:{value_inputs[0]}" if value_inputs else step['function']
        inputs = f"[{','.join(map(str, step['inputs']))}]" if step['inputs'] else ""

        if add_ref and can_add_ref(step_idx, step):
            if step['ans']:
                objs = POINTS_PLACEHOLDER
                points = [scene['objects'][obj_idx]['pixel_coords'][:2] for obj_idx in step['ans']]
                # Kept in its own name so the loop index is never overwritten.
                seq.append(get_boxes_idx(boxes_list=boxes, refs=points))
            else:
                objs = " Found no object."
        else:
            objs = ""
        cot.append(f"{func}{inputs}{objs}")

    return " -> ".join(cot), boxes, seq