from lightning.pytorch import LightningModule
from lightning.pytorch.core.saving import _load_state
from transformers import PreTrainedModel, PretrainedConfig


class GenBioConfig(PretrainedConfig):
    model_type = "genbio"

    def __init__(self, hparams=None, **kwargs):
        self.hparams = hparams
        super().__init__(**kwargs)


class GenBioModel(PreTrainedModel):
    config_class = GenBioConfig

    def __init__(self, config: GenBioConfig, genbio_model=None, **kwargs):
        super().__init__(config, **kwargs)
        # if genbio_model is provided, we don't need to initialize it
        if genbio_model is not None:
            self.genbio_model = genbio_model
            return
        # otherwise, initialize the model from hyperparameters
        cls_path = config.hparams["_class_path"]
        module_path, name = cls_path.rsplit(".", 1)
        genbio_cls = getattr(__import__(module_path, fromlist=[name]), name)
        checkpoint = {
            LightningModule.CHECKPOINT_HYPER_PARAMS_KEY: config.hparams,
            "state_dict": {},
        }
        # TODO: _load_state is a private function and throws a warning for an
        # empty state_dict. We need a function to initialize our model; this
        # is the only choice we have for now.
        self.genbio_model = _load_state(genbio_cls, checkpoint, strict_loading=False)
    @classmethod
    def from_genbio_model(cls, model: LightningModule):
        return cls(GenBioConfig(hparams=model.hparams), genbio_model=model)

    def forward(self, *args, **kwargs):
        return self.genbio_model(*args, **kwargs)


GenBioModel.register_for_auto_class("AutoModel")
GenBioConfig.register_for_auto_class("AutoConfig")
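
# Usage sketch (illustrative, not part of the original file): wrap a trained
# GenBio LightningModule so it can be saved and reloaded through the Hugging
# Face auto classes. `MyGenBioModule` and the paths below are hypothetical
# placeholders.
#
#   from transformers import AutoModel
#
#   lightning_module = MyGenBioModule.load_from_checkpoint("path/to/model.ckpt")
#   hf_model = GenBioModel.from_genbio_model(lightning_module)
#   hf_model.save_pretrained("genbio-hf-checkpoint")
#
#   # The register_for_auto_class calls above let the saved wrapper be reloaded
#   # with the generic auto classes (trust_remote_code is required for custom code):
#   reloaded = AutoModel.from_pretrained("genbio-hf-checkpoint", trust_remote_code=True)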