Ren Jiawei commited on
Commit
ac0541e
·
1 Parent(s): 1c55e0d
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -19,14 +19,14 @@ with open('shape_names.txt') as f:
19
 
20
  model_gda = GDANET()
21
  model_gda = nn.DataParallel(model_gda)
22
- # model_gda.load_state_dict(torch.load('./GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
23
- model_gda.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
24
  model_gda.eval()
25
 
26
  model_dgcnn = DGCNN()
27
  model_dgcnn = nn.DataParallel(model_dgcnn)
28
- # model_dgcnn.load_state_dict(torch.load('./dgcnn.t7', map_location=torch.device('cpu')))
29
- model_dgcnn.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/dgcnn.t7', map_location=torch.device('cpu')))
30
  model_dgcnn.eval()
31
 
32
  def pyplot_draw_point_cloud(points, corruption):
@@ -68,11 +68,11 @@ def load_dataset(corruption_idx, severity):
68
  ]
69
  corruption_type = corruptions[corruption_idx]
70
  if corruption_type == 'clean':
71
- # f = h5py.File(osp.join('modelnet_c', corruption_type + '.h5'))
72
- f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '.h5'))
73
  else:
74
- # f = h5py.File(osp.join('modelnet_c', corruption_type + '_{}'.format(severity-1) + '.h5'))
75
- f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '_{}'.format(severity - 1) + '.h5'))
76
  data = f['data'][:].astype('float32')
77
  label = f['label'][:].astype('int64')
78
  f.close()
 
19
 
20
  model_gda = GDANET()
21
  model_gda = nn.DataParallel(model_gda)
22
+ model_gda.load_state_dict(torch.load('./GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
23
+ # model_gda.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
24
  model_gda.eval()
25
 
26
  model_dgcnn = DGCNN()
27
  model_dgcnn = nn.DataParallel(model_dgcnn)
28
+ model_dgcnn.load_state_dict(torch.load('./dgcnn.t7', map_location=torch.device('cpu')))
29
+ # model_dgcnn.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/dgcnn.t7', map_location=torch.device('cpu')))
30
  model_dgcnn.eval()
31
 
32
  def pyplot_draw_point_cloud(points, corruption):
 
68
  ]
69
  corruption_type = corruptions[corruption_idx]
70
  if corruption_type == 'clean':
71
+ f = h5py.File(osp.join('modelnet_c', corruption_type + '.h5'))
72
+ # f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '.h5'))
73
  else:
74
+ f = h5py.File(osp.join('modelnet_c', corruption_type + '_{}'.format(severity-1) + '.h5'))
75
+ # f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '_{}'.format(severity - 1) + '.h5'))
76
  data = f['data'][:].astype('float32')
77
  label = f['label'][:].astype('int64')
78
  f.close()