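# Inference demo for LEO's SequentialGrounder: given a preprocessed ScanNet scan and a
# task with step-by-step instructions, the model grounds each step to an object in the
# scene and (in predict mode) generates the plan text.
# Expected layout (taken from the paths used below): assets/meta/ for the category files,
# assets/inputs/<scan_id>/ for <scan_id>_pcd.pth, <scan_id>_inst.pth and obj_feats.pth,
# and checkpoint/leo/pytorch_model.bin for the model weights.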
import json
import os
import torch
import numpy as np
from leo.model import SequentialGrounder
from leo.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence
from torch.utils.data import default_collate
ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint/leo')
with open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8") as f:
    int2cat = json.load(f)
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))
role_prompt = "You are an AI visual assistant situated in a 3D scene. "\
              "You can perceive (1) an ego-view image (accessible when necessary) and (2) the objects (including yourself) in the scene (always accessible). "\
              "You should properly respond to the USER's instruction according to the given visual information. "
#role_prompt = " "
egoview_prompt = "Ego-view image:"
objects_prompt = "Objects (including you) in the scene:"
task_prompt = "USER: {instruction} ASSISTANT:"
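# The prompt is split into the pieces consumed downstream: a system-style prefix, two
# separators for the ego-view image and the object tokens, and the USER/ASSISTANT wrapper
# around the instruction.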
def get_prompt(instruction):
    return {
        'prompt_before_obj': role_prompt,
        'prompt_middle_1': egoview_prompt,
        'prompt_middle_2': objects_prompt,
        'prompt_after_obj': task_prompt.format(instruction=instruction),
    }

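# Build the language part of the sample: the prompt dict from the task description plus,
# when action steps are given, the ground-truth output text with one '<s>' token per step.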
def get_lang(task_item):
    task_description = task_item['task_description']
    sentence = task_description
    data_dict = get_prompt(task_description)
    # scan_id = task_item['scan_id']
    if 'action_steps' in task_item:
        action_steps = task_item['action_steps']
        # tgt_object_id = [int(action['target_id']) for action in action_steps]
        # tgt_object_name = [action['label'] for action in action_steps]
        for action in action_steps:
            sentence += ' ' + action['action']
        data_dict['output_gt'] = ' '.join([action['action'] + ' <s>' for action in action_steps])
    # return scan_id, tgt_object_id, tgt_object_name, sentence, data_dict
    return data_dict

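# Load one preprocessed scan and convert it into the model's object-centric input:
# per-instance point clouds, labels, ids, box locations and precomputed object features.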
def load_data(scan_id):
    one_scan = {}
    # load scan
    pcd_data = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_pcd.pth'))
    inst_to_label = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_inst.pth'))
    points, colors, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
    colors = colors / 127.5 - 1  # normalize RGB from [0, 255] to [-1, 1]
    pcds = np.concatenate([points, colors], 1)
    one_scan['pcds'] = pcds
    one_scan['instance_labels'] = instance_labels
    one_scan['inst_to_label'] = inst_to_label
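    # Split the scene into per-instance object point clouds using the instance labels.
    # Points belonging to wall/floor/ceiling instances (or to no labelled instance) stay
    # in the background mask and are kept separately as 'bg_pcds'.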
    # convert to gt object
    obj_pcds = []
    inst_ids = []
    inst_labels = []
    bg_indices = np.full((points.shape[0], ), 1, dtype=np.bool_)
    for inst_id in inst_to_label.keys():
        if inst_to_label[inst_id] in cat2int.keys():
            mask = instance_labels == inst_id
            if np.sum(mask) == 0:
                continue
            obj_pcds.append(pcds[mask])
            inst_ids.append(inst_id)
            inst_labels.append(cat2int[inst_to_label[inst_id]])
            if inst_to_label[inst_id] not in ['wall', 'floor', 'ceiling']:
                bg_indices[mask] = False
    one_scan['obj_pcds'] = obj_pcds
    one_scan['inst_labels'] = inst_labels
    one_scan['inst_ids'] = inst_ids
    one_scan['bg_pcds'] = pcds[bg_indices]
    # calculate box for matching
    obj_center = []
    obj_box_size = []
    for obj_pcd in obj_pcds:
        _c, _b = convert_pc_to_box(obj_pcd)
        obj_center.append(_c)
        obj_box_size.append(_b)
    one_scan['obj_loc'] = obj_center
    one_scan['obj_box'] = obj_box_size
    # load point feat
    feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', 'obj_feats.pth')
    one_scan['obj_feats'] = torch.load(feat_pth, map_location='cpu')
    # convert to pq3d input
    obj_labels = one_scan['inst_labels']  # N
    obj_pcds = one_scan['obj_pcds']
    obj_ids = one_scan['inst_ids']
    # object filter
    excluded_labels = ['wall', 'floor', 'ceiling']

    def keep_obj(i, obj_label):
        category = int2cat[obj_label]
        # filter out background
        if category in excluded_labels:
            return False
        # filter out objects not mentioned in the sentence
        return True

    selected_obj_idxs = [i for i, obj_label in enumerate(obj_labels) if keep_obj(i, obj_label)]
    # crop objects to max_obj_len and reorganize ids ? # TODO
    obj_labels = [obj_labels[i] for i in selected_obj_idxs]
    obj_pcds = [obj_pcds[i] for i in selected_obj_idxs]
    # subsample points
    obj_pcds = np.array([obj_pcd[np.random.choice(len(obj_pcd), size=1024,
                                                  replace=len(obj_pcd) < 1024)] for obj_pcd in obj_pcds])
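    # obj_processing_post presumably centers/normalizes each object point cloud and returns
    # per-object features, center locations, boxes and a rotation matrix (unused here).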
    obj_fts, obj_locs, obj_boxes, rot_matrix = obj_processing_post(obj_pcds, rot_aug=False)
    data_dict = {
        "scan_id": scan_id,
        "obj_fts": obj_fts.float(),
        "obj_locs": obj_locs.float(),
        "obj_labels": torch.LongTensor(obj_labels),
        "obj_boxes": obj_boxes,
        "obj_pad_masks": torch.ones((len(obj_locs)), dtype=torch.bool),  # used for padding in collate
        "obj_ids": torch.LongTensor([obj_ids[i] for i in selected_obj_idxs])
    }
    # convert point feature
    data_dict['obj_feats'] = one_scan['obj_feats'].squeeze(0)
    useful_keys = ['tgt_object_id', 'scan_id', 'obj_labels', 'data_idx',
                   'obj_fts', 'obj_locs', 'obj_pad_masks', 'obj_ids',
                   'source', 'prompt_before_obj', 'prompt_middle_1',
                   'prompt_middle_2', 'prompt_after_obj', 'output_gt', 'obj_feats']
    for k in list(data_dict.keys()):
        if k not in useful_keys:
            del data_dict[k]
    # add new keys because of leo
    data_dict['img_fts'] = torch.zeros(3, 224, 224)
    data_dict['img_masks'] = torch.LongTensor([0]).bool()
    data_dict['anchor_locs'] = torch.zeros(3)
    data_dict['anchor_orientation'] = torch.zeros(4)
    data_dict['anchor_orientation'][-1] = 1  # xyzw
    # convert to leo format
    data_dict['obj_masks'] = data_dict['obj_pad_masks']
    del data_dict['obj_pad_masks']
    return data_dict

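# Wrap a single sample into a batch of size 1: pad the variable-length object tensors with
# leo's pad_sequence, then let default_collate stack the remaining fields.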
def form_batch(data_dict):
    batch = [data_dict]
    new_batch = {}
    # pad
    padding_keys = ['obj_fts', 'obj_locs', 'obj_masks', 'obj_labels', 'obj_ids']
    for k in padding_keys:
        tensors = [sample.pop(k) for sample in batch]
        padded_tensor = pad_sequence(tensors, pad=0)
        new_batch[k] = padded_tensor
    # # list
    # list_keys = ['tgt_object_id']
    # for k in list_keys:
    #     new_batch[k] = [sample.pop(k) for sample in batch]
    # default collate
    new_batch.update(default_collate(batch))
    return new_batch

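# Run the model on one scan/task pair. With predict_mode=False the model only grounds the
# given action steps and a list of predicted object ids is returned; with predict_mode=True
# it also generates the plan text, and a list of dicts holding the grounded object ids and
# the predicted plan text is returned.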
def inference(scan_id, task, predict_mode=False):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = 'cpu'  # ok for predict_mode=False, and both for Gradio demo local preview
    data_dict = load_data(scan_id)
    data_dict.update(get_lang(task))
    data_dict = form_batch(data_dict)
    for key, value in data_dict.items():
        if isinstance(value, torch.Tensor):
            data_dict[key] = value.to(device)
    model = SequentialGrounder(predict_mode)
    load_msg = model.load_state_dict(torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu'), strict=False)
    model.to(device)
    data_dict = model(data_dict)
    if not predict_mode:
        # calculate result id
        result_id_list = [data_dict['obj_ids'][0][torch.argmax(data_dict['ground_logits'][i]).item()]
                          for i in range(len(data_dict['ground_logits']))]
    else:
        # calculate language
        # tgt_object_id = data_dict['tgt_object_id']
        if data_dict['ground_logits'] is None:
            og_pred = []
        else:
            og_pred = torch.argmax(data_dict['ground_logits'], dim=1)
        grd_batch_ind_list = data_dict['grd_batch_ind_list']
        response_pred = []
        for i in range(1):  # len(tgt_object_id)
            # target_sequence = list(tgt_object_id[i].cpu().numpy())
            predict_sequence = []
            for j in range(len(og_pred)):
                if grd_batch_ind_list[j] == i:
                    predict_sequence.append(og_pred[j].item())
            obj_ids = data_dict['obj_ids']
            response_pred.append({
                'predict_object_id': [obj_ids[i][o].item() for o in predict_sequence],
                'pred_plan_text': data_dict['output_txt'][i]
            })
    return result_id_list if not predict_mode else response_pred

if __name__ == '__main__':
    inference("scene0050_00", {
        "task_description": "Find the chair and move it to the table.",
        "action_steps": [
            {
                "target_id": "1",
                "label": "chair",
                "action": "Find the chair."
            },
            {
                "target_id": "2",
                "label": "table",
                "action": "Move the chair to the table."
            }
        ],
        "scan_id": "scene0050_00"
    }, predict_mode=True)