RxnIM / mllm /models /builder /builder.py
CYF200127's picture
Upload 235 files
3e1d9f3 verified
raw
history blame contribute delete
390 Bytes
from typing import Dict, Any, Tuple
from torch import nn
from .build_shikra import load_pretrained_shikra
# Type alias for the preprocessor bundle that accompanies a loaded model.
# NOTE(review): presumably maps component names (e.g. a tokenizer) to the
# corresponding objects -- confirm against load_pretrained_shikra.
PREPROCESSOR = Dict[str, Any]


# TODO: Registry -- replace the if-chain below with a model-type registry.
def load_pretrained(model_args, training_args) -> Tuple[nn.Module, PREPROCESSOR]:
    """Load a pretrained model selected by ``model_args.type``.

    Dispatches to the type-specific loader; currently only ``'shikra'`` is
    supported.

    Args:
        model_args: Model configuration. Must expose a ``type`` attribute
            naming the architecture; forwarded unchanged to the loader.
        training_args: Training configuration; forwarded unchanged to the
            loader.

    Returns:
        Whatever the type-specific loader returns, annotated as a
        ``(model, preprocessor)`` tuple.

    Raises:
        ValueError: If ``model_args.type`` is not a supported model type.
    """
    type_ = model_args.type
    if type_ == 'shikra':
        return load_pretrained_shikra(model_args, training_args)
    # Explicit raise instead of ``assert False``: asserts are stripped under
    # ``python -O``, which would make this silently return None, and the bare
    # assert gave no hint about which type was rejected.
    raise ValueError(
        f"Unsupported model type {type_!r}; supported types: ['shikra']"
    )