Spaces:
Runtime error
Runtime error
move to(device) outside of collate
Browse files
utils.py
CHANGED
@@ -226,7 +226,7 @@ def collate_fn(examples, tokenizer=None, padding=None, device=None):
|
|
226 |
batch[k].append(v)
|
227 |
|
228 |
return {
|
229 |
-
k: torch.tensor(v, dtype=torch.long
|
230 |
}
|
231 |
|
232 |
@torch.inference_mode()
|
@@ -332,8 +332,8 @@ def batch_embed(
|
|
332 |
drop_last=False,
|
333 |
collate_fn=partial(collate_fn, device=device)
|
334 |
):
|
335 |
-
ids = batch["input_ids"]
|
336 |
-
mask = batch["attention_mask"]
|
337 |
|
338 |
t_ids = torch.zeros_like(ids)
|
339 |
|
|
|
226 |
batch[k].append(v)
|
227 |
|
228 |
return {
|
229 |
+
k: torch.tensor(v, dtype=torch.long) if k in {"attention_mask", "input_ids"} else v for k, v in batch.items()
|
230 |
}
|
231 |
|
232 |
@torch.inference_mode()
|
|
|
332 |
drop_last=False,
|
333 |
collate_fn=partial(collate_fn, device=device)
|
334 |
):
|
335 |
+
ids = batch["input_ids"].to(device)
|
336 |
+
mask = batch["attention_mask"].to(device)
|
337 |
|
338 |
t_ids = torch.zeros_like(ids)
|
339 |
|