Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# @Time : 2021/12/6 3:35 下午 | |
# @Author : JianingWang | |
# @File : __init__.py | |
# from models.chid_mlm import BertForChidMLM | |
from models.multiple_choice.duma import BertDUMAForMultipleChoice, AlbertDUMAForMultipleChoice, MegatronDumaForMultipleChoice | |
from models.span_extraction.global_pointer import BertForEffiGlobalPointer, RobertaForEffiGlobalPointer, RoformerForEffiGlobalPointer, MegatronForEffiGlobalPointer | |
from transformers import AutoModelForTokenClassification, AutoModelForSequenceClassification, AutoModelForMaskedLM, AutoModelForMultipleChoice, BertTokenizer, \ | |
AutoModelForQuestionAnswering, AutoModelForCausalLM | |
from transformers import AutoTokenizer | |
from transformers.models.roformer import RoFormerTokenizer | |
from transformers.models.bert import BertTokenizerFast, BertForTokenClassification, BertTokenizer | |
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer | |
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast | |
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer | |
from transformers.models.bart.tokenization_bart import BartTokenizer | |
from transformers.models.t5.tokenization_t5 import T5Tokenizer | |
from transformers.models.plbart.tokenization_plbart import PLBartTokenizer | |
# from models.deberta import DebertaV2ForMultipleChoice, DebertaForMultipleChoice | |
# from models.fengshen.models.longformer import LongformerForMultipleChoice | |
from models.kg import BertForPretrainWithKG, BertForPretrainWithKGV2 | |
from models.language_modeling.mlm import BertForMaskedLM, RobertaForMaskedLM, AlbertForMaskedLM, RoFormerForMaskedLM | |
# from models.sequence_classification.classification import build_cls_model | |
from models.multiple_choice.multiple_choice_tag import BertForTagMultipleChoice, RoFormerForTagMultipleChoice, MegatronBertForTagMultipleChoice | |
from models.multiple_choice.multiple_choice import MegatronBertForMultipleChoice, MegatronBertRDropForMultipleChoice | |
from models.semeval7 import DebertaV2ForSemEval7MultiTask | |
from models.sequence_matching.fusion_siamese import BertForFusionSiamese, BertForWSC | |
# from roformer import RoFormerForTokenClassification, RoFormerForSequenceClassification | |
from models.fewshot_learning.span_proto import SpanProto | |
from models.fewshot_learning.token_proto import TokenProto | |
from models.sequence_labeling.head_token_cls import ( | |
BertSoftmaxForSequenceLabeling, BertCrfForSequenceLabeling, | |
RobertaSoftmaxForSequenceLabeling, RobertaCrfForSequenceLabeling, | |
AlbertSoftmaxForSequenceLabeling, AlbertCrfForSequenceLabeling, | |
MegatronBertSoftmaxForSequenceLabeling, MegatronBertCrfForSequenceLabeling, | |
) | |
from models.span_extraction.span_for_ner import BertSpanForNer, RobertaSpanForNer, AlbertSpanForNer, MegatronBertSpanForNer | |
from models.language_modeling.mlm import BertForMaskedLM | |
from models.language_modeling.kpplm import BertForWikiKGPLM, RoBertaKPPLMForProcessedWikiKGPLM, DeBertaKPPLMForProcessedWikiKGPLM | |
from models.language_modeling.causal_lm import GPT2ForCausalLM | |
from models.sequence_classification.head_cls import ( | |
BertForSequenceClassification, BertPrefixForSequenceClassification, | |
BertPtuningForSequenceClassification, BertAdapterForSequenceClassification, | |
RobertaForSequenceClassification, RobertaPrefixForSequenceClassification, | |
RobertaPtuningForSequenceClassification,RobertaAdapterForSequenceClassification, | |
BartForSequenceClassification, GPT2ForSequenceClassification | |
) | |
from models.sequence_classification.masked_prompt_cls import ( | |
PromptBertForSequenceClassification, PromptBertPtuningForSequenceClassification, | |
PromptBertPrefixForSequenceClassification, PromptBertAdapterForSequenceClassification, | |
PromptRobertaForSequenceClassification, PromptRobertaPtuningForSequenceClassification, | |
PromptRobertaPrefixForSequenceClassification, PromptRobertaAdapterForSequenceClassification | |
) | |
from models.sequence_classification.causal_prompt_cls import PromptGPT2ForSequenceClassification | |
from models.code.code_classification import ( | |
RobertaForCodeClassification, CodeBERTForCodeClassification, | |
GraphCodeBERTForCodeClassification, PLBARTForCodeClassification, CodeT5ForCodeClassification | |
) | |
from models.code.code_generation import ( | |
PLBARTForCodeGeneration | |
) | |
from models.reinforcement_learning.actor import CausalActor | |
from models.reinforcement_learning.critic import AutoModelCritic | |
from models.reinforcement_learning.reward_model import ( | |
RobertaForReward, GPT2ForReward | |
) | |
# Registry for pre-training models. Top-level keys are task types (this dict is
# merged into MODEL_CLASSES); a value is either a model class or a nested
# {model_type: model class} dict. A None value marks a backbone that is listed
# but has no implementation wired up yet.
PRETRAIN_MODEL_CLASSES = dict(
    mlm=dict(
        bert=BertForMaskedLM,
        roberta=RobertaForMaskedLM,
        albert=AlbertForMaskedLM,
        roformer=RoFormerForMaskedLM,
    ),
    auto_mlm=AutoModelForMaskedLM,
    causal_lm={
        "gpt2": GPT2ForCausalLM,
        # Placeholders (value None) for backbones without a causal-LM wrapper.
        **dict.fromkeys(("bart", "t5", "llama")),
    },
    auto_causal_lm=AutoModelForCausalLM,
)
# Sequence-classification registry. Top-level keys are task types; nested dicts
# map a backbone model_type to its implementation.
CLASSIFICATION_MODEL_CLASSES = dict(
    # Hugging Face sequence-classification auto class (two task-type aliases).
    auto_cls=AutoModelForSequenceClassification,
    classification=AutoModelForSequenceClassification,
    # Standard fine-tuning head, e.g. bert + mlp.
    head_cls=dict(
        bert=BertForSequenceClassification,
        roberta=RobertaForSequenceClassification,
        bart=BartForSequenceClassification,
        gpt2=GPT2ForSequenceClassification,
    ),
    # Standard fine-tuning head combined with prefix-tuning.
    head_prefix_cls=dict(
        bert=BertPrefixForSequenceClassification,
        roberta=RobertaPrefixForSequenceClassification,
    ),
    # Standard fine-tuning head combined with p-tuning.
    head_ptuning_cls=dict(
        bert=BertPtuningForSequenceClassification,
        roberta=RobertaPtuningForSequenceClassification,
    ),
    # Standard fine-tuning head combined with adapter-tuning.
    head_adapter_cls=dict(
        bert=BertAdapterForSequenceClassification,
        roberta=RobertaAdapterForSequenceClassification,
    ),
    # Prompt-based classification through the masked-LM head, e.g. bert + mlm.
    # DeBERTa / DeBERTa-v2 variants of the masked_prompt_* families are not
    # implemented yet.
    masked_prompt_cls=dict(
        bert=PromptBertForSequenceClassification,
        roberta=PromptRobertaForSequenceClassification,
    ),
    # Masked-LM prompt head combined with prefix-tuning.
    masked_prompt_prefix_cls=dict(
        bert=PromptBertPrefixForSequenceClassification,
        roberta=PromptRobertaPrefixForSequenceClassification,
    ),
    # Masked-LM prompt head combined with p-tuning.
    masked_prompt_ptuning_cls=dict(
        bert=PromptBertPtuningForSequenceClassification,
        roberta=PromptRobertaPtuningForSequenceClassification,
    ),
    # Masked-LM prompt head combined with adapter-tuning.
    masked_prompt_adapter_cls=dict(
        bert=PromptBertAdapterForSequenceClassification,
        roberta=PromptRobertaAdapterForSequenceClassification,
    ),
    # Prompt-tuning through a causal-LM head, e.g. gpt2 + lm.
    causal_prompt_cls={
        "gpt2": PromptGPT2ForSequenceClassification,
        # Placeholders (value None): not implemented yet.
        **dict.fromkeys(("bart", "t5")),
    },
)
# Token-classification (sequence-labeling) registry. Top-level keys are task
# types; nested dicts map a backbone model_type to its implementation.
#
# The Megatron-BERT heads are registered under both "megatron" (the original
# key, kept for backward compatibility) and "megatronbert". The latter is the
# key TOKENIZER_CLASSES and SPAN_EXTRACTION_MODEL_CLASSES use. Without it, no
# single model_type resolves both a tokenizer and a token-classification head.
# NOTE(review): presumably callers index these nested dicts with the same
# model_type they pass to TOKENIZER_CLASSES; confirm against the lookup code.
TOKEN_CLASSIFICATION_MODEL_CLASSES = {
    "auto_token_cls": AutoModelForTokenClassification,
    # Encoder + softmax tagging head.
    "head_softmax_token_cls": {
        "bert": BertSoftmaxForSequenceLabeling,
        "roberta": RobertaSoftmaxForSequenceLabeling,
        "albert": AlbertSoftmaxForSequenceLabeling,
        "megatron": MegatronBertSoftmaxForSequenceLabeling,
        "megatronbert": MegatronBertSoftmaxForSequenceLabeling,
    },
    # Encoder + CRF tagging head.
    "head_crf_token_cls": {
        "bert": BertCrfForSequenceLabeling,
        "roberta": RobertaCrfForSequenceLabeling,
        "albert": AlbertCrfForSequenceLabeling,
        "megatron": MegatronBertCrfForSequenceLabeling,
        "megatronbert": MegatronBertCrfForSequenceLabeling,
    },
}
# Span-extraction registry: task type -> {model_type: implementation}.
SPAN_EXTRACTION_MODEL_CLASSES = dict(
    global_pointer=dict(
        bert=BertForEffiGlobalPointer,
        roberta=RobertaForEffiGlobalPointer,
        roformer=RoformerForEffiGlobalPointer,
        megatronbert=MegatronForEffiGlobalPointer,
    ),
)
# Few-shot learning registry: task type -> model class. "sequence_proto" is
# declared but has no implementation yet (None).
FEWSHOT_MODEL_CLASSES = dict(
    sequence_proto=None,
    span_proto=SpanProto,
    token_proto=TokenProto,
)
# Programming-language model registry: task type -> {model_type: implementation}.
CODE_MODEL_CLASSES = dict(
    code_cls=dict(
        roberta=RobertaForCodeClassification,
        codebert=CodeBERTForCodeClassification,
        graphcodebert=GraphCodeBERTForCodeClassification,
        codet5=CodeT5ForCodeClassification,
        plbart=PLBARTForCodeClassification,
    ),
    # Only PLBART is wired up for generation. Not registered yet (the classes
    # are not imported): roberta -> RobertaForCodeGeneration,
    # codebert / graphcodebert -> BertForCodeGeneration,
    # codet5 -> T5ForCodeGeneration.
    code_generation=dict(
        plbart=PLBARTForCodeGeneration,
    ),
)
# Reinforcement-learning registry: actor, critic and per-backbone reward models.
REINFORCEMENT_MODEL_CLASSES = dict(
    causal_actor=CausalActor,
    auto_critic=AutoModelCritic,
    rl_reward={
        "roberta": RobertaForReward,
        "gpt2": GPT2ForReward,
        # Backbones listed without an implementation yet (value None).
        **dict.fromkeys(("gpt-neo", "opt", "llama")),
    },
)
# Miscellaneous task-specific models: here the task_type alone selects the
# model class (no nested per-backbone dict).
OTHER_MODEL_CLASSES = dict(
    # sequence labeling (span-based NER)
    bert_span_ner=BertSpanForNer,
    roberta_span_ner=RobertaSpanForNer,
    albert_span_ner=AlbertSpanForNer,
    megatronbert_span_ner=MegatronBertSpanForNer,
    # sequence matching
    fusion_siamese=BertForFusionSiamese,
    # multiple choice
    multi_choice=AutoModelForMultipleChoice,
    multi_choice_megatron=MegatronBertForMultipleChoice,
    multi_choice_megatron_rdrop=MegatronBertRDropForMultipleChoice,
    megatron_multi_choice_tag=MegatronBertForTagMultipleChoice,
    roformer_multi_choice_tag=RoFormerForTagMultipleChoice,
    multi_choice_tag=BertForTagMultipleChoice,
    duma=BertDUMAForMultipleChoice,
    duma_albert=AlbertDUMAForMultipleChoice,
    duma_megatron=MegatronDumaForMultipleChoice,
    # language modeling
    # (disabled: bert_mlm_acc -> BertForMaskedLMWithACC,
    #  roformer_mlm_acc -> RoFormerForMaskedLMWithACC)
    bert_pretrain_kg=BertForPretrainWithKG,
    bert_pretrain_kg_v2=BertForPretrainWithKGV2,
    kpplm_roberta=RoBertaKPPLMForProcessedWikiKGPLM,
    kpplm_deberta=DeBertaKPPLMForProcessedWikiKGPLM,
    # other
    clue_wsc=BertForWSC,
    semeval7multitask=DebertaV2ForSemEval7MultiTask,
    # Disabled entries (their imports are commented out or absent):
    #   debertav2_multi_choice -> DebertaV2ForMultipleChoice
    #   deberta_multi_choice   -> DebertaForMultipleChoice
    #   qa                     -> AutoModelForQuestionAnswering
    #   roformer_cls           -> RoFormerForSequenceClassification
    #   roformer_ner           -> RoFormerForTokenClassification
    #   fensheng_multi_choice  -> LongformerForMultipleChoice
    #   chid_mlm               -> BertForChidMLM
)
# Every task-family registry, merged below (in this order) into one flat
# MODEL_CLASSES lookup keyed by task_type.
MODEL_CLASSES_LIST = [
    PRETRAIN_MODEL_CLASSES,
    CLASSIFICATION_MODEL_CLASSES,
    TOKEN_CLASSIFICATION_MODEL_CLASSES,
    SPAN_EXTRACTION_MODEL_CLASSES,
    FEWSHOT_MODEL_CLASSES,
    CODE_MODEL_CLASSES,
    REINFORCEMENT_MODEL_CLASSES,
    OTHER_MODEL_CLASSES,
]
MODEL_CLASSES = dict()
for model_class in MODEL_CLASSES_LIST:
    # A task_type registered in two families would otherwise be silently
    # shadowed by the later registry; fail loudly at import time instead.
    duplicated_task_types = MODEL_CLASSES.keys() & model_class.keys()
    if duplicated_task_types:
        raise ValueError(
            "task_type registered in more than one model registry: "
            f"{sorted(duplicated_task_types)}"
        )
    # Merge in place rather than rebuilding the whole dict from concatenated
    # item lists on every iteration.
    MODEL_CLASSES.update(model_class)
# Tokenizer registry: the model_type selects which tokenizer class is used.
TOKENIZER_CLASSES = dict(
    # natural-language models
    auto=AutoTokenizer,
    bert=BertTokenizerFast,
    roberta=RobertaTokenizer,
    wobert=RoFormerTokenizer,
    roformer=RoFormerTokenizer,
    bigbird=BertTokenizerFast,
    erlangshen=BertTokenizerFast,
    # NOTE(review): "deberta" maps to the slow BertTokenizer, presumably for
    # checkpoints shipping a BERT-style vocab; confirm this is intended.
    deberta=BertTokenizer,
    roformer_v2=BertTokenizerFast,
    gpt2=GPT2Tokenizer,
    megatronbert=BertTokenizerFast,
    bart=BartTokenizer,
    t5=T5Tokenizer,
    # programming-language models
    codebert=RobertaTokenizer,
    graphcodebert=RobertaTokenizer,
    codet5=RobertaTokenizer,
    plbart=PLBartTokenizer,
)