# RxnIM: mllm/engine/builder.py
# (uploaded to Hugging Face by CYF200127 in "Upload 235 files", commit 3e1d9f3; 882 bytes)
from functools import partial
from typing import Tuple, Dict, Any, Type
from transformers.trainer import DataCollator
from .shikra import ShikraTrainer
from .base_engine import TrainerForMMLLM, Seq2Seq2DataCollatorWithImage
# Registry mapping a model-type string (read from ``model_args.type`` in
# prepare_trainer_collator) to the trainer class used for that model type.
TYPE2TRAINER = dict(
    shikra=ShikraTrainer,
)
def prepare_trainer_collator(
        model_args,
        preprocessor: Dict[str, Any],
        collator_kwargs: Dict[str, Any]
) -> Tuple[Type[TrainerForMMLLM], Dict[str, DataCollator]]:
    """Resolve the trainer class and build the train/eval data collators.

    Args:
        model_args: Model configuration object. Only its ``type`` attribute
            is read here; it must be a key of ``TYPE2TRAINER``.
        preprocessor: Passed as the ``preprocessor`` keyword to every
            ``Seq2Seq2DataCollatorWithImage`` constructed here.
        collator_kwargs: Extra keyword arguments forwarded to
            ``Seq2Seq2DataCollatorWithImage``. It must not contain a
            ``preprocessor`` key, because that would clash with the explicit
            argument and ``functools.partial`` raises ``TypeError``. An
            ``inference_mode`` key, if present, is overridden per collator.

    Returns:
        ``(trainer_cls, collators)``. ``collators`` has two keys:
        ``"train_collator"``, built with ``inference_mode=False``, and
        ``"eval_collator"``, built with ``inference_mode=True``.

    Raises:
        KeyError: If ``model_args.type`` is not registered in
            ``TYPE2TRAINER``.
    """
    type_ = model_args.type
    try:
        trainer_cls = TYPE2TRAINER[type_]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what is
        # supported instead of echoing only the bad key.
        raise KeyError(
            f"unsupported model type {type_!r}; "
            f"expected one of {sorted(TYPE2TRAINER)}"
        ) from None

    # Shared constructor arguments. Only inference_mode differs between the
    # two collators, and call-time keywords override those stored in the
    # partial.
    data_collator_func = partial(
        Seq2Seq2DataCollatorWithImage,
        preprocessor=preprocessor,
        **collator_kwargs,
    )
    data_collator_dict = {
        "train_collator": data_collator_func(inference_mode=False),
        "eval_collator": data_collator_func(inference_mode=True),
    }
    return trainer_cls, data_collator_dict