Skip to content

Commit

Permalink
.npz serialization (#39)
Browse files Browse the repository at this point in the history
* npz serialization

* tests
  • Loading branch information
pmaher86 authored Oct 13, 2024
1 parent 861ae5c commit 33ff4cf
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 20 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
exclude: src/blacksquare/word_list.npz
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.8
hooks:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ local_scheme = "no-local-version"
where = ["src"]

[tool.setuptools.package-data]
blacksquare = ["*.dict"]
blacksquare = ["*.npz"]
Binary file not shown.
72 changes: 53 additions & 19 deletions src/blacksquare/word_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,39 @@ def __init__(
"""Representation of a scored word list.
Args:
source (Union[ str, Path, List[str], Dict[str, Union[int, float]], ]): The
source for the word list. Can be a list of strings, a dict of strings
to scores, or a path to a .dict file with words in "word;score" format.
Words will be normalized and scores will be scaled from 0-1.
source: The source for the word list. Can be a list of strings, a dict of
strings to scores, a path to a .dict file with words in "word;score"
format, or a path to a .npz file (produced to `.to_npz`) Words will be
normalized and scores will be scaled from 0-1.
Raises:
ValueError: If input type is not recognized
"""
if (
isinstance(source, str)
or isinstance(source, Path)
or isinstance(source, io.IOBase)
):
df = pd.read_csv(
source,
sep=";",
header=None,
names=["word", "score"],
dtype={"word": str, "score": float},
na_filter=False,
)
raw_words_scores = df.values
if isinstance(source, str) or isinstance(source, Path):
if Path(source).suffix == ".npz":
loaded = np.load(source)
length_keys = {
k.split("_")[0]
for k in loaded.keys()
if k not in ("words", "scores")
}
self._words = loaded["words"]
self._scores = loaded["scores"]
self._word_scores_by_length = {
int(k): (loaded[f"{k}_words"], loaded[f"{k}_scores"])
for k in length_keys
}
return
else:
df = pd.read_csv(
source,
sep=";",
header=None,
names=["word", "score"],
dtype={"word": str, "score": float},
na_filter=False,
)
raw_words_scores = df.values
elif isinstance(source, list):
assert len(source) > 0 and isinstance(source[0], str)
raw_words_scores = [(w, 1) for w in source]
Expand Down Expand Up @@ -210,8 +221,31 @@ def score_filter(self, threshold: float) -> WordList:
return WordList(dict(zip(self._words[score_mask], self._scores[score_mask])))

def filter(self, filter_fn: Callable[[ScoredWord], bool]) -> WordList:
"""Returns a new word list filtered by a custom function.
Args:
filtern_fn: The filtering function. Takes a ScoredWord as an
input and outputs a bool.
Returns:
The resulting WordList
"""
return WordList(dict([w for w in self if filter_fn(w)]))

def to_npz(self, file: str | Path) -> None:
"""Serializes word list to a .npz format that is fast to load from disk.
Args:
file: The output file path.
"""
by_length_arrays = {}
for k, v in self._word_scores_by_length.items():
by_length_arrays[f"{k}_words"] = v[0]
by_length_arrays[f"{k}_scores"] = v[1]
np.savez_compressed(
file, words=self._words, scores=self._scores, **by_length_arrays
)

def __len__(self):
return len(self._words)

Expand Down Expand Up @@ -347,4 +381,4 @@ def _normalize(word: str) -> str:
return word.upper().replace(" ", "")


DEFAULT_WORDLIST = WordList(files("blacksquare").joinpath("xwordlist.dict"))
DEFAULT_WORDLIST = WordList(files("blacksquare").joinpath("word_list.npz"))
18 changes: 18 additions & 0 deletions tests/test_word_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,24 @@ def test_add_word_list(self):
def test_score_filter(self, word_list):
assert len(word_list.score_filter(0.5)) == 6

def test_serialization(self, tmp_path):
words = """
AAA;1.0
BBB;0.99
BB;0.5
C;0.1
"""
with (tmp_path / "list.dict").open("w") as f:
f.write(words)

xw = Crossword(3)
wl = WordList(tmp_path / "list.dict")
dict_matches = wl.find_matches(xw[ACROSS, 1])
wl.to_npz(tmp_path / "list.npz")
wl_from_npz = WordList(tmp_path / "list.npz")
npz_matches = wl_from_npz.find_matches(xw[ACROSS, 1])
assert dict_matches.words == npz_matches.words


class TestMatchWordList:
@pytest.fixture
Expand Down

0 comments on commit 33ff4cf

Please sign in to comment.