Denys Rozumnyi commited on
Commit
49de1ea
·
1 Parent(s): b02e5d5
Files changed (5) hide show
  1. dataset.py +88 -0
  2. geom_solver.py +5 -5
  3. pointnet.py +213 -0
  4. testing.ipynb +0 -0
  5. train_pointnet.py +148 -0
dataset.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ class ShapeNetDataset(data.Dataset):
4
+ def __init__(self,
5
+ root,
6
+ npoints=2500,
7
+ classification=False,
8
+ class_choice=None,
9
+ split='train',
10
+ data_augmentation=True):
11
+ self.npoints = npoints
12
+ self.root = root
13
+ self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
14
+ self.cat = {}
15
+ self.data_augmentation = data_augmentation
16
+ self.classification = classification
17
+ self.seg_classes = {}
18
+
19
+ with open(self.catfile, 'r') as f:
20
+ for line in f:
21
+ ls = line.strip().split()
22
+ self.cat[ls[0]] = ls[1]
23
+ #print(self.cat)
24
+ if not class_choice is None:
25
+ self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
26
+
27
+ self.id2cat = {v: k for k, v in self.cat.items()}
28
+
29
+ self.meta = {}
30
+ splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split))
31
+ #from IPython import embed; embed()
32
+ filelist = json.load(open(splitfile, 'r'))
33
+ for item in self.cat:
34
+ self.meta[item] = []
35
+
36
+ for file in filelist:
37
+ _, category, uuid = file.split('/')
38
+ if category in self.cat.values():
39
+ self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid+'.pts'),
40
+ os.path.join(self.root, category, 'points_label', uuid+'.seg')))
41
+
42
+ self.datapath = []
43
+ for item in self.cat:
44
+ for fn in self.meta[item]:
45
+ self.datapath.append((item, fn[0], fn[1]))
46
+
47
+ self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
48
+ print(self.classes)
49
+ with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:
50
+ for line in f:
51
+ ls = line.strip().split()
52
+ self.seg_classes[ls[0]] = int(ls[1])
53
+ self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
54
+ print(self.seg_classes, self.num_seg_classes)
55
+
56
+ def __getitem__(self, index):
57
+ fn = self.datapath[index]
58
+ cls = self.classes[self.datapath[index][0]]
59
+ point_set = np.loadtxt(fn[1]).astype(np.float32)
60
+ seg = np.loadtxt(fn[2]).astype(np.int64)
61
+ #print(point_set.shape, seg.shape)
62
+
63
+ choice = np.random.choice(len(seg), self.npoints, replace=True)
64
+ #resample
65
+ point_set = point_set[choice, :]
66
+
67
+ point_set = point_set - np.expand_dims(np.mean(point_set, axis = 0), 0) # center
68
+ dist = np.max(np.sqrt(np.sum(point_set ** 2, axis = 1)),0)
69
+ point_set = point_set / dist #scale
70
+
71
+ if self.data_augmentation:
72
+ theta = np.random.uniform(0,np.pi*2)
73
+ rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]])
74
+ point_set[:,[0,2]] = point_set[:,[0,2]].dot(rotation_matrix) # random rotation
75
+ point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter
76
+
77
+ seg = seg[choice]
78
+ point_set = torch.from_numpy(point_set)
79
+ seg = torch.from_numpy(seg)
80
+ cls = torch.from_numpy(np.array([cls]).astype(np.int64))
81
+
82
+ if self.classification:
83
+ return point_set, cls
84
+ else:
85
+ return point_set, seg
86
+
87
+ def __len__(self):
88
+ return len(self.datapath)
geom_solver.py CHANGED
@@ -17,7 +17,7 @@ class GeomSolver(object):
17
 
18
  def __init__(self):
19
  self.min_vertices = 18
20
- self.kmeans_th = 150
21
  self.clr_th = 2.5
22
  self.device = 'cuda:0'
23
 
@@ -44,9 +44,9 @@ class GeomSolver(object):
44
  vert_mask = (vert_mask > 0).astype(np.uint8)
45
 
46
  dist = cv2.distanceTransform(1-vert_mask, cv2.DIST_L2, 3)
47
- dist[dist > 100] = 100
48
- ndist = np.zeros_like(dist)
49
- ndist = cv2.normalize(dist, ndist, 0, 1.0, cv2.NORM_MINMAX)
50
 
51
  in_this_image = np.array([cki in p.image_ids for p in self.points3D.values()])
52
  uv = torch.round(self.pyt_cameras[ki].transform_points(self.verts)[:, :2]).cpu().numpy().astype(int)
@@ -57,7 +57,7 @@ class GeomSolver(object):
57
  dist_points[uv_inl] += dist[uv[:,1], uv[:,0]]
58
  visible_counts[uv_inl] += 1
59
 
60
- selected_points = (dist_points / (visible_counts + 1e-6)) <= 10
61
  selected_points[visible_counts < 1] = False
62
 
63
  pnts = torch.from_numpy(self.xyz[selected_points].astype(np.float32))[None]
 
17
 
18
  def __init__(self):
19
  self.min_vertices = 18
20
+ self.kmeans_th = 70
21
  self.clr_th = 2.5
22
  self.device = 'cuda:0'
23
 
 
44
  vert_mask = (vert_mask > 0).astype(np.uint8)
45
 
46
  dist = cv2.distanceTransform(1-vert_mask, cv2.DIST_L2, 3)
47
+ # dist[dist > 100] = 100
48
+ # ndist = np.zeros_like(dist)
49
+ # ndist = cv2.normalize(dist, ndist, 0, 1.0, cv2.NORM_MINMAX)
50
 
51
  in_this_image = np.array([cki in p.image_ids for p in self.points3D.values()])
52
  uv = torch.round(self.pyt_cameras[ki].transform_points(self.verts)[:, :2]).cpu().numpy().astype(int)
 
57
  dist_points[uv_inl] += dist[uv[:,1], uv[:,0]]
58
  visible_counts[uv_inl] += 1
59
 
60
+ selected_points = (dist_points / (visible_counts + 1e-6)) <= 15
61
  selected_points[visible_counts < 1] = False
62
 
63
  pnts = torch.from_numpy(self.xyz[selected_points].astype(np.float32))[None]
pointnet.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.parallel
5
+ import torch.utils.data
6
+ from torch.autograd import Variable
7
+ import numpy as np
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class STN3d(nn.Module):
12
+ def __init__(self):
13
+ super(STN3d, self).__init__()
14
+ self.conv1 = torch.nn.Conv1d(3, 64, 1)
15
+ self.conv2 = torch.nn.Conv1d(64, 128, 1)
16
+ self.conv3 = torch.nn.Conv1d(128, 1024, 1)
17
+ self.fc1 = nn.Linear(1024, 512)
18
+ self.fc2 = nn.Linear(512, 256)
19
+ self.fc3 = nn.Linear(256, 9)
20
+ self.relu = nn.ReLU()
21
+
22
+ self.bn1 = nn.BatchNorm1d(64)
23
+ self.bn2 = nn.BatchNorm1d(128)
24
+ self.bn3 = nn.BatchNorm1d(1024)
25
+ self.bn4 = nn.BatchNorm1d(512)
26
+ self.bn5 = nn.BatchNorm1d(256)
27
+
28
+
29
+ def forward(self, x):
30
+ batchsize = x.size()[0]
31
+ x = F.relu(self.bn1(self.conv1(x)))
32
+ x = F.relu(self.bn2(self.conv2(x)))
33
+ x = F.relu(self.bn3(self.conv3(x)))
34
+ x = torch.max(x, 2, keepdim=True)[0]
35
+ x = x.view(-1, 1024)
36
+
37
+ x = F.relu(self.bn4(self.fc1(x)))
38
+ x = F.relu(self.bn5(self.fc2(x)))
39
+ x = self.fc3(x)
40
+
41
+ iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
42
+ if x.is_cuda:
43
+ iden = iden.cuda()
44
+ x = x + iden
45
+ x = x.view(-1, 3, 3)
46
+ return x
47
+
48
+
49
+ class STNkd(nn.Module):
50
+ def __init__(self, k=64):
51
+ super(STNkd, self).__init__()
52
+ self.conv1 = torch.nn.Conv1d(k, 64, 1)
53
+ self.conv2 = torch.nn.Conv1d(64, 128, 1)
54
+ self.conv3 = torch.nn.Conv1d(128, 1024, 1)
55
+ self.fc1 = nn.Linear(1024, 512)
56
+ self.fc2 = nn.Linear(512, 256)
57
+ self.fc3 = nn.Linear(256, k*k)
58
+ self.relu = nn.ReLU()
59
+
60
+ self.bn1 = nn.BatchNorm1d(64)
61
+ self.bn2 = nn.BatchNorm1d(128)
62
+ self.bn3 = nn.BatchNorm1d(1024)
63
+ self.bn4 = nn.BatchNorm1d(512)
64
+ self.bn5 = nn.BatchNorm1d(256)
65
+
66
+ self.k = k
67
+
68
+ def forward(self, x):
69
+ batchsize = x.size()[0]
70
+ x = F.relu(self.bn1(self.conv1(x)))
71
+ x = F.relu(self.bn2(self.conv2(x)))
72
+ x = F.relu(self.bn3(self.conv3(x)))
73
+ x = torch.max(x, 2, keepdim=True)[0]
74
+ x = x.view(-1, 1024)
75
+
76
+ x = F.relu(self.bn4(self.fc1(x)))
77
+ x = F.relu(self.bn5(self.fc2(x)))
78
+ x = self.fc3(x)
79
+
80
+ iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
81
+ if x.is_cuda:
82
+ iden = iden.cuda()
83
+ x = x + iden
84
+ x = x.view(-1, self.k, self.k)
85
+ return x
86
+
87
+ class PointNetfeat(nn.Module):
88
+ def __init__(self, global_feat = True, feature_transform = False):
89
+ super(PointNetfeat, self).__init__()
90
+ self.stn = STN3d()
91
+ self.conv1 = torch.nn.Conv1d(3, 64, 1)
92
+ self.conv2 = torch.nn.Conv1d(64, 128, 1)
93
+ self.conv3 = torch.nn.Conv1d(128, 1024, 1)
94
+ self.bn1 = nn.BatchNorm1d(64)
95
+ self.bn2 = nn.BatchNorm1d(128)
96
+ self.bn3 = nn.BatchNorm1d(1024)
97
+ self.global_feat = global_feat
98
+ self.feature_transform = feature_transform
99
+ if self.feature_transform:
100
+ self.fstn = STNkd(k=64)
101
+
102
+ def forward(self, x):
103
+ n_pts = x.size()[2]
104
+ trans = self.stn(x)
105
+ x = x.transpose(2, 1)
106
+ x = torch.bmm(x, trans)
107
+ x = x.transpose(2, 1)
108
+ x = F.relu(self.bn1(self.conv1(x)))
109
+
110
+ if self.feature_transform:
111
+ trans_feat = self.fstn(x)
112
+ x = x.transpose(2,1)
113
+ x = torch.bmm(x, trans_feat)
114
+ x = x.transpose(2,1)
115
+ else:
116
+ trans_feat = None
117
+
118
+ pointfeat = x
119
+ x = F.relu(self.bn2(self.conv2(x)))
120
+ x = self.bn3(self.conv3(x))
121
+ x = torch.max(x, 2, keepdim=True)[0]
122
+ x = x.view(-1, 1024)
123
+ if self.global_feat:
124
+ return x, trans, trans_feat
125
+ else:
126
+ x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
127
+ return torch.cat([x, pointfeat], 1), trans, trans_feat
128
+
129
+ class PointNetCls(nn.Module):
130
+ def __init__(self, k=2, feature_transform=False):
131
+ super(PointNetCls, self).__init__()
132
+ self.feature_transform = feature_transform
133
+ self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
134
+ self.fc1 = nn.Linear(1024, 512)
135
+ self.fc2 = nn.Linear(512, 256)
136
+ self.fc3 = nn.Linear(256, k)
137
+ self.dropout = nn.Dropout(p=0.3)
138
+ self.bn1 = nn.BatchNorm1d(512)
139
+ self.bn2 = nn.BatchNorm1d(256)
140
+ self.relu = nn.ReLU()
141
+
142
+ def forward(self, x):
143
+ x, trans, trans_feat = self.feat(x)
144
+ x = F.relu(self.bn1(self.fc1(x)))
145
+ x = F.relu(self.bn2(self.dropout(self.fc2(x))))
146
+ x = self.fc3(x)
147
+ return F.log_softmax(x, dim=1), trans, trans_feat
148
+
149
+
150
+ class PointNetDenseCls(nn.Module):
151
+ def __init__(self, k = 2, feature_transform=False):
152
+ super(PointNetDenseCls, self).__init__()
153
+ self.k = k
154
+ self.feature_transform=feature_transform
155
+ self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
156
+ self.conv1 = torch.nn.Conv1d(1088, 512, 1)
157
+ self.conv2 = torch.nn.Conv1d(512, 256, 1)
158
+ self.conv3 = torch.nn.Conv1d(256, 128, 1)
159
+ self.conv4 = torch.nn.Conv1d(128, self.k, 1)
160
+ self.bn1 = nn.BatchNorm1d(512)
161
+ self.bn2 = nn.BatchNorm1d(256)
162
+ self.bn3 = nn.BatchNorm1d(128)
163
+
164
+ def forward(self, x):
165
+ batchsize = x.size()[0]
166
+ n_pts = x.size()[2]
167
+ x, trans, trans_feat = self.feat(x)
168
+ x = F.relu(self.bn1(self.conv1(x)))
169
+ x = F.relu(self.bn2(self.conv2(x)))
170
+ x = F.relu(self.bn3(self.conv3(x)))
171
+ x = self.conv4(x)
172
+ x = x.transpose(2,1).contiguous()
173
+ x = F.log_softmax(x.view(-1,self.k), dim=-1)
174
+ x = x.view(batchsize, n_pts, self.k)
175
+ return x, trans, trans_feat
176
+
177
+ def feature_transform_regularizer(trans):
178
+ d = trans.size()[1]
179
+ batchsize = trans.size()[0]
180
+ I = torch.eye(d)[None, :, :]
181
+ if trans.is_cuda:
182
+ I = I.cuda()
183
+ loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
184
+ return loss
185
+
186
+ if __name__ == '__main__':
187
+ sim_data = Variable(torch.rand(32,3,2500))
188
+ trans = STN3d()
189
+ out = trans(sim_data)
190
+ print('stn', out.size())
191
+ print('loss', feature_transform_regularizer(out))
192
+
193
+ sim_data_64d = Variable(torch.rand(32, 64, 2500))
194
+ trans = STNkd(k=64)
195
+ out = trans(sim_data_64d)
196
+ print('stn64d', out.size())
197
+ print('loss', feature_transform_regularizer(out))
198
+
199
+ pointfeat = PointNetfeat(global_feat=True)
200
+ out, _, _ = pointfeat(sim_data)
201
+ print('global feat', out.size())
202
+
203
+ pointfeat = PointNetfeat(global_feat=False)
204
+ out, _, _ = pointfeat(sim_data)
205
+ print('point feat', out.size())
206
+
207
+ cls = PointNetCls(k = 5)
208
+ out, _, _ = cls(sim_data)
209
+ print('class', out.size())
210
+
211
+ seg = PointNetDenseCls(k = 3)
212
+ out, _, _ = seg(sim_data)
213
+ print('seg', out.size())
testing.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
train_pointnet.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import argparse
3
+ import os
4
+ import random
5
+ import torch
6
+ import torch.nn.parallel
7
+ import torch.optim as optim
8
+ import torch.utils.data
9
+ from pointnet.dataset import ShapeNetDataset, ModelNetDataset
10
+ from pointnet import PointNetCls, feature_transform_regularizer
11
+ import torch.nn.functional as F
12
+ from tqdm import tqdm
13
+
14
+
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ '--batchSize', type=int, default=32, help='input batch size')
18
+ parser.add_argument(
19
+ '--num_points', type=int, default=2500, help='input batch size')
20
+ parser.add_argument(
21
+ '--workers', type=int, help='number of data loading workers', default=4)
22
+ parser.add_argument(
23
+ '--nepoch', type=int, default=250, help='number of epochs to train for')
24
+ parser.add_argument('--outf', type=str, default='cls', help='output folder')
25
+ parser.add_argument('--model', type=str, default='', help='model path')
26
+ parser.add_argument('--dataset', type=str, required=True, help="dataset path")
27
+ parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
28
+ parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
29
+
30
+ opt = parser.parse_args()
31
+ print(opt)
32
+
33
+ blue = lambda x: '\033[94m' + x + '\033[0m'
34
+
35
+ opt.manualSeed = random.randint(1, 10000) # fix seed
36
+ print("Random Seed: ", opt.manualSeed)
37
+ random.seed(opt.manualSeed)
38
+ torch.manual_seed(opt.manualSeed)
39
+
40
+ if opt.dataset_type == 'shapenet':
41
+ dataset = ShapeNetDataset(
42
+ root=opt.dataset,
43
+ classification=True,
44
+ npoints=opt.num_points)
45
+
46
+ test_dataset = ShapeNetDataset(
47
+ root=opt.dataset,
48
+ classification=True,
49
+ split='test',
50
+ npoints=opt.num_points,
51
+ data_augmentation=False)
52
+ elif opt.dataset_type == 'modelnet40':
53
+ dataset = ModelNetDataset(
54
+ root=opt.dataset,
55
+ npoints=opt.num_points,
56
+ split='trainval')
57
+
58
+ test_dataset = ModelNetDataset(
59
+ root=opt.dataset,
60
+ split='test',
61
+ npoints=opt.num_points,
62
+ data_augmentation=False)
63
+ else:
64
+ exit('wrong dataset type')
65
+
66
+
67
+ dataloader = torch.utils.data.DataLoader(
68
+ dataset,
69
+ batch_size=opt.batchSize,
70
+ shuffle=True,
71
+ num_workers=int(opt.workers))
72
+
73
+ testdataloader = torch.utils.data.DataLoader(
74
+ test_dataset,
75
+ batch_size=opt.batchSize,
76
+ shuffle=True,
77
+ num_workers=int(opt.workers))
78
+
79
+ print(len(dataset), len(test_dataset))
80
+ num_classes = len(dataset.classes)
81
+ print('classes', num_classes)
82
+
83
+ try:
84
+ os.makedirs(opt.outf)
85
+ except OSError:
86
+ pass
87
+
88
+ classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)
89
+
90
+ if opt.model != '':
91
+ classifier.load_state_dict(torch.load(opt.model))
92
+
93
+
94
+ optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
95
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
96
+ classifier.cuda()
97
+
98
+ num_batch = len(dataset) / opt.batchSize
99
+
100
+ for epoch in range(opt.nepoch):
101
+ scheduler.step()
102
+ for i, data in enumerate(dataloader, 0):
103
+ points, target = data
104
+ target = target[:, 0]
105
+ points = points.transpose(2, 1)
106
+ points, target = points.cuda(), target.cuda()
107
+ optimizer.zero_grad()
108
+ classifier = classifier.train()
109
+ pred, trans, trans_feat = classifier(points)
110
+ loss = F.nll_loss(pred, target)
111
+ if opt.feature_transform:
112
+ loss += feature_transform_regularizer(trans_feat) * 0.001
113
+ loss.backward()
114
+ optimizer.step()
115
+ pred_choice = pred.data.max(1)[1]
116
+ correct = pred_choice.eq(target.data).cpu().sum()
117
+ print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(opt.batchSize)))
118
+
119
+ if i % 10 == 0:
120
+ j, data = next(enumerate(testdataloader, 0))
121
+ points, target = data
122
+ target = target[:, 0]
123
+ points = points.transpose(2, 1)
124
+ points, target = points.cuda(), target.cuda()
125
+ classifier = classifier.eval()
126
+ pred, _, _ = classifier(points)
127
+ loss = F.nll_loss(pred, target)
128
+ pred_choice = pred.data.max(1)[1]
129
+ correct = pred_choice.eq(target.data).cpu().sum()
130
+ print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize)))
131
+
132
+ torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))
133
+
134
+ total_correct = 0
135
+ total_testset = 0
136
+ for i,data in tqdm(enumerate(testdataloader, 0)):
137
+ points, target = data
138
+ target = target[:, 0]
139
+ points = points.transpose(2, 1)
140
+ points, target = points.cuda(), target.cuda()
141
+ classifier = classifier.eval()
142
+ pred, _, _ = classifier(points)
143
+ pred_choice = pred.data.max(1)[1]
144
+ correct = pred_choice.eq(target.data).cpu().sum()
145
+ total_correct += correct.item()
146
+ total_testset += points.size()[0]
147
+
148
+ print("final accuracy {}".format(total_correct / float(total_testset)))