Updated tokenization logic in classification models
Thilina Rajapakse committed Jan 31, 2021
1 parent 5840749 commit f56302d
Showing 4 changed files with 178 additions and 88 deletions.
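In short, the commit registers "reformer" as a new classification model_type and swaps the eager convert_examples_to_features / TensorDataset pipeline for a ClassificationDataset that tokenizes each example directly with tokenizer.encode_plus (sliding-window and LayoutLM models keep the old path). A minimal sketch of the newly supported model type, assuming the standard simpletransformers API; the checkpoint name and argument values are illustrative, not taken from the commit:

```python
import pandas as pd

from simpletransformers.classification import ClassificationModel

# "reformer" is accepted as a model_type after this commit; the checkpoint
# name and the args values below are illustrative assumptions.
model = ClassificationModel(
    "reformer",
    "google/reformer-crime-and-punishment",
    num_labels=2,
    args={"max_seq_length": 128, "no_cache": True, "overwrite_output_dir": True},
    use_cuda=False,
)

# simpletransformers expects a DataFrame with "text" and "labels" columns.
train_df = pd.DataFrame(
    [["a clearly positive example", 1], ["a clearly negative example", 0]],
    columns=["text", "labels"],
)
model.train_model(train_df)
```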
174 changes: 92 additions & 82 deletions simpletransformers/classification/classification_model.py
@@ -65,6 +65,8 @@
     MobileBertConfig,
     MobileBertForSequenceClassification,
     MobileBertTokenizer,
+    ReformerConfig,
+    ReformerTokenizer,
     RobertaConfig,
     RobertaTokenizer,
     WEIGHTS_NAME,
@@ -80,6 +82,7 @@
 from simpletransformers.classification.classification_utils import (
     InputExample,
     LazyClassificationDataset,
+    ClassificationDataset,
     convert_examples_to_features,
 )
 from simpletransformers.classification.transformer_models.albert_model import AlbertForSequenceClassification
@@ -97,6 +100,8 @@
 from simpletransformers.config.utils import sweep_config_to_sweep_values
 from simpletransformers.custom_models.models import ElectraForSequenceClassification

+from transformers.models.reformer import ReformerForSequenceClassification
+
 try:
     import wandb

@@ -147,6 +152,7 @@ def __init__(
             "layoutlm": (LayoutLMConfig, LayoutLMForSequenceClassification, LayoutLMTokenizer),
             "longformer": (LongformerConfig, LongformerForSequenceClassification, LongformerTokenizer),
             "mobilebert": (MobileBertConfig, MobileBertForSequenceClassification, MobileBertTokenizer),
+            "reformer": (ReformerConfig, ReformerForSequenceClassification, ReformerTokenizer),
             "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
             "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
             "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
@@ -1137,97 +1143,101 @@ def load_and_cache_examples(
             os.makedirs(self.args.cache_dir, exist_ok=True)

         mode = "dev" if evaluate else "train"
-        cached_features_file = os.path.join(
-            args.cache_dir,
-            "cached_{}_{}_{}_{}_{}".format(
-                mode, args.model_type, args.max_seq_length, self.num_labels, len(examples),
-            ),
-        )

-        if os.path.exists(cached_features_file) and (
-            (not args.reprocess_input_data and not no_cache)
-            or (mode == "dev" and args.use_cached_eval_features and not no_cache)
-        ):
-            features = torch.load(cached_features_file)
-            if verbose:
-                logger.info(f" Features loaded from cache at {cached_features_file}")
-        else:
-            if verbose:
-                logger.info(" Converting to features started. Cache is not used.")
-                if args.sliding_window:
-                    logger.info(" Sliding window enabled")

-            # If labels_map is defined, then labels need to be replaced with ints
-            if self.args.labels_map and not self.args.regression:
-                for example in examples:
-                    if multi_label:
-                        example.label = [self.args.labels_map[label] for label in example.label]
-                    else:
-                        example.label = self.args.labels_map[example.label]

-            features = convert_examples_to_features(
-                examples,
-                args.max_seq_length,
-                tokenizer,
-                output_mode,
-                # XLNet has a CLS token at the end
-                cls_token_at_end=bool(args.model_type in ["xlnet"]),
-                cls_token=tokenizer.cls_token,
-                cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
-                sep_token=tokenizer.sep_token,
-                # RoBERTa uses an extra separator b/w pairs of sentences,
-                # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
-                sep_token_extra=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
-                # PAD on the left for XLNet
-                pad_on_left=bool(args.model_type in ["xlnet"]),
-                pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
-                pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
-                process_count=process_count,
-                multi_label=multi_label,
-                silent=args.silent or silent,
-                use_multiprocessing=args.use_multiprocessing,
-                sliding_window=args.sliding_window,
-                flatten=not evaluate,
-                stride=args.stride,
-                add_prefix_space=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
-                # avoid padding in case of single example/online inferencing to decrease execution time
-                pad_to_max_length=bool(len(examples) > 1),
-                args=args,
-            )
+        if args.sliding_window or self.args.model_type == "layoutlm":
+            cached_features_file = os.path.join(
+                args.cache_dir,
+                "cached_{}_{}_{}_{}_{}".format(
+                    mode, args.model_type, args.max_seq_length, self.num_labels, len(examples),
+                ),
+            )
-            if verbose and args.sliding_window:
-                logger.info(f" {len(features)} features created from {len(examples)} samples.")

-            if not no_cache:
-                torch.save(features, cached_features_file)
+            if os.path.exists(cached_features_file) and (
+                (not args.reprocess_input_data and not no_cache)
+                or (mode == "dev" and args.use_cached_eval_features and not no_cache)
+            ):
+                features = torch.load(cached_features_file)
+                if verbose:
+                    logger.info(f" Features loaded from cache at {cached_features_file}")
+            else:
+                if verbose:
+                    logger.info(" Converting to features started. Cache is not used.")
+                    if args.sliding_window:
+                        logger.info(" Sliding window enabled")

+                # If labels_map is defined, then labels need to be replaced with ints
+                if self.args.labels_map and not self.args.regression:
+                    for example in examples:
+                        if multi_label:
+                            example.label = [self.args.labels_map[label] for label in example.label]
+                        else:
+                            example.label = self.args.labels_map[example.label]

+                features = convert_examples_to_features(
+                    examples,
+                    args.max_seq_length,
+                    tokenizer,
+                    output_mode,
+                    # XLNet has a CLS token at the end
+                    cls_token_at_end=bool(args.model_type in ["xlnet"]),
+                    cls_token=tokenizer.cls_token,
+                    cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
+                    sep_token=tokenizer.sep_token,
+                    # RoBERTa uses an extra separator b/w pairs of sentences,
+                    # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
+                    sep_token_extra=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
+                    # PAD on the left for XLNet
+                    pad_on_left=bool(args.model_type in ["xlnet"]),
+                    pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
+                    pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
+                    process_count=process_count,
+                    multi_label=multi_label,
+                    silent=args.silent or silent,
+                    use_multiprocessing=args.use_multiprocessing,
+                    sliding_window=args.sliding_window,
+                    flatten=not evaluate,
+                    stride=args.stride,
+                    add_prefix_space=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
+                    # avoid padding in case of single example/online inferencing to decrease execution time
+                    pad_to_max_length=bool(len(examples) > 1),
+                    args=args,
+                )
+                if verbose and args.sliding_window:
+                    logger.info(f" {len(features)} features created from {len(examples)} samples.")

+                if not no_cache:
+                    torch.save(features, cached_features_file)

-        if args.sliding_window and evaluate:
-            features = [
-                [feature_set] if not isinstance(feature_set, list) else feature_set for feature_set in features
-            ]
-            window_counts = [len(sample) for sample in features]
-            features = [feature for feature_set in features for feature in feature_set]
+            if args.sliding_window and evaluate:
+                features = [
+                    [feature_set] if not isinstance(feature_set, list) else feature_set for feature_set in features
+                ]
+                window_counts = [len(sample) for sample in features]
+                features = [feature for feature_set in features for feature in feature_set]

-        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
-        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
-        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
+            all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+            all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
+            all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)

-        if self.args.model_type == "layoutlm":
-            all_bboxes = torch.tensor([f.bboxes for f in features], dtype=torch.long)
+            if self.args.model_type == "layoutlm":
+                all_bboxes = torch.tensor([f.bboxes for f in features], dtype=torch.long)

-        if output_mode == "classification":
-            all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
-        elif output_mode == "regression":
-            all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)
+            if output_mode == "classification":
+                all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
+            elif output_mode == "regression":
+                all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)

-        if self.args.model_type == "layoutlm":
-            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_bboxes)
-        else:
-            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
+            if self.args.model_type == "layoutlm":
+                dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_bboxes)
+            else:
+                dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

-        if args.sliding_window and evaluate:
-            return dataset, window_counts
+            if args.sliding_window and evaluate:
+                return dataset, window_counts
-        else:
-            return dataset
+            else:
+                return dataset
+        train_dataset = ClassificationDataset(examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode)
+        return train_dataset

     def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, multi_label=False, **kwargs):
         """
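Read as a whole, the interleaved hunk above boils down to a two-way branch. A paraphrased control-flow sketch (not the verbatim method; build_feature_tensors is a hypothetical stand-in for the cached convert_examples_to_features path):

```python
from torch.utils.data import TensorDataset

from simpletransformers.classification.classification_utils import ClassificationDataset


def load_examples_sketch(examples, tokenizer, args, evaluate=False,
                         multi_label=False, output_mode="classification"):
    # Paraphrase of load_and_cache_examples after this commit.
    mode = "dev" if evaluate else "train"

    if args.sliding_window or args.model_type == "layoutlm":
        # Old path, kept only for these model types: build all feature
        # tensors up front (with on-disk caching) and wrap them in a
        # TensorDataset. build_feature_tensors is hypothetical.
        tensors, window_counts = build_feature_tensors(examples, tokenizer, args, evaluate)
        dataset = TensorDataset(*tensors)
        if args.sliding_window and evaluate:
            return dataset, window_counts
        return dataset

    # New path for every other model type: ClassificationDataset tokenizes
    # each example with tokenizer.encode_plus and does its own caching.
    return ClassificationDataset(
        examples, tokenizer, args, mode=mode,
        multi_label=multi_label, output_mode=output_mode,
    )
```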
85 changes: 85 additions & 0 deletions simpletransformers/classification/classification_utils.py
@@ -18,6 +18,7 @@

 import csv
 import json
+import logging
 import linecache
 import os
 import sys
@@ -44,6 +45,8 @@

 csv.field_size_limit(2147483647)

+logger = logging.getLogger(__name__)


 class InputExample(object):
     """A single training/test example for simple sequence classification."""
@@ -84,6 +87,88 @@ def __init__(self, input_ids, input_mask, segment_ids, label_id, bboxes=None):
         self.bboxes = bboxes


+def preprocess_data(data):
+    example, tokenizer, args = data

+    if example.text_b:
+        tokenized_example = tokenizer.encode_plus(
+            text=example.text_a, text_pair=example.text_b,
+            max_length=args.max_seq_length,
+            truncation=True,
+            padding="max_length",
+            return_tensors="pt"
+        )
+    else:
+        tokenized_example = tokenizer.encode_plus(
+            text=example.text_a,
+            max_length=args.max_seq_length,
+            truncation=True,
+            padding="max_length",
+            return_tensors="pt"
+        )

+    return {**tokenized_example, "label": example.label}


+class ClassificationDataset(Dataset):
+    def __init__(self, data, tokenizer, args, mode, multi_label, output_mode):
+        self.tokenizer = tokenizer
+        self.output_mode = output_mode

+        cached_features_file = os.path.join(
+            args.cache_dir,
+            "cached_{}_{}_{}_{}_{}".format(
+                mode, args.model_type, args.max_seq_length, len(args.labels_list), len(data),
+            ),
+        )

+        if os.path.exists(cached_features_file) and (
+            (not args.reprocess_input_data and not args.no_cache)
+            or (mode == "dev" and args.use_cached_eval_features and not args.no_cache)
+        ):
+            self.examples = torch.load(cached_features_file)
+            logger.info(f" Features loaded from cache at {cached_features_file}")
+        else:
+            logger.info(" Converting to features started. Cache is not used.")

+            # If labels_map is defined, then labels need to be replaced with ints
+            if args.labels_map and not args.regression:
+                for example in data:
+                    if multi_label:
+                        example.label = [args.labels_map[label] for label in example.label]
+                    else:
+                        example.label = args.labels_map[example.label]
+            data = [(example, tokenizer, args) for example in data]

+            if args.use_multiprocessing:
+                with Pool(args.process_count) as p:
+                    self.examples = list(
+                        tqdm(
+                            p.imap(preprocess_data, data, chunksize=args.multiprocessing_chunksize),
+                            total=len(data),
+                            disable=args.silent,
+                        )
+                    )
+            else:
+                self.examples = [preprocess_data(d) for d in tqdm(data, disable=args.silent)]

+            if not args.no_cache:
+                logger.info(" Saving features into cached file %s", cached_features_file)
+                torch.save(self.examples, cached_features_file)

+    def __len__(self):
+        return len(self.examples)

+    def __getitem__(self, index):
+        # Copy the cached dict so popping the label does not mutate the
+        # stored example (it would otherwise be missing on the next epoch).
+        features = dict(self.examples[index])
+        label = features.pop("label")
+        if self.output_mode == "classification":
+            label = torch.tensor(label, dtype=torch.long)
+        elif self.output_mode == "regression":
+            label = torch.tensor(label, dtype=torch.float)
+        return features, label


 def convert_example_to_feature(
     example_row,
     pad_token=0,
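For orientation, each dictionary that preprocess_data produces holds the encode_plus output plus the raw label; __getitem__ later converts that label to a tensor. A small check of the shapes involved (the tokenizer checkpoint is an illustrative assumption):

```python
from transformers import AutoTokenizer

from simpletransformers.classification.classification_utils import (
    InputExample,
    preprocess_data,
)

# Any tokenizer with a pad token behaves the same way here; the checkpoint
# is an illustrative assumption.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


class _Args:
    # Minimal stand-in for the simpletransformers args object; only the
    # attribute preprocess_data reads is defined.
    max_seq_length = 32


example = InputExample(guid=0, text_a="An example sentence.", label=1)
item = preprocess_data((example, tokenizer, _Args()))

# encode_plus(..., return_tensors="pt") keeps a leading batch dimension, so
# each per-example tensor is [1, max_seq_length]; the default DataLoader
# collation therefore stacks batches to [batch_size, 1, max_seq_length].
print(item["input_ids"].shape)  # torch.Size([1, 32])
print(item["label"])            # 1 (tensorized later, in __getitem__)
```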

This file was deleted.

2 changes: 1 addition & 1 deletion tests/test_classification.py
@@ -113,7 +113,7 @@ def test_multiclass_classification(model_type, model_name):
 @pytest.mark.parametrize(
     "model_type, model_name",
     [
-        # ("bert", "bert-base-uncased"),
+        ("bert", "bert-base-uncased"),
         ("xlnet", "xlnet-base-cased"),
         # ("xlm", "xlm-mlm-17-1280"),
         # ("roberta", "roberta-base"),
