from transformers import PretrainedConfig

class COMET19_CN_Config(PretrainedConfig):
    """Configuration object for the COMET19_CN model.

    Each constructor argument is stored verbatim as an attribute of the same
    name. Any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the short names are not explained anywhere in this file.
    ``nL``/``nH``/``hSize`` presumably mean layer count, head count and hidden
    size. ``edpt``/``adpt``/``rdpt``/``odpt`` are presumably dropout
    probabilities (embedding/attention/residual/output). Confirm against the
    model code that consumes them.
    """

    def __init__(
        self,
        model: str = "transformer",
        nL: int = 12,
        nH: int = 12,
        hSize: int = 768,
        edpt: float = 0.1,
        adpt: float = 0.1,
        rdpt: float = 0.1,
        odpt: float = 0.1,
        pt: str = "gpt",
        afn: str = "gelu",
        init: str = "pt",
        vSize: int = 40545,
        n_ctx: int = 31,
        n_vocab: int = 40545,
        return_acts: bool = True,
        return_probs: bool = False,
        **kwargs,
    ):
        # The order of this table fixes attribute insertion order, which
        # determines key order when the config is serialized.
        hyperparams = (
            ("model", model),
            ("nL", nL),
            ("nH", nH),
            ("hSize", hSize),
            ("edpt", edpt),
            ("adpt", adpt),
            ("rdpt", rdpt),
            ("odpt", odpt),
            ("pt", pt),
            ("afn", afn),
            ("init", init),
            ("vSize", vSize),
            ("n_ctx", n_ctx),
            ("n_vocab", n_vocab),
            ("return_acts", return_acts),
            ("return_probs", return_probs),
        )
        for attr_name, attr_value in hyperparams:
            setattr(self, attr_name, attr_value)

        # Everything not named above is handled by the base class.
        super().__init__(**kwargs)


def parse_net_config(config):
    """Collect the network-construction settings from *config* into a dict.

    Reads 13 attributes from *config* and returns a new dict that maps each
    attribute name to its value. Keys appear in the order listed below.
    ``n_vocab``, ``return_acts`` and ``return_probs`` are intentionally not
    copied.

    Raises:
        AttributeError: if *config* lacks any of the listed attributes.
    """
    net_keys = (
        "model",
        "nL",
        "nH",
        "hSize",
        "edpt",
        "adpt",
        "rdpt",
        "odpt",
        "pt",
        "afn",
        "init",
        "vSize",
        "n_ctx",
    )
    return {key: getattr(config, key) for key in net_keys}