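"""Single-scene inference for the LEO SequentialGrounder.

Loads a preprocessed ScanNet scan from ./assets, wraps the task description
in the LEO prompt template, and runs the grounder to predict the target
object for each action step (and, in predict mode, the generated plan text).
"""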
import json
import os
import torch
import numpy as np
from leo.model import SequentialGrounder
from leo.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence
from torch.utils.data import default_collate


ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint/leo')
with open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8") as f:
    int2cat = json.load(f)
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))


role_prompt = "You are an AI visual assistant situated in a 3D scene. "\
    "You can perceive (1) an ego-view image (accessible when necessary) and (2) the objects (including yourself) in the scene (always accessible). "\
    "You should properly respond to the USER's instruction according to the given visual information. "
egoview_prompt = "Ego-view image:"
objects_prompt = "Objects (including you) in the scene:"
task_prompt = "USER: {instruction} ASSISTANT:"

def get_prompt(instruction):
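    """Wrap a task instruction with the fixed LEO prompt segments."""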
    return {
        'prompt_before_obj': role_prompt,
        'prompt_middle_1': egoview_prompt,
        'prompt_middle_2': objects_prompt,
        'prompt_after_obj': task_prompt.format(instruction=instruction),
    }

def get_lang(task_item):
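    """Build the language inputs for one task item.

    Returns the prompt dict; when ground-truth 'action_steps' are provided,
    also adds 'output_gt', the step actions joined with the '<s>' separator.
    """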
    task_description = task_item['task_description']
    data_dict = get_prompt(task_description)

    if 'action_steps' in task_item:
        action_steps = task_item['action_steps']
        # ground-truth plan text: action strings joined by the '<s>' step separator
        data_dict['output_gt'] = ' '.join([action['action'] + ' <s>' for action in action_steps])

    return data_dict


def load_data(scan_id):
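    """Load a preprocessed scan and convert it to the model input format.

    Expects '{scan_id}_pcd.pth', '{scan_id}_inst.pth', and precomputed object
    features 'obj_feats.pth' under assets/inputs/{scan_id}/.
    """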
    one_scan = {}
    # load scan
    pcd_data = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_pcd.pth'))
    inst_to_label = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_inst.pth'))
    points, colors, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
    colors = colors / 127.5 - 1  # normalize RGB from [0, 255] to [-1, 1]
    pcds = np.concatenate([points, colors], 1)
    one_scan['pcds'] = pcds
    one_scan['instance_labels'] = instance_labels
    one_scan['inst_to_label'] = inst_to_label
    # convert to gt object
    obj_pcds = []
    inst_ids = []
    inst_labels = []
    bg_indices = np.full((points.shape[0], ), True, dtype=np.bool_)
    for inst_id, label in inst_to_label.items():
        if label in cat2int:
            mask = instance_labels == inst_id
            if np.sum(mask) == 0:
                continue
            obj_pcds.append(pcds[mask])
            inst_ids.append(inst_id)
            inst_labels.append(cat2int[label])
            if label not in ['wall', 'floor', 'ceiling']:
                bg_indices[mask] = False
    one_scan['obj_pcds'] = obj_pcds
    one_scan['inst_labels'] = inst_labels
    one_scan['inst_ids'] = inst_ids
    one_scan['bg_pcds'] = pcds[bg_indices]
    # calculate box for matching
    obj_center = []
    obj_box_size = []
    for obj_pcd in obj_pcds:
        _c, _b = convert_pc_to_box(obj_pcd)
        obj_center.append(_c)
        obj_box_size.append(_b)
    one_scan['obj_loc'] = obj_center
    one_scan['obj_box'] = obj_box_size
    # load point feat
    feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', 'obj_feats.pth')
    one_scan['obj_feats'] = torch.load(feat_pth, map_location='cpu')
    # convert to pq3d input
    obj_labels = one_scan['inst_labels']  # N
    obj_pcds = one_scan['obj_pcds']
    obj_ids = one_scan['inst_ids']
    # object filter: drop background categories
    excluded_labels = ['wall', 'floor', 'ceiling']
    def keep_obj(obj_label):
        return int2cat[obj_label] not in excluded_labels
    selected_obj_idxs = [i for i, obj_label in enumerate(obj_labels) if keep_obj(obj_label)]
    # TODO: crop objects to max_obj_len and reorganize ids
    obj_labels = [obj_labels[i] for i in selected_obj_idxs]
    obj_pcds = [obj_pcds[i] for i in selected_obj_idxs]
    # subsample each object point cloud to exactly 1024 points (with replacement when fewer)
    obj_pcds = np.array([obj_pcd[np.random.choice(len(obj_pcd), size=1024,
                                                  replace=len(obj_pcd) < 1024)] for obj_pcd in obj_pcds])
    obj_fts, obj_locs, obj_boxes, rot_matrix = obj_processing_post(obj_pcds, rot_aug=False)
    data_dict = {
        "scan_id": scan_id,
        "obj_fts": obj_fts.float(),
        "obj_locs": obj_locs.float(),
        "obj_labels": torch.LongTensor(obj_labels),
        "obj_boxes": obj_boxes, 
        "obj_pad_masks": torch.ones((len(obj_locs)), dtype=torch.bool), # used for padding in collate
        "obj_ids": torch.LongTensor([obj_ids[i] for i in selected_obj_idxs])
    }
    # convert point feature
    data_dict['obj_feats'] = one_scan['obj_feats'].squeeze(0)

    useful_keys = ['tgt_object_id', 'scan_id', 'obj_labels', 'data_idx', 
                   'obj_fts', 'obj_locs', 'obj_pad_masks', 'obj_ids', 
                   'source', 'prompt_before_obj', 'prompt_middle_1', 
                   'prompt_middle_2', 'prompt_after_obj', 'output_gt', 'obj_feats']
    for k in list(data_dict.keys()):
        if k not in useful_keys:
            del data_dict[k]
    # add extra keys required by the LEO model
    data_dict['img_fts'] = torch.zeros(3, 224, 224)        # placeholder: no ego-view image
    data_dict['img_masks'] = torch.LongTensor([0]).bool()  # mark the image as absent
    data_dict['anchor_locs'] = torch.zeros(3)
    data_dict['anchor_orientation'] = torch.zeros(4)
    data_dict['anchor_orientation'][-1] = 1  # identity quaternion, xyzw
    # convert to leo format
    data_dict['obj_masks'] = data_dict['obj_pad_masks']
    del data_dict['obj_pad_masks']

    return data_dict

def form_batch(data_dict):
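    """Collate a single sample into a batch of size 1, padding object-level tensors."""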
    batch = [data_dict]
    new_batch = {}

    # pad
    padding_keys = ['obj_fts', 'obj_locs', 'obj_masks', 'obj_labels', 'obj_ids']
    for k in padding_keys:
        tensors = [sample.pop(k) for sample in batch]
        padded_tensor = pad_sequence(tensors, pad=0)
        new_batch[k] = padded_tensor

    # default collate
    new_batch.update(default_collate(batch))
    return new_batch

    
def inference(scan_id, task, predict_mode=False):
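    """Ground a task in the given scan.

    With predict_mode=False, returns the highest-scoring object id per action
    step; with predict_mode=True, returns dicts with the predicted object ids
    and the generated plan text.
    """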
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = 'cpu'  # CPU also works for predict_mode=False and for local Gradio demo previews

    data_dict = load_data(scan_id)
    data_dict.update(get_lang(task))
    data_dict = form_batch(data_dict)

    for key, value in data_dict.items():
        if isinstance(value, torch.Tensor):
            data_dict[key] = value.to(device)
    
    model = SequentialGrounder(predict_mode)
    load_msg = model.load_state_dict(torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu'),
                                     strict=False)
    model.to(device)
    model.eval()

    with torch.no_grad():
        data_dict = model(data_dict)

    if not predict_mode:
        # pick the highest-scoring object id for each action step
        result_id_list = [data_dict['obj_ids'][0][torch.argmax(data_dict['ground_logits'][i]).item()]
                          for i in range(len(data_dict['ground_logits']))]
    else:
        # calculate language and grounding predictions
        if data_dict['ground_logits'] is None:
            og_pred = []
        else:
            og_pred = torch.argmax(data_dict['ground_logits'], dim=1)
        grd_batch_ind_list = data_dict['grd_batch_ind_list']

        response_pred = []
        for i in range(1):  # batch size is fixed to 1 here
            # collect the predicted object for each action step assigned to sample i
            predict_sequence = []
            for j in range(len(og_pred)):
                if grd_batch_ind_list[j] == i:
                    predict_sequence.append(og_pred[j].item())

            obj_ids = data_dict['obj_ids']
            response_pred.append({
                'predict_object_id': [obj_ids[i][o].item() for o in predict_sequence],
                'pred_plan_text': data_dict['output_txt'][i]
            })

    return response_pred if predict_mode else result_id_list

if __name__ == '__main__':
    # example: a two-step pick-and-place style task on ScanNet scene0050_00
    response = inference("scene0050_00", {
        "task_description": "Find the chair and move it to the table.",
        "action_steps": [
            {
                "target_id": "1",
                "label": "chair",
                "action": "Find the chair."
            },
            {
                "target_id": "2",
                "label": "table",
                "action": "Move the chair to the table."
            }
        ],
        "scan_id": "scene0050_00"
    }, predict_mode=True)
    print(response)