|
import json |
|
import re |
|
|
|
from ..root import DATASETS, IMAGE_PLACEHOLDER, BOXES_PLACEHOLDER, QUESTION_PLACEHOLDER, METRICS |
|
from ..utils.flickr30k_entities_utils import PHRASE_ST_PLACEHOLDER, PHRASE_ED_PLACEHOLDER |
|
from ..utils import MInstrDataset, BaseComputeMetrics |
|
|
|
# Matches one reference-id group in a GQA semantic string, including the
# whitespace character before it and at most one whitespace character after
# it. The group is either a comma-separated list of digit ids, e.g.
# " (123,456) ", or the empty marker " (-) ". The single capture group spans
# the whole match, so findall() returns the full matched text.
REFID_PAT = re.compile(r'(\s\((?:(?:\d+(?:,\d+)*)|-)\)\s?)')

# Captures the text that follows "So the answer is", "The answer is",
# "the answer is" or "Answer:" (plus one space), up to the end of that line
# ('.' does not match newlines).
ANS_EXTRACT_PAT = re.compile(r'(?:(?:(?:(?:(?:So t)|(?:T)|(?:t))he answer is)|(?:Answer:)) (.+))')
|
|
|
|
|
@DATASETS.register_module()
class GQADataset(MInstrDataset):
    """GQA visual question answering with optional box grounding.

    ``version`` has the form ``"<qtype>-<atype>"``.

    qtype selects how the question is rendered:
        ``q``    plain question text, no boxes;
        ``qb``   question with BOXES_PLACEHOLDER after each annotated phrase;
        ``qbp``  per sample, ``qb`` when ``self.rng.uniform() <= question_box_prob``,
                 otherwise ``q``.

    atype selects the target text:
        ``a``    "The answer is <answer>.";
        ``c``    ``question['cot']['value']`` with phrase markers removed;
        ``bc``   the same text with each phrase-end marker turned into
                 BOXES_PLACEHOLDER; boxes and box sequence come from
                 ``question['cot']``;
        ``s``    ``semanticStr`` with reference-id groups removed, plus the answer;
        ``bs``   ``semanticStr`` with reference ids turned into boxes, plus the answer;
        ``l``    the ``fullAnswer`` sentence plus the answer, without boxes;
        ``bl``   the ``fullAnswer`` sentence with boxes, plus the answer.

    The ``qb``/``qbp``, ``bs``, ``l`` and ``bl`` paths may look up object ids
    in ``scene['objects']``, so they need the scene graph to be loaded.
    """

    def __init__(
            self,
            *args,
            scene_graph_file,
            scene_graph_index,
            version,
            question_box_prob=0.5,
            **kwargs
    ):
        """Configure the dataset and optionally load the scene graph.

        Args:
            *args, **kwargs: forwarded to ``MInstrDataset``. ``placeholders``
                is always set here to (IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER).
            scene_graph_file: path to a text file whose lines are each parsed
                as JSON on demand, or a falsy value for no scene graph.
            scene_graph_index: path to a JSON file mapping ``imageId`` to a line
                index into ``scene_graph_file``. Must be given together with
                ``scene_graph_file``.
            version: ``"<qtype>-<atype>"``; see the class docstring.
            question_box_prob: chance of using the boxed question in ``qbp`` mode.
        """
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.scene_graph_file = scene_graph_file
        self.scene_graph_index = scene_graph_index
        self.version = version
        self.question_box_prob = question_box_prob

        qtype, atype = version.split('-')
        assert qtype in ['q', 'qb', 'qbp'], f"unsupported question type: {qtype!r}"
        assert atype in ['a', 'c', 'bc', 's', 'bs', 'l', 'bl'], f"unsupported answer type: {atype!r}"
        self.qtype = qtype
        self.atype = atype

        # The scene-graph inputs must be given together or not at all. Use the
        # same truthiness test as the assertion when deciding whether to load.
        assert bool(scene_graph_file) == bool(scene_graph_index)
        if scene_graph_file and scene_graph_index:
            # Keep the raw lines; get_raw_item parses only the one it needs.
            with open(scene_graph_file, 'r', encoding='utf8') as f:
                self.scene_graph = f.readlines()
            with open(scene_graph_index, 'r', encoding='utf8') as f:
                self.scene_index = json.load(f)
        else:
            self.scene_graph = None
            self.scene_index = None

    def get_raw_item(self, index):
        """Return ``(question, scene)`` for ``index``.

        ``question`` is ``self.data[index]`` parsed as JSON. ``scene`` is the
        scene-graph record for ``question['imageId']``, or None when no scene
        graph was loaded.
        """
        question = json.loads(self.data[index])
        if self.scene_graph is None:
            return question, None
        scene = json.loads(self.scene_graph[self.scene_index[question['imageId']]])
        return question, scene

    def __getitem__(self, index):
        """Build one conversation sample (image, boxes, human turn, gpt turn)."""
        question, scene = self.get_raw_item(index)
        image = self.get_image(f"{question['imageId']}.jpg")

        # The answer is built first: the query may append to the same boxes list.
        boxes, answer, answer_boxes_seq = self._build_answer(question, scene)
        boxes, query, query_boxes_seq = self._build_query(boxes, question, scene)

        final_query = self.get_template().replace(QUESTION_PLACEHOLDER, query)

        return {
            'image': image,
            'target': {'boxes': boxes},
            'conversations': [
                {
                    'from': 'human',
                    'value': final_query,
                    'boxes_seq': query_boxes_seq,
                },
                {
                    'from': 'gpt',
                    'value': answer,
                    'boxes_seq': answer_boxes_seq,
                },
            ],
        }

    def _build_answer(self, question, scene):
        """Return ``(boxes, answer_text, answer_boxes_seq)`` for ``self.atype``."""
        atype = self.atype
        if atype == 'bc':
            cot = question['cot']
            answer = cot['value'].replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
            return cot['boxes'], answer, cot['seq']
        if atype == 'c':
            answer = question['cot']['value'].replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, "")
            return [], answer, []
        if atype == 'bs':
            boxes, bss, answer_boxes_seq = get_bss_example(question, scene)
            return boxes, f"{bss}. The answer is {question['answer']}.", answer_boxes_seq
        if atype == 's':
            ss = REFID_PAT.sub('', question['semanticStr'])
            return [], f"{ss}. The answer is {question['answer']}.", []
        if atype == 'bl':
            return get_bl_example(question, scene)
        if atype == 'l':
            _, answer, _ = get_bl_example(question, scene)
            return [], answer.replace(BOXES_PLACEHOLDER, ""), []
        if atype == 'a':
            return [], f"The answer is {question['answer']}.", []
        # Unreachable after __init__ validation; raised explicitly so it survives -O.
        raise AssertionError(f"unsupported answer type: {atype!r}")

    def _build_query(self, boxes, question, scene):
        """Render the question for ``self.qtype``; may append to ``boxes``."""
        qtype = self.qtype
        if qtype == 'q':
            return prepare_query_dummy(boxes, question, scene)
        if qtype == 'qb':
            return prepare_query_box(boxes, question, scene)
        if qtype == 'qbp':
            if self.rng.uniform() > self.question_box_prob:
                return prepare_query_dummy(boxes, question, scene)
            return prepare_query_box(boxes, question, scene)
        # Unreachable after __init__ validation; raised explicitly so it survives -O.
        raise AssertionError(f"unsupported question type: {qtype!r}")
|
|
|
|
|
def prepare_query_dummy(boxes_list, q, scene):
    """Use the raw question text as the query, without box references.

    ``boxes_list`` is passed through untouched (same object) and ``scene`` is
    ignored. The signature mirrors ``prepare_query_box`` so the two can be
    swapped.

    Returns:
        ``(boxes_list, q['question'], [])``.
    """
    question_text = q['question']
    no_box_refs = list()
    return (boxes_list, question_text, no_box_refs)
|
|
|
|
|
def prepare_query_box(boxes_list, q, scene):
    """Build a query with BOXES_PLACEHOLDER after each annotated question phrase.

    Args:
        boxes_list: list of ``[x1, y1, x2, y2]`` boxes; new boxes are appended
            in place, and existing equal boxes are reused.
        q: GQA question record. Reads ``q['question']`` and
            ``q['annotations']['question']``, which maps a token span (``"i"``
            or ``"start:end"`` over whitespace tokens) to comma-separated
            object ids.
        scene: scene-graph record whose ``objects`` are looked up by id.

    Returns:
        ``(boxes_list, query_text, query_boxes_seq)``. ``query_boxes_seq`` holds
        one list of box indices per inserted placeholder, in the order the
        placeholders appear in ``query_text``.
    """
    tokens = q['question'].split()

    spans = []
    for span_str, rids_str in q['annotations']['question'].items():
        bounds = tuple(map(int, span_str.split(':')))
        start = bounds[0]
        end = bounds[1] if len(bounds) > 1 else start + 1
        spans.append((end, start, rids_str))
    # Placeholders are read in textual order, so process spans by position,
    # not in annotation-dict order. The sort is stable, so ties keep dict order.
    spans.sort(key=lambda item: (item[0], item[1]))

    query_boxes_seq = []
    for end, _start, rids_str in spans:
        # The placeholder goes right after the last token of the span.
        tokens[end - 1] = f"{tokens[end - 1]}{BOXES_PLACEHOLDER}"
        query_boxes_seq.append(add_boxes_by_rids(boxes_list, rids_str.split(','), scene))

    return boxes_list, " ".join(tokens).strip(), query_boxes_seq
|
|
|
|
|
def add_boxes_by_rids(boxes_list, rids, scene):
    """Map scene-graph object ids to indices into ``boxes_list``.

    Each object's ``x``, ``y``, ``w`` and ``h`` fields become an
    ``[x1, y1, x2, y2]`` list. A box already present in ``boxes_list``
    (compared with ``==``) is reused. Otherwise it is appended, so
    ``boxes_list`` is mutated in place.

    Args:
        boxes_list: list of boxes shared across calls; extended in place.
        rids: iterable of object ids used as keys into ``scene['objects']``.
        scene: scene-graph record with an ``objects`` mapping.

    Returns:
        A list with one index into ``boxes_list`` per id in ``rids``.
    """
    indices = []
    for rid in rids:
        obj = scene['objects'][rid]
        left, top, width, height = obj['x'], obj['y'], obj['w'], obj['h']
        xyxy = [left, top, left + width, top + height]
        try:
            position = boxes_list.index(xyxy)
        except ValueError:
            boxes_list.append(xyxy)
            position = len(boxes_list) - 1
        indices.append(position)
    return indices
|
|
|
|
|
def get_bss_example(question, scene):
    """Build a grounded semantic-string target (the ``bs`` answer type).

    Each reference group in ``question['semanticStr']`` matched by REFID_PAT
    (e.g. ``" (123,456) "``) is replaced by BOXES_PLACEHOLDER, and its object
    ids are resolved to boxes through ``scene``. Empty groups (``" (-) "``)
    are removed without adding a placeholder.

    Returns:
        ``(boxes, answer, seqs)``: the collected boxes, the rewritten string,
        and one list of box indices per placeholder, in the order the
        placeholders appear in ``answer``.
    """
    boxes = []
    seqs = []

    def _replace_ref(match):
        ref = match.group(0)
        if '-' in ref:
            return ''
        # " (123,456) " -> ["123", "456"]
        refids = ref.strip()[1:-1].split(',')
        seqs.append(add_boxes_by_rids(boxes, refids, scene))
        return BOXES_PLACEHOLDER

    # re.sub calls the callback left to right, so seqs follows textual order.
    # No str.format step is involved, so literal braces in the text are left intact.
    answer = REFID_PAT.sub(_replace_ref, question['semanticStr'])
    return boxes, answer, seqs
|
|
|
|
|
def get_bl_example(ann, scene):
    """Build a grounded full-answer target (the ``bl`` answer type).

    BOXES_PLACEHOLDER is appended after each annotated phrase of
    ``ann['fullAnswer']``. A leading "Yes,"/"No," is then dropped, and
    " The answer is <answer>." is appended.

    Annotation spans (``"i"`` or ``"start:end"``) index whitespace tokens of the
    unmodified ``fullAnswer``. This is the same convention ``prepare_query_box``
    uses for the question. Placeholders are therefore inserted *before* the
    "Yes,"/"No," prefix is removed; removing it first would shift every index
    by one.

    Args:
        ann: GQA question record. Reads ``fullAnswer``, ``answer`` and
            ``ann['annotations']['fullAnswer']`` (span -> comma-separated ids).
        scene: scene-graph record whose ``objects`` are looked up by id.

    Returns:
        ``(boxes, answer, boxes_seq)``. ``boxes_seq`` holds one list of box
        indices per placeholder, in the order the placeholders appear.
    """
    tokens = ann['fullAnswer'].split()

    spans = []
    for span_str, rids_str in ann['annotations']['fullAnswer'].items():
        bounds = tuple(map(int, span_str.split(':')))
        start = bounds[0]
        end = bounds[1] if len(bounds) > 1 else start + 1
        spans.append((end, start, rids_str))
    # Process spans in token order so boxes_seq matches placeholder order.
    spans.sort(key=lambda item: (item[0], item[1]))

    boxes = []
    boxes_seq = []
    for end, _start, rids_str in spans:
        tokens[end - 1] = f"{tokens[end - 1]}{BOXES_PLACEHOLDER}"
        boxes_seq.append(add_boxes_by_rids(boxes, rids_str.split(','), scene))

    sentence = re.sub('(?:^Yes,)|(?:^No,)', '', " ".join(tokens)).strip()
    answer = f"The answer is {ann['answer']}."
    if sentence:
        answer = f"{sentence} {answer}"
    return boxes, answer, boxes_seq
|
|
|
|
|
@METRICS.register_module()
class GQAComputeMetrics(BaseComputeMetrics):
    """GQA metric: overrides answer extraction, inherits everything else.

    All other behavior is inherited unchanged from ``BaseComputeMetrics``.
    """

    def extract_ans(self, string: str):
        """Pull the final answer out of a generated string.

        The answer is the text following "So the answer is", "The answer is",
        "the answer is" or "Answer:" (matched by ANS_EXTRACT_PAT). Surrounding
        whitespace and trailing periods are removed.

        Returns:
            The extracted answer. Returns None when the pattern matches zero
            times or more than once, or when ``string`` has no ``strip`` method.
        """
        try:
            candidates = ANS_EXTRACT_PAT.findall(string.strip())
            if len(candidates) == 1:
                return candidates[0].strip().rstrip('.').strip()
            return None
        except (AttributeError, IndexError):
            return None
|
|