Skip to content

Commit

Permalink
gh-3243: propose a different caching approach for sentence
Browse files Browse the repository at this point in the history
  • Loading branch information
helpmefindaname authored and Benedikt Fuchs committed Jul 17, 2023
1 parent 419f13a commit a053fcc
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 70 deletions.
75 changes: 30 additions & 45 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,25 +547,9 @@ def set_label(self, typename: str, value: str, score: float = 1.0):
class Span(_PartOfSentence):
"""This class represents one textual span consisting of Tokens."""

def __new__(self, tokens: List[Token]):
# check if the span already exists. If so, return it
unlabeled_identifier = self._make_unlabeled_identifier(tokens)
if unlabeled_identifier in tokens[0].sentence._known_spans:
span = tokens[0].sentence._known_spans[unlabeled_identifier]
return span

# else make a new span
else:
span = super().__new__(self)
span.initialized = False
tokens[0].sentence._known_spans[unlabeled_identifier] = span
return span

def __init__(self, tokens: List[Token]) -> None:
if not self.initialized:
super().__init__(tokens[0].sentence)
self.tokens = tokens
self.initialized: bool = True
super().__init__(tokens[0].sentence)
self.tokens = tokens

@property
def start_position(self) -> int:
Expand Down Expand Up @@ -606,26 +590,10 @@ def embedding(self):


class Relation(_PartOfSentence):
def __new__(self, first: Span, second: Span):
# check if the relation already exists. If so, return it
unlabeled_identifier = self._make_unlabeled_identifier(first, second)
if unlabeled_identifier in first.sentence._known_spans:
span = first.sentence._known_spans[unlabeled_identifier]
return span

# else make a new relation
else:
span = super().__new__(self)
span.initialized = False
first.sentence._known_spans[unlabeled_identifier] = span
return span

def __init__(self, first: Span, second: Span) -> None:
if not self.initialized:
super().__init__(sentence=first.sentence)
self.first: Span = first
self.second: Span = second
self.initialized: bool = True
super().__init__(sentence=first.sentence)
self.first: Span = first
self.second: Span = second

def __repr__(self) -> str:
return str(self)
Expand Down Expand Up @@ -692,7 +660,7 @@ def __init__(
self.tokens: List[Token] = []

# private field for all known spans
self._known_spans: Dict[str, _PartOfSentence] = {}
self._known_parts: Dict[str, _PartOfSentence] = {}

self.language_code: Optional[str] = language_code

Expand Down Expand Up @@ -769,7 +737,7 @@ def get_relations(self, type: str) -> List[Relation]:

def get_spans(self, type: str) -> List[Span]:
spans: List[Span] = []
for potential_span in self._known_spans.values():
for potential_span in self._known_parts.values():
if isinstance(potential_span, Span) and potential_span.has_label(type):
spans.append(potential_span)
return sorted(spans)
Expand Down Expand Up @@ -949,8 +917,7 @@ def to_dict(self, tag_type: Optional[str] = None):
return {"text": self.to_original_text(), "all labels": labels}

def get_span(self, start: int, stop: int):
span_slice = slice(start, stop)
return self[span_slice]
return self[start:stop]

@typing.overload
def __getitem__(self, idx: int) -> Token:
Expand All @@ -960,9 +927,27 @@ def __getitem__(self, idx: int) -> Token:
def __getitem__(self, s: slice) -> Span:
...

@typing.overload
def __getitem__(self, s: typing.Tuple[Span, Span]) -> Relation:
...

def __getitem__(self, subscript):
if isinstance(subscript, slice):
return Span(self.tokens[subscript])
if isinstance(subscript, tuple):
first, second = subscript
identifier = ""
if isinstance(first, Span) and isinstance(second, Span):
identifier = Relation._make_unlabeled_identifier(first, second)
if identifier not in self._known_parts:
self._known_parts[identifier] = Relation(first, second)

return self._known_parts[identifier]
elif isinstance(subscript, slice):
identifier = Span._make_unlabeled_identifier(self.tokens[subscript])

if identifier not in self._known_parts:
self._known_parts[identifier] = Span(self.tokens[subscript])

return self._known_parts[identifier]
else:
return self.tokens[subscript]

Expand Down Expand Up @@ -1108,11 +1093,11 @@ def remove_labels(self, typename: str):
token.remove_labels(typename)

# labels also need to be deleted at all known spans
for span in self._known_spans.values():
for span in self._known_parts.values():
span.remove_labels(typename)

# remove spans without labels
self._known_spans = {k: v for k, v in self._known_spans.items() if len(v.labels) > 0}
self._known_parts = {k: v for k, v in self._known_parts.items() if len(v.labels) > 0}

# delete labels at object itself
super().remove_labels(typename)
Expand Down
5 changes: 1 addition & 4 deletions flair/datasets/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
Corpus,
FlairDataset,
MultiCorpus,
Relation,
Sentence,
Token,
get_spans_from_bio,
Expand Down Expand Up @@ -684,9 +683,7 @@ def _convert_lines_to_sentence(
tail_end = int(indices[3])
label = indices[4]
# head and tail span indices are 1-indexed and end index is inclusive
relation = Relation(
first=sentence[head_start - 1 : head_end], second=sentence[tail_start - 1 : tail_end]
)
relation = sentence[sentence[head_start - 1 : head_end], sentence[tail_start - 1 : tail_end]]
remapped = self._remap_label(label)
if remapped != "O":
relation.add_label(typename="relation", value=remapped)
Expand Down
2 changes: 1 addition & 1 deletion flair/models/regexp_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_token_span(self, span: Tuple[int, int]) -> Span:
"""
span_start: int = self.__tokens_start_pos.index(span[0])
span_end: int = self.__tokens_end_pos.index(span[1])
return Span(self.tokens[span_start : span_end + 1])
return self.sentence[span_start : span_end + 1]


class RegexpTagger:
Expand Down
15 changes: 6 additions & 9 deletions flair/models/relation_classifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,11 +377,9 @@ def _entity_pair_permutations(
"""
valid_entities: List[_Entity] = list(self._valid_entities(sentence))

# Use a dictionary to find gold relation annotations for a given entity pair
relation_to_gold_label: Dict[str, str] = {
relation.unlabeled_identifier: relation.get_label(self.label_type, zero_tag_value=self.zero_tag_value).value
for relation in sentence.get_relations(self.label_type)
}
# ensure that all existing relations without label have the label set to zero_tag_value.
for relation in sentence.get_relations(self.label_type):
relation.set_label(self.label_type, relation.get_label(self.label_type, self.zero_tag_value).value)

# Yield head and tail entity pairs from the cross product of all entities
for head, tail in itertools.product(valid_entities, repeat=2):
Expand All @@ -398,9 +396,8 @@ def _entity_pair_permutations(
continue

# Obtain gold label, if existing
original_relation: Relation = Relation(first=head.span, second=tail.span)
gold_label: Optional[str] = relation_to_gold_label.get(original_relation.unlabeled_identifier)

gold_relation = sentence[head.span, tail.span]
gold_label: Optional[str] = gold_relation.get_label(self.label_type, zero_tag_value=None).value
yield head, tail, gold_label

def _encode_sentence(
Expand Down Expand Up @@ -481,7 +478,7 @@ def _encode_sentence_for_inference(
tail=tail,
gold_label=gold_label if gold_label is not None else self.zero_tag_value,
)
original_relation: Relation = Relation(first=head.span, second=tail.span)
original_relation: Relation = sentence[head.span, tail.span]
yield masked_sentence, original_relation

def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
Expand Down
2 changes: 1 addition & 1 deletion flair/models/relation_extractor_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Relation]:
):
continue

relation = Relation(span_1, span_2)
relation = sentence[span_1, span_2]
if self.training and self.train_on_gold_pairs_only and relation.get_label(self.label_type).value == "O":
continue
entity_pairs.append(relation)
Expand Down
14 changes: 7 additions & 7 deletions flair/models/tars_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,9 @@ def _get_tars_formatted_sentence(self, label, sentence):

for entity_label in sentence.get_labels(self.label_type):
if entity_label.value == label:
new_span = Span(
[tars_sentence.get_token(token.idx + label_length) for token in entity_label.data_point]
)
new_span.add_label(self.static_label_type, value="entity")
start_pos = entity_label.data_point[0].idx + label_length - 1
end_pos = entity_label.data_point[-1].idx + label_length
tars_sentence[start_pos:end_pos].add_label(self.static_label_type, value="entity")
tars_sentence.copy_context_from_sentence(sentence)
return tars_sentence

Expand Down Expand Up @@ -588,9 +587,10 @@ def predict(
# only add if all tokens have no label
if tag_this:
# make and add a corresponding predicted span
predicted_span = Span(
[sentence.get_token(token.idx - label_length) for token in label.data_point]
)
start_pos = label.data_point.data_point[0].idx - label_length - 1
end_pos = label.data_point.data_point[-1].idx - label_length

predicted_span = sentence[start_pos:end_pos]
predicted_span.add_label(label_name, value=label.value, score=label.score)

# set indices so that no token can be tagged twice
Expand Down
6 changes: 3 additions & 3 deletions tests/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ def test_relation_tags():
sentence = Sentence("Humboldt Universität zu Berlin is located in Berlin .")

# create two relation label
Relation(sentence[0:4], sentence[7:8]).add_label("rel", "located in")
Relation(sentence[0:2], sentence[3:4]).add_label("rel", "university of")
Relation(sentence[0:2], sentence[3:4]).add_label("syntactic", "apposition")
sentence[sentence[0:4], sentence[7:8]].add_label("rel", "located in")
sentence[sentence[0:2], sentence[3:4]].add_label("rel", "university of")
sentence[sentence[0:2], sentence[3:4]].add_label("syntactic", "apposition")

# there should be two relation labels
labels: List[Label] = sentence.get_labels("rel")
Expand Down
37 changes: 37 additions & 0 deletions tests/test_sentence.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import copy
import pickle

from flair.data import Sentence


Expand Down Expand Up @@ -73,3 +76,37 @@ def test_start_end_position_pretokenized() -> None:
(10, 18),
(19, 20),
]


def test_spans_support_deepcopy() -> None:
sentence = Sentence(["I", "live", "in", "Vienna", "."])
sentence[3:4].add_label("ner", "LOC")

_ = copy.deepcopy(sentence)


def test_spans_support_pickle() -> None:
sentence = Sentence(["I", "live", "in", "Vienna", "."])
sentence[3:4].add_label("ner", "LOC")

pickle_data = pickle.dumps(sentence)
_ = pickle.loads(pickle_data)


def test_relations_support_deepcopy() -> None:
sentence = Sentence(["Vienna", "is", "the", "capital", "of", "Austria"])
sentence[0:1].add_label("ner", "LOC")
sentence[5:6].add_label("ner", "LOC")
sentence[sentence[0:1], sentence[5:6]].add_label("rel", "capital")

_ = copy.deepcopy(sentence)


def test_relations_support_pickle() -> None:
sentence = Sentence(["Vienna", "is", "the", "capital", "of", "Austria"])
sentence[0:1].add_label("ner", "LOC")
sentence[5:6].add_label("ner", "LOC")
sentence[sentence[0:1], sentence[5:6]].add_label("rel", "capital")

pickle_data = pickle.dumps(sentence)
_ = pickle.loads(pickle_data)

0 comments on commit a053fcc

Please sign in to comment.