from lightning.pytorch import LightningModule
from lightning.pytorch.core.saving import _load_state
from transformers import PreTrainedModel, PretrainedConfig


class GenBioConfig(PretrainedConfig):
    """Hugging Face config that carries a GenBio LightningModule's hyperparameters."""

    model_type = "genbio"

    def __init__(self, hparams=None, **kwargs):
        self.hparams = hparams
        super().__init__(**kwargs)


class GenBioModel(PreTrainedModel):
    """Hugging Face wrapper that exposes a GenBio LightningModule as a PreTrainedModel."""

    config_class = GenBioConfig

    def __init__(self, config: GenBioConfig, genbio_model=None, **kwargs):
        super().__init__(config, **kwargs)
        # if genbio_model is provided, we don't need to initialize it
        if genbio_model is not None:
            self.genbio_model = genbio_model
            return
        # otherwise, initialize the model from hyperparameters
        cls_path = config.hparams["_class_path"]
        module_path, name = cls_path.rsplit(".", 1)
        genbio_cls = getattr(__import__(module_path, fromlist=[name]), name)
        checkpoint = {
            LightningModule.CHECKPOINT_HYPER_PARAMS_KEY: config.hparams,
            "state_dict": {},
        }
        # TODO: _load_state is a private function and throws a warning for an
        # empty state_dict. We need a function to initialize our model from
        # hyperparameters alone; this is the only option available for now.
        self.genbio_model = _load_state(genbio_cls, checkpoint, strict_loading=False)

    @classmethod
    def from_genbio_model(cls, model: LightningModule):
        """Wrap an already-instantiated GenBio LightningModule without re-initializing it."""
        return cls(GenBioConfig(hparams=model.hparams), genbio_model=model)

    def forward(self, *args, **kwargs):
        return self.genbio_model(*args, **kwargs)


GenBioModel.register_for_auto_class("AutoModel")
GenBioConfig.register_for_auto_class("AutoConfig")
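
# Minimal usage sketch (illustrative only): `MyGenBioModule` and the paths below
# are hypothetical placeholders for a concrete GenBio LightningModule and checkpoint.
#
#   lit_model = MyGenBioModule.load_from_checkpoint("model.ckpt")
#   hf_model = GenBioModel.from_genbio_model(lit_model)
#   hf_model.save_pretrained("genbio-hf")        # writes config + weights
#   reloaded = GenBioModel.from_pretrained("genbio-hf")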