Skip to content

Commit

Permalink
add metadata_gold_id_keys and metadata_prediction_id_keys to BratSeri…
Browse files Browse the repository at this point in the history
…alizer (#172)
  • Loading branch information
ArneBinder authored Jul 27, 2024
1 parent 582845d commit 95d7fdd
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 20 deletions.
98 changes: 81 additions & 17 deletions src/serializer/brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,29 @@ def serialize_annotations(
indices: Dict[str, int],
annotation2id: Dict[Annotation, str],
label_prefix: Optional[str] = None,
annotation_ids: Optional[List[str]] = None,
) -> Tuple[List[str], Dict[Annotation, str]]:
serialized_annotations = []
new_annotation2id: Dict[Annotation, str] = {}
for annotation in annotations:
for idx, annotation in enumerate(annotations):
annotation_type, serialized_annotation = serialize_annotation(
annotation=annotation,
annotation2id=annotation2id,
label_prefix=label_prefix,
)
idx = indices[annotation_type]
annotation_id = f"{annotation_type}{idx}"
if annotation_ids is not None:
if indices.get(annotation_type, 0) > 0:
raise ValueError(
"Cannot specify annotation IDs for the same type (e.g. T or R) if there are "
"other annotations of the same type without an ID."
)
annotation_id = annotation_ids[idx]
else:
index = indices[annotation_type]
annotation_id = f"{annotation_type}{index}"
indices[annotation_type] += 1
serialized_annotations.append(f"{annotation_id}\t{serialized_annotation}")
new_annotation2id[annotation] = annotation_id
indices[annotation_type] += 1

return serialized_annotations, new_annotation2id

Expand All @@ -135,6 +144,8 @@ def serialize_annotation_layers(
layers: List[Tuple[AnnotationLayer, str]],
gold_label_prefix: Optional[str] = None,
prediction_label_prefix: Optional[str] = None,
gold_annotation_ids: Optional[List[Optional[List[str]]]] = None,
prediction_annotation_ids: Optional[List[Optional[List[str]]]] = None,
) -> List[str]:
"""Serialize annotations from given annotation layers into a list of strings.
Expand All @@ -145,15 +156,20 @@ def serialize_annotation_layers(
Defaults to None.
prediction_label_prefix (Optional[str], optional): Prefix to be added to prediction labels.
Defaults to None.
gold_annotation_ids (Optional[List[Optional[List[str]]]], optional): Gold annotation IDs, one
entry per layer. Each entry is either None or a list with one ID per gold annotation of that
layer. Defaults to None.
prediction_annotation_ids (Optional[List[Optional[List[str]]]], optional): Prediction
annotation IDs, one entry per layer. Each entry is either None or a list with one ID per
predicted annotation of that layer. Defaults to None.
Returns:
List[str]: List of serialized annotations.
"""

all_serialized_annotations = []
gold_annotation2id: Dict[Annotation, str] = {}
prediction_annotation2id: Dict[Annotation, str] = {}
indices: Dict[str, int] = defaultdict(int)
for layer, what in layers:
for idx, (layer, what) in enumerate(layers):
if what not in ["gold", "prediction", "both"]:
raise ValueError(
f'Invalid value for what to serialize: "{what}". Expected "gold", "prediction", or "both".'
Expand All @@ -171,23 +187,54 @@ def serialize_annotation_layers(
)
serialized_annotations = []
if what in ["gold", "both"]:
if gold_annotation_ids is not None:
if len(gold_annotation_ids) <= idx:
raise ValueError(
"gold_annotation_ids should have the same length as the number of layers."
)
current_gold_annotation_ids = gold_annotation_ids[idx]
if current_gold_annotation_ids is not None and len(
current_gold_annotation_ids
) != len(layer):
raise ValueError(
"gold_annotation_ids should have the same length as the number of gold annotations."
)
else:
current_gold_annotation_ids = None

serialized_gold_annotations, new_gold_ann2id = serialize_annotations(
annotations=layer,
indices=indices,
# gold annotations can only reference other gold annotations
annotation2id=gold_annotation2id,
label_prefix=gold_label_prefix,
annotation_ids=current_gold_annotation_ids,
)
serialized_annotations.extend(serialized_gold_annotations)
gold_annotation2id.update(new_gold_ann2id)
if what in ["prediction", "both"]:
if prediction_annotation_ids is not None:
if len(prediction_annotation_ids) <= idx:
raise ValueError(
"prediction_annotation_ids should have the same length as the number of layers."
)
current_prediction_annotation_ids = prediction_annotation_ids[idx]
if current_prediction_annotation_ids is not None and len(
current_prediction_annotation_ids
) != len(layer.predictions):
raise ValueError(
"prediction_annotation_ids should have the same length as the number of prediction annotations."
)
else:
current_prediction_annotation_ids = None
serialized_predicted_annotations, new_pred_ann2id = serialize_annotations(
annotations=layer.predictions,
indices=indices,
# Predicted annotations can reference both gold and predicted annotations.
# Note that predictions take precedence over gold annotations.
annotation2id={**gold_annotation2id, **prediction_annotation2id},
label_prefix=prediction_label_prefix,
annotation_ids=current_prediction_annotation_ids,
)
prediction_annotation2id.update(new_pred_ann2id)
serialized_annotations.extend(serialized_predicted_annotations)
Expand All @@ -200,10 +247,6 @@ class BratSerializer(DocumentSerializer):
specify the annotation layers to serialize. For now, it supports layers containing LabeledSpan,
LabeledMultiSpan, and BinaryRelation annotations.
If a gold_label_prefix is provided, the gold annotations are serialized with the given prefix.
Otherwise, only the predicted annotations are serialized. A document_processor can be provided
to process documents before serialization.
Attributes:
layers: A mapping from annotation layer names that should be serialized to what should be
serialized, i.e. "gold", "prediction", or "both".
Expand All @@ -212,21 +255,20 @@ class BratSerializer(DocumentSerializer):
with the given string. Otherwise, only predicted annotations are serialized.
prediction_label_prefix: If provided, labels of predicted annotations are prefixed with the
given string.
default_kwargs: Additional keyword arguments to be used as defaults during serialization.
metadata_gold_id_keys: A dictionary mapping layer names to metadata keys that contain the
gold annotation IDs.
metadata_prediction_id_keys: A dictionary mapping layer names to metadata keys that contain
the prediction annotation IDs.
"""

def __init__(
self,
layers: Dict[str, str],
document_processor=None,
prediction_label_prefix=None,
gold_label_prefix=None,
**kwargs,
):
self.document_processor = document_processor
self.layers = layers
self.prediction_label_prefix = prediction_label_prefix
self.gold_label_prefix = gold_label_prefix
self.default_kwargs = kwargs

def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
Expand All @@ -235,8 +277,6 @@ def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
return self.write_with_defaults(
documents=documents,
layers=self.layers,
prediction_label_prefix=self.prediction_label_prefix,
gold_label_prefix=self.gold_label_prefix,
**kwargs,
)

Expand All @@ -254,6 +294,8 @@ def write(
split: Optional[str] = None,
gold_label_prefix: Optional[str] = None,
prediction_label_prefix: Optional[str] = None,
metadata_gold_id_keys: Optional[Dict[str, str]] = None,
metadata_prediction_id_keys: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:

realpath = os.path.realpath(path)
Expand All @@ -280,10 +322,32 @@ def write(
file_name = f"{doc_id}.ann"
metadata_text[f"{file_name}"] = doc.text
ann_path = os.path.join(realpath, file_name)
layer_names = list(layers)
if metadata_gold_id_keys is not None:
gold_annotation_ids = [
doc.metadata[metadata_gold_id_keys[layer_name]]
if layer_name in metadata_gold_id_keys
else None
for layer_name in layer_names
]
else:
gold_annotation_ids = None

if metadata_prediction_id_keys is not None:
prediction_annotation_ids = [
doc.metadata[metadata_prediction_id_keys[layer_name]]
if layer_name in metadata_prediction_id_keys
else None
for layer_name in layer_names
]
else:
prediction_annotation_ids = None
serialized_annotations = serialize_annotation_layers(
layers=[(doc[layer_name], what) for layer_name, what in layers.items()],
layers=[(doc[layer_name], layers[layer_name]) for layer_name in layer_names],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
gold_annotation_ids=gold_annotation_ids,
prediction_annotation_ids=prediction_annotation_ids,
)
with open(ann_path, "w+") as f:
f.writelines(serialized_annotations)
Expand Down
86 changes: 83 additions & 3 deletions tests/unit/serializer/test_brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,22 @@ class TextDocumentWithLabeledMultiSpansAndBinaryRelations(TextBasedDocument):
@pytest.fixture
def document():
document = TextDocumentWithLabeledSpansAndBinaryRelations(
text="Harry lives in Berlin, Germany. He works at DFKI.", id="tmp"
text="Harry lives in Berlin, Germany. He works at DFKI.",
id="tmp",
metadata={
"span_ids": [],
"relation_ids": [],
"prediction_span_ids": [],
"prediction_relation_ids": [],
},
)
document.labeled_spans.predictions.extend(
[
LabeledSpan(start=0, end=5, label="PERSON"),
LabeledSpan(start=44, end=48, label="ORGANIZATION"),
]
)
document.metadata["prediction_span_ids"].extend(["T200", "T201"])

assert str(document.labeled_spans.predictions[0]) == "Harry"
assert str(document.labeled_spans.predictions[1]) == "DFKI"
Expand All @@ -119,6 +127,8 @@ def document():
LabeledSpan(start=44, end=48, label="ORGANIZATION"),
]
)
document.metadata["span_ids"].extend(["T100", "T101", "T102"])

assert str(document.labeled_spans[0]) == "Harry"
assert str(document.labeled_spans[1]) == "Berlin, Germany"
assert str(document.labeled_spans[2]) == "DFKI"
Expand All @@ -132,6 +142,7 @@ def document():
),
]
)
document.metadata["prediction_relation_ids"].extend(["R200"])

document.binary_relations.extend(
[
Expand All @@ -147,6 +158,7 @@ def document():
),
]
)
document.metadata["relation_ids"].extend(["R100", "R101"])

return document

Expand Down Expand Up @@ -192,6 +204,51 @@ def test_serialize_annotations(document, what):
raise ValueError(f"Unexpected value for what: {what}")


@pytest.mark.parametrize(
    "what",
    ["gold", "prediction", "both"],
)
def test_serialize_annotations_with_annotation_ids(document, what):
    """Check that explicitly passed annotation IDs appear in the serialized output.

    The span and relation IDs stored in the document metadata are passed per layer. They
    are expected both as line IDs and as the argument references of the relations.
    """
    span_layer = document.labeled_spans
    relation_layer = document.binary_relations
    id_store = document.metadata
    # Predicted labels only get a prefix when gold and predictions are serialized together.
    pred_prefix = "PRED" if what == "both" else None

    serialized_annotations = serialize_annotation_layers(
        layers=[(span_layer, what), (relation_layer, what)],
        gold_label_prefix="GOLD",
        prediction_label_prefix=pred_prefix,
        gold_annotation_ids=[id_store["span_ids"], id_store["relation_ids"]],
        prediction_annotation_ids=[
            id_store["prediction_span_ids"],
            id_store["prediction_relation_ids"],
        ],
    )

    expected_by_what = {
        "both": [
            "T100\tGOLD-PERSON 0 5\tHarry\n",
            "T101\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
            "T102\tGOLD-ORGANIZATION 44 48\tDFKI\n",
            "T200\tPRED-PERSON 0 5\tHarry\n",
            "T201\tPRED-ORGANIZATION 44 48\tDFKI\n",
            "R100\tGOLD-lives_in Arg1:T100 Arg2:T101\n",
            "R101\tGOLD-works_at Arg1:T100 Arg2:T102\n",
            "R200\tPRED-works_at Arg1:T200 Arg2:T201\n",
        ],
        "gold": [
            "T100\tGOLD-PERSON 0 5\tHarry\n",
            "T101\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
            "T102\tGOLD-ORGANIZATION 44 48\tDFKI\n",
            "R100\tGOLD-lives_in Arg1:T100 Arg2:T101\n",
            "R101\tGOLD-works_at Arg1:T100 Arg2:T102\n",
        ],
        "prediction": [
            "T200\tPERSON 0 5\tHarry\n",
            "T201\tORGANIZATION 44 48\tDFKI\n",
            "R200\tworks_at Arg1:T200 Arg2:T201\n",
        ],
    }
    if what not in expected_by_what:
        raise ValueError(f"Unexpected value for what: {what}")

    assert serialized_annotations == expected_by_what[what]


def test_serialize_annotations_unknown_what(document):
with pytest.raises(ValueError) as e:
serialize_annotation_layers(
Expand All @@ -215,7 +272,7 @@ def test_serialize_annotations_missing_prefix(document):
)


def document_processor(document) -> TextBasedDocument:
def append_empty_span_to_labeled_spans(document) -> TextBasedDocument:
    """Return a copy of the document with a zero-length "empty" span appended.

    The span is added to the "labeled_spans" layer of the copy; the copy is returned.
    """
    processed = document.copy()
    empty_span = LabeledSpan(start=0, end=0, label="empty")
    processed["labeled_spans"].append(empty_span)
    return processed
Expand All @@ -225,7 +282,7 @@ def test_write(tmp_path, document):
path = str(tmp_path)
serializer = BratSerializer(
path=path,
document_processor=document_processor,
document_processor=append_empty_span_to_labeled_spans,
layers={"labeled_spans": "prediction", "binary_relations": "prediction"},
)

Expand All @@ -243,6 +300,29 @@ def test_write(tmp_path, document):
]


def test_write_with_annotation_ids(tmp_path, document):
    """Check that the serializer writes gold span IDs taken from the document metadata.

    Only the span layer is mapped to a metadata key ("span_ids"). The predicted relation
    has no ID key, so it is expected with the generated ID "R0".
    """
    serializer = BratSerializer(
        path=str(tmp_path),
        layers={"labeled_spans": "gold", "binary_relations": "prediction"},
        metadata_gold_id_keys={"labeled_spans": "span_ids"},
    )

    result = serializer(documents=[document])
    ann_file = os.path.join(result["path"], f"{document.id}.ann")

    with open(ann_file, "r") as handle:
        written_lines = handle.readlines()

    expected_lines = [
        "T100\tPERSON 0 5\tHarry\n",
        "T101\tLOCATION 15 30\tBerlin, Germany\n",
        "T102\tORGANIZATION 44 48\tDFKI\n",
        "R0\tworks_at Arg1:T100 Arg2:T102\n",
    ]
    assert written_lines == expected_lines


def test_write_with_exceptions_and_warnings(tmp_path, caplog, document):
path = str(tmp_path)
serializer = BratSerializer(
Expand Down

0 comments on commit 95d7fdd

Please sign in to comment.