Skip to content

Commit

Permalink
use document types from pie-modules (#156)
Browse files: browse the repository at this point in the history
  • Loading branch information
ArneBinder authored Mar 6, 2024
1 parent b47507c commit 961ee64
Show file tree
Hide file tree
Showing 11 changed files with 27 additions and 44 deletions.
2 changes: 1 addition & 1 deletion configs/metric/count_entity_labels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ field: labeled_spans
labels: INFERRED
show_histogram: true
show_as_markdown: true
-document_type: pytorch_ie.documents.TextDocumentWithLabeledSpans
+document_type: pie_modules.documents.TextDocumentWithLabeledSpans
2 changes: 1 addition & 1 deletion configs/metric/span_lengths_characters.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ title: span lengths (characters)
layer: labeled_spans
show_histogram: true
show_as_markdown: true
-document_type: pytorch_ie.documents.TextDocumentWithLabeledSpans
+document_type: pie_modules.documents.TextDocumentWithLabeledSpans
2 changes: 1 addition & 1 deletion configs/metric/span_lengths_characters_per_label.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ layer: labeled_spans
labels: INFERRED
show_histogram: true
show_as_markdown: true
-document_type: pytorch_ie.documents.TextDocumentWithLabeledSpans
+document_type: pie_modules.documents.TextDocumentWithLabeledSpans
4 changes: 2 additions & 2 deletions configs/metric/span_lengths_tokens.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ tokenizer: bert-base-uncased
tokenize_kwargs:
add_special_tokens: false
# strict_span_conversion: false
-tokenized_document_type: src.document.types.TokenDocumentWithLabeledSpans
+tokenized_document_type: pie_modules.documents.TokenDocumentWithLabeledSpans
show_histogram: true
show_as_markdown: true
-document_type: pytorch_ie.documents.TextDocumentWithLabeledSpans
+document_type: pie_modules.documents.TextDocumentWithLabeledSpans
4 changes: 2 additions & 2 deletions dataset_builders/pie/conll2003/conll2003.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from pie_datasets import GeneratorBasedBuilder
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
-from pytorch_ie.documents import TextDocument, TextDocumentWithLabeledSpans
+from pytorch_ie.documents import TextBasedDocument, TextDocumentWithLabeledSpans
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans


@dataclass
-class CoNLL2003Document(TextDocument):
+class CoNLL2003Document(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")


Expand Down
21 changes: 2 additions & 19 deletions src/document/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import dataclasses
from typing import Optional

-from pytorch_ie.annotations import BinaryRelation, LabeledSpan
+from pie_modules.annotations import LabeledSpan
from pytorch_ie.core import Annotation, AnnotationList, annotation_field
-from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument
+from pytorch_ie.documents import TextBasedDocument

# =========================== Annotation Types ============================= #

Expand Down Expand Up @@ -40,20 +40,3 @@ def __str__(self) -> str:
class TextDocumentWithLabeledEntitiesAndEntityAttributes(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
entity_attributes: AnnotationList[Attribute] = annotation_field(target="entities")


-@dataclasses.dataclass
-class TokenDocumentWithLabeledSpans(TokenBasedDocument):
-labeled_spans: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
-
-
-@dataclasses.dataclass
-class TokenDocumentWithLabeledSpansAndBinaryRelations(TokenDocumentWithLabeledSpans):
-binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="labeled_spans")
-
-
-@dataclasses.dataclass
-class TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
-TokenDocumentWithLabeledSpansAndBinaryRelations
-):
-labeled_partitions: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
18 changes: 9 additions & 9 deletions src/taskmodules/transformer_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,23 @@

import torch
import torch.nn.functional as F
+from pie_modules.annotations import LabeledSpan
from pie_modules.document.processing import token_based_document_to_text_based, tokenize_document
-from pytorch_ie import AnnotationLayer, annotation_field
-from pytorch_ie.annotations import LabeledSpan
-from pytorch_ie.core import TaskEncoding, TaskModule
-from pytorch_ie.documents import (
-TextDocument,
+from pie_modules.documents import (
TextDocumentWithLabeledSpans,
TextDocumentWithLabeledSpansAndLabeledPartitions,
TokenBasedDocument,
)
+from pytorch_ie import AnnotationLayer, annotation_field
+from pytorch_ie.core import TaskEncoding, TaskModule
+from pytorch_ie.documents import TextBasedDocument
from pytorch_ie.models.transformer_token_classification import ModelOutputType, ModelStepInputType
from pytorch_ie.utils.span import bio_tags_to_spans
from tokenizers import Encoding
from transformers import AutoTokenizer
from typing_extensions import TypeAlias

-DocumentType: TypeAlias = TextDocument
+DocumentType: TypeAlias = TextBasedDocument

InputEncodingType: TypeAlias = Encoding
TargetEncodingType: TypeAlias = Sequence[int]
Expand Down Expand Up @@ -152,8 +152,8 @@ def __init__(
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

@property
-def document_type(self) -> Optional[Type[TextDocument]]:
-dt: Type[TextDocument]
+def document_type(self) -> Optional[Type[TextBasedDocument]]:
+dt: Type[TextBasedDocument]
errors = []
if self.span_annotation != "labeled_spans":
errors.append(
Expand Down Expand Up @@ -207,7 +207,7 @@ def _post_prepare(self):

def encode_input(
self,
-document: TextDocument,
+document: TextBasedDocument,
) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
if self.partition_annotation is None:
tokenized_document_type = TokenDocumentWithLabeledSpans
Expand Down
2 changes: 1 addition & 1 deletion tests/dataset_builders/pie/test_conll2003.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import datasets
import pytest
from pie_datasets import DatasetDict
+from pie_modules.documents import TextDocumentWithLabeledSpans
from pytorch_ie.core import Document
-from pytorch_ie.documents import TextDocumentWithLabeledSpans

from dataset_builders.pie.conll2003.conll2003 import Conll2003
from tests.dataset_builders import PIE_BASE_PATH
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/metrics/test_f1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass

import pytest
-from pytorch_ie.annotations import LabeledSpan
+from pie_modules.annotations import LabeledSpan
from pytorch_ie.core import AnnotationLayer, annotation_field
from pytorch_ie.documents import TextBasedDocument

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/serializer/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import pytest
from pie_datasets import DatasetDict
-from pytorch_ie.annotations import BinaryRelation, LabeledSpan
+from pie_modules.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
-from pytorch_ie.documents import TextDocument
+from pytorch_ie.documents import TextBasedDocument

from src.serializer import JsonSerializer


@dataclass
-class ExampleDocument(TextDocument):
+class ExampleDocument(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

import pytest
import torch
-from pytorch_ie import AnnotationLayer, annotation_field
-from pytorch_ie.annotations import LabeledSpan
-from pytorch_ie.documents import (
-TextBasedDocument,
+from pie_modules.annotations import LabeledSpan
+from pie_modules.documents import (
TextDocumentWithLabeledSpans,
TextDocumentWithLabeledSpansAndLabeledPartitions,
)
+from pytorch_ie import AnnotationLayer, annotation_field
+from pytorch_ie.documents import TextBasedDocument
from transformers import BatchEncoding

from src.taskmodules import MyTokenClassificationTaskModule
Expand Down

0 comments on commit 961ee64

Please sign in to comment.