From 6f189e0a332c89a2da98001dd92cd7fcbc8c92b8 Mon Sep 17 00:00:00 2001 From: Thilina Rajapakse Date: Tue, 2 Feb 2021 02:04:00 +0530 Subject: [PATCH] Updated classification model tokenization logic. Added deberta, mpnet, squeezenet for classification --- CHANGELOG.md | 20 ++++- docs/_docs/04-classification-specifics.md | 37 ++++---- examples/t5/mt5_translation/test.py | 2 +- examples/t5/training_on_a_new_task/train.py | 2 +- setup.py | 2 +- .../classification/classification_model.py | 84 +++++++++++++++---- .../classification/classification_utils.py | 6 +- .../transformer_models/longformer_model.py | 73 ++++++++++++++++ .../transformer_models/mobilebert_model.py | 66 +++++++++++++++ simpletransformers/config/model_args.py | 7 +- simpletransformers/conv_ai/conv_ai_model.py | 8 +- simpletransformers/ner/ner_model.py | 4 +- .../question_answering_model.py | 11 ++- simpletransformers/seq2seq/seq2seq_model.py | 2 +- simpletransformers/seq2seq/seq2seq_utils.py | 11 ++- tests/test_classification.py | 53 +++++++++++- 16 files changed, 329 insertions(+), 59 deletions(-) create mode 100755 simpletransformers/classification/transformer_models/longformer_model.py create mode 100755 simpletransformers/classification/transformer_models/mobilebert_model.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 83e531dd..f6e177f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,20 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.60.0] - 2021-02-02 + +# Added + +- Added class weights support for Longformer classification +- Added new classification models: + - SqueezeBert + - DeBERTa + - MPNet + +# Changed + +- Updated ClassificationModel logic to make it easier to add new models + ## [0.51.16] - 2021-01-29 ## Fixed @@ -1386,7 +1400,11 @@ Model checkpoint is now saved for all epochs again. - This CHANGELOG file to hopefully serve as an evolving example of a standardized open source project CHANGELOG. -[0.51.15]: https://github.com/ThilinaRajapakse/simpletransformers/compare/2af55e9...HEAD +[0.60.0]: https://github.com/ThilinaRajapakse/simpletransformers/compare/5840749...HEAD + +[0.51.16]: https://github.com/ThilinaRajapakse/simpletransformers/compare/b42898e...5840749 + +[0.51.15]: https://github.com/ThilinaRajapakse/simpletransformers/compare/2af55e9...b42898e [0.51.14]: https://github.com/ThilinaRajapakse/simpletransformers/compare/278fca1...2af55e9 diff --git a/docs/_docs/04-classification-specifics.md b/docs/_docs/04-classification-specifics.md index 32919f72..4444cdc5 100644 --- a/docs/_docs/04-classification-specifics.md +++ b/docs/_docs/04-classification-specifics.md @@ -2,7 +2,7 @@ title: Classification Specifics permalink: /docs/classification-specifics/ excerpt: "Specific notes for text classification tasks." -last_modified_at: 2020/12/21 22:13:56 +last_modified_at: 2021/02/02 02:03:09 toc: true --- @@ -32,22 +32,25 @@ The process of performing text classification in Simple Transformers does not de New model types are regularly added to the library. Text classification tasks currently supports the model types given below. -| Model | Model code for `ClassificationModel` | -| ----------- | ------------------------------------ | -| ALBERT | albert | -| BERT | bert | -| BERTweet | bertweet | -| CamemBERT | camembert | -| RoBERTa | roberta | -| DistilBERT | distilbert | -| ELECTRA | electra | -| FlauBERT | flaubert | -| *LayoutLM | layoutlm | -| Longformer | longformer | -| *MobileBERT | mobilebert | -| XLM | xlm | -| XLM-RoBERTa | xlmroberta | -| XLNet | xlnet | +| Model | Model code for `ClassificationModel` | +| ------------ | ------------------------------------ | +| ALBERT | albert | +| BERT | bert | +| BERTweet | bertweet | +| CamemBERT | camembert | +| *DeBERTa | deberta | +| DistilBERT | distilbert | +| ELECTRA | electra | +| FlauBERT | flaubert | +| LayoutLM | layoutlm | +| *Longformer | longformer | +| *MPNet | mpnet | +| MobileBERT | mobilebert | +| RoBERTa | roberta | +| *SqueezeBert | squeezebert | +| XLM | xlm | +| XLM-RoBERTa | xlmroberta | +| XLNet | xlnet | \* *Not available with Multi-label classification* diff --git a/examples/t5/mt5_translation/test.py b/examples/t5/mt5_translation/test.py index 5e282bdc..9eb047f9 100644 --- a/examples/t5/mt5_translation/test.py +++ b/examples/t5/mt5_translation/test.py @@ -33,4 +33,4 @@ english_preds = model.predict(to_english) sin_eng_bleu = sacrebleu.corpus_bleu(english_preds, english_truth) -print("Sinhalese to English: ", sin_eng_bleu.score) \ No newline at end of file +print("Sinhalese to English: ", sin_eng_bleu.score) diff --git a/examples/t5/training_on_a_new_task/train.py b/examples/t5/training_on_a_new_task/train.py index cc2b5d5d..db2cf981 100644 --- a/examples/t5/training_on_a_new_task/train.py +++ b/examples/t5/training_on_a_new_task/train.py @@ -22,6 +22,6 @@ "wandb_project": "Question Generation with T5", } -model = T5Model("t5","t5-large",args=model_args) +model = T5Model("t5", "t5-large", args=model_args) model.train_model(train_df, eval_data=eval_df) diff --git a/setup.py b/setup.py index 5312030b..49c6e90c 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="simpletransformers", - version="0.51.16", + version="0.60.0", author="Thilina Rajapakse", author_email="chaturangarajapakshe@gmail.com", description="An easy-to-use wrapper library for the Transformers library.", diff --git a/simpletransformers/classification/classification_model.py b/simpletransformers/classification/classification_model.py index 00c5c9d3..8e34e03b 100755 --- a/simpletransformers/classification/classification_model.py +++ b/simpletransformers/classification/classification_model.py @@ -27,7 +27,7 @@ mean_squared_error, roc_curve, auc, - average_precision_score + average_precision_score, ) from tensorboardX import SummaryWriter from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset @@ -46,11 +46,17 @@ from transformers import ( AlbertConfig, AlbertTokenizer, + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, BertConfig, BertTokenizer, BertweetTokenizer, CamembertConfig, CamembertTokenizer, + DebertaConfig, + DebertaForSequenceClassification, + DebertaTokenizer, DistilBertConfig, DistilBertTokenizer, ElectraConfig, @@ -60,15 +66,17 @@ LayoutLMConfig, LayoutLMTokenizer, LongformerConfig, - LongformerForSequenceClassification, LongformerTokenizer, + MPNetConfig, + MPNetForSequenceClassification, + MPNetTokenizer, MobileBertConfig, - MobileBertForSequenceClassification, MobileBertTokenizer, - ReformerConfig, - ReformerTokenizer, RobertaConfig, RobertaTokenizer, + SqueezeBertConfig, + SqueezeBertForSequenceClassification, + SqueezeBertTokenizer, WEIGHTS_NAME, XLMConfig, XLMRobertaConfig, @@ -91,6 +99,8 @@ from simpletransformers.classification.transformer_models.distilbert_model import DistilBertForSequenceClassification from simpletransformers.classification.transformer_models.flaubert_model import FlaubertForSequenceClassification from simpletransformers.classification.transformer_models.layoutlm_model import LayoutLMForSequenceClassification +from simpletransformers.classification.transformer_models.longformer_model import LongformerForSequenceClassification +from simpletransformers.classification.transformer_models.mobilebert_model import MobileBertForSequenceClassification from simpletransformers.classification.transformer_models.roberta_model import RobertaForSequenceClassification from simpletransformers.classification.transformer_models.xlm_model import XLMForSequenceClassification from simpletransformers.classification.transformer_models.xlm_roberta_model import XLMRobertaForSequenceClassification @@ -100,7 +110,6 @@ from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.custom_models.models import ElectraForSequenceClassification -from transformers.models.reformer import ReformerForSequenceClassification try: import wandb @@ -112,11 +121,22 @@ logger = logging.getLogger(__name__) +MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT = ["squeezebert", "deberta", "mpnet"] + +MODELS_WITH_EXTRA_SEP_TOKEN = ["roberta", "camembert", "xlmroberta", "longformer", "mpnet"] + +MODELS_WITH_ADD_PREFIX_SPACE = ["roberta", "camembert", "xlmroberta", "longformer", "mpnet"] + +MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT = ["squeezebert"] + + class ClassificationModel: def __init__( self, model_type, model_name, + tokenizer_type=None, + tokenizer_name=None, num_labels=None, weight=None, args=None, @@ -132,6 +152,9 @@ def __init__( Args: model_type: The type of model (bert, xlnet, xlm, roberta, distilbert) model_name: The exact architecture and trained weights to use. This may be a Hugging Face Transformers compatible pre-trained model, a community model, or the path to a directory containing model files. + tokenizer_type: The type of tokenizer (auto, bert, xlnet, xlm, roberta, distilbert, etc.) to use. If a string is passed, Simple Transformers will try to initialize a tokenizer class from the available MODEL_CLASSES. + Alternatively, a Tokenizer class (subclassed from PreTrainedTokenizer) can be passed. + tokenizer_name: The name/path to the tokenizer. If the tokenizer_type is not specified, the model_type will be used to determine the type of the tokenizer. num_labels (optional): The number of labels or classes in the dataset. weight (optional): A list of length num_labels containing the weights to assign to each label for loss calculation. args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args. @@ -143,17 +166,20 @@ def __init__( MODEL_CLASSES = { "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer), + "auto": (AutoConfig, AutoModelForSequenceClassification, AutoTokenizer), "bert": (BertConfig, BertForSequenceClassification, BertTokenizer), "bertweet": (RobertaConfig, RobertaForSequenceClassification, BertweetTokenizer), "camembert": (CamembertConfig, CamembertForSequenceClassification, CamembertTokenizer), + "deberta": (DebertaConfig, DebertaForSequenceClassification, DebertaTokenizer), "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer), "electra": (ElectraConfig, ElectraForSequenceClassification, ElectraTokenizer), "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer), "layoutlm": (LayoutLMConfig, LayoutLMForSequenceClassification, LayoutLMTokenizer), "longformer": (LongformerConfig, LongformerForSequenceClassification, LongformerTokenizer), "mobilebert": (MobileBertConfig, MobileBertForSequenceClassification, MobileBertTokenizer), - "reformer": (ReformerConfig, ReformerForSequenceClassification, ReformerTokenizer), + "mpnet": (MPNetConfig, MPNetForSequenceClassification, MPNetTokenizer), "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), + "squeezebert": (SqueezeBertConfig, SqueezeBertForSequenceClassification, SqueezeBertTokenizer), "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer), "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer), "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), @@ -166,6 +192,9 @@ def __init__( elif isinstance(args, ClassificationArgs): self.args = args + if model_type in MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT and self.args.sliding_window: + raise ValueError("{} does not currently support sliding window".format(model_type)) + if self.args.thread_count: torch.set_num_threads(self.args.thread_count) @@ -200,13 +229,24 @@ def __init__( self.args.labels_list = [i for i in range(len_labels_list)] config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type] + + if tokenizer_type is not None: + if isinstance(tokenizer_type, str): + _, _, tokenizer_class = MODEL_CLASSES[tokenizer_type] + else: + tokenizer_class = tokenizer_type + if num_labels: self.config = config_class.from_pretrained(model_name, num_labels=num_labels, **self.args.config) self.num_labels = num_labels else: self.config = config_class.from_pretrained(model_name, **self.args.config) self.num_labels = self.config.num_labels - self.weight = weight + + if model_type in MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT and weight is not None: + raise ValueError("{} does not currently support class weights".format(model_type)) + else: + self.weight = weight if use_cuda: if torch.cuda.is_available(): @@ -275,17 +315,20 @@ def __init__( except AttributeError: raise AttributeError("fp16 requires Pytorch >= 1.6. Please update Pytorch or turn off fp16.") - if model_name in [ + if tokenizer_name is None: + tokenizer_name = model_name + + if tokenizer_name in [ "vinai/bertweet-base", "vinai/bertweet-covid19-base-cased", "vinai/bertweet-covid19-base-uncased", ]: self.tokenizer = tokenizer_class.from_pretrained( - model_name, do_lower_case=self.args.do_lower_case, normalization=True, **kwargs + tokenizer_name, do_lower_case=self.args.do_lower_case, normalization=True, **kwargs ) else: self.tokenizer = tokenizer_class.from_pretrained( - model_name, do_lower_case=self.args.do_lower_case, **kwargs + tokenizer_name, do_lower_case=self.args.do_lower_case, **kwargs ) if self.args.special_tokens_list: @@ -294,6 +337,8 @@ def __init__( self.args.model_name = model_name self.args.model_type = model_type + self.args.tokenizer_name = tokenizer_name + self.args.tokenizer_type = tokenizer_type if model_type in ["camembert", "xlmroberta"]: warnings.warn( @@ -1184,7 +1229,7 @@ def load_and_cache_examples( sep_token=tokenizer.sep_token, # RoBERTa uses an extra separator b/w pairs of sentences, # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 - sep_token_extra=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]), + sep_token_extra=args.model_type in MODELS_WITH_EXTRA_SEP_TOKEN, # PAD on the left for XLNet pad_on_left=bool(args.model_type in ["xlnet"]), pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], @@ -1196,7 +1241,7 @@ def load_and_cache_examples( sliding_window=args.sliding_window, flatten=not evaluate, stride=args.stride, - add_prefix_space=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]), + add_prefix_space=args.model_type in MODELS_WITH_ADD_PREFIX_SPACE, # avoid padding in case of single example/online inferencing to decrease execution time pad_to_max_length=bool(len(examples) > 1), args=args, @@ -1236,8 +1281,10 @@ def load_and_cache_examples( else: return dataset else: - train_dataset = ClassificationDataset(examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode) - return train_dataset + dataset = ClassificationDataset( + examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode + ) + return dataset def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, multi_label=False, **kwargs): """ @@ -1302,7 +1349,10 @@ def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, mult auroc = auc(fpr, tpr) auprc = average_precision_score(labels, scores) return ( - {**{"mcc": mcc, "tp": tp, "tn": tn, "fp": fp, "fn": fn, "auroc": auroc, "auprc": auprc}, **extra_metrics}, + { + **{"mcc": mcc, "tp": tp, "tn": tn, "fp": fp, "fn": fn, "auroc": auroc, "auprc": auprc}, + **extra_metrics, + }, wrong, ) else: @@ -1575,7 +1625,7 @@ def _move_model_to_device(self): def _get_inputs_dict(self, batch): if isinstance(batch[0], dict): - inputs = {key: value.squeeze().to(self.device) for key, value in batch[0].items()} + inputs = {key: value.squeeze(1).to(self.device) for key, value in batch[0].items()} inputs["labels"] = batch[1].to(self.device) else: batch = tuple(t.to(self.device) for t in batch) diff --git a/simpletransformers/classification/classification_utils.py b/simpletransformers/classification/classification_utils.py index dfb8228b..76d729d8 100755 --- a/simpletransformers/classification/classification_utils.py +++ b/simpletransformers/classification/classification_utils.py @@ -96,7 +96,7 @@ def preprocess_data(data): max_length=args.max_seq_length, truncation=True, padding="max_length", - return_tensors="pt" + return_tensors="pt", ) else: tokenized_example = tokenizer.encode_plus( @@ -104,7 +104,7 @@ def preprocess_data(data): max_length=args.max_seq_length, truncation=True, padding="max_length", - return_tensors="pt" + return_tensors="pt", ) return {**tokenized_example, "label": example.label} @@ -600,7 +600,7 @@ def __init__( self.data = [ dict( json.load(open(os.path.join(data_path, l + self.data_type_extension))), - **{"images": l + image_type_extension} + **{"images": l + image_type_extension}, ) for l in files_list ] diff --git a/simpletransformers/classification/transformer_models/longformer_model.py b/simpletransformers/classification/transformer_models/longformer_model.py new file mode 100755 index 00000000..eb23aa24 --- /dev/null +++ b/simpletransformers/classification/transformer_models/longformer_model.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss, MSELoss +from transformers.models.longformer.modeling_longformer import ( + LongformerModel, + LongformerPreTrainedModel, + LongformerClassificationHead, +) + + +class LongformerForSequenceClassification(LongformerPreTrainedModel): + def __init__(self, config, weight=None): + super(LongformerForSequenceClassification, self).__init__(config) + self.num_labels = config.num_labels + + self.longformer = LongformerModel(config) + self.classifier = LongformerClassificationHead(config) + self.weight = weight + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + global_attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if global_attention_mask is None: + global_attention_mask = torch.zeros_like(input_ids) + # global attention on cls token + global_attention_mask[:, 0] = 1 + + outputs = self.longformer( + input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + if self.weight is not None: + weight = self.weight.to(labels.device) + else: + weight = None + loss_fct = CrossEntropyLoss(weight=weight) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output diff --git a/simpletransformers/classification/transformer_models/mobilebert_model.py b/simpletransformers/classification/transformer_models/mobilebert_model.py new file mode 100755 index 00000000..d1ee3c0c --- /dev/null +++ b/simpletransformers/classification/transformer_models/mobilebert_model.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss, MSELoss +from transformers.models.mobilebert.modeling_mobilebert import MobileBertModel, MobileBertPreTrainedModel + + +class MobileBertForSequenceClassification(MobileBertPreTrainedModel): + def __init__(self, config, weight=None): + super(MobileBertForSequenceClassification, self).__init__(config) + self.num_labels = config.num_labels + + self.mobilebert = MobileBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, self.num_labels) + self.weight = weight + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + outputs = self.mobilebert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + if self.weight is not None: + weight = self.weight.to(labels.device) + else: + weight = None + loss_fct = CrossEntropyLoss(weight=weight) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output diff --git a/simpletransformers/config/model_args.py b/simpletransformers/config/model_args.py index 4614e3b6..ff25bde9 100644 --- a/simpletransformers/config/model_args.py +++ b/simpletransformers/config/model_args.py @@ -85,6 +85,8 @@ class ModelArgs: skip_special_tokens: bool = True tensorboard_dir: str = None thread_count: int = None + tokenizer_type: str = None + tokenizer_name: str = None train_batch_size: int = 8 train_custom_parameters_only: bool = False use_cached_eval_features: bool = False @@ -110,7 +112,10 @@ def get_args_for_saving(self): def save(self, output_dir): os.makedirs(output_dir, exist_ok=True) with open(os.path.join(output_dir, "model_args.json"), "w") as f: - json.dump(self.get_args_for_saving(), f) + args_dict = self.get_args_for_saving() + if args_dict["tokenizer_type"] is not None and not isinstance(args_dict["tokenizer_type"], str): + args_dict["tokenizer_type"] = type(args_dict["tokenizer_type"]).__name__ + json.dump(args_dict, f) def load(self, input_dir): if input_dir: diff --git a/simpletransformers/conv_ai/conv_ai_model.py b/simpletransformers/conv_ai/conv_ai_model.py index 5281a86d..96c57658 100644 --- a/simpletransformers/conv_ai/conv_ai_model.py +++ b/simpletransformers/conv_ai/conv_ai_model.py @@ -697,14 +697,10 @@ def evaluate(self, eval_file, output_dir, verbose=True, silent=False, **kwargs): if args.fp16: with amp.autocast(): - outputs = model( - input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids, - ) + outputs = model(input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,) lm_logits, mc_logits = outputs[:2] else: - outputs = model( - input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids, - ) + outputs = model(input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,) lm_logits, mc_logits = outputs[:2] # model outputs are always tuple in pytorch-transformers (see doc) diff --git a/simpletransformers/ner/ner_model.py b/simpletransformers/ner/ner_model.py index e5666c3e..7a5e33e4 100755 --- a/simpletransformers/ner/ner_model.py +++ b/simpletransformers/ner/ner_model.py @@ -240,9 +240,7 @@ def __init__( model_name, do_lower_case=self.args.do_lower_case, normalization=True, **kwargs ) elif model_type == "auto": - self.tokenizer = tokenizer_class.from_pretrained( - model_name, **kwargs - ) + self.tokenizer = tokenizer_class.from_pretrained(model_name, **kwargs) else: self.tokenizer = tokenizer_class.from_pretrained( model_name, do_lower_case=self.args.do_lower_case, **kwargs diff --git a/simpletransformers/question_answering/question_answering_model.py b/simpletransformers/question_answering/question_answering_model.py index 984b086a..6a7adc67 100755 --- a/simpletransformers/question_answering/question_answering_model.py +++ b/simpletransformers/question_answering/question_answering_model.py @@ -195,7 +195,9 @@ def __init__(self, model_type, model_name, args=None, use_cuda=True, cuda_device if model_type == "auto": self.tokenizer = tokenizer_class.from_pretrained(model_name, **kwargs) else: - self.tokenizer = tokenizer_class.from_pretrained(model_name, do_lower_case=self.args.do_lower_case, **kwargs) + self.tokenizer = tokenizer_class.from_pretrained( + model_name, do_lower_case=self.args.do_lower_case, **kwargs + ) if self.args.special_tokens_list: self.tokenizer.add_tokens(self.args.special_tokens_list, special_tokens=True) @@ -248,7 +250,12 @@ def load_and_cache_examples(self, examples, evaluate=False, no_cache=False, outp if mode == "dev": all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long) dataset = TensorDataset( - all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask + all_input_ids, + all_attention_masks, + all_token_type_ids, + all_feature_index, + all_cls_index, + all_p_mask, ) else: all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) diff --git a/simpletransformers/seq2seq/seq2seq_model.py b/simpletransformers/seq2seq/seq2seq_model.py index 56d6343d..850c6005 100644 --- a/simpletransformers/seq2seq/seq2seq_model.py +++ b/simpletransformers/seq2seq/seq2seq_model.py @@ -79,7 +79,7 @@ except ImportError: wandb_available = False -if transformers.__version__ < '4.2.0': +if transformers.__version__ < "4.2.0": MBartForConditionalGeneration._keys_to_ignore_on_save = [] logger = logging.getLogger(__name__) diff --git a/simpletransformers/seq2seq/seq2seq_utils.py b/simpletransformers/seq2seq/seq2seq_utils.py index 092cc2d6..1533ced4 100644 --- a/simpletransformers/seq2seq/seq2seq_utils.py +++ b/simpletransformers/seq2seq/seq2seq_utils.py @@ -16,11 +16,14 @@ logger = logging.getLogger(__name__) -if transformers.__version__ < '4.2.0': - shift_tokens_right = lambda input_ids, pad_token_id, decoder_start_token_id: _shift_tokens_right(input_ids, pad_token_id) +if transformers.__version__ < "4.2.0": + shift_tokens_right = lambda input_ids, pad_token_id, decoder_start_token_id: _shift_tokens_right( + input_ids, pad_token_id + ) else: shift_tokens_right = _shift_tokens_right + def preprocess_data(data): input_text, target_text, encoder_tokenizer, decoder_tokenizer, args = data @@ -112,7 +115,9 @@ def preprocess_data_mbart(data): ) decoder_input_ids = tokenized_example["labels"].clone() - decoder_input_ids = shift_tokens_right(decoder_input_ids, tokenizer.pad_token_id, tokenizer.lang_code_to_id[args.tgt_lang]) + decoder_input_ids = shift_tokens_right( + decoder_input_ids, tokenizer.pad_token_id, tokenizer.lang_code_to_id[args.tgt_lang] + ) labels = tokenized_example["labels"] labels[labels == tokenizer.pad_token_id] = -100 diff --git a/tests/test_classification.py b/tests/test_classification.py index bf24087b..88ff5b50 100644 --- a/tests/test_classification.py +++ b/tests/test_classification.py @@ -12,6 +12,7 @@ ("electra", "google/electra-small-discriminator"), ("mobilebert", "google/mobilebert-uncased"), ("bertweet", "vinai/bertweet-base"), + ("deberta", "microsoft/deberta-base"), # ("xlnet", "xlnet-base-cased"), # ("xlm", "xlm-mlm-17-1280"), # ("roberta", "roberta-base"), @@ -48,6 +49,7 @@ def test_binary_classification(model_type, model_name): "reprocess_input_data": True, "overwrite_output_dir": True, "scheduler": "constant_schedule", + "max_seq_length": 20, }, ) @@ -57,6 +59,8 @@ def test_binary_classification(model_type, model_name): # Evaluate the model result, model_outputs, wrong_predictions = model.eval_model(eval_df) + predictions, raw_outputs = model.predict(["Some arbitary sentence"]) + @pytest.mark.parametrize( "model_type, model_name", @@ -97,7 +101,7 @@ def test_multiclass_classification(model_type, model_name): model_type, model_name, num_labels=3, - args={"no_save": True, "reprocess_input_data": True, "overwrite_output_dir": True}, + args={"no_save": True, "reprocess_input_data": True, "overwrite_output_dir": True, "max_seq_length": 20}, use_cuda=False, ) @@ -142,7 +146,13 @@ def test_multilabel_classification(model_type, model_name): model_type, model_name, num_labels=6, - args={"no_save": True, "reprocess_input_data": True, "overwrite_output_dir": True, "num_train_epochs": 1}, + args={ + "no_save": True, + "reprocess_input_data": True, + "overwrite_output_dir": True, + "num_train_epochs": 1, + "max_seq_length": 20, + }, use_cuda=False, ) @@ -153,3 +163,42 @@ def test_multilabel_classification(model_type, model_name): result, model_outputs, wrong_predictions = model.eval_model(eval_df) predictions, raw_outputs = model.predict(["This thing is entirely different from the other thing. "]) + + +def test_sliding_window(): + # Train and Evaluation data needs to be in a Pandas Dataframe of two columns. + # The first column is the text with type str, and the second column is the + # label with type int. + train_data = [ + ["Example sentence belonging to class 1" * 10, 1], + ["Example sentence belonging to class 0", 0], + ] + train_df = pd.DataFrame(train_data) + + eval_data = [ + ["Example eval sentence belonging to class 1", 1], + ["Example eval sentence belonging to class 0" * 10, 0], + ] + eval_df = pd.DataFrame(eval_data) + + # Create a ClassificationModel + model = ClassificationModel( + "distilbert", + "distilbert-base-uncased", + use_cuda=False, + args={ + "no_save": True, + "reprocess_input_data": True, + "overwrite_output_dir": True, + "max_seq_length": 20, + "sliding_window": True, + }, + ) + + # Train the model + model.train_model(train_df) + + # Evaluate the model + result, model_outputs, wrong_predictions = model.eval_model(eval_df) + + predictions, raw_outputs = model.predict(["Some arbitary sentence"])