nbroad HF staff commited on
Commit
b93ed22
·
1 Parent(s): df70302

move to(device) outside of collate

Browse files
Files changed (1) hide show
  1. utils.py +3 -3
utils.py CHANGED
@@ -226,7 +226,7 @@ def collate_fn(examples, tokenizer=None, padding=None, device=None):
226
  batch[k].append(v)
227
 
228
  return {
229
- k: torch.tensor(v, dtype=torch.long, device=device) if k in {"attention_mask", "input_ids"} else v for k, v in batch.items()
230
  }
231
 
232
  @torch.inference_mode()
@@ -332,8 +332,8 @@ def batch_embed(
332
  drop_last=False,
333
  collate_fn=partial(collate_fn, device=device)
334
  ):
335
- ids = batch["input_ids"]
336
- mask = batch["attention_mask"]
337
 
338
  t_ids = torch.zeros_like(ids)
339
 
 
226
  batch[k].append(v)
227
 
228
  return {
229
+ k: torch.tensor(v, dtype=torch.long) if k in {"attention_mask", "input_ids"} else v for k, v in batch.items()
230
  }
231
 
232
  @torch.inference_mode()
 
332
  drop_last=False,
333
  collate_fn=partial(collate_fn, device=device)
334
  ):
335
+ ids = batch["input_ids"].to(device)
336
+ mask = batch["attention_mask"].to(device)
337
 
338
  t_ids = torch.zeros_like(ids)
339