import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_simple_model import SimpleNNConfig


class SimpleNN(PreTrainedModel):
    """Single-layer model: one ``nn.Linear`` wrapped as a ``PreTrainedModel``.

    The layer maps ``config.input_size`` input features to
    ``config.num_classes`` output features.
    """

    # Configuration class associated with this model.
    config_class = SimpleNNConfig

    def __init__(self, config):
        """Build the linear layer from ``config``.

        Args:
            config: Configuration object; ``input_size`` and ``num_classes``
                are read from it to size the layer.
        """
        super().__init__(config)
        # NOTE(review): Hugging Face models conventionally end __init__ with
        # ``self.post_init()``; it is not called here -- confirm whether the
        # targeted transformers version needs it.
        self.dense = nn.Linear(
            in_features=config.input_size,
            out_features=config.num_classes,
        )

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output.

        Args:
            x: Input handed directly to ``self.dense``; presumably a tensor
                whose last dimension is ``config.input_size`` -- confirm
                against callers.

        Returns:
            The output of ``self.dense(x)``, unchanged.
        """
        projected = self.dense(x)
        return projected