import os
import json

import numpy as np
import torch
import torch.utils.data as data


class ShapeNetDataset(data.Dataset):
|
    def __init__(self,
                 root,
                 npoints=2500,
                 classification=False,
                 class_choice=None,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.cat = {}
        self.data_augmentation = data_augmentation
        self.classification = classification
        self.seg_classes = {}
|
        # Map human-readable category names (e.g. 'Chair') to synset offset folder names.
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
|
        # Optionally restrict the dataset to a user-supplied subset of categories.
        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
|
        self.id2cat = {v: k for k, v in self.cat.items()}
|
        # Load the shuffled file list for the requested split and group it by category.
        self.meta = {}
        splitfile = os.path.join(self.root, 'train_test_split',
                                 'shuffled_{}_file_list.json'.format(split))

        with open(splitfile, 'r') as f:
            filelist = json.load(f)

        for item in self.cat:
            self.meta[item] = []
|
        for file in filelist:
            _, category, uuid = file.split('/')
            if category in self.cat.values():
                # Pair each model's point cloud (.pts) with its per-point labels (.seg).
                self.meta[self.id2cat[category]].append(
                    (os.path.join(self.root, category, 'points', uuid + '.pts'),
                     os.path.join(self.root, category, 'points_label', uuid + '.seg')))
|
        # Flatten into a list of (category, points_path, labels_path) tuples.
        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn[0], fn[1]))
|
        # Assign a contiguous integer label to each category (sorted for determinism).
        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
        print(self.classes)

        # Read the number of part-segmentation classes for each category.
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               '../misc/num_seg_classes.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.seg_classes[ls[0]] = int(ls[1])
        self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
        print(self.seg_classes, self.num_seg_classes)
|
    def __getitem__(self, index):
        fn = self.datapath[index]
        cls = self.classes[self.datapath[index][0]]
        point_set = np.loadtxt(fn[1]).astype(np.float32)
        seg = np.loadtxt(fn[2]).astype(np.int64)

        # Randomly resample to a fixed number of points (with replacement).
        choice = np.random.choice(len(seg), self.npoints, replace=True)
        point_set = point_set[choice, :]
|
        # Center the point cloud and scale it to fit inside the unit sphere.
        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist
|
        if self.data_augmentation:
            # Random rotation about the vertical (y) axis.
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],
                                        [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)
            # Small Gaussian jitter on every coordinate.
            point_set += np.random.normal(0, 0.02, size=point_set.shape)
|
        # Keep labels aligned with the resampled points and convert everything to tensors.
        seg = seg[choice]
        point_set = torch.from_numpy(point_set)
        seg = torch.from_numpy(seg)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))
|
        if self.classification:
            return point_set, cls
        else:
            return point_set, seg
|
    def __len__(self):
        return len(self.datapath)
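

# A minimal usage sketch (not part of the original file). It assumes the ShapeNet
# part-annotation dataset has been extracted to the hypothetical local path passed
# as `root`; adjust the path and category choice to your setup.
if __name__ == '__main__':
    dataset = ShapeNetDataset(
        root='shapenetcore_partanno_segmentation_benchmark_v0',  # assumed dataset path
        class_choice=['Chair'],
        classification=False)
    print(len(dataset))
    points, seg = dataset[0]
    print(points.size(), seg.size())

    # Wrap the dataset in a DataLoader for batched, shuffled iteration during training.
    loader = data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)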