-
Notifications
You must be signed in to change notification settings - Fork 412
/
pretrained_models.py
executable file
·27 lines (26 loc) · 1.28 KB
/
pretrained_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import (
BertConfig, BertModel, BertTokenizer,
RobertaConfig, RobertaModel, RobertaTokenizer,
AlbertConfig, AlbertModel, AlbertTokenizer,
XLMRobertaConfig, XLMRobertaModel, XLMRobertaTokenizer,
ElectraConfig, ElectraModel, ElectraTokenizer,
T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Tokenizer,
DebertaConfig, DebertaModel, DebertaTokenizer
)
from module.san_model import SanModel
from msrt5.modeling_t5 import MSRT5ForConditionalGeneration, MSRT5EncoderModel
from msrt5.configuration import MSRT5Config
from msrt5.tokenization_t5 import MSRT5Tokenizer
# Registry of supported backbones, keyed by a short model-type string.
# Each value is a (config class, model class, tokenizer class) triple.
#
# Notes on the entries:
#   - "san" pairs the project's SanModel with BERT's config and tokenizer.
#   - "t5" / "msrt5" use the *EncoderModel classes.
#   - "t5g" / "msrt5g" use the *ForConditionalGeneration classes.
MODEL_CLASSES = dict(
    bert=(BertConfig, BertModel, BertTokenizer),
    roberta=(RobertaConfig, RobertaModel, RobertaTokenizer),
    albert=(AlbertConfig, AlbertModel, AlbertTokenizer),
    xlm=(XLMRobertaConfig, XLMRobertaModel, XLMRobertaTokenizer),
    san=(BertConfig, SanModel, BertTokenizer),
    electra=(ElectraConfig, ElectraModel, ElectraTokenizer),
    t5=(T5Config, T5EncoderModel, T5Tokenizer),
    deberta=(DebertaConfig, DebertaModel, DebertaTokenizer),
    t5g=(T5Config, T5ForConditionalGeneration, T5Tokenizer),
    msrt5g=(MSRT5Config, MSRT5ForConditionalGeneration, MSRT5Tokenizer),
    msrt5=(MSRT5Config, MSRT5EncoderModel, MSRT5Tokenizer),
)