Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Text only steps #5390

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added a way to remix datasets flexibly
- Added `from_pretrained_transformer_and_instances` constructor to `Vocabulary`
- `TransformerTextField` now supports `__len__`.
- Added Tango steps for keeping only the strings (`strings_only`) or only the text fields (`text_fields_only`) from a dataset

### Fixed

Expand Down
2 changes: 2 additions & 0 deletions allennlp/tango/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from allennlp.tango.training import TrainingStep
from allennlp.tango.evaluation import EvaluationStep
from allennlp.tango.text_fields_only import TextFieldsOnlyDataset
from allennlp.tango.strings_only import StringsOnlyDataset

import warnings

Expand Down
44 changes: 26 additions & 18 deletions allennlp/tango/text_only.py → allennlp/tango/strings_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,27 @@
every time we release a new version.*
"""

import dataclasses
from typing import Set, Optional, Iterable, Any

from allennlp.tango.dataset import DatasetDict
from allennlp.tango.step import Step
from allennlp.common.sqlite_sparse_sequence import SqliteSparseSequence
from allennlp.tango.sqlite_format import SqliteDictFormat
from tqdm import tqdm


@Step.register("text_only")
class TextOnlyDataset(Step):
@Step.register("strings_only")
class StringsOnlyDataset(Step):
"""
This step converts a dataset into another dataset that contains only the strings from the original dataset.
This step converts a dataset into another dataset that contains only strings from the original dataset.

You can specify exactly which fields to keep from the original dataset (default is all of them).
You can specify a minimum length of string to keep, to filter out strings that are too short.
"""

DETERMINISTIC = True
VERSION = "001"
FORMAT = SqliteDictFormat()

def run( # type: ignore
self,
Expand All @@ -39,25 +43,29 @@ def run( # type: ignore
"""

def find_nested_strings(o: Any, prefix: str = "") -> Iterable[str]:
if isinstance(o, list) or isinstance(o, tuple):
if isinstance(o, str):
if fields_to_keep is None or prefix in fields_to_keep:
if min_length is None or len(o) >= min_length:
yield o
elif isinstance(o, list) or isinstance(o, tuple):
for i, item in enumerate(o):
new_prefix = f"{prefix}.{i}"
yield from find_nested_strings(item, new_prefix)
elif isinstance(o, dict):
for name, item in o.items():
new_prefix = f"{prefix}.{name}"
yield from find_nested_strings(item, new_prefix)
elif isinstance(o, str):
if fields_to_keep is None or prefix in fields_to_keep:
if min_length is None or len(o) >= min_length:
yield o

return dataclasses.replace(
input,
splits={
split_name: [
{"text": text} for instance in split for text in find_nested_strings(instance)
]
for split_name, split in input.splits.items()
},
)
splits = {}
for split_name, split in input.splits.items():
sequence_file = self.work_dir() / f"{split_name}.sqlite"
sequence_file.unlink(missing_ok=True)
sequence = SqliteSparseSequence(sequence_file)
sequence.extend(
{"text": string}
for instance in tqdm(split, desc=f"Processing split '{split_name}'")
for string in find_nested_strings(instance)
)
splits[split_name] = sequence

return DatasetDict(splits=splits, vocab=input.vocab, metadata=input.metadata)
76 changes: 76 additions & 0 deletions allennlp/tango/text_fields_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
*AllenNLP Tango is an experimental API and parts of it might change or disappear
every time we release a new version.*
"""

from typing import Set, Optional, Iterable, Any

from allennlp.tango.dataset import DatasetDict
from allennlp.tango.step import Step
from allennlp.common.sqlite_sparse_sequence import SqliteSparseSequence
from allennlp.data.fields import TextField, TransformerTextField
from allennlp.data.instance import Instance
from allennlp.tango.sqlite_format import SqliteDictFormat
from allennlp.data import Field
from tqdm import tqdm


@Step.register("text_fields_only")
class TextFieldsOnlyDataset(Step):
    """
    This step converts a dataset into another dataset that contains only text fields from the
    original dataset.

    Every `TextField` or `TransformerTextField` found in an instance (including ones nested in
    lists, tuples, and dicts) becomes its own `Instance` with a single field named `"text"`.

    You can specify exactly which fields to keep from the original dataset (default is all of them).
    You can specify a minimum length a text field must have to be kept, to filter out text fields
    that are too short.
    """

    DETERMINISTIC = True
    # Bumped from "003": `fields_to_keep` now also matches plain field names, so the output for
    # the same arguments can differ from what earlier versions produced.
    # NOTE(review): presumably the version is part of the step's cache key -- confirm.
    VERSION = "004"
    FORMAT = SqliteDictFormat()

    def run(  # type: ignore
        self,
        input: DatasetDict,
        *,
        fields_to_keep: Optional[Set[str]] = None,
        min_length: Optional[int] = None,
    ) -> DatasetDict:
        """
        Turns the `input` dataset into another dataset that contains only the text fields from the
        original dataset.

        * `fields_to_keep` is an optional collection of field names that you want to keep in the
          result. If this is `None`, all fields are kept. Nested fields are addressed by joining
          dict keys and sequence indices with dots, e.g. `"passages.0"`. The same path with a
          leading dot (`".passages.0"`) is accepted as well.
        * `min_length` specifies the minimum length (`len(field)`) that a text field must have to
          be part of the result. If this is `None`, all text fields are considered.

        Returns a new `DatasetDict` with the same split names, `vocab`, and `metadata` as `input`.
        Each split is written to `<work_dir>/<split_name>.sqlite`.
        """

        def is_selected(path: str) -> bool:
            if fields_to_keep is None:
                return True
            # Paths are built by prepending "." to every component, so a top-level field called
            # "context" has the path ".context". Callers pass plain names such as "context", so
            # compare without the leading dot too; the dotted form keeps working as before.
            return path in fields_to_keep or path[1:] in fields_to_keep

        def find_nested_fields(o: Any, prefix: str = "") -> Iterable[Field]:
            if isinstance(o, (list, tuple)):
                for i, item in enumerate(o):
                    yield from find_nested_fields(item, f"{prefix}.{i}")
            elif isinstance(o, dict):
                for name, item in o.items():
                    yield from find_nested_fields(item, f"{prefix}.{name}")
            elif isinstance(o, Instance):
                # An instance contributes no path component of its own.
                yield from find_nested_fields(o.fields, prefix)
            elif isinstance(o, (TextField, TransformerTextField)):
                if is_selected(prefix) and (min_length is None or len(o) >= min_length):
                    yield o
            # Anything else (other field types, scalars, ...) is dropped.

        splits = {}
        for split_name, split in input.splits.items():
            sequence_file = self.work_dir() / f"{split_name}.sqlite"
            # Start from a clean file so results of an earlier run are not appended to.
            # Same effect as `unlink(missing_ok=True)`, but also works before Python 3.8.
            try:
                sequence_file.unlink()
            except FileNotFoundError:
                pass
            sequence = SqliteSparseSequence(sequence_file)
            sequence.extend(
                Instance({"text": text_field})
                for instance in tqdm(split, desc=f"Processing split '{split_name}'")
                for text_field in find_nested_fields(instance)
            )
            splits[split_name] = sequence

        return DatasetDict(splits=splits, vocab=input.vocab, metadata=input.metadata)
2 changes: 1 addition & 1 deletion tests/commands/tango_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_dry_run(self):
params_as_dict_because_mypy_is_lame = {
"dataset": {"type": "hf_dataset", "dataset_name": "squad"},
"dataset_text_only": {
"type": "text_only",
"type": "text_fields_only",
"input": "dataset",
"fields_to_keep": ["context", "question"],
},
Expand Down
6 changes: 3 additions & 3 deletions tests/tango/steps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import logging

from allennlp.tango.text_only import TextOnlyDataset
from allennlp.tango.text_fields_only import TextFieldsOnlyDataset

logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -63,7 +63,7 @@ def test_make_step_graph(ordered_ascending: bool):
params_as_dict_because_mypy_is_lame = {
"dataset": {"type": "hf_dataset", "dataset_name": "squad"},
"dataset_text_only": {
"type": "text_only",
"type": "text_fields_only",
"input": {"type": "ref", "ref": "dataset"},
"fields_to_keep": ["context", "question"],
},
Expand All @@ -75,7 +75,7 @@ def test_make_step_graph(ordered_ascending: bool):
step_graph = step_graph_from_params(params.pop("steps"))
assert len(step_graph) == 2
assert isinstance(step_graph["dataset"], HuggingfaceDataset)
assert isinstance(step_graph["dataset_text_only"], TextOnlyDataset)
assert isinstance(step_graph["dataset_text_only"], TextFieldsOnlyDataset)
assert step_graph["dataset_text_only"].kwargs["input"] == step_graph["dataset"]


Expand Down