zaydzuhri's picture
Training in progress, step 2500
0094a2a verified
raw
history blame contribute delete
557 Bytes
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.transformer.modeling_transformer import (
TransformerForCausalLM, TransformerModel)
# Make the FLA Transformer discoverable through the Hugging Face Auto* factories.
# The config class is registered under its `model_type` string; each model class
# is registered against the config class it is built from.
AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
for _auto_factory, _model_cls in (
    (AutoModel, TransformerModel),
    (AutoModelForCausalLM, TransformerForCausalLM),
):
    _auto_factory.register(TransformerConfig, _model_cls)
# Drop the loop temporaries so they do not linger as module attributes.
del _auto_factory, _model_cls

# Public names re-exported by `from <this package> import *`.
__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']