|
|
|
import torch |
|
|
|
from ..collator_for_classification import DataCollatorForGeneClassification |
|
|
|
""" |
|
Geneformer collator for multi-task cell classification. |
|
""" |
|
|
|
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
    """Collate cell examples into a padded batch with one label tensor per task.

    Each feature is a mapping holding the model inputs consumed by
    ``self.tokenizer.pad`` and, optionally, a ``"label"`` mapping of
    ``task name -> label`` for that cell. The collated batch exposes labels
    as ``batch["labels"] = {task_name: int64 tensor}``.
    """

    class_type = "cell"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    @staticmethod
    def _labels_to_tensor(task, task_labels):
        """Convert the per-cell labels of one task into a single int64 tensor.

        Args:
            task: Task name, used only in error messages.
            task_labels: Non-empty list holding one label per cell in the batch.

        Returns:
            A ``torch.long`` tensor with the cells along dimension 0.

        Raises:
            TypeError: If the labels are dicts, which cannot be turned into a
                tensor.
        """
        first = task_labels[0]
        if isinstance(first, dict):
            raise TypeError(
                f"Labels for task {task!r} are dicts; nested label structures "
                "are not supported by this collator."
            )
        if isinstance(first, torch.Tensor):
            # torch.tensor() on a list of tensors only accepts one-element
            # tensors; stacking works for any equally-shaped label tensors.
            stacked = torch.stack([torch.as_tensor(label) for label in task_labels])
            return stacked.to(torch.long)
        return torch.tensor(task_labels, dtype=torch.long)

    def _prepare_batch(self, features):
        """Pad the model inputs and gather the labels of every task.

        Args:
            features: Non-empty sequence of per-cell mappings. When the first
                one has a ``"label"`` key, every feature must have a
                ``"label"`` mapping over the same set of task names.

        Returns:
            The padded batch from ``self.tokenizer.pad`` with an added
            ``"labels"`` entry mapping each task name to an int64 tensor.
            When the features carry no ``"label"`` key, ``"labels"`` is an
            empty dict.

        Raises:
            ValueError: If cells are labeled for different sets of tasks.
            TypeError: If a task's labels are dicts.
        """
        batch = self.tokenizer.pad(
            features,
            class_type=self.class_type,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if "label" not in features[0]:
            # Task names are only available from the label mappings, so with
            # unlabeled features there is no task to create an entry for.
            batch["labels"] = {}
            return batch

        task_names = list(features[0]["label"])
        expected_tasks = set(task_names)
        per_task = {task: [] for task in task_names}

        for index, feature in enumerate(features):
            cell_labels = feature["label"]
            # A missing or extra task would misalign labels with cells.
            if set(cell_labels) != expected_tasks:
                raise ValueError(
                    f"Feature {index} is labeled for tasks {sorted(cell_labels)}, "
                    f"expected {sorted(expected_tasks)}."
                )
            for task in task_names:
                per_task[task].append(cell_labels[task])

        batch["labels"] = {
            task: self._labels_to_tensor(task, task_labels)
            for task, task_labels in per_task.items()
        }
        return batch

    def __call__(self, features):
        """Collate ``features`` into a batch of tensors.

        Tensor entries of the padded batch are copied. Dict entries (the
        per-task labels) and any remaining non-tensor entries are converted
        to ``torch.int64`` tensors.

        Args:
            features: Sequence of per-cell mappings (see ``_prepare_batch``).

        Returns:
            The collated batch.
        """
        batch = self._prepare_batch(features)

        # Snapshot the items: values are reassigned while we iterate.
        for key, value in list(batch.items()):
            if torch.is_tensor(value):
                batch[key] = value.clone().detach()
            elif isinstance(value, dict):
                # as_tensor accepts existing tensors without the copy-construct
                # warning torch.tensor() emits for tensor inputs.
                batch[key] = {
                    task: torch.as_tensor(labels, dtype=torch.int64)
                    for task, labels in value.items()
                }
            else:
                batch[key] = torch.tensor(value, dtype=torch.int64)

        return batch