hvaldez commited on
Commit
e0cb1ec
·
verified ·
1 Parent(s): 3a0b984

update to demo.py

Browse files
Files changed (1) hide show
  1. demo.py +7 -5
demo.py CHANGED
@@ -33,6 +33,10 @@ class VideoModel(nn.Module):
33
  super(VideoModel, self).__init__()
34
  self.cfg = load_cfg(config)
35
  self.model = self.build_model()
 
 
 
 
36
  self.templates = ['{}']
37
  self.dataset = self.cfg['data']['dataset']
38
  self.eval()
@@ -156,7 +160,7 @@ class VideoCLSModel(VideoModel):
156
  truncation=True,
157
  max_length=self.model_cfg.max_txt_l.video,
158
  return_tensors="pt",
159
- )
160
  _, class_embeddings = self.model.encode_text(embeddings)
161
  return class_embeddings
162
 
@@ -170,9 +174,7 @@ class VideoCLSModel(VideoModel):
170
  images = values[0]
171
  target = values[1]
172
 
173
- if torch.cuda.is_available():
174
- images = images.cuda(non_blocking=True)
175
- target = target.cuda(non_blocking=True)
176
 
177
  # encode images
178
  images = rearrange(images, 'b c k h w -> b k c h w')
@@ -190,7 +192,7 @@ class VideoCLSModel(VideoModel):
190
  similarity = self.model.get_sim(image_features, self.text_features)[0]
191
 
192
  all_outputs.append(similarity.cpu())
193
- all_targets.append(target.cpu())
194
 
195
  all_outputs = torch.cat(all_outputs)
196
  all_targets = torch.cat(all_targets)
 
33
  super(VideoModel, self).__init__()
34
  self.cfg = load_cfg(config)
35
  self.model = self.build_model()
36
+ use_gpu = torch.cuda.is_available()
37
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ if use_gpu:
39
+ self.model = self.model.to(self.device)
40
  self.templates = ['{}']
41
  self.dataset = self.cfg['data']['dataset']
42
  self.eval()
 
160
  truncation=True,
161
  max_length=self.model_cfg.max_txt_l.video,
162
  return_tensors="pt",
163
+ ).to(self.device)
164
  _, class_embeddings = self.model.encode_text(embeddings)
165
  return class_embeddings
166
 
 
174
  images = values[0]
175
  target = values[1]
176
 
177
+ images = images.to(self.device)
 
 
178
 
179
  # encode images
180
  images = rearrange(images, 'b c k h w -> b k c h w')
 
192
  similarity = self.model.get_sim(image_features, self.text_features)[0]
193
 
194
  all_outputs.append(similarity.cpu())
195
+ all_targets.append(target)
196
 
197
  all_outputs = torch.cat(all_outputs)
198
  all_targets = torch.cat(all_targets)