[DONT MERGE] simple text classification #36

Open · wants to merge 18 commits into main
3 changes: 3 additions & 0 deletions configs/dataset/imdb.yaml
@@ -0,0 +1,3 @@
_target_: datasets.load_dataset

path: ${original_work_dir}/datasets/imdb
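
At runtime, Hydra turns this config into a plain function call. A minimal sketch of the equivalent Python, assuming ${original_work_dir} resolves to the repository root (the absolute path below is a placeholder for this sketch):

import datasets

# what hydra.utils.instantiate effectively executes for the config above;
# the resolved path is a placeholder, not a real location
dataset = datasets.load_dataset(path="/path/to/repo/datasets/imdb")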
44 changes: 44 additions & 0 deletions configs/experiment/imdb.yaml
@@ -0,0 +1,44 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=imdb

defaults:
  - override /dataset: imdb
  - override /datamodule: default
  - override /taskmodule: simple_transformer_text_classification
  - override /model: transformer_text_classification
  - override /callbacks: default
  - override /logger: wandb
  - override /trainer: default

# all parameters below will be merged with parameters from the default configurations set above;
# this allows you to overwrite only specified parameters

# the name of the run determines the folder name in logs
name: "imdb/transformer_text_classification"

seed: 12345

trainer:
  min_epochs: 5
  max_epochs: 20
  # gradient_clip_val: 0.5

taskmodule:
  # the texts in imdb are rather short, so we decrease max_length to save resources
  max_length: 128

datamodule:
  batch_size: 32
  # the imdb dataset has no val split, so we use "test" for that
  val_split: test

logger:
  wandb:
    name: "first-run"
    tags:
      - dataset=imdb
      - model=transformer_text_classification
      - task=sentiment_classification
    # save_dir: models/${name}/debug
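
To check how these overrides merge with the defaults without starting a run, the Hydra compose API can be used. A sketch; the primary config name ("train") and the relative config path are assumptions about this repository's layout:

from hydra import compose, initialize

with initialize(config_path="configs"):  # path relative to the calling script
    cfg = compose(config_name="train", overrides=["experiment=imdb"])
    print(cfg.taskmodule.max_length)  # 128, set by this experiment
    print(cfg.datamodule.val_split)   # "test"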
4 changes: 4 additions & 0 deletions configs/model/transformer_text_classification.yaml
@@ -0,0 +1,4 @@
_target_: pytorch_ie.models.TransformerTextClassificationModel

#model_name_or_path: ${transformer_model} # transformer_model is specified in config_rel.yaml
model_name_or_path: bert-base-uncased
4 changes: 4 additions & 0 deletions configs/taskmodule/simple_transformer_text_classification.yaml
@@ -0,0 +1,4 @@
_target_: pytorch_ie.taskmodules.SimpleTransformerTextClassificationTaskModule

#tokenizer_name_or_path: ${transformer_model} # transformer_model is specified in config_rel.yaml
tokenizer_name_or_path: bert-base-uncased
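
Like the model config, this file is materialized via hydra.utils.instantiate in the training pipeline. A minimal sketch; the prepare call is an assumption about how the pytorch-ie dev API populates label_to_id:

import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "pytorch_ie.taskmodules.SimpleTransformerTextClassificationTaskModule",
        "tokenizer_name_or_path": "bert-base-uncased",
    }
)
taskmodule = hydra.utils.instantiate(cfg)
# taskmodule.prepare(train_documents)  # assumed API: derives label_to_id from the documents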
55 changes: 55 additions & 0 deletions datasets/imdb/imdb.py
@@ -0,0 +1,55 @@
from dataclasses import dataclass

import pytorch_ie.data.builder
from pytorch_ie.annotations import Label
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextDocument

import datasets


class ImdbConfig(datasets.BuilderConfig):
    """BuilderConfig for IMDB."""

    def __init__(self, **kwargs):
        """BuilderConfig for IMDB.

        Args:
            **kwargs: keyword arguments forwarded to super.
        """
        super().__init__(**kwargs)


@dataclass
class ImdbDocument(TextDocument):
    label: AnnotationList[Label] = annotation_field()


class Imdb(pytorch_ie.data.builder.GeneratorBasedBuilder):
    DOCUMENT_TYPE = ImdbDocument

    BASE_DATASET_PATH = "imdb"

    BUILDER_CONFIGS = [
        ImdbConfig(
            name="plain_text",
            version=datasets.Version("1.0.0"),
            description="IMDB sentiment classification dataset",
        ),
    ]

    def _generate_document_kwargs(self, dataset):
        return {"int2str": dataset.features["label"].int2str}

    def _generate_document(self, example, int2str):
        text = example["text"]
        document = ImdbDocument(text=text)
        label_id = example["label"]
        # examples from the "unsupervised" split carry label -1; keep them unlabeled
        if label_id < 0:
            return document

        label = int2str(label_id)
        label_annotation = Label(label=label)
        document.label.append(label_annotation)

        return document
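
A quick way to exercise the builder locally. A sketch: it assumes running from the repository root, and that indexing the loaded dataset yields ImdbDocument instances as _generate_document suggests; the "pos"/"neg" label names come from the underlying HF imdb dataset:

import datasets

dataset = datasets.load_dataset("datasets/imdb", name="plain_text")
doc = dataset["train"][0]
print(type(doc).__name__)                     # expected: ImdbDocument
print(doc.text[:80])                          # raw review text
print([label.label for label in doc.label])   # e.g. ["pos"] or ["neg"]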
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,7 +1,8 @@
# --------- pytorch-ie --------- #
# pytorch-ie>=0.8.0,<0.9
# we need the latest dev version of pytorch-ie where
# the pipeline works without the parameter "predict_field"
# we need the latest dev version of pytorch-ie
# * where the pipeline works without the parameter "predict_field", and
# * to allow for annotations without a target (the Label annotation)
git+https://github.com/ChristophAlt/pytorch-ie.git

# --------- hydra --------- #
@@ -10,7 +11,7 @@ hydra-colorlog>=1.1.0
hydra-optuna-sweeper>=1.1.0

# --------- loggers --------- #
# wandb
wandb # used in the example experiment
# neptune-client
# mlflow
# comet-ml
7 changes: 6 additions & 1 deletion src/training_pipeline.py
@@ -51,7 +51,12 @@ def train(config: DictConfig) -> Optional[float]:
    # Init pytorch-ie model
    log.info(f"Instantiating model <{config.model._target_}>")
    # NOTE: THE FOLLOWING LINE MAY NEED ADAPTATION WHEN YOU DEFINE YOUR OWN MODELS OR TASKMODULES!
    additional_model_kwargs: Dict[str, Any] = dict(num_classes=len(taskmodule.label_to_id))
    # additional_model_kwargs: Dict[str, Any] = dict(num_classes=len(taskmodule.label_to_id))
    additional_model_kwargs: Dict[str, Any] = dict(
        num_classes=len(taskmodule.label_to_id),
        tokenizer_vocab_size=taskmodule.tokenizer.vocab_size,
        t_total=datamodule.num_train * config["trainer"]["max_epochs"],
    )
    model: PyTorchIEModel = hydra.utils.instantiate(
        config.model, _convert_="partial", **additional_model_kwargs
    )
9 changes: 8 additions & 1 deletion tests/shell/test_basic_commands.py
@@ -12,7 +12,14 @@

def test_fast_dev_run_with_evaluation():
    """Test running for 1 train, val and test batch."""
    command = ["train.py", "++trainer.fast_dev_run=true", "++test=true"]
    command = [
        "train.py",
        "experiment=imdb",
        "logger=wandb",
        "logger.wandb.offline=true",
        "++trainer.fast_dev_run=true",
        "++test=true",
    ]
    run_command(command)

