"""Factory helpers that pair a trainer class with its data collators."""
from functools import partial
from typing import Tuple, Dict, Any, Type

from transformers.trainer import DataCollator

from .shikra import ShikraTrainer
from .base_engine import TrainerForMMLLM, Seq2Seq2DataCollatorWithImage

# Registry of trainer classes, keyed by the value read from ``model_args.type``.
TYPE2TRAINER = {
    'shikra': ShikraTrainer,
}


def prepare_trainer_collator(
        model_args,
        preprocessor: Dict[str, Any],
        collator_kwargs: Dict[str, Any]
) -> Tuple[Type[TrainerForMMLLM], Dict[str, DataCollator]]:
    """Select the trainer class and build the train/eval data collators.

    Args:
        model_args: Object exposing a ``type`` attribute; that value is used
            as the lookup key into ``TYPE2TRAINER``.
        preprocessor: Passed through unchanged as the ``preprocessor`` keyword
            of ``Seq2Seq2DataCollatorWithImage``.
        collator_kwargs: Additional keyword arguments forwarded to every
            collator that is constructed.

    Returns:
        A ``(trainer_cls, collators)`` pair. ``trainer_cls`` is the class
        itself (not an instance). ``collators`` has two entries:
        ``"train_collator"`` (built with ``inference_mode=False``) and
        ``"eval_collator"`` (built with ``inference_mode=True``).

    Raises:
        KeyError: If ``model_args.type`` has no entry in ``TYPE2TRAINER``.
    """
    # Resolve the trainer first so an unknown type fails before any collator
    # is constructed.
    selected_trainer = TYPE2TRAINER[model_args.type]

    # Bind the arguments shared by both collators; only ``inference_mode``
    # differs between the two instances.
    build_collator = partial(
        Seq2Seq2DataCollatorWithImage,
        preprocessor=preprocessor,
        **collator_kwargs,
    )

    # Train collator is constructed before the eval collator.
    mode_by_key = (
        ("train_collator", False),
        ("eval_collator", True),
    )
    collators = {
        key: build_collator(inference_mode=is_inference)
        for key, is_inference in mode_by_key
    }
    return selected_trainer, collators