Denys Rozumnyi
commited on
Commit
·
49de1ea
1
Parent(s):
b02e5d5
update
Browse files- dataset.py +88 -0
- geom_solver.py +5 -5
- pointnet.py +213 -0
- testing.ipynb +0 -0
- train_pointnet.py +148 -0
dataset.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
class ShapeNetDataset(data.Dataset):
|
4 |
+
def __init__(self,
|
5 |
+
root,
|
6 |
+
npoints=2500,
|
7 |
+
classification=False,
|
8 |
+
class_choice=None,
|
9 |
+
split='train',
|
10 |
+
data_augmentation=True):
|
11 |
+
self.npoints = npoints
|
12 |
+
self.root = root
|
13 |
+
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
|
14 |
+
self.cat = {}
|
15 |
+
self.data_augmentation = data_augmentation
|
16 |
+
self.classification = classification
|
17 |
+
self.seg_classes = {}
|
18 |
+
|
19 |
+
with open(self.catfile, 'r') as f:
|
20 |
+
for line in f:
|
21 |
+
ls = line.strip().split()
|
22 |
+
self.cat[ls[0]] = ls[1]
|
23 |
+
#print(self.cat)
|
24 |
+
if not class_choice is None:
|
25 |
+
self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
|
26 |
+
|
27 |
+
self.id2cat = {v: k for k, v in self.cat.items()}
|
28 |
+
|
29 |
+
self.meta = {}
|
30 |
+
splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split))
|
31 |
+
#from IPython import embed; embed()
|
32 |
+
filelist = json.load(open(splitfile, 'r'))
|
33 |
+
for item in self.cat:
|
34 |
+
self.meta[item] = []
|
35 |
+
|
36 |
+
for file in filelist:
|
37 |
+
_, category, uuid = file.split('/')
|
38 |
+
if category in self.cat.values():
|
39 |
+
self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid+'.pts'),
|
40 |
+
os.path.join(self.root, category, 'points_label', uuid+'.seg')))
|
41 |
+
|
42 |
+
self.datapath = []
|
43 |
+
for item in self.cat:
|
44 |
+
for fn in self.meta[item]:
|
45 |
+
self.datapath.append((item, fn[0], fn[1]))
|
46 |
+
|
47 |
+
self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
|
48 |
+
print(self.classes)
|
49 |
+
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:
|
50 |
+
for line in f:
|
51 |
+
ls = line.strip().split()
|
52 |
+
self.seg_classes[ls[0]] = int(ls[1])
|
53 |
+
self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
|
54 |
+
print(self.seg_classes, self.num_seg_classes)
|
55 |
+
|
56 |
+
def __getitem__(self, index):
|
57 |
+
fn = self.datapath[index]
|
58 |
+
cls = self.classes[self.datapath[index][0]]
|
59 |
+
point_set = np.loadtxt(fn[1]).astype(np.float32)
|
60 |
+
seg = np.loadtxt(fn[2]).astype(np.int64)
|
61 |
+
#print(point_set.shape, seg.shape)
|
62 |
+
|
63 |
+
choice = np.random.choice(len(seg), self.npoints, replace=True)
|
64 |
+
#resample
|
65 |
+
point_set = point_set[choice, :]
|
66 |
+
|
67 |
+
point_set = point_set - np.expand_dims(np.mean(point_set, axis = 0), 0) # center
|
68 |
+
dist = np.max(np.sqrt(np.sum(point_set ** 2, axis = 1)),0)
|
69 |
+
point_set = point_set / dist #scale
|
70 |
+
|
71 |
+
if self.data_augmentation:
|
72 |
+
theta = np.random.uniform(0,np.pi*2)
|
73 |
+
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]])
|
74 |
+
point_set[:,[0,2]] = point_set[:,[0,2]].dot(rotation_matrix) # random rotation
|
75 |
+
point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter
|
76 |
+
|
77 |
+
seg = seg[choice]
|
78 |
+
point_set = torch.from_numpy(point_set)
|
79 |
+
seg = torch.from_numpy(seg)
|
80 |
+
cls = torch.from_numpy(np.array([cls]).astype(np.int64))
|
81 |
+
|
82 |
+
if self.classification:
|
83 |
+
return point_set, cls
|
84 |
+
else:
|
85 |
+
return point_set, seg
|
86 |
+
|
87 |
+
def __len__(self):
|
88 |
+
return len(self.datapath)
|
geom_solver.py
CHANGED
@@ -17,7 +17,7 @@ class GeomSolver(object):
|
|
17 |
|
18 |
def __init__(self):
|
19 |
self.min_vertices = 18
|
20 |
-
self.kmeans_th =
|
21 |
self.clr_th = 2.5
|
22 |
self.device = 'cuda:0'
|
23 |
|
@@ -44,9 +44,9 @@ class GeomSolver(object):
|
|
44 |
vert_mask = (vert_mask > 0).astype(np.uint8)
|
45 |
|
46 |
dist = cv2.distanceTransform(1-vert_mask, cv2.DIST_L2, 3)
|
47 |
-
dist[dist > 100] = 100
|
48 |
-
ndist = np.zeros_like(dist)
|
49 |
-
ndist = cv2.normalize(dist, ndist, 0, 1.0, cv2.NORM_MINMAX)
|
50 |
|
51 |
in_this_image = np.array([cki in p.image_ids for p in self.points3D.values()])
|
52 |
uv = torch.round(self.pyt_cameras[ki].transform_points(self.verts)[:, :2]).cpu().numpy().astype(int)
|
@@ -57,7 +57,7 @@ class GeomSolver(object):
|
|
57 |
dist_points[uv_inl] += dist[uv[:,1], uv[:,0]]
|
58 |
visible_counts[uv_inl] += 1
|
59 |
|
60 |
-
selected_points = (dist_points / (visible_counts + 1e-6)) <=
|
61 |
selected_points[visible_counts < 1] = False
|
62 |
|
63 |
pnts = torch.from_numpy(self.xyz[selected_points].astype(np.float32))[None]
|
|
|
17 |
|
18 |
def __init__(self):
|
19 |
self.min_vertices = 18
|
20 |
+
self.kmeans_th = 70
|
21 |
self.clr_th = 2.5
|
22 |
self.device = 'cuda:0'
|
23 |
|
|
|
44 |
vert_mask = (vert_mask > 0).astype(np.uint8)
|
45 |
|
46 |
dist = cv2.distanceTransform(1-vert_mask, cv2.DIST_L2, 3)
|
47 |
+
# dist[dist > 100] = 100
|
48 |
+
# ndist = np.zeros_like(dist)
|
49 |
+
# ndist = cv2.normalize(dist, ndist, 0, 1.0, cv2.NORM_MINMAX)
|
50 |
|
51 |
in_this_image = np.array([cki in p.image_ids for p in self.points3D.values()])
|
52 |
uv = torch.round(self.pyt_cameras[ki].transform_points(self.verts)[:, :2]).cpu().numpy().astype(int)
|
|
|
57 |
dist_points[uv_inl] += dist[uv[:,1], uv[:,0]]
|
58 |
visible_counts[uv_inl] += 1
|
59 |
|
60 |
+
selected_points = (dist_points / (visible_counts + 1e-6)) <= 15
|
61 |
selected_points[visible_counts < 1] = False
|
62 |
|
63 |
pnts = torch.from_numpy(self.xyz[selected_points].astype(np.float32))[None]
|
pointnet.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.parallel
|
5 |
+
import torch.utils.data
|
6 |
+
from torch.autograd import Variable
|
7 |
+
import numpy as np
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class STN3d(nn.Module):
|
12 |
+
def __init__(self):
|
13 |
+
super(STN3d, self).__init__()
|
14 |
+
self.conv1 = torch.nn.Conv1d(3, 64, 1)
|
15 |
+
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
16 |
+
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
|
17 |
+
self.fc1 = nn.Linear(1024, 512)
|
18 |
+
self.fc2 = nn.Linear(512, 256)
|
19 |
+
self.fc3 = nn.Linear(256, 9)
|
20 |
+
self.relu = nn.ReLU()
|
21 |
+
|
22 |
+
self.bn1 = nn.BatchNorm1d(64)
|
23 |
+
self.bn2 = nn.BatchNorm1d(128)
|
24 |
+
self.bn3 = nn.BatchNorm1d(1024)
|
25 |
+
self.bn4 = nn.BatchNorm1d(512)
|
26 |
+
self.bn5 = nn.BatchNorm1d(256)
|
27 |
+
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
batchsize = x.size()[0]
|
31 |
+
x = F.relu(self.bn1(self.conv1(x)))
|
32 |
+
x = F.relu(self.bn2(self.conv2(x)))
|
33 |
+
x = F.relu(self.bn3(self.conv3(x)))
|
34 |
+
x = torch.max(x, 2, keepdim=True)[0]
|
35 |
+
x = x.view(-1, 1024)
|
36 |
+
|
37 |
+
x = F.relu(self.bn4(self.fc1(x)))
|
38 |
+
x = F.relu(self.bn5(self.fc2(x)))
|
39 |
+
x = self.fc3(x)
|
40 |
+
|
41 |
+
iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
|
42 |
+
if x.is_cuda:
|
43 |
+
iden = iden.cuda()
|
44 |
+
x = x + iden
|
45 |
+
x = x.view(-1, 3, 3)
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
class STNkd(nn.Module):
|
50 |
+
def __init__(self, k=64):
|
51 |
+
super(STNkd, self).__init__()
|
52 |
+
self.conv1 = torch.nn.Conv1d(k, 64, 1)
|
53 |
+
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
54 |
+
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
|
55 |
+
self.fc1 = nn.Linear(1024, 512)
|
56 |
+
self.fc2 = nn.Linear(512, 256)
|
57 |
+
self.fc3 = nn.Linear(256, k*k)
|
58 |
+
self.relu = nn.ReLU()
|
59 |
+
|
60 |
+
self.bn1 = nn.BatchNorm1d(64)
|
61 |
+
self.bn2 = nn.BatchNorm1d(128)
|
62 |
+
self.bn3 = nn.BatchNorm1d(1024)
|
63 |
+
self.bn4 = nn.BatchNorm1d(512)
|
64 |
+
self.bn5 = nn.BatchNorm1d(256)
|
65 |
+
|
66 |
+
self.k = k
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
batchsize = x.size()[0]
|
70 |
+
x = F.relu(self.bn1(self.conv1(x)))
|
71 |
+
x = F.relu(self.bn2(self.conv2(x)))
|
72 |
+
x = F.relu(self.bn3(self.conv3(x)))
|
73 |
+
x = torch.max(x, 2, keepdim=True)[0]
|
74 |
+
x = x.view(-1, 1024)
|
75 |
+
|
76 |
+
x = F.relu(self.bn4(self.fc1(x)))
|
77 |
+
x = F.relu(self.bn5(self.fc2(x)))
|
78 |
+
x = self.fc3(x)
|
79 |
+
|
80 |
+
iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
|
81 |
+
if x.is_cuda:
|
82 |
+
iden = iden.cuda()
|
83 |
+
x = x + iden
|
84 |
+
x = x.view(-1, self.k, self.k)
|
85 |
+
return x
|
86 |
+
|
87 |
+
class PointNetfeat(nn.Module):
|
88 |
+
def __init__(self, global_feat = True, feature_transform = False):
|
89 |
+
super(PointNetfeat, self).__init__()
|
90 |
+
self.stn = STN3d()
|
91 |
+
self.conv1 = torch.nn.Conv1d(3, 64, 1)
|
92 |
+
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
93 |
+
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
|
94 |
+
self.bn1 = nn.BatchNorm1d(64)
|
95 |
+
self.bn2 = nn.BatchNorm1d(128)
|
96 |
+
self.bn3 = nn.BatchNorm1d(1024)
|
97 |
+
self.global_feat = global_feat
|
98 |
+
self.feature_transform = feature_transform
|
99 |
+
if self.feature_transform:
|
100 |
+
self.fstn = STNkd(k=64)
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
n_pts = x.size()[2]
|
104 |
+
trans = self.stn(x)
|
105 |
+
x = x.transpose(2, 1)
|
106 |
+
x = torch.bmm(x, trans)
|
107 |
+
x = x.transpose(2, 1)
|
108 |
+
x = F.relu(self.bn1(self.conv1(x)))
|
109 |
+
|
110 |
+
if self.feature_transform:
|
111 |
+
trans_feat = self.fstn(x)
|
112 |
+
x = x.transpose(2,1)
|
113 |
+
x = torch.bmm(x, trans_feat)
|
114 |
+
x = x.transpose(2,1)
|
115 |
+
else:
|
116 |
+
trans_feat = None
|
117 |
+
|
118 |
+
pointfeat = x
|
119 |
+
x = F.relu(self.bn2(self.conv2(x)))
|
120 |
+
x = self.bn3(self.conv3(x))
|
121 |
+
x = torch.max(x, 2, keepdim=True)[0]
|
122 |
+
x = x.view(-1, 1024)
|
123 |
+
if self.global_feat:
|
124 |
+
return x, trans, trans_feat
|
125 |
+
else:
|
126 |
+
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
|
127 |
+
return torch.cat([x, pointfeat], 1), trans, trans_feat
|
128 |
+
|
129 |
+
class PointNetCls(nn.Module):
|
130 |
+
def __init__(self, k=2, feature_transform=False):
|
131 |
+
super(PointNetCls, self).__init__()
|
132 |
+
self.feature_transform = feature_transform
|
133 |
+
self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
|
134 |
+
self.fc1 = nn.Linear(1024, 512)
|
135 |
+
self.fc2 = nn.Linear(512, 256)
|
136 |
+
self.fc3 = nn.Linear(256, k)
|
137 |
+
self.dropout = nn.Dropout(p=0.3)
|
138 |
+
self.bn1 = nn.BatchNorm1d(512)
|
139 |
+
self.bn2 = nn.BatchNorm1d(256)
|
140 |
+
self.relu = nn.ReLU()
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
x, trans, trans_feat = self.feat(x)
|
144 |
+
x = F.relu(self.bn1(self.fc1(x)))
|
145 |
+
x = F.relu(self.bn2(self.dropout(self.fc2(x))))
|
146 |
+
x = self.fc3(x)
|
147 |
+
return F.log_softmax(x, dim=1), trans, trans_feat
|
148 |
+
|
149 |
+
|
150 |
+
class PointNetDenseCls(nn.Module):
|
151 |
+
def __init__(self, k = 2, feature_transform=False):
|
152 |
+
super(PointNetDenseCls, self).__init__()
|
153 |
+
self.k = k
|
154 |
+
self.feature_transform=feature_transform
|
155 |
+
self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
|
156 |
+
self.conv1 = torch.nn.Conv1d(1088, 512, 1)
|
157 |
+
self.conv2 = torch.nn.Conv1d(512, 256, 1)
|
158 |
+
self.conv3 = torch.nn.Conv1d(256, 128, 1)
|
159 |
+
self.conv4 = torch.nn.Conv1d(128, self.k, 1)
|
160 |
+
self.bn1 = nn.BatchNorm1d(512)
|
161 |
+
self.bn2 = nn.BatchNorm1d(256)
|
162 |
+
self.bn3 = nn.BatchNorm1d(128)
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
batchsize = x.size()[0]
|
166 |
+
n_pts = x.size()[2]
|
167 |
+
x, trans, trans_feat = self.feat(x)
|
168 |
+
x = F.relu(self.bn1(self.conv1(x)))
|
169 |
+
x = F.relu(self.bn2(self.conv2(x)))
|
170 |
+
x = F.relu(self.bn3(self.conv3(x)))
|
171 |
+
x = self.conv4(x)
|
172 |
+
x = x.transpose(2,1).contiguous()
|
173 |
+
x = F.log_softmax(x.view(-1,self.k), dim=-1)
|
174 |
+
x = x.view(batchsize, n_pts, self.k)
|
175 |
+
return x, trans, trans_feat
|
176 |
+
|
177 |
+
def feature_transform_regularizer(trans):
|
178 |
+
d = trans.size()[1]
|
179 |
+
batchsize = trans.size()[0]
|
180 |
+
I = torch.eye(d)[None, :, :]
|
181 |
+
if trans.is_cuda:
|
182 |
+
I = I.cuda()
|
183 |
+
loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
|
184 |
+
return loss
|
185 |
+
|
186 |
+
if __name__ == '__main__':
|
187 |
+
sim_data = Variable(torch.rand(32,3,2500))
|
188 |
+
trans = STN3d()
|
189 |
+
out = trans(sim_data)
|
190 |
+
print('stn', out.size())
|
191 |
+
print('loss', feature_transform_regularizer(out))
|
192 |
+
|
193 |
+
sim_data_64d = Variable(torch.rand(32, 64, 2500))
|
194 |
+
trans = STNkd(k=64)
|
195 |
+
out = trans(sim_data_64d)
|
196 |
+
print('stn64d', out.size())
|
197 |
+
print('loss', feature_transform_regularizer(out))
|
198 |
+
|
199 |
+
pointfeat = PointNetfeat(global_feat=True)
|
200 |
+
out, _, _ = pointfeat(sim_data)
|
201 |
+
print('global feat', out.size())
|
202 |
+
|
203 |
+
pointfeat = PointNetfeat(global_feat=False)
|
204 |
+
out, _, _ = pointfeat(sim_data)
|
205 |
+
print('point feat', out.size())
|
206 |
+
|
207 |
+
cls = PointNetCls(k = 5)
|
208 |
+
out, _, _ = cls(sim_data)
|
209 |
+
print('class', out.size())
|
210 |
+
|
211 |
+
seg = PointNetDenseCls(k = 3)
|
212 |
+
out, _, _ = seg(sim_data)
|
213 |
+
print('seg', out.size())
|
testing.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
train_pointnet.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import torch.nn.parallel
|
7 |
+
import torch.optim as optim
|
8 |
+
import torch.utils.data
|
9 |
+
from pointnet.dataset import ShapeNetDataset, ModelNetDataset
|
10 |
+
from pointnet import PointNetCls, feature_transform_regularizer
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument(
|
17 |
+
'--batchSize', type=int, default=32, help='input batch size')
|
18 |
+
parser.add_argument(
|
19 |
+
'--num_points', type=int, default=2500, help='input batch size')
|
20 |
+
parser.add_argument(
|
21 |
+
'--workers', type=int, help='number of data loading workers', default=4)
|
22 |
+
parser.add_argument(
|
23 |
+
'--nepoch', type=int, default=250, help='number of epochs to train for')
|
24 |
+
parser.add_argument('--outf', type=str, default='cls', help='output folder')
|
25 |
+
parser.add_argument('--model', type=str, default='', help='model path')
|
26 |
+
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
|
27 |
+
parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
|
28 |
+
parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
|
29 |
+
|
30 |
+
opt = parser.parse_args()
|
31 |
+
print(opt)
|
32 |
+
|
33 |
+
blue = lambda x: '\033[94m' + x + '\033[0m'
|
34 |
+
|
35 |
+
opt.manualSeed = random.randint(1, 10000) # fix seed
|
36 |
+
print("Random Seed: ", opt.manualSeed)
|
37 |
+
random.seed(opt.manualSeed)
|
38 |
+
torch.manual_seed(opt.manualSeed)
|
39 |
+
|
40 |
+
if opt.dataset_type == 'shapenet':
|
41 |
+
dataset = ShapeNetDataset(
|
42 |
+
root=opt.dataset,
|
43 |
+
classification=True,
|
44 |
+
npoints=opt.num_points)
|
45 |
+
|
46 |
+
test_dataset = ShapeNetDataset(
|
47 |
+
root=opt.dataset,
|
48 |
+
classification=True,
|
49 |
+
split='test',
|
50 |
+
npoints=opt.num_points,
|
51 |
+
data_augmentation=False)
|
52 |
+
elif opt.dataset_type == 'modelnet40':
|
53 |
+
dataset = ModelNetDataset(
|
54 |
+
root=opt.dataset,
|
55 |
+
npoints=opt.num_points,
|
56 |
+
split='trainval')
|
57 |
+
|
58 |
+
test_dataset = ModelNetDataset(
|
59 |
+
root=opt.dataset,
|
60 |
+
split='test',
|
61 |
+
npoints=opt.num_points,
|
62 |
+
data_augmentation=False)
|
63 |
+
else:
|
64 |
+
exit('wrong dataset type')
|
65 |
+
|
66 |
+
|
67 |
+
dataloader = torch.utils.data.DataLoader(
|
68 |
+
dataset,
|
69 |
+
batch_size=opt.batchSize,
|
70 |
+
shuffle=True,
|
71 |
+
num_workers=int(opt.workers))
|
72 |
+
|
73 |
+
testdataloader = torch.utils.data.DataLoader(
|
74 |
+
test_dataset,
|
75 |
+
batch_size=opt.batchSize,
|
76 |
+
shuffle=True,
|
77 |
+
num_workers=int(opt.workers))
|
78 |
+
|
79 |
+
print(len(dataset), len(test_dataset))
|
80 |
+
num_classes = len(dataset.classes)
|
81 |
+
print('classes', num_classes)
|
82 |
+
|
83 |
+
try:
|
84 |
+
os.makedirs(opt.outf)
|
85 |
+
except OSError:
|
86 |
+
pass
|
87 |
+
|
88 |
+
classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)
|
89 |
+
|
90 |
+
if opt.model != '':
|
91 |
+
classifier.load_state_dict(torch.load(opt.model))
|
92 |
+
|
93 |
+
|
94 |
+
optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
|
95 |
+
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
|
96 |
+
classifier.cuda()
|
97 |
+
|
98 |
+
num_batch = len(dataset) / opt.batchSize
|
99 |
+
|
100 |
+
for epoch in range(opt.nepoch):
|
101 |
+
scheduler.step()
|
102 |
+
for i, data in enumerate(dataloader, 0):
|
103 |
+
points, target = data
|
104 |
+
target = target[:, 0]
|
105 |
+
points = points.transpose(2, 1)
|
106 |
+
points, target = points.cuda(), target.cuda()
|
107 |
+
optimizer.zero_grad()
|
108 |
+
classifier = classifier.train()
|
109 |
+
pred, trans, trans_feat = classifier(points)
|
110 |
+
loss = F.nll_loss(pred, target)
|
111 |
+
if opt.feature_transform:
|
112 |
+
loss += feature_transform_regularizer(trans_feat) * 0.001
|
113 |
+
loss.backward()
|
114 |
+
optimizer.step()
|
115 |
+
pred_choice = pred.data.max(1)[1]
|
116 |
+
correct = pred_choice.eq(target.data).cpu().sum()
|
117 |
+
print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(opt.batchSize)))
|
118 |
+
|
119 |
+
if i % 10 == 0:
|
120 |
+
j, data = next(enumerate(testdataloader, 0))
|
121 |
+
points, target = data
|
122 |
+
target = target[:, 0]
|
123 |
+
points = points.transpose(2, 1)
|
124 |
+
points, target = points.cuda(), target.cuda()
|
125 |
+
classifier = classifier.eval()
|
126 |
+
pred, _, _ = classifier(points)
|
127 |
+
loss = F.nll_loss(pred, target)
|
128 |
+
pred_choice = pred.data.max(1)[1]
|
129 |
+
correct = pred_choice.eq(target.data).cpu().sum()
|
130 |
+
print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize)))
|
131 |
+
|
132 |
+
torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))
|
133 |
+
|
134 |
+
total_correct = 0
|
135 |
+
total_testset = 0
|
136 |
+
for i,data in tqdm(enumerate(testdataloader, 0)):
|
137 |
+
points, target = data
|
138 |
+
target = target[:, 0]
|
139 |
+
points = points.transpose(2, 1)
|
140 |
+
points, target = points.cuda(), target.cuda()
|
141 |
+
classifier = classifier.eval()
|
142 |
+
pred, _, _ = classifier(points)
|
143 |
+
pred_choice = pred.data.max(1)[1]
|
144 |
+
correct = pred_choice.eq(target.data).cpu().sum()
|
145 |
+
total_correct += correct.item()
|
146 |
+
total_testset += points.size()[0]
|
147 |
+
|
148 |
+
print("final accuracy {}".format(total_correct / float(total_testset)))
|