# train_pointnet.py
# Trains a PointNet classifier on ShapeNet or ModelNet40 point clouds.
import argparse
import os
import sys
import random
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from pointnet.dataset import ShapeNetDataset, ModelNetDataset
from pointnet import PointNetCls, feature_transform_regularizer
import torch.nn.functional as F
from tqdm import tqdm
parser = argparse.ArgumentParser()
parser.add_argument(
'--batchSize', type=int, default=32, help='input batch size')
parser.add_argument(
    '--num_points', type=int, default=2500, help='number of input points per sample')
parser.add_argument(
'--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument(
'--nepoch', type=int, default=250, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='cls', help='output folder')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
opt = parser.parse_args()
print(opt)
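# Example invocation (the dataset path is illustrative):
#   python train_pointnet.py --dataset /path/to/shapenet --dataset_type shapenet --feature_transform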
blue = lambda x: '\033[94m' + x + '\033[0m'
opt.manualSeed = random.randint(1, 10000)  # pick a random seed and record it for reproducibility
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
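# Also seed CUDA so GPU-side initialization is reproducible (a no-op without a GPU).
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(opt.manualSeed)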
if opt.dataset_type == 'shapenet':
dataset = ShapeNetDataset(
root=opt.dataset,
classification=True,
npoints=opt.num_points)
test_dataset = ShapeNetDataset(
root=opt.dataset,
classification=True,
split='test',
npoints=opt.num_points,
data_augmentation=False)
elif opt.dataset_type == 'modelnet40':
dataset = ModelNetDataset(
root=opt.dataset,
npoints=opt.num_points,
split='trainval')
test_dataset = ModelNetDataset(
root=opt.dataset,
split='test',
npoints=opt.num_points,
data_augmentation=False)
else:
    sys.exit('wrong dataset type: %s' % opt.dataset_type)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=opt.batchSize,
shuffle=True,
num_workers=int(opt.workers))
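# The test loader is also shuffled so the periodic evaluation during training
# (every 10 batches below) draws a different random test batch each time.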
testdataloader = torch.utils.data.DataLoader(
test_dataset,
batch_size=opt.batchSize,
shuffle=True,
num_workers=int(opt.workers))
print('train size: %d, test size: %d' % (len(dataset), len(test_dataset)))
num_classes = len(dataset.classes)
print('classes', num_classes)
os.makedirs(opt.outf, exist_ok=True)
classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)
if opt.model != '':
classifier.load_state_dict(torch.load(opt.model))
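# Adam with a StepLR schedule that halves the learning rate every 20 epochs.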
optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
classifier.cuda()  # the script assumes a CUDA-capable GPU throughout
num_batch = len(dataset) // opt.batchSize
for epoch in range(opt.nepoch):
    for i, data in enumerate(dataloader, 0):
        points, target = data
        target = target[:, 0]  # labels come as (B, 1); squeeze to (B,)
        points = points.transpose(2, 1)  # (B, N, 3) -> (B, 3, N), channels-first as PointNet expects
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()
        classifier.train()
        pred, trans, trans_feat = classifier(points)
        loss = F.nll_loss(pred, target)
        if opt.feature_transform:
            # push the feature transform toward orthogonality (||I - AA^T|| penalty)
            loss += feature_transform_regularizer(trans_feat) * 0.001
        loss.backward()
        optimizer.step()
        pred_choice = pred.max(1)[1]
        correct = pred_choice.eq(target).cpu().sum()
        # use the actual batch size: the last batch may be smaller than opt.batchSize
        print('[%d: %d/%d] train loss: %f accuracy: %f' % (
            epoch, i, num_batch, loss.item(), correct.item() / float(points.size(0))))
        if i % 10 == 0:
            # every 10 training batches, evaluate on one random test batch
            data = next(iter(testdataloader))
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            classifier.eval()
            with torch.no_grad():
                pred, _, _ = classifier(points)
                loss = F.nll_loss(pred, target)
            pred_choice = pred.max(1)[1]
            correct = pred_choice.eq(target).cpu().sum()
            print('[%d: %d/%d] %s loss: %f accuracy: %f' % (
                epoch, i, num_batch, blue('test'), loss.item(), correct.item() / float(points.size(0))))
    scheduler.step()  # decay the learning rate once per epoch, after the optimizer updates
    torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))
# final evaluation over the full test set
total_correct = 0
total_testset = 0
classifier.eval()
with torch.no_grad():
    for i, data in tqdm(enumerate(testdataloader, 0)):
        points, target = data
        target = target[:, 0]
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        pred, _, _ = classifier(points)
        pred_choice = pred.max(1)[1]
        correct = pred_choice.eq(target).cpu().sum()
        total_correct += correct.item()
        total_testset += points.size(0)

print("final accuracy {}".format(total_correct / float(total_testset)))