|
from functools import partial |
|
from typing import Tuple, Dict, Any, Type |
|
|
|
from transformers.trainer import DataCollator |
|
|
|
from .shikra import ShikraTrainer |
|
from .base_engine import TrainerForMMLLM, Seq2Seq2DataCollatorWithImage |
|
|
|
# Registry of supported model families: ``prepare_trainer_collator`` looks up
# ``model_args.type`` here to pick the trainer class to instantiate.
# Register a new model family by adding its type string -> trainer class entry.
TYPE2TRAINER: Dict[str, Type[TrainerForMMLLM]] = dict(
    shikra=ShikraTrainer,
)
|
|
|
|
|
def prepare_trainer_collator(
    model_args,
    preprocessor: Dict[str, Any],
    collator_kwargs: Dict[str, Any]
) -> Tuple[Type[TrainerForMMLLM], Dict[str, DataCollator]]:
    """Resolve the trainer class for a model type and build its data collators.

    Args:
        model_args: Config object whose ``type`` attribute is used as the key
            into ``TYPE2TRAINER``.
        preprocessor: Forwarded unchanged to every
            ``Seq2Seq2DataCollatorWithImage`` as its ``preprocessor`` argument.
        collator_kwargs: Extra keyword arguments forwarded to every collator.
            Any ``inference_mode`` entry here is overridden by the
            per-collator value set below (``functools.partial`` lets
            call-time keywords override the stored ones).

    Returns:
        A ``(trainer_cls, collators)`` pair, where ``collators`` has the keys
        ``"train_collator"`` (built with ``inference_mode=False``) and
        ``"eval_collator"`` (built with ``inference_mode=True``).

    Raises:
        KeyError: If ``model_args.type`` is not registered in
            ``TYPE2TRAINER``. The message lists the supported types.
    """
    type_ = model_args.type
    try:
        trainer_cls = TYPE2TRAINER[type_]
    except KeyError:
        # Keep the KeyError type for existing callers, but replace the bare
        # "KeyError: 'foo'" with an actionable message.
        raise KeyError(
            f"Unknown model type {type_!r}; "
            f"supported types: {list(TYPE2TRAINER)}"
        ) from None

    # Shared constructor: every collator gets the same preprocessor/kwargs
    # and differs only in inference_mode.
    data_collator_func = partial(
        Seq2Seq2DataCollatorWithImage,
        preprocessor=preprocessor,
        **collator_kwargs,
    )

    data_collator_dict = {
        "train_collator": data_collator_func(inference_mode=False),
        "eval_collator": data_collator_func(inference_mode=True),
    }
    return trainer_cls, data_collator_dict
|
|