# Provenance (recovered from model-hub file-page residue):
# uploaded by pemami4911 in commit bf0bf67 ("Upload model");
# original file size 488 bytes.
from transformers import PreTrainedModel
from .configuration_avGFP import avGFPConfig
from evo_prot_grad.models import OneHotCNN
class avGFPModel(PreTrainedModel):
    """``transformers`` ``PreTrainedModel`` wrapper around ``OneHotCNN``.

    Subclassing ``PreTrainedModel`` gives the ``OneHotCNN`` from
    ``evo_prot_grad`` the ``transformers`` model API (e.g. ``from_pretrained``
    / ``save_pretrained``).

    All computation is delegated to the wrapped network stored on
    ``self.model``.
    """

    # Associates this model with its configuration class so the
    # transformers loading machinery knows which config to build.
    config_class = avGFPConfig

    def __init__(self, config):
        """Construct the wrapped CNN from hyperparameters held by *config*.

        Args:
            config: an ``avGFPConfig`` instance. Its ``vocab_size``,
                ``kernel_size`` and ``input_size`` attributes are forwarded
                to ``OneHotCNN`` as the keyword arguments of the same names.
        """
        super().__init__(config)
        # Read the three hyperparameters in the same order as before and
        # hand them to OneHotCNN by keyword.
        cnn_kwargs = {
            name: getattr(config, name)
            for name in ("vocab_size", "kernel_size", "input_size")
        }
        self.model = OneHotCNN(**cnn_kwargs)

    def forward(self, x):
        """Run the wrapped CNN on *x* and return its output unchanged.

        NOTE(review): the expected format of ``x`` is defined by
        ``OneHotCNN``. Given the name, it is presumably a batch of one-hot
        encoded sequences; confirm against ``OneHotCNN``.
        """
        prediction = self.model(x)
        return prediction