Skip to content

Commit

Permalink
fix tars typing
Browse files Browse the repository at this point in the history
  • Loading branch information
helpmefindaname committed Aug 30, 2024
1 parent a66a856 commit 642337b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
3 changes: 1 addition & 2 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,7 @@ def __getitem__(self, idx: int) -> Token: ...
def __getitem__(self, s: slice) -> Span: ...

@typing.overload
def __getitem__(self, s: typing.Tuple[Span, Span]) -> Relation:
...
def __getitem__(self, s: typing.Tuple[Span, Span]) -> Relation: ...

def __getitem__(self, subscript):
if isinstance(subscript, tuple):
Expand Down
4 changes: 3 additions & 1 deletion flair/models/relation_classifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,9 @@ def _entity_pair_permutations(

# Obtain gold label, if existing
gold_relation = sentence[head.span, tail.span]
gold_label: Optional[str] = gold_relation.get_label(self.label_type, zero_tag_value=None).value
gold_label: Optional[str] = gold_relation.get_label(self.label_type, zero_tag_value="O").value
if gold_label == "O":
gold_label = None
yield head, tail, gold_label

def _encode_sentence(
Expand Down
15 changes: 6 additions & 9 deletions flair/models/tars_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,19 +571,16 @@ def predict(

already_set_indices: List[int] = []

sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1))
sorted_x.reverse()
for tuple in sorted_x:
# get the span and its label
label = tuple[0]

sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1), reverse=True)
for label, _ in sorted_x:
span = typing.cast(Span, label.data_point)
label_length = (
0 if not self.prefix else len(label.value.split(" ")) + len(self.separator.split(" "))
)

# determine whether tokens in this span already have a label
tag_this = True
for token in label.data_point:
for token in span:
corresponding_token = sentence.get_token(token.idx - label_length)
if corresponding_token is None:
tag_this = False
Expand All @@ -595,8 +592,8 @@ def predict(
# only add if all tokens have no label
if tag_this:
# make and add a corresponding predicted span
start_pos = label.data_point.data_point[0].idx - label_length - 1
end_pos = label.data_point.data_point[-1].idx - label_length
start_pos = span.tokens[0].idx - label_length - 1
end_pos = span.tokens[-1].idx - label_length

predicted_span = sentence[start_pos:end_pos]
predicted_span.add_label(label_name, value=label.value, score=label.score)
Expand Down

0 comments on commit 642337b

Please sign in to comment.