from transformers import PreTrainedModel | |
from .configuration_avGFP import avGFPConfig | |
from evo_prot_grad.models import OneHotCNN | |
class avGFPModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around an evo_prot_grad ``OneHotCNN``.

    The wrapped network is constructed from fields of the supplied config and
    stored on the ``model`` attribute; ``forward`` simply delegates to it.
    """

    # Configuration class associated with this model for the transformers
    # PreTrainedModel API.
    config_class = avGFPConfig

    def __init__(self, config):
        """Initialise the base model and build the wrapped CNN.

        Args:
            config: configuration object (an ``avGFPConfig`` per
                ``config_class``). Its ``vocab_size``, ``kernel_size`` and
                ``input_size`` attributes are passed to ``OneHotCNN`` as
                keyword arguments.
        """
        super().__init__(config)
        # Gather the CNN hyper-parameters first (same attribute read order as
        # a direct keyword call), then build the network in one step.
        cnn_kwargs = {
            "vocab_size": config.vocab_size,
            "kernel_size": config.kernel_size,
            "input_size": config.input_size,
        }
        self.model = OneHotCNN(**cnn_kwargs)

    def forward(self, x):
        """Pass ``x`` through the wrapped ``OneHotCNN`` and return its output unchanged.

        NOTE(review): the expected encoding/shape of ``x`` is defined by
        ``OneHotCNN``, which is not visible here — confirm against that class.
        """
        network = self.model
        return network(x)