Skip to content

Commit

Permalink
dont allow repeats
Browse files Browse the repository at this point in the history
  • Loading branch information
pmaher86 committed Oct 8, 2024
1 parent 31884d0 commit 3476fdc
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 10 deletions.
15 changes: 12 additions & 3 deletions src/blacksquare/crossword.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,15 +627,22 @@ def fill(
word_list: Optional[WordList] = None,
timeout: Optional[float] = 30.0,
temperature: float = 0,
score_filter: Optional[float] = None,
allow_repeats: bool = False,
) -> Optional[Crossword]:
"""Searches for a possible fill, and returns the result as a new Crossword
object. Uses a modified depth-first-search algorithm.
Args:
timeout (int, optional): The maximum time in seconds to search before
word_list (WordList, optional): An optional word list to use instead of the
default for the crossword.
timeout (float, optional): The maximum time in seconds to search before
returning. Defaults to 30. If None, will search until completion.
temperature (float, optional): A parameter to control randomness. Defaults
to 0 (no randomness). Reasonable values are around 1.
score_filter: A threshold to apply to the word list before filling.
allow_repeats (bool): Whether to allow words that already appear in the
grid. Defaults to false.
Returns:
Optional[Crossword]: The filled Crossword. Returns None if the search is
Expand All @@ -645,6 +652,8 @@ def fill(
subgraphs = self.get_disconnected_open_subgrids()
start_time = time.time()
word_list = word_list if word_list is not None else self.word_list
if score_filter:
word_list = word_list.score_filter(score_filter)
xw = self.copy()

def recurse_subgraph_fill(
Expand All @@ -657,7 +666,7 @@ def recurse_subgraph_fill(
)
noise = np.abs(np.random.normal(scale=num_matches)) * temperature
word_to_match: Word = xw[active_subgraph[np.argmin(num_matches + noise)]]
matches = word_to_match.find_matches(word_list)
matches = word_to_match.find_matches(word_list, allow_repeats=allow_repeats)
if not matches:
dead_end_states.add(xw.hashable_state(active_subgraph))
return False
Expand Down Expand Up @@ -781,7 +790,7 @@ def _grid_html(self, size_px: Optional[int] = None) -> str:
for c in self.itercells():
c.number
cell_number_span = f'<span class="cell-number">{c.number or ""}</span>'
letter_span = f'<span class="letter">{c.value if c!=BLACK else ""}</span>'
letter_span = f'<span class="letter">{c.str if c!=BLACK else ""}</span>'
circle_span = '<span class="circle"></span>'
if c == BLACK:
extra_class = " black"
Expand Down
20 changes: 15 additions & 5 deletions src/blacksquare/word.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def __getitem__(self, key) -> Cell:

def __setitem__(self, key, value):
self.cells[key].value = value
# grid_indices = self._parent.get_indices(self.index)
# self._parent[grid_indices[key]] = value

@property
def direction(self) -> Direction:
Expand All @@ -69,7 +67,6 @@ def cells(self) -> List[Cell]:
def value(self) -> str:
# TODO: rename to str
"""str: The current fill value of the word"""
# return "".join(self._parent.grid[self._parent._get_word_mask(self.index)])
return "".join([c.str for c in self.cells])

# Todo: array, str?
Expand Down Expand Up @@ -105,19 +102,32 @@ def symmetric_image(self) -> Optional[Union[Word, List[Word]]]:
else:
return self._parent[result]

def find_matches(self, word_list: Optional[WordList] = None) -> MatchWordList:
def find_matches(
self, word_list: Optional[WordList] = None, allow_repeats: bool = False
) -> MatchWordList:
"""Finds matches for the word, ranking matches by how many valid crosses they
allow.
Args:
word_list (Optional[WordList], optional): The word list to use for matching.
If None, the default wordlist of the parent crossword is used..
If None, the default wordlist of the parent crossword is used.
allow_repeats (bool): Whether to include words that are already in the grid.
Defaults to False.
Returns:
MatchWordList: The matching words, scored by compatible crosses.
"""
word_list = self._parent.word_list if word_list is None else word_list
self_len = len(self)
matches = word_list.find_matches(self)
if not allow_repeats:
matches = matches.filter_words(
[
w.value
for w in self._parent.iterwords()
if len(w) == self_len and not w.is_open()
]
)
open_indices = np.argwhere(
np.equal(self.cells, SpecialCellValue.EMPTY)
).squeeze(axis=1)
Expand Down
17 changes: 15 additions & 2 deletions src/blacksquare/word_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import io
import re
from collections import defaultdict
from functools import lru_cache, cached_property
from functools import cached_property, lru_cache
from importlib.resources import files
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, NamedTuple, Optional, Union
Expand Down Expand Up @@ -318,6 +318,20 @@ def score_filter(self, threshold: float) -> MatchWordList:
self._word_length, self._words[score_mask], self._scores[score_mask]
)

def filter_words(self, words: List[str]) -> MatchWordList:
    """Returns a new word list with a specific set of words filtered out.

    Args:
        words (List[str]): The words to filter out. Any iterable of strings
            (list, tuple, set, generator) is accepted; a single string is
            treated as one word to exclude.

    Returns:
        MatchWordList: A new MatchWordList holding the remaining words and their
            scores. This list itself is not modified.
    """
    # np.isin treats a set (or any other non-sequence collection) as a single
    # object scalar, which would silently filter nothing. Materialize the input
    # as a list first. A lone string is wrapped rather than split into letters,
    # so it keeps meaning "exclude this one word".
    excluded = [words] if isinstance(words, str) else list(words)
    keep_mask = ~np.isin(self._words, excluded)
    return MatchWordList(
        self._word_length, self._words[keep_mask], self._scores[keep_mask]
    )


def _normalize(word: str) -> str:
"""Sanitizes an input word.
Expand All @@ -331,5 +345,4 @@ def _normalize(word: str) -> str:
return word.upper().replace(" ", "")


# DEFAULT_WORDLIST = WordList(pkg_resources.resource_stream(__name__, "xwordlist.dict"))
DEFAULT_WORDLIST = WordList(files("blacksquare").joinpath("xwordlist.dict"))
15 changes: 15 additions & 0 deletions tests/test_crossword.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ def test_subgrids(self, xw: Crossword):
class TestCrosswordFill:
def test_fill(self, xw):
    """A default fill yields "ABB" at 5-Across and never repeats a word."""
    solution = xw.fill()
    assert solution[ACROSS, 5].value == "ABB"
    # Every filled entry must be distinct: de-duplicating the values should
    # not shrink the collection.
    filled_values = [entry.value for entry in solution.iterwords()]
    assert len(set(filled_values)) == len(filled_values)

def test_fill_with_repeats(self, xw):
    """With allow_repeats=True the fill yields "BBB" at 5-Across."""
    filled = xw.fill(allow_repeats=True)
    five_across = filled[ACROSS, 5]
    assert five_across.value == "BBB"

def test_find_solutions_with_custom_dictionary(self, xw):
Expand All @@ -194,6 +201,14 @@ def test_unsolvable(self, xw):
solution = xw.fill()
assert solution is None

def test_unsolvable_if_repeats_needed(self):
    """A size-3 grid with a one-word list ("AAA") has no fill by default."""
    grid = Crossword(3)
    single_word_list = WordList({"AAA": 1})
    # Repeats are disallowed by default, so fill() is expected to give up.
    result = grid.fill(single_word_list)
    assert result is None

def test_unsolvable_with_score_filter(self, xw):
    """Filling with a score_filter of 1.0 is expected to find no solution."""
    result = xw.fill(score_filter=1.0)
    assert result is None

def test_filled_by_setting_letter(self):
xw = Crossword(5, symmetry=None)
for i in range(5):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_word_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def test_add_word_list(self):
assert new_word_list.words == ["AAA", "XYZ", "ABC", "ZZZ"]
assert new_word_list.scores == [1.0, 1.0, 0.9, 0.1]

def test_score_filter(self, word_list):
    """Filtering the fixture word list at 0.5 leaves six entries."""
    filtered = word_list.score_filter(0.5)
    assert len(filtered) == 6


class TestMatchWordList:
@pytest.fixture
Expand All @@ -75,3 +78,7 @@ def rescore_fn(word, score):
rescored = matches.rescore(rescore_fn)
assert rescored.get_score("ABC") == 2.0
assert rescored.get_score("BCD") == 0.1

def test_filter_words(self, matches: MatchWordList):
    """Excluding "ABC" and "ABB" shrinks the five-entry fixture to three."""
    assert len(matches) == 5
    remaining = matches.filter_words(["ABC", "ABB"])
    assert len(remaining) == 3

0 comments on commit 3476fdc

Please sign in to comment.