File size: 407 Bytes
44c3947 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from transformers import PretrainedConfig
class ConNetConfig(PretrainedConfig):
    """Configuration for the custom ``"convnet"`` model type.

    Holds a single model-specific field, ``num_classes``; every other
    keyword argument is handed unchanged to ``PretrainedConfig``.

    Args:
        num_classes: Number of classes, stored as ``self.num_classes``.
            Defaults to 10.
        **kwargs: Forwarded as-is to ``PretrainedConfig.__init__``.
    """

    model_type = "convnet"

    def __init__(self, num_classes=10, **kwargs):
        # Record the model-specific field before delegating the remaining
        # options to the base-class constructor.
        self.num_classes = num_classes
        super().__init__(**kwargs)
if __name__ == "__main__":
    # Script entry point: build a config with 10 classes and serialize it
    # into the "custom-convnet" directory via save_pretrained.
    convnet_config = ConNetConfig(num_classes=10)
    convnet_config.save_pretrained("custom-convnet")