diff --git a/.gitignore b/.gitignore index c2aaddf5..faa33005 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ tmp/* /wheelhouse docs/build docs/src +*.orig diff --git a/textworld/challenges/tests/test_coin_collector.py b/textworld/challenges/tests/test_coin_collector.py index f705a4cd..facab8b1 100644 --- a/textworld/challenges/tests/test_coin_collector.py +++ b/textworld/challenges/tests/test_coin_collector.py @@ -20,5 +20,5 @@ def test_making_coin_collector(): settings = {"level": level} game = coin_collector.make(settings, options) - assert len(game.quests[0].commands) == expected[level]["quest_length"] + assert len(game.walkthrough) == expected[level]["quest_length"] assert len(game.world.rooms) == expected[level]["nb_rooms"] diff --git a/textworld/challenges/tests/test_treasure_hunter.py b/textworld/challenges/tests/test_treasure_hunter.py index 65dbd9fd..922a70cf 100644 --- a/textworld/challenges/tests/test_treasure_hunter.py +++ b/textworld/challenges/tests/test_treasure_hunter.py @@ -13,5 +13,5 @@ def test_making_treasure_hunter_games(): settings = {"level": level} game = treasure_hunter.make(settings, options) - assert len(game.quests[0].commands) == game.metadata["quest_length"], "Level {}".format(level) + assert len(game.walkthrough) == game.metadata["quest_length"], "Level {}".format(level) assert len(game.world.rooms) == game.metadata["world_size"], "Level {}".format(level) diff --git a/textworld/challenges/tw_cooking/cooking.py b/textworld/challenges/tw_cooking/cooking.py index 4513b1b1..c7b312fd 100644 --- a/textworld/challenges/tw_cooking/cooking.py +++ b/textworld/challenges/tw_cooking/cooking.py @@ -928,7 +928,7 @@ def make(settings: Mapping[str, str], options: Optional[GameOptions] = None) -> start_room = rng_map.choice(M.rooms) M.set_player(start_room) - M.grammar = textworld.generator.make_grammar(options.grammar, rng=rng_grammar) + M.grammar = textworld.generator.make_grammar(options.grammar, rng=rng_grammar, kb=options.kb) # Remove every food preparation with grilled, if there is no BBQ. if M.find_by_name("BBQ") is None: diff --git a/textworld/challenges/tw_treasure_hunter/treasure_hunter.py b/textworld/challenges/tw_treasure_hunter/treasure_hunter.py index fabea384..45a8d57d 100644 --- a/textworld/challenges/tw_treasure_hunter/treasure_hunter.py +++ b/textworld/challenges/tw_treasure_hunter/treasure_hunter.py @@ -233,7 +233,7 @@ def make_game(mode: str, options: GameOptions) -> textworld.Game: quest = Quest(win_events=[event], fail_events=[Event(conditions={Proposition("in", [wrong_obj, world.inventory])})]) - grammar = textworld.generator.make_grammar(options.grammar, rng=rng_grammar) + grammar = textworld.generator.make_grammar(options.grammar, rng=rng_grammar, kb=options.kb) game = textworld.generator.make_game_with(world, [quest], grammar) game.metadata.update(metadata) mode_choice = modes.index(mode) diff --git a/textworld/generator/__init__.py b/textworld/generator/__init__.py index 87f5de50..6dd70764 100644 --- a/textworld/generator/__init__.py +++ b/textworld/generator/__init__.py @@ -17,7 +17,9 @@ from textworld.generator.chaining import ChainingOptions, QuestGenerationError from textworld.generator.chaining import sample_quest from textworld.generator.world import World -from textworld.generator.game import Game, Quest, Event, GameOptions +from textworld.generator.game import Game, Quest, GameOptions +from textworld.generator.game import EventCondition, EventAction, EventAnd, EventOr +from textworld.generator.game import Event # For backward compatibility from textworld.generator.graph_networks import create_map, create_small_map from textworld.generator.text_generation import generate_text_from_grammar @@ -142,19 +144,18 @@ def make_quest(world: Union[World, State], options: Optional[GameOptions] = None for i in range(1, len(chain.nodes)): actions.append(chain.actions[i - 1]) if chain.nodes[i].breadth != chain.nodes[i - 1].breadth: - event = Event(actions) - quests.append(Quest(win_events=[event])) + quests.append(Quest(win_event=EventCondition(actions=actions))) actions.append(chain.actions[-1]) - event = Event(actions) - quests.append(Quest(win_events=[event])) + quests.append(Quest(win_event=EventCondition(actions=actions))) return quests -def make_grammar(options: Mapping = {}, rng: Optional[RandomState] = None) -> Grammar: +def make_grammar(options: Mapping = {}, rng: Optional[RandomState] = None, + kb: Optional[KnowledgeBase] = None) -> Grammar: rng = g_rng.next() if rng is None else rng - grammar = Grammar(options, rng) + grammar = Grammar(options, rng, kb) grammar.check() return grammar diff --git a/textworld/generator/game.py b/textworld/generator/game.py index e7feb9b8..cc97f74a 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -5,10 +5,11 @@ import copy import json import textwrap +import warnings +import itertools from typing import List, Dict, Optional, Mapping, Any, Iterable, Union, Tuple from collections import OrderedDict -from functools import partial from numpy.random import RandomState @@ -17,7 +18,7 @@ from textworld.generator.data import KnowledgeBase from textworld.generator.text_grammar import Grammar, GrammarOptions from textworld.generator.world import World -from textworld.logic import Action, Proposition, State +from textworld.logic import Action, Proposition, State, Rule, Variable from textworld.generator.graph_networks import DIRECTIONS from textworld.generator.chaining import ChainingOptions @@ -45,53 +46,24 @@ def __init__(self): super().__init__(msg) -def gen_commands_from_actions(actions: Iterable[Action], kb: Optional[KnowledgeBase] = None) -> List[str]: - kb = kb or KnowledgeBase.default() +class TextworldGameVersionWarning(UserWarning): + pass - def _get_name_mapping(action): - mapping = kb.rules[action.name].match(action) - return {ph.name: var.name for ph, var in mapping.items()} - commands = [] - for action in actions: - command = "None" - if action is not None: - command = kb.inform7_commands[action.name] - command = command.format(**_get_name_mapping(action)) +class AbstractEvent: - commands.append(command) + _SERIAL_VERSION = 2 - return commands - - -class Event: - """ - Event happening in TextWorld. - - An event gets triggered when its set of conditions become all statisfied. - - Attributes: - actions: Actions to be performed to trigger this event - commands: Human readable version of the actions. - condition: :py:class:`textworld.logic.Action` that can only be applied - when all conditions are statisfied. - """ - - def __init__(self, actions: Iterable[Action] = (), - conditions: Iterable[Proposition] = (), - commands: Iterable[str] = ()) -> None: + def __init__(self, actions: Iterable[Action] = (), commands: Iterable[str] = (), name: str = "") -> None: """ Args: actions: The actions to be performed to trigger this event. - If an empty list, then `conditions` must be provided. - conditions: Set of propositions which need to - be all true in order for this event - to get triggered. commands: Human readable version of the actions. """ self.actions = actions self.commands = commands - self.condition = self.set_conditions(conditions) + self.name = name + self.is_dnf = False @property def actions(self) -> Iterable[Action]: @@ -109,9 +81,113 @@ def commands(self) -> Iterable[str]: def commands(self, commands: Iterable[str]) -> None: self._commands = tuple(commands) - def is_triggering(self, state: State) -> bool: - """ Check if this event would be triggered in a given state. """ - return state.is_applicable(self.condition) + def __hash__(self) -> int: + return hash((self.actions, self.commands)) + + def __eq__(self, other: Any) -> bool: + return (isinstance(other, AbstractEvent) + and self.actions == other.actions + and self.commands == other.commands + and self.name == other.name + and self.is_dnf == other.is_dnf) + + @classmethod + def deserialize(cls, data: Mapping) -> Union["AbstractEvent", "EventCondition", "EventAction", "EventOr", "EventAnd"]: + """ Creates a `AbstractEvent` (or one of its subtypes) from serialized data. + + Args: + data: Serialized data with the needed information to build a `AbstractEvent` (or one of its subtypes) object. + """ + version = data.get("version", 1) + if version == 1: + data["type"] = "EventCondition" + + if data["type"] == "EventCondition": + obj = EventCondition.deserialize(data) + elif data["type"] == "EventAction": + obj = EventAction.deserialize(data) + elif data["type"] == "EventOr": + obj = EventOr.deserialize(data) + elif data["type"] == "EventAnd": + obj = EventAnd.deserialize(data) + elif data["type"] == "AbstractEvent": + obj = cls() + + obj.actions = [Action.deserialize(d) for d in data["actions"]] + obj.commands = data["commands"] + obj.name = data.get("name", "") + obj.is_dnf = data.get("is_dnf", False) + return obj + + def serialize(self) -> Mapping: + """ Serialize this event. + + Results: + Event's data serialized to be JSON compatible + """ + return { + "version": self._SERIAL_VERSION, + "type": self.__class__.__name__, + "commands": self.commands, + "actions": [action.serialize() for action in self.actions], + "name": self.name, + "is_dnf": self.is_dnf, + } + + def copy(self) -> "AbstractEvent": + """ Copy this event. """ + return self.deserialize(self.serialize()) + + @classmethod + def to_dnf(cls, expr: Optional["AbstractEvent"]) -> Optional["AbstractEvent"]: + """Normalize a boolean expression to its DNF. + + Expr can be an AbstractEvent, it this case it returns EventOr([EventAnd([element])]). + Expr can be an EventOr(...) / EventAnd(...) expressions, + in which cases it returns also a disjunctive normalised form (removing identical elements) + + References: + Code inspired by https://stackoverflow.com/a/58372345 + """ + if expr is None: + return None + + if expr.is_dnf: + return expr # Expression is already in DNF. + + if not isinstance(expr, (EventOr, EventAnd)): + result = EventOr((EventAnd((expr,)),)) + + elif isinstance(expr, EventOr): + result = EventOr(se for e in expr for se in cls.to_dnf(e)) + + elif isinstance(expr, EventAnd): + total = [] + for c in itertools.product(*[cls.to_dnf(e) for e in expr]): + total.append(EventAnd(se for e in c for se in e)) + + result = EventOr(total) + + result.is_dnf = True + return result + + +class EventCondition(AbstractEvent): + + def __init__(self, conditions: Iterable[Proposition] = (), + actions: Iterable[Action] = (), + commands: Iterable[str] = (), + **kwargs) -> None: + """ + Args: + conditions: Set of propositions which need to be all true in order for this event + to get triggered. + actions: The actions to be performed to trigger this event. + If an empty list, then `conditions` must be provided. + commands: Human readable version of the actions. + """ + super(EventCondition, self).__init__(actions, commands, **kwargs) + self.condition = self.set_conditions(conditions) def set_conditions(self, conditions: Iterable[Proposition]) -> Action: """ @@ -138,43 +214,266 @@ def set_conditions(self, conditions: Iterable[Proposition]) -> Action: postconditions=list(conditions) + [event]) return self.condition + def is_triggering(self, state: State, action: Optional[Action] = None, callback: Optional[callable] = None) -> bool: + """ Check if this event would be triggered in a given state. """ + is_triggering = state.is_applicable(self.condition) + if callback and is_triggering: + callback(self) + + return is_triggering + + def __str__(self) -> str: + return str(self.condition) + + def __repr__(self) -> str: + return "EventCondition(Action.parse('{}'), name={})".format(self.condition, self.name) + def __hash__(self) -> int: return hash((self.actions, self.commands, self.condition)) def __eq__(self, other: Any) -> bool: - return (isinstance(other, Event) - and self.actions == other.actions - and self.commands == other.commands + return (isinstance(other, EventCondition) + and super().__eq__(other) and self.condition == other.condition) @classmethod - def deserialize(cls, data: Mapping) -> "Event": - """ Creates an `Event` from serialized data. + def deserialize(cls, data: Mapping) -> "EventCondition": + """ Creates an `EventCondition` from serialized data. Args: - data: Serialized data with the needed information to build a - `Event` object. + data: Serialized data with the needed information to build a `EventCondition` object. """ - actions = [Action.deserialize(d) for d in data["actions"]] condition = Action.deserialize(data["condition"]) - event = cls(actions, condition.preconditions, data["commands"]) - return event + return cls(conditions=condition.preconditions) def serialize(self) -> Mapping: """ Serialize this event. Results: - `Event`'s data serialized to be JSON compatible. + `EventCondition`'s data serialized to be JSON compatible. """ - data = {} - data["commands"] = self.commands - data["actions"] = [action.serialize() for action in self.actions] + data = super().serialize() data["condition"] = self.condition.serialize() return data - def copy(self) -> "Event": - """ Copy this event. """ - return self.deserialize(self.serialize()) + +class Event: # For backward compatibility. + """ + Event happening in TextWorld. + + An event gets triggered when its set of conditions become all statisfied. + + .. warning:: Deprecated in favor of + :py:class:`textworld.generator.EventCondition `. + """ + + def __new__(cls, actions: Iterable[Action] = (), + conditions: Iterable[Proposition] = (), + commands: Iterable[str] = ()): + return EventCondition(actions=actions, conditions=conditions, commands=commands) + + +class EventAction(AbstractEvent): + + def __init__(self, action: Rule, + actions: Iterable[Action] = (), + commands: Iterable[str] = (), + **kwargs) -> None: + """ + Args: + action: The action to be performed to trigger this event. + actions: The actions to be performed to trigger this event. + commands: Human readable version of the actions. + + Notes: + TODO: EventAction are temporal. + """ + super(EventAction, self).__init__(actions, commands, **kwargs) + self.action = action + + def is_triggering(self, state: Optional[State] = None, + action: Optional[Action] = None, + callback: Optional[callable] = None) -> bool: + """ Check if this event would be triggered for a given action. """ + if action is None: + return False + + mapping = self.action.match(action) + if mapping is None: + return False + + is_triggering = all( + ph.name == mapping[ph].name for ph in self.action.placeholders if ph.name != ph.type + ) + if callback and is_triggering: + callback(self) + + return is_triggering + + def __str__(self) -> str: + return str(self.action) + + def __repr__(self) -> str: + return "EventAction(Rule.parse('{}'), name={})".format(self.action, self.name) + + def __hash__(self) -> int: + return hash((self.actions, self.commands, self.action)) + + def __eq__(self, other: Any) -> bool: + return (isinstance(other, EventAction) + and super().__eq__(other) + and self.action == other.action) + + @classmethod + def deserialize(cls, data: Mapping) -> "EventAction": + """ Creates an `EventAction` from serialized data. + + Args: + data: Serialized data with the needed information to build a + `EventAction` object. + """ + action = Rule.deserialize(data["action"]) + return cls(action=action) + + def serialize(self) -> Mapping: + """ Serialize this event. + + Results: + `EventAction`'s data serialized to be JSON compatible. + """ + data = super().serialize() + data["action"] = self.action.serialize() + return data + + +class EventOr(AbstractEvent): + def __init__(self, events: Iterable[AbstractEvent] = ()): + super().__init__() + self.events = events + if len(self.events) == 1: + self.commands = self.events[0].commands + self.actions = self.events[0].actions + + @property + def events(self) -> Tuple[AbstractEvent]: + return self._events + + @events.setter + def events(self, events: Iterable[AbstractEvent]) -> None: + self._events = tuple(events) + + def is_triggering(self, state: Optional[State] = None, + action: Optional[Action] = None, + callback: Optional[callable] = None) -> bool: + """ Check if this event would be triggered for a given state and/or action. """ + is_triggering = any(event.is_triggering(state, action, callback) for event in self.events) + if callback and is_triggering: + callback(self) + + return is_triggering + + def __iter__(self) -> Iterable[AbstractEvent]: + yield from self.events + + def __len__(self) -> int: + return len(self.events) + + def __repr__(self) -> str: + return "EventOr({!r})".format(self.events) + + def __str__(self) -> str: + return "EventOr({})".format(self.events) + + def __hash__(self) -> int: + return hash(self.events) + + def __eq__(self, other: Any) -> bool: + return (isinstance(other, EventOr) + and self.events == other.events) + + def serialize(self) -> Mapping: + """ Serialize this EventOr. + + Results: + EventOr's data serialized to be JSON compatible + """ + data = super().serialize() + data["events"] = [e.serialize() for e in self.events] + return data + + @classmethod + def deserialize(cls, data: Mapping) -> "EventOr": + """ Creates a `EventOr` from serialized data. + + Args: + data: Serialized data with the needed information to build a `EventOr` object. + """ + return cls([AbstractEvent.deserialize(d) for d in data["events"]]) + + +class EventAnd(AbstractEvent): + def __init__(self, events: Iterable[AbstractEvent] = ()): + super().__init__() + self.events = events + if len(self.events) == 1: + self.commands = self.events[0].commands + self.actions = self.events[0].actions + + @property + def events(self) -> Tuple[AbstractEvent]: + return self._events + + @events.setter + def events(self, events: Iterable[AbstractEvent]) -> None: + self._events = tuple(events) + + def is_triggering(self, state: Optional[State] = None, + action: Optional[Action] = None, + callback: Optional[callable] = None) -> bool: + """ Check if this event would be triggered for a given state and/or action. """ + is_triggering = all(event.is_triggering(state, action, callback) for event in self.events) + if callback and is_triggering: + callback(self) + + return is_triggering + + def __iter__(self) -> Iterable[AbstractEvent]: + yield from self.events + + def __len__(self) -> int: + return len(self.events) + + def __repr__(self) -> str: + return "EventAnd({!r})".format(self.events) + + def __str__(self) -> str: + return "EventAnd({})".format(self.events) + + def __hash__(self) -> int: + return hash(self.events) + + def __eq__(self, other: Any) -> bool: + return (isinstance(other, EventAnd) + and self.events == other.events) + + def serialize(self) -> Mapping: + """ Serialize this EventAnd. + + Results: + EventAnd's data serialized to be JSON compatible + """ + data = super().serialize() + data["events"] = [e.serialize() for e in self.events] + return data + + @classmethod + def deserialize(cls, data: Mapping) -> "EventAnd": + """ Creates a `EventAnd` from serialized data. + + Args: + data: Serialized data with the needed information to build a `EventAnd` object. + """ + return cls([AbstractEvent.deserialize(d) for d in data["events"]]) class Quest: @@ -184,10 +483,10 @@ class Quest: a mutually exclusive set of failing events. Attributes: - win_events: Mutually exclusive set of winning events. That is, + win_event: Mutually exclusive set of winning events. That is, only one such event needs to be triggered in order to complete this quest. - fail_events: Mutually exclusive set of failing events. That is, + fail_event: Mutually exclusive set of failing events. That is, only one such event needs to be triggered in order to fail this quest. reward: Reward given for completing this quest. @@ -195,53 +494,64 @@ class Quest: commands: List of text commands leading to this quest completion. """ + _SERIAL_VERSION = 2 + def __init__(self, - win_events: Iterable[Event] = (), - fail_events: Iterable[Event] = (), + win_event: Optional[AbstractEvent] = None, + fail_event: Optional[AbstractEvent] = None, reward: Optional[int] = None, desc: Optional[str] = None, - commands: Iterable[str] = ()) -> None: + commands: Iterable[str] = (), + **kwargs) -> None: r""" Args: - win_events: Mutually exclusive set of winning events. That is, + win_event: Mutually exclusive set of winning events. That is, + only one such event needs to be triggered in order + to complete this quest. + fail_event: Mutually exclusive set of failing events. That is, only one such event needs to be triggered in order - to complete this quest. - fail_events: Mutually exclusive set of failing events. That is, - only one such event needs to be triggered in order - to fail this quest. + to fail this quest. reward: Reward given for completing this quest. By default, reward is set to 1 if there is at least one winning events otherwise it is set to 0. desc: A text description of the quest. commands: List of text commands leading to this quest completion. """ - self.win_events = tuple(win_events) - self.fail_events = tuple(fail_events) + # Backward compatibility: check for old argument names. + if "win_events" in kwargs: + win_event = kwargs["win_events"] + if "fail_events" in kwargs: + fail_event = kwargs["fail_events"] + + # Backward compatibility: convert list of Events to EventOr(events). + if win_event is not None and not isinstance(win_event, AbstractEvent): + win_event = EventOr(win_event) + + if fail_event is not None and not isinstance(fail_event, AbstractEvent): + fail_event = EventOr(fail_event) + + self.win_event = AbstractEvent.to_dnf(win_event) if win_event else None + self.fail_event = AbstractEvent.to_dnf(fail_event) if fail_event else None self.desc = desc self.commands = tuple(commands) # Unless explicitly provided, reward is set to 1 if there is at least # one winning events otherwise it is set to 0. - self.reward = int(len(win_events) > 0) if reward is None else reward + self.reward = reward or int(self.win_event is not None) - if len(self.win_events) == 0 and len(self.fail_events) == 0: + if self.win_event is None and self.fail_event is None: raise UnderspecifiedQuestError() @property - def win_events(self) -> Iterable[Event]: - return self._win_events + def events(self) -> Iterable[EventAnd]: + events = [] + if self.win_event: + events += list(self.win_event) - @win_events.setter - def win_events(self, events: Iterable[Event]) -> None: - self._win_events = tuple(events) + if self.fail_event: + events += list(self.fail_event) - @property - def fail_events(self) -> Iterable[Event]: - return self._fail_events - - @fail_events.setter - def fail_events(self, events: Iterable[Event]) -> None: - self._fail_events = tuple(events) + return events @property def commands(self) -> Iterable[str]: @@ -251,22 +561,21 @@ def commands(self) -> Iterable[str]: def commands(self, commands: Iterable[str]) -> None: self._commands = tuple(commands) - def is_winning(self, state: State) -> bool: - """ Check if this quest is winning in that particular state. """ - return any(event.is_triggering(state) for event in self.win_events) + def is_winning(self, state: Optional[State] = None, action: Optional[Action] = None) -> bool: + """ Check if this quest is winning for a given state and/or after a given action. """ + return self.win_event.is_triggering(state, action) - def is_failing(self, state: State) -> bool: - """ Check if this quest is failing in that particular state. """ - return any(event.is_triggering(state) for event in self.fail_events) + def is_failing(self, state: Optional[State] = None, action: Optional[Action] = None) -> bool: + """ Check if this quest is failing for a given state and/or after a given action. """ + return self.fail_event.is_triggering(state, action) def __hash__(self) -> int: - return hash((self.win_events, self.fail_events, self.reward, - self.desc, self.commands)) + return hash((self.win_event, self.fail_event, self.reward, self.desc, self.commands)) def __eq__(self, other: Any) -> bool: return (isinstance(other, Quest) - and self.win_events == other.win_events - and self.fail_events == other.fail_events + and self.win_event == other.win_event + and self.fail_event == other.fail_event and self.reward == other.reward and self.desc == other.desc and self.commands == other.commands) @@ -279,12 +588,22 @@ def deserialize(cls, data: Mapping) -> "Quest": data: Serialized data with the needed information to build a `Quest` object. """ - win_events = [Event.deserialize(d) for d in data["win_events"]] - fail_events = [Event.deserialize(d) for d in data["fail_events"]] + version = data.get("version", 1) + if version == 1: + win_events = [AbstractEvent.deserialize(event) for event in data["win_events"]] + fail_events = [AbstractEvent.deserialize(event) for event in data["fail_events"]] + commands = data.get("commands", []) + reward = data["reward"] + desc = data["desc"] + quest = cls(win_events, fail_events, reward, desc, commands) + return quest + + win_event = AbstractEvent.deserialize(data["win_event"]) if data["win_event"] else None + fail_event = AbstractEvent.deserialize(data["fail_event"]) if data["fail_event"] else None commands = data.get("commands", []) reward = data["reward"] desc = data["desc"] - return cls(win_events, fail_events, reward, desc, commands) + return cls(win_event, fail_event, reward, desc, commands) def serialize(self) -> Mapping: """ Serialize this quest. @@ -292,13 +611,14 @@ def serialize(self) -> Mapping: Results: Quest's data serialized to be JSON compatible """ - data = {} - data["desc"] = self.desc - data["reward"] = self.reward - data["commands"] = self.commands - data["win_events"] = [event.serialize() for event in self.win_events] - data["fail_events"] = [event.serialize() for event in self.fail_events] - return data + return { + "version": self._SERIAL_VERSION, + "desc": self.desc, + "reward": self.reward, + "commands": self.commands, + "win_event": self.win_event.serialize() if self.win_event else None, + "fail_event": self.fail_event.serialize() if self.fail_event else None + } def copy(self) -> "Quest": """ Copy this quest. """ @@ -371,9 +691,16 @@ class Game: A `Game` is defined by a world and it can have quest(s) or not. Additionally, a grammar can be provided to control the text generation. + + Notes: + ----- + Here's the list of the diffrent `Game` class versions. + - v1: Initial version. + - v2: Games that have been created using the new Event classes. + """ - _SERIAL_VERSION = 1 + _SERIAL_VERSION = 2 def __init__(self, world: World, grammar: Optional[Grammar] = None, quests: Iterable[Quest] = ()) -> None: @@ -417,21 +744,10 @@ def change_grammar(self, grammar: Grammar) -> None: """ Changes the grammar used and regenerate all text. """ self.grammar = grammar - _gen_commands = partial(gen_commands_from_actions, kb=self.kb) if self.grammar: - from textworld.generator.inform7 import Inform7Game from textworld.generator.text_generation import generate_text_from_grammar - inform7 = Inform7Game(self) - _gen_commands = inform7.gen_commands_from_actions generate_text_from_grammar(self, self.grammar) - for quest in self.quests: - for event in quest.win_events: - event.commands = _gen_commands(event.actions) - - if quest.win_events: - quest.commands = quest.win_events[0].commands - # Check if we can derive a global winning policy from the quests. if self.grammar: from textworld.generator.text_generation import describe_event @@ -440,7 +756,7 @@ def change_grammar(self, grammar: Grammar) -> None: mapping = {k: info.name for k, info in self._infos.items()} commands = [a.format_command(mapping) for a in policy] self.metadata["walkthrough"] = commands - self.objective = describe_event(Event(policy), self, self.grammar) + self.objective = describe_event(AbstractEvent(policy), self, self.grammar) def save(self, filename: str) -> None: """ Saves the serialized data of this game to a file. """ @@ -462,8 +778,12 @@ def deserialize(cls, data: Mapping) -> "Game": `Game` object. """ - version = data.get("version", cls._SERIAL_VERSION) - if version != cls._SERIAL_VERSION: + version = data.get("version", 1) + if version == 1: + msg = "Loading TextWorld game format (v{})! Current version is {}.".format(version, cls._SERIAL_VERSION) + warnings.warn(msg, TextworldGameVersionWarning) + + elif version != cls._SERIAL_VERSION: msg = "Cannot deserialize a TextWorld version {} game, expected version {}" raise ValueError(msg.format(version, cls._SERIAL_VERSION)) @@ -570,7 +890,7 @@ def objective(self) -> str: return self._objective # TODO: Find a better way of describing the objective of the game with several quests. - self._objective = "\nAND\n".join(quest.desc for quest in self.quests if quest.desc) + self._objective = "\n The next quest is \n".join(quest.desc for quest in self.quests if quest.desc) return self._objective @@ -722,28 +1042,56 @@ class EventProgression: relevant actions to be performed. """ - def __init__(self, event: Event, kb: KnowledgeBase) -> None: + def __init__(self, event: AbstractEvent, kb: KnowledgeBase) -> None: """ Args: quest: The quest to keep track of its completion. """ self._kb = kb or KnowledgeBase.default() - self.event = event + self.event = event # TODO: convert to dnf just to be safe. self._triggered = False self._untriggerable = False - self._policy = () - - # Build a tree representation of the quest. - self._tree = ActionDependencyTree(kb=self._kb, - element_type=ActionDependencyTreeElement) - - if len(event.actions) > 0: - self._tree.push(event.condition) + self._policy = None + # self._policy = () + + # Build a tree representations for each subevent. + self._trees = [] + for events in self.event: # Assuming self.event is in DNF. + # trees = [] + + # Dummy action that should trigger when all events are triggered. + conditions = set() + + for event in events: + if isinstance(event, EventCondition): + conditions |= set(event.condition.preconditions) + elif isinstance(event, EventAction): + mapping = {ph: Variable(ph.name, ph.type) for ph in event.action.placeholders} + conditions |= set(predicate.instantiate(mapping) for predicate in event.action.postconditions) + else: + raise NotImplementedError() + + variables = sorted(set([v for c in conditions for v in c.arguments])) + event = Proposition("event", arguments=variables) + trigger = Action("trigger", preconditions=conditions, postconditions=list(conditions) + [event]) + + tree = ActionDependencyTree(kb=self._kb, element_type=ActionDependencyTreeElement) + tree.push(trigger) + + if events.actions: + for action in events.actions[::-1]: + tree.push(action) + else: + for event in events: + for action in event.actions[::-1]: + tree.push(action) - for action in event.actions[::-1]: - self._tree.push(action) + # trees.append(tree) - self._policy = event.actions + (event.condition,) + # trees = ActionDependencyTree(kb=self._kb, + # element_type=ActionDependencyTreeElement, + # trees=trees) + self._trees.append(tree) def copy(self) -> "EventProgression": """ Return a soft copy. """ @@ -751,17 +1099,37 @@ def copy(self) -> "EventProgression": ep._triggered = self._triggered ep._untriggerable = self._untriggerable ep._policy = self._policy - ep._tree = self._tree.copy() + ep._trees = [tree.copy() for tree in self._trees] return ep @property - def triggering_policy(self) -> List[Action]: - """ Actions to be performed in order to trigger the event. """ + def triggering_policy(self) -> Optional[List[Action]]: if self.done: return () - # Discard all "trigger" actions. - return tuple(a for a in self._policy if a.name != "trigger") + if self._policy is None or True: # TODO + policies = [] + for trees in self._trees: + # Discard all "trigger" actions. + policies.append(tuple(a for a in trees.flatten() if a.name != "trigger")) + + self._policy = min(policies, key=lambda policy: len(policy)) + + return self._policy + + @property + def _tree(self): + best = None + best_policy = None + for trees in self._trees: + # Discard all "trigger" actions. + policy = tuple(a for a in trees.flatten() if a.name != "trigger") + + if best is None or len(best_policy) > len(policy): + best = trees + best_policy = policy + + return best @property def done(self) -> bool: @@ -776,9 +1144,11 @@ def triggered(self) -> bool: @property def untriggerable(self) -> bool: """ Check whether the event is in an untriggerable state. """ - return self._untriggerable + return len(self._trees) == 0 - def update(self, action: Optional[Action] = None, state: Optional[State] = None) -> None: + def update(self, action: Optional[Action] = None, + state: Optional[State] = None, + callback: Optional[callable] = None) -> None: """ Update event progression given available information. Args: @@ -790,23 +1160,29 @@ def update(self, action: Optional[Action] = None, state: Optional[State] = None) if state is not None: # Check if event is triggered. - self._triggered = self.event.is_triggering(state) + self._triggered = self.event.is_triggering(state, action, callback) + + # Update each dependency trees. + to_delete = [] + for i, trees in enumerate(self._trees): + if self._compress_policy(i, state): + continue # A shorter winning policy has been found. - # Try compressing the winning policy given the new game state. - if self.compress_policy(state): - return # A shorter winning policy has been found. + if action and not trees.empty: + # Determine if we moved away from the goal or closer to it. + changed, reverse_action = trees.remove(action) + if changed and reverse_action is None: # Irreversible action. + to_delete.append(trees) - if action is not None and not self._tree.empty: - # Determine if we moved away from the goal or closer to it. - changed, reverse_action = self._tree.remove(action) - if changed and reverse_action is None: # Irreversible action. - self._untriggerable = True # Can't track quest anymore. + if changed and reverse_action is not None: + # Rebuild policy. + # self._policy = tuple(self._tree.flatten()) + self._policy = None # Will be rebuilt on the next call of triggering_policy. - if changed and reverse_action is not None: - # Rebuild policy. - self._policy = tuple(self._tree.flatten()) + for e in to_delete: + self._trees.remove(e) - def compress_policy(self, state: State) -> bool: + def _compress_policy(self, idx, state: State) -> bool: """ Compress the policy given a game state. Args: @@ -815,26 +1191,26 @@ def compress_policy(self, state: State) -> bool: Returns: Whether the policy was compressed or not. """ + # Make sure the compressed policy has the same roots. + root_actions = [root.element.action for root in self._trees[idx].roots] def _find_shorter_policy(policy): for j in range(0, len(policy)): for i in range(j + 1, len(policy))[::-1]: shorter_policy = policy[:j] + policy[i:] - if state.is_sequence_applicable(shorter_policy): - self._tree = ActionDependencyTree(kb=self._kb, - element_type=ActionDependencyTreeElement) + if state.is_sequence_applicable(shorter_policy) and all(a in shorter_policy for a in root_actions): + self._trees[idx] = ActionDependencyTree(kb=self._kb, element_type=ActionDependencyTreeElement) for action in shorter_policy[::-1]: - self._tree.push(action) + self._trees[idx].push(action, allow_multi_root=True) return shorter_policy return None compressed = False - policy = _find_shorter_policy(tuple(a for a in self._tree.flatten())) + policy = _find_shorter_policy(tuple(a for a in self._trees[idx].flatten())) while policy is not None: compressed = True - self._policy = policy policy = _find_shorter_policy(policy) return compressed @@ -854,41 +1230,32 @@ def __init__(self, quest: Quest, kb: KnowledgeBase) -> None: """ self.quest = quest self.kb = kb - self.win_events = [EventProgression(event, kb) for event in quest.win_events] - self.fail_events = [EventProgression(event, kb) for event in quest.fail_events] + self.win_event = EventProgression(quest.win_event, kb) if quest.win_event is not None else None + self.fail_event = EventProgression(quest.fail_event, kb) if quest.fail_event is not None else None def copy(self) -> "QuestProgression": """ Return a soft copy. """ qp = QuestProgression(self.quest, self.kb) - qp.win_events = [event_progression.copy() for event_progression in self.win_events] - qp.fail_events = [event_progression.copy() for event_progression in self.fail_events] + qp.win_event = self.win_event.copy() if self.win_event is not None else None + qp.fail_event = self.fail_event.copy() if self.fail_event is not None else None return qp @property def _tree(self) -> Optional[List[ActionDependencyTree]]: - events = [event for event in self.win_events if len(event.triggering_policy) > 0] - if len(events) == 0: - return None - - event = min(events, key=lambda event: len(event.triggering_policy)) - return event._tree + return self.win_event._tree @property def winning_policy(self) -> Optional[List[Action]]: """ Actions to be performed in order to complete the quest. """ - if self.done: - return None - - winning_policies = [event.triggering_policy for event in self.win_events if len(event.triggering_policy) > 0] - if len(winning_policies) == 0: + if self.done or self.win_event is None: return None - return min(winning_policies, key=lambda policy: len(policy)) + return self.win_event.triggering_policy @property def completable(self) -> bool: """ Check if the quest has winning events. """ - return len(self.win_events) > 0 + return self.win_event is not None @property def done(self) -> bool: @@ -898,19 +1265,21 @@ def done(self) -> bool: @property def completed(self) -> bool: """ Check whether the quest is completed. """ - return any(event.triggered for event in self.win_events) + return self.win_event is not None and self.win_event.triggered @property def failed(self) -> bool: """ Check whether the quest has failed. """ - return any(event.triggered for event in self.fail_events) + return self.fail_event is not None and self.fail_event.triggered @property def unfinishable(self) -> bool: """ Check whether the quest is in an unfinishable state. """ - return any(event.untriggerable for event in self.win_events) + return self.win_event.untriggerable if self.win_event else False - def update(self, action: Optional[Action] = None, state: Optional[State] = None) -> None: + def update(self, action: Optional[Action] = None, + state: Optional[State] = None, + callback: Optional[callable] = None) -> None: """ Update quest progression given available information. Args: @@ -920,8 +1289,15 @@ def update(self, action: Optional[Action] = None, state: Optional[State] = None) if self.done: return # Nothing to do, the quest is already done. - for event in (self.win_events + self.fail_events): - event.update(action, state) + if self.win_event: + self.win_event.update(action, state, callback) + + # Only update fail_event if the quest is not completed. + if self.completed: + return + + if self.fail_event: + self.fail_event.update(action, state, callback) class GameProgression: @@ -939,6 +1315,7 @@ def __init__(self, game: Game, track_quests: bool = True) -> None: """ self.game = game self.state = game.world.state.copy() + self.callback = None self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), self.game.kb.types.constants_mapping)) @@ -965,7 +1342,7 @@ def done(self) -> bool: @property def completed(self) -> bool: - """ Whether all quests are completed. """ + """ Whether all completable quests are completed. """ if not self.tracking_quests: return False # There is nothing to be "completed". @@ -1021,7 +1398,7 @@ def winning_policy(self) -> Optional[List[Action]]: # Discard all "trigger" actions. return tuple(a for a in master_quest_tree.flatten() if a.name != "trigger") - def update(self, action: Action) -> None: + def update(self, action: Action, callback: Optional[callable] = None) -> None: """ Update the state of the game given the provided action. Args: @@ -1036,7 +1413,7 @@ def update(self, action: Action) -> None: # Update all quest progressions given the last action and new state. for quest_progression in self.quest_progressions: - quest_progression.update(action, self.state) + quest_progression.update(action, self.state, callback or self.callback) class GameOptions: diff --git a/textworld/generator/inform7/tests/test_world2inform7.py b/textworld/generator/inform7/tests/test_world2inform7.py index 795d5b19..09861727 100644 --- a/textworld/generator/inform7/tests/test_world2inform7.py +++ b/textworld/generator/inform7/tests/test_world2inform7.py @@ -3,9 +3,14 @@ import itertools +import unittest +import shutil +import tempfile +from os.path import join as pjoin import textworld from textworld import g_rng +from textworld import testing from textworld.utils import make_temp_directory from textworld.core import EnvInfos @@ -104,9 +109,10 @@ def _rule_to_skip(rule): assert not done assert not game_state.won - game_state, _, done = env.step(event.commands[0]) - assert done - assert game_state.won + for cmd in game.walkthrough: + game_state, _, done = env.step(cmd) + assert done + assert game_state.won def test_quest_with_multiple_winning_and_losing_conditions(): @@ -479,3 +485,104 @@ def test_take_all_and_variants(): assert "blue ball:" in game_state.feedback assert "red ball" in game_state.inventory assert "blue ball" in game_state.inventory + + +class TestInform7Game(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.DATA = testing.build_complex_test_game() + cls.game = cls.DATA["game"] + cls.quest1 = cls.DATA["quest1"] + cls.quest2 = cls.DATA["quest2"] + cls.eating_carrot = cls.DATA["eating_carrot"] + cls.onion_eaten = cls.DATA["onion_eaten"] + cls.closing_chest_without_carrot = cls.DATA["closing_chest_without_carrot"] + + cls.tmpdir = pjoin(tempfile.mkdtemp(prefix="test_inform7_game"), "") + options = textworld.GameOptions() + options.path = cls.tmpdir + options.seeds = 20210512 + options.file_ext = ".z8" + cls.game_file = compile_game(cls.game, options) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir) + + def _apply_commands(self, env, commands, state=None): + state = state or env.reset() + for cmd in commands: + assert not state.done + state, _, done = env.step(cmd) + + return state + + def test_game_completion(self): + env = textworld.start(self.game_file) + + # Do quest1. + state = self._apply_commands(env, self.quest1.win_event.events[0].commands) + assert state.score == self.quest1.reward + assert not state.done + + # Then quest2. + state = self._apply_commands(env, self.quest2.win_event.events[0].commands, state) + assert state.score == state.max_score + assert state.done + assert state.won and not state.lost + + # Alternative winning strategy for quest1. + state = self._apply_commands(env, self.quest1.win_event.events[1].commands) + assert state.score == self.quest1.reward + assert not state.done + + state = self._apply_commands(env, self.quest2.win_event.events[0].commands, state) + assert state.score == state.max_score + assert state.done + assert state.won and not state.lost + + # Start with quest2, then quest1. + state = self._apply_commands(env, self.quest2.win_event.events[0].commands) + assert state.score == self.quest2.reward + assert not state.done + + state = self._apply_commands(env, self.quest1.win_event.events[0].commands, state) + assert state.score == state.max_score + assert state.done + assert state.won and not state.lost + + # Closing the chest containing the onion while holding the carrot should not fail. + state = self._apply_commands(env, self.eating_carrot.commands[:1]) # Take carrot. + state = self._apply_commands(env, self.quest1.win_event.events[1].commands, state) + assert state.score == self.quest1.reward + assert not state.done + + def test_game_failure(self): + env = textworld.start(self.game_file) + + # onion_eaten -> eating_carrot != eating_carrot -> onion_eaten + # Eating the carrot *after* eating the onion causes the game to be lost. + state = self._apply_commands(env, self.onion_eaten.commands) + state = self._apply_commands(env, self.eating_carrot.commands, state) + assert state.done + assert not state.won and state.lost + + # Eating the carrot *before* eating the onion does not lose the game, + # but the game becomes unfinishable. + state = self._apply_commands(env, self.eating_carrot.commands) + state = self._apply_commands(env, self.onion_eaten.commands, state) + assert not (state.done or state.won or state.lost) # Couldn't detect game is unfinishable. + + env_with_state_tracking = textworld.start(self.game_file, EnvInfos(policy_commands=True)) + state = self._apply_commands(env_with_state_tracking, self.eating_carrot.commands) + state = self._apply_commands(env_with_state_tracking, self.onion_eaten.commands, state) + assert not (state.done or state.won or state.lost) # Still won't tell the game is unfinishable. + assert state.policy_commands == [] # But there isn't no sequence of commands that can complete the game. + + # Closing the chest *while* the carrot is in the inventory. + state = self._apply_commands(env, ["open chest", "close chest"]) + assert not state.lost + state = self._apply_commands(env, self.closing_chest_without_carrot.commands, state) + assert state.done + assert not state.won and state.lost diff --git a/textworld/generator/inform7/world2inform7.py b/textworld/generator/inform7/world2inform7.py index 45669e22..1f996a3d 100644 --- a/textworld/generator/inform7/world2inform7.py +++ b/textworld/generator/inform7/world2inform7.py @@ -16,6 +16,7 @@ from textworld.utils import make_temp_directory, str2bool, chunk from textworld.generator.game import Game +from textworld.generator.game import EventCondition, EventAction from textworld.generator.world import WorldRoom, WorldEntity from textworld.logic import Signature, Proposition, Action, Variable @@ -114,12 +115,21 @@ def gen_source_for_conditions(self, conds: Iterable[Proposition]) -> str: if i7_cond: i7_conds.append(i7_cond) - # HACK: In Inform7 we have to mention a container/door is unlocked AND closed. - for cond in conds: + # HACK: In Inform7 we have to mention a container/door is unlocked AND closed. if cond.name == "closed": i7_conds.append("the {} is unlocked".format(cond.arguments[0].name)) - return " and ".join(i7_conds) + return i7_conds + + def gen_source_for_event_action(self, event: EventAction) -> Optional[str]: + pt = self.kb.inform7_events[event.action.name] + if pt is None: + msg = "Undefined Inform7's action: {}".format(event.action.name) + warnings.warn(msg, TextworldInform7Warning) + return [] + + mapping = {ph.type: ph.name for ph in event.action.placeholders} + return [pt.format(**mapping)] def gen_source_for_objects(self, objects: Iterable[WorldEntity]) -> str: source = "" @@ -321,11 +331,9 @@ def gen_source(self, seed: int = 1234) -> str: """) source += quest_completed.format(quest_id=quest_id) - for event_id, event in enumerate(quest.win_events): - commands = self.gen_commands_from_actions(event.actions) - event.commands = commands - - walkthrough = '\nTest quest{}_{} with "{}"\n\n'.format(quest_id, event_id, " / ".join(commands)) + if quest.win_event: + commands = quest.win_event.commands or self.gen_commands_from_actions(quest.win_event.actions) + walkthrough = '\nTest quest{} with "{}"\n\n'.format(quest_id, " / ".join(commands)) source += walkthrough # Add winning and losing conditions for quest. @@ -342,15 +350,35 @@ def gen_source(self, seed: int = 1234) -> str: increase the score by {reward}; [Quest completed] Now the quest{quest_id} completed is true;""") - for fail_event in quest.fail_events: - conditions = self.gen_source_for_conditions(fail_event.condition.preconditions) - quest_ending_conditions += fail_template.format(conditions=conditions) - - for win_event in quest.win_events: - conditions = self.gen_source_for_conditions(win_event.condition.preconditions) - quest_ending_conditions += win_template.format(conditions=conditions, - reward=quest.reward, - quest_id=quest_id) + if quest.win_event: + # Assuming quest.win_event is in a DNF. + for events in quest.win_event: # Loop over EventOr + i7_conds = [] + for event in events: # Loop over EventAnd + if isinstance(event, EventCondition): + i7_conds += self.gen_source_for_conditions(event.condition.preconditions) + elif isinstance(event, EventAction): + i7_conds += self.gen_source_for_event_action(event) + else: + raise NotImplementedError("Unknown event type: {!r}".format(event)) + + quest_ending_conditions += win_template.format(conditions=" and ".join(i7_conds), + reward=quest.reward, + quest_id=quest_id) + + if quest.fail_event: + # Assuming quest.fail_event is in a DNF. + for events in quest.fail_event: # Loop over EventOr + i7_conds = [] + for event in events: # Loop over EventAnd + if isinstance(event, EventCondition): + i7_conds += self.gen_source_for_conditions(event.condition.preconditions) + elif isinstance(event, EventAction): + i7_conds += self.gen_source_for_event_action(event) + else: + raise NotImplementedError("Unknown event type: {!r}".format(event)) + + quest_ending_conditions += fail_template.format(conditions=" and ".join(i7_conds)) quest_ending = """\ Every turn:\n{conditions} diff --git a/textworld/generator/logger.py b/textworld/generator/logger.py index 88a118d7..358733fa 100644 --- a/textworld/generator/logger.py +++ b/textworld/generator/logger.py @@ -76,8 +76,8 @@ def collect(self, game): # Collect distribution of commands leading to winning events. for quest in game.quests: self.quests.add(quest.desc) - for event in quest.win_events: - actions = event.actions + for events in quest.win_event: + actions = events.actions update_bincount(self.dist_quest_length_count, len(actions)) for action in actions: diff --git a/textworld/generator/maker.py b/textworld/generator/maker.py index 1b89f60c..569fd74e 100644 --- a/textworld/generator/maker.py +++ b/textworld/generator/maker.py @@ -19,9 +19,9 @@ from textworld.generator.graph_networks import direction from textworld.generator.data import KnowledgeBase from textworld.generator.vtypes import get_new -from textworld.logic import State, Variable, Proposition, Action +from textworld.logic import State, Variable, Proposition, Action, Placeholder from textworld.generator.game import GameOptions -from textworld.generator.game import Game, World, Quest, Event, EntityInfo +from textworld.generator.game import Game, World, Quest, EventAnd, EventCondition, EventAction, EntityInfo from textworld.generator.graph_networks import DIRECTIONS from textworld.render import visualize from textworld.envs.wrappers import Recorder @@ -76,6 +76,12 @@ def __init__(self, failed_constraints: List[Action]) -> None: super().__init__(msg) +class UnderspecifiedEventError(NameError): + def __init__(self): + msg = "The event type should be specified. It can be either the action or condition." + super().__init__(msg) + + class WorldEntity: """ Represents an entity in the world. @@ -632,7 +638,7 @@ def record_quest(self) -> Quest: actions = [action for action in recorder.actions if action is not None] # Assume the last action contains all the relevant facts about the winning condition. - event = Event(actions=actions) + event = EventCondition(actions=actions) self.quests.append(Quest(win_events=[event])) # Calling build will generate the description for the quest. self.build() @@ -665,14 +671,14 @@ def set_quest_from_commands(self, commands: List[str]) -> Quest: unrecognized_commands = [c for c, a in zip(commands, recorder.actions) if a is None] raise QuestError("Some of the actions were unrecognized: {}".format(unrecognized_commands)) - event = Event(actions=actions) - self.quests = [Quest(win_events=[event])] + event = EventCondition(actions=actions, commands=commands) + self.quests = [Quest(win_event=event, commands=commands)] # Calling build will generate the description for the quest. self.build() return self.quests[-1] - def new_fact(self, name: str, *entities: List["WorldEntity"]) -> None: + def new_fact(self, name: str, *entities: List["WorldEntity"]) -> Proposition: """ Create new fact. Args: @@ -682,7 +688,28 @@ def new_fact(self, name: str, *entities: List["WorldEntity"]) -> None: args = [entity.var for entity in entities] return Proposition(name, args) - def new_event_using_commands(self, commands: List[str]) -> Event: + def new_action(self, name: str, *entities: List["WorldEntity"]) -> Union[None, Action]: + """ Create new fact about a rule. + + Args: + name: The name of the rule which can be used for the new rule fact as well. + *entities: A list of entities as arguments to the new rule fact. + """ + if name not in self._kb.rules: + raise ValueError("Can't find action: '{}'".format(name)) + + rule = self._kb.rules[name] + mapping = {Placeholder(entity.type): Placeholder(entity.id, entity.type) for entity in entities} + return rule.substitute(mapping) + + def new_quest(self, win_event=None, fail_event=None, reward=None, desc=None, commands=()) -> Quest: + return Quest(win_event=win_event, + fail_event=fail_event, + reward=reward, + desc=desc, + commands=commands) + + def new_event_using_commands(self, commands: List[str]) -> Union[EventCondition, EventAction]: """ Creates a new event using predefined text commands. This launches a `textworld.play` session to execute provided commands. @@ -704,7 +731,7 @@ def new_event_using_commands(self, commands: List[str]) -> Event: # Skip "None" actions. actions, commands = zip(*[(a, c) for a, c in zip(recorder.actions, commands) if a is not None]) - event = Event(actions=actions, commands=commands) + event = EventCondition(actions=actions, commands=commands) return event def new_quest_using_commands(self, commands: List[str]) -> Quest: @@ -719,37 +746,74 @@ def new_quest_using_commands(self, commands: List[str]) -> Quest: The resulting quest. """ event = self.new_event_using_commands(commands) - return Quest(win_events=[event], commands=event.commands) + return Quest(win_event=event, commands=event.commands) + + def set_walkthrough(self, *walkthroughs: List[str]): + # Assuming quest.events return a list of EventAnd. + events = {event: event.copy() for quest in self.quests for event in quest.events} + + actions = [] + cmds_performed = [] + + def _callback(event): + if not isinstance(event, EventAnd): + return + + if event not in events: + assert False + + if event not in events or events[event].commands: + return + + events[event].commands = list(cmds_performed) - def set_walkthrough(self, commands: List[str]): with make_temp_directory() as tmpdir: game_file = self.compile(pjoin(tmpdir, "set_walkthrough.ulx")) env = textworld.start(game_file, infos=EnvInfos(last_action=True, intermediate_reward=True)) - state = env.reset() - events = {event: event.copy() for quest in self.quests for event in quest.win_events} - event_progressions = [ep for qp in state._game_progression.quest_progressions for ep in qp.win_events] + for walkthrough in walkthroughs: + state = env.reset() + state._game_progression.callback = _callback + + done = False + for i, cmd in enumerate(walkthrough): + if done: + msg = "Game has ended before finishing playing all commands." + raise ValueError(msg) + + cmds_performed.append(cmd) + state, score, done = env.step(cmd) + actions.append(state._last_action) + + for k, v in events.items(): + if v.commands and not k.actions: + k.commands = v.commands + k.actions = list(actions[:len(v.commands)]) + + actions.clear() + cmds_performed.clear() + + for quest in self.quests: + if quest.win_event: + quest.commands = quest.win_event.commands + + def get_action_from_commands(self, commands: List[str]): + with make_temp_directory() as tmpdir: + game_file = self.compile(pjoin(tmpdir, "get_actions.ulx")) + env = textworld.start(game_file, infos=EnvInfos(last_action=True)) + state = env.reset() - done = False actions = [] + done = False for i, cmd in enumerate(commands): if done: msg = "Game has ended before finishing playing all commands." raise ValueError(msg) - events_triggered = [ep.triggered for ep in event_progressions] - state, score, done = env.step(cmd) actions.append(state._last_action) - for was_triggered, ep in zip(events_triggered, event_progressions): - if not was_triggered and ep.triggered: - events[ep.event].actions = list(actions) - events[ep.event].commands = commands[:i + 1] - - for k, v in events.items(): - k.actions = v.actions - k.commands = v.commands + return actions def validate(self) -> bool: """ Check if the world is valid and can be compiled. @@ -857,6 +921,7 @@ def render(self, interactive: bool = False): :param filename: filename for screenshot """ game = self.build(validate=False) + game.change_grammar(self.grammar) # Generate missing object names. return visualize(game, interactive=interactive) def import_graph(self, G: nx.Graph) -> List[WorldRoom]: diff --git a/textworld/generator/tests/test_game.py b/textworld/generator/tests/test_game.py index 7cc47d86..f56a8301 100644 --- a/textworld/generator/tests/test_game.py +++ b/textworld/generator/tests/test_game.py @@ -6,29 +6,29 @@ import textwrap from typing import Iterable -import numpy.testing as npt - import textworld from textworld import g_rng from textworld import GameMaker +from textworld import testing from textworld.generator.data import KnowledgeBase from textworld.generator import World from textworld.generator import make_small_map from textworld.generator.chaining import ChainingOptions, sample_quest -from textworld.logic import Action - +from textworld.logic import Action, State, Proposition, Rule from textworld.generator.game import GameOptions -from textworld.generator.game import Quest, Game, Event +from textworld.generator.game import Quest, Game, Event, EventAction, EventCondition, EventOr, EventAnd from textworld.generator.game import QuestProgression, GameProgression, EventProgression -from textworld.generator.game import UnderspecifiedEventError, UnderspecifiedQuestError from textworld.generator.game import ActionDependencyTree, ActionDependencyTreeElement from textworld.generator.inform7 import Inform7Game from textworld.logic import GameLogic +DATA = testing.build_complex_test_game() + + def _find_action(command: str, actions: Iterable[Action], inform7: Inform7Game) -> None: """ Apply a text command to a game_progression object. """ commands = inform7.gen_commands_from_actions(actions) @@ -121,46 +121,66 @@ def test_variable_infos(verbose=False): class TestEvent(unittest.TestCase): - @classmethod - def setUpClass(cls): - M = GameMaker() + def test_init(self): + event = Event(conditions=[Proposition.parse("in(carrot: f, chest: c)")]) + assert type(event) is EventCondition - # The goal - commands = ["take carrot", "insert carrot into chest"] - R1 = M.new_room("room") - M.set_player(R1) +class TestEventCondition(unittest.TestCase): - carrot = M.new(type='f', name='carrot') - R1.add(carrot) + @classmethod + def setUpClass(cls): + cls.condition = {Proposition.parse("in(carrot: f, chest: c)")} + cls.event = EventCondition(conditions=cls.condition) + + def test_is_triggering(self): + state = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, chest: c)"), + Proposition.parse("in(lettuce: f, chest: c)"), + ]) + assert self.event.is_triggering(state=state) + + state = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, I: I)"), + Proposition.parse("in(lettuce: f, chest: c)"), + ]) + assert not self.event.is_triggering(state=state) - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R1.add(chest) + def test_serialization(self): + data = self.event.serialize() + event = EventCondition.deserialize(data) + assert event == self.event - cls.event = M.new_event_using_commands(commands) - cls.actions = cls.event.actions - cls.conditions = {M.new_fact("in", carrot, chest)} + def test_copy(self): + event = self.event.copy() + assert event == self.event + assert id(event) != id(self.event) - def test_init(self): - event = Event(self.actions) - assert event.actions == self.actions - assert event.condition == self.event.condition - assert event.condition.preconditions == self.actions[-1].postconditions - assert set(event.condition.preconditions).issuperset(self.conditions) - event = Event(conditions=self.conditions) - assert len(event.actions) == 0 - assert set(event.condition.preconditions) == set(self.conditions) +class TestEventAction(unittest.TestCase): - npt.assert_raises(UnderspecifiedEventError, Event, actions=[]) - npt.assert_raises(UnderspecifiedEventError, Event, actions=[], conditions=[]) - npt.assert_raises(UnderspecifiedEventError, Event, conditions=[]) + @classmethod + def setUpClass(cls): + cls.rule = Rule.parse("close :: $at(P, r) & $at(chest: c, r) & open(chest: c) -> closed(chest: c)") + cls.action = Action.parse("close :: $at(P, room: r) & $at(chest: c, room: r) & open(chest: c) -> closed(chest: c)") + cls.event = EventAction(action=cls.rule) + + def test_is_triggering(self): + # State should be ignored in a EventAction. + state = State(KnowledgeBase.default().logic, [ + Proposition.parse("open(chest: c)"), + ]) + assert self.event.is_triggering(state=state, action=self.action) + + state = State(KnowledgeBase.default().logic, [ + Proposition.parse("closed(chest: c)"), + ]) + action = Action.parse("close :: open(fridge: c) -> closed(fridge: c)") + assert not self.event.is_triggering(state=state, action=action) def test_serialization(self): data = self.event.serialize() - event = Event.deserialize(data) + event = EventAction.deserialize(data) assert event == self.event def test_copy(self): @@ -169,60 +189,166 @@ def test_copy(self): assert id(event) != id(self.event) -class TestQuest(unittest.TestCase): +class TestEventOr(unittest.TestCase): @classmethod def setUpClass(cls): - M = GameMaker() + cls.event_A_condition = {Proposition.parse("in(carrot: f, chest: c)")} + cls.event_A = EventCondition(conditions=cls.event_A_condition) + + cls.event_B_action = Rule.parse("close :: open(chest: c) -> closed(chest: c)") + cls.event_B = EventAction(action=cls.event_B_action) + + cls.event_A_or_B = EventOr(events=(cls.event_A, cls.event_B)) + + def test_is_triggering(self): + open_chest = Action.parse("open :: closed(chest: c) -> open(chest: c)") + close_chest = Action.parse("close :: open(chest: c) -> closed(chest: c)") + carrot_in_chest = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, chest: c)"), + ]) + carrot_in_inventory = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, I: I)"), + ]) + + # A | B + assert self.event_A.is_triggering(state=carrot_in_chest, action=close_chest) + assert self.event_B.is_triggering(state=carrot_in_chest, action=close_chest) + assert self.event_A_or_B.is_triggering(state=carrot_in_chest, action=close_chest) + + # !A | !B + assert not self.event_A.is_triggering(state=carrot_in_inventory, action=open_chest) + assert not self.event_B.is_triggering(state=carrot_in_inventory, action=open_chest) + assert not self.event_A_or_B.is_triggering(state=carrot_in_inventory, action=open_chest) + + # !A | B + assert not self.event_A.is_triggering(state=carrot_in_inventory, action=close_chest) + assert self.event_B.is_triggering(state=carrot_in_inventory, action=close_chest) + assert self.event_A_or_B.is_triggering(state=carrot_in_inventory, action=close_chest) + + # A | !B + assert self.event_A.is_triggering(state=carrot_in_chest, action=open_chest) + assert not self.event_B.is_triggering(state=carrot_in_chest, action=open_chest) + assert self.event_A_or_B.is_triggering(state=carrot_in_chest, action=open_chest) - # The goal - commands = ["go east", "insert carrot into chest"] + def test_serialization(self): + data = self.event_A_or_B.serialize() + event = EventOr.deserialize(data) + assert event == self.event_A_or_B - # Create a 'bedroom' room. - R1 = M.new_room("bedroom") - R2 = M.new_room("kitchen") - M.set_player(R1) + def test_copy(self): + event = self.event_A_or_B.copy() + assert event == self.event_A_or_B + assert id(event) != id(self.event_A_or_B) - path = M.connect(R1.east, R2.west) - path.door = M.new(type='d', name='wooden door') - path.door.add_property("open") - carrot = M.new(type='f', name='carrot') - M.inventory.add(carrot) +class TestEventAnd(unittest.TestCase): - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R2.add(chest) + @classmethod + def setUpClass(cls): + cls.event_A_condition = {Proposition.parse("in(carrot: f, chest: c)")} + cls.event_A = EventCondition(conditions=cls.event_A_condition) + + cls.event_B_action = Rule.parse("close :: open(chest: c) -> closed(chest: c)") + cls.event_B = EventAction(action=cls.event_B_action) + + cls.event_A_and_B = EventAnd(events=(cls.event_A, cls.event_B)) + + def test_is_triggering(self): + open_chest = Action.parse("open :: closed(chest: c) -> open(chest: c)") + close_chest = Action.parse("close :: open(chest: c) -> closed(chest: c)") + carrot_in_chest = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, chest: c)"), + ]) + carrot_in_inventory = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, I: I)"), + ]) + + # A & B + assert self.event_A.is_triggering(state=carrot_in_chest, action=close_chest) + assert self.event_B.is_triggering(state=carrot_in_chest, action=close_chest) + assert self.event_A_and_B.is_triggering(state=carrot_in_chest, action=close_chest) + + # !A & !B + assert not self.event_A.is_triggering(state=carrot_in_inventory, action=open_chest) + assert not self.event_B.is_triggering(state=carrot_in_inventory, action=open_chest) + assert not self.event_A_and_B.is_triggering(state=carrot_in_inventory, action=open_chest) + + # !A & B + assert not self.event_A.is_triggering(state=carrot_in_inventory, action=close_chest) + assert self.event_B.is_triggering(state=carrot_in_inventory, action=close_chest) + assert not self.event_A_and_B.is_triggering(state=carrot_in_inventory, action=close_chest) + + # A & !B + assert self.event_A.is_triggering(state=carrot_in_chest, action=open_chest) + assert not self.event_B.is_triggering(state=carrot_in_chest, action=open_chest) + assert not self.event_A_and_B.is_triggering(state=carrot_in_chest, action=open_chest) - cls.eventA = M.new_event_using_commands(commands) - cls.eventB = Event(conditions={M.new_fact("at", carrot, R1), - M.new_fact("closed", path.door)}) - cls.eventC = Event(conditions={M.new_fact("eaten", carrot)}) - cls.eventD = Event(conditions={M.new_fact("closed", chest), - M.new_fact("closed", path.door)}) - cls.quest = Quest(win_events=[cls.eventA, cls.eventB], - fail_events=[cls.eventC, cls.eventD], - reward=2) - - M.quests = [cls.quest] - cls.game = M.build() - cls.inform7 = Inform7Game(cls.game) + def test_serialization(self): + data = self.event_A_and_B.serialize() + event = EventAnd.deserialize(data) + assert event == self.event_A_and_B - def test_init(self): - npt.assert_raises(UnderspecifiedQuestError, Quest) + def test_copy(self): + event = self.event_A_and_B.copy() + assert event == self.event_A_and_B + assert id(event) != id(self.event_A_and_B) + + +class TestQuest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.carrot_in_chest = {Proposition.parse("in(carrot: f, chest: c)")} + cls.event_carrot_in_chest = EventCondition(conditions=cls.carrot_in_chest) + + cls.close_chest = Rule.parse("close :: open(chest: c) -> closed(chest: c)") + cls.event_close_chest = EventAction(action=cls.close_chest) + + cls.event_closing_chest_with_carrot = EventAnd(events=(cls.event_carrot_in_chest, cls.event_close_chest)) + + cls.carrot_in_inventory = {Proposition.parse("in(carrot: f, I: I)")} + cls.event_carrot_in_inventory = EventCondition(conditions=cls.carrot_in_inventory) - quest = Quest(win_events=[self.eventA, self.eventB]) - assert len(quest.fail_events) == 0 + cls.event_closing_chest_without_carrot = EventAnd(events=(cls.event_carrot_in_inventory, cls.event_close_chest)) - quest = Quest(fail_events=[self.eventC, self.eventD]) - assert len(quest.win_events) == 0 + cls.eat_carrot = Rule.parse("eat :: in(carrot: f, I: I) -> consumed(carrot: f)") + cls.event_eat_carrot = EventAction(action=cls.eat_carrot) - quest = Quest(win_events=[self.eventA], - fail_events=[self.eventC, self.eventD]) + cls.event_closing_chest_whithout_carrot_or_eating_carrot = \ + EventOr(events=(cls.event_closing_chest_without_carrot, cls.event_eat_carrot)) - assert len(quest.win_events) > 0 - assert len(quest.fail_events) > 0 + cls.quest = Quest(win_event=cls.event_closing_chest_with_carrot, + fail_event=cls.event_closing_chest_whithout_carrot_or_eating_carrot) + + def test_backward_compatiblity(self): + # Backward compatibility tests. + quest = Quest(win_events=[self.event_closing_chest_with_carrot], + fail_events=[self.event_closing_chest_without_carrot, self.event_eat_carrot]) + assert quest == self.quest + + quest = Quest([self.event_closing_chest_with_carrot], + [self.event_closing_chest_without_carrot, self.event_eat_carrot]) + assert quest == self.quest + + def test_is_winning_or_failing(self): + close_chest = Action.parse("close :: open(chest: c) -> closed(chest: c)") + eat_carrot = Action.parse("eat :: in(carrot: f, I: I) -> consumed(carrot: f)") + carrot_in_chest = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, chest: c)"), + ]) + carrot_in_inventory = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, I: I)"), + ]) + + assert self.quest.is_winning(state=carrot_in_chest, action=close_chest) + assert not self.quest.is_failing(state=carrot_in_chest, action=close_chest) + assert self.quest.is_failing(state=carrot_in_inventory, action=close_chest) + assert not self.quest.is_winning(state=carrot_in_inventory, action=close_chest) + assert self.quest.is_failing(state=carrot_in_inventory, action=eat_carrot) + assert not self.quest.is_winning(state=carrot_in_inventory, action=eat_carrot) + assert self.quest.is_failing(state=carrot_in_chest, action=eat_carrot) + assert not self.quest.is_winning(state=carrot_in_chest, action=eat_carrot) def test_serialization(self): data = self.quest.serialize() @@ -270,7 +396,8 @@ def _rule_to_skip(rule): # Build the quest by providing the actions. actions = chain.actions assert len(actions) == max_depth, rule.name - quest = Quest(win_events=[Event(actions)]) + + quest = Quest(win_event=EventCondition(actions=actions)) tmp_world = World.from_facts(chain.initial_state.facts) state = tmp_world.state @@ -281,7 +408,7 @@ def _rule_to_skip(rule): assert quest.is_winning(state) # Build the quest by only providing the winning conditions. - quest = Quest(win_events=[Event(conditions=actions[-1].postconditions)]) + quest = Quest(win_event=EventCondition(conditions=actions[-1].postconditions)) tmp_world = World.from_facts(chain.initial_state.facts) state = tmp_world.state @@ -291,76 +418,6 @@ def _rule_to_skip(rule): assert quest.is_winning(state) - def test_win_actions(self): - state = self.game.world.state.copy() - for action in self.quest.win_events[0].actions: - assert not self.quest.is_winning(state) - state.apply(action) - - assert self.quest.is_winning(state) - - # Test alternative way of winning, - # i.e. dropping the carrot and closing the door. - state = self.game.world.state.copy() - actions = list(state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) - - drop_carrot = _find_action("drop carrot", actions, self.inform7) - close_door = _find_action("close wooden door", actions, self.inform7) - - state = self.game.world.state.copy() - assert state.apply(drop_carrot) - assert not self.quest.is_winning(state) - assert state.apply(close_door) - assert self.quest.is_winning(state) - - # Or the other way around. - state = self.game.world.state.copy() - assert state.apply(close_door) - assert not self.quest.is_winning(state) - assert state.apply(drop_carrot) - assert self.quest.is_winning(state) - - def test_fail_actions(self): - state = self.game.world.state.copy() - assert not self.quest.is_failing(state) - - actions = list(state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) - eat_carrot = _find_action("eat carrot", actions, self.inform7) - go_east = _find_action("go east", actions, self.inform7) - - for action in actions: - state = self.game.world.state.copy() - state.apply(action) - # Only the `eat carrot` should fail. - assert self.quest.is_failing(state) == (action == eat_carrot) - - state = self.game.world.state.copy() - state.apply(go_east) # Move to the kitchen. - actions = list(state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) - close_door = _find_action("close wooden door", actions, self.inform7) - close_chest = _find_action("close chest", actions, self.inform7) - - # Only closing the door doesn't fail the quest. - state_ = state.apply_on_copy(close_door) - assert not self.quest.is_failing(state_) - - # Only closing the chest doesn't fail the quest. - state_ = state.apply_on_copy(close_chest) - assert not self.quest.is_failing(state_) - - # Closing the chest, then the door should fail the quest. - state_ = state.apply_on_copy(close_chest) - state_.apply(close_door) - assert self.quest.is_failing(state_) - - # Closing the door, then the chest should fail the quest. - state_ = state.apply_on_copy(close_door) - state_.apply(close_chest) - assert self.quest.is_failing(state_) - class TestGame(unittest.TestCase): @@ -390,6 +447,7 @@ def setUpClass(cls): M.set_quest_from_commands(commands) cls.game = M.build() + cls.walkthrough = commands def test_directions_names(self): expected = set(["north", "south", "east", "west"]) @@ -413,6 +471,9 @@ def test_verbs(self): "inventory", "examine"} assert set(self.game.verbs) == expected_verbs + def test_walkthrough(self): + assert self.game.walkthrough == self.walkthrough + def test_command_templates(self): expected_templates = { 'close {c}', 'close {d}', 'drop {o}', 'eat {f}', 'examine {d}', @@ -442,36 +503,36 @@ class TestEventProgression(unittest.TestCase): @classmethod def setUpClass(cls): - M = GameMaker() - - # The goal - commands = ["take carrot", "insert carrot into chest"] + cls.game = DATA["game"] + cls.win_event = DATA["quest"].win_event + cls.eating_carrot = DATA["eating_carrot"] + cls.onion_eaten = DATA["onion_eaten"] - R1 = M.new_room("room") - M.set_player(R1) - - carrot = M.new(type='f', name='carrot') - R1.add(carrot) - - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R1.add(chest) + def test_triggering_policy(self): + event = EventProgression(self.win_event, KnowledgeBase.default()) - cls.event = M.new_event_using_commands(commands) - cls.actions = cls.event.actions - cls.conditions = {M.new_fact("in", carrot, chest)} - cls.game = M.build() - commands = ["take carrot", "eat carrot"] - cls.eating_carrot = M.new_event_using_commands(commands) + state = self.game.world.state.copy() + for action in event.triggering_policy: + assert not event.done + assert not event.triggered + assert not event.untriggerable + state.apply(action) + event.update(action=action, state=state) - def test_triggering_policy(self): - event = EventProgression(self.event, KnowledgeBase.default()) + assert event.triggering_policy == () + assert event.done + assert event.triggered + assert not event.untriggerable + event = EventProgression(self.win_event, KnowledgeBase.default()) state = self.game.world.state.copy() - expected_actions = self.event.actions + + expected_actions = self.eating_carrot.actions for i, action in enumerate(expected_actions): - assert event.triggering_policy == expected_actions[i:] + state.apply(action) + event.update(action=action, state=state) + + for action in event.triggering_policy: assert not event.done assert not event.triggered assert not event.untriggerable @@ -484,7 +545,7 @@ def test_triggering_policy(self): assert not event.untriggerable def test_untriggerable(self): - event = EventProgression(self.event, KnowledgeBase.default()) + event = EventProgression(self.win_event, KnowledgeBase.default()) state = self.game.world.state.copy() for action in self.eating_carrot.actions: @@ -495,6 +556,14 @@ def test_untriggerable(self): state.apply(action) event.update(action=action, state=state) + for action in self.onion_eaten.actions: + assert event.triggering_policy != () + assert not event.done + assert not event.triggered + assert not event.untriggerable + state.apply(action) + event.update(action=action, state=state) + assert event.triggering_policy == () assert event.done assert not event.triggered @@ -505,85 +574,71 @@ class TestQuestProgression(unittest.TestCase): @classmethod def setUpClass(cls): - M = GameMaker() - - room = M.new_room("room") - M.set_player(room) - - carrot = M.new(type='f', name='carrot') - lettuce = M.new(type='f', name='lettuce') - room.add(carrot) - room.add(lettuce) - - chest = M.new(type='c', name='chest') - chest.add_property("open") - room.add(chest) - - # The goals - commands = ["take carrot", "insert carrot into chest"] - cls.eventA = M.new_event_using_commands(commands) - - commands = ["take lettuce", "insert lettuce into chest", "close chest"] - event = M.new_event_using_commands(commands) - cls.eventB = Event(actions=event.actions, - conditions={M.new_fact("in", lettuce, chest), - M.new_fact("closed", chest)}) - - cls.fail_eventA = Event(conditions={M.new_fact("eaten", carrot)}) - cls.fail_eventB = Event(conditions={M.new_fact("eaten", lettuce)}) - - cls.quest = Quest(win_events=[cls.eventA, cls.eventB], - fail_events=[cls.fail_eventA, cls.fail_eventB]) - - commands = ["take carrot", "eat carrot"] - cls.eating_carrot = M.new_event_using_commands(commands) - commands = ["take lettuce", "eat lettuce"] - cls.eating_lettuce = M.new_event_using_commands(commands) - - M.quests = [cls.quest] - cls.game = M.build() - - def _apply_actions_to_quest(self, actions, quest): - state = self.game.world.state.copy() + cls.game = DATA["game"] + cls.quest = DATA["quest"] + cls.eating_carrot = DATA["eating_carrot"] + cls.onion_eaten = DATA["onion_eaten"] + cls.closing_chest_without_carrot = DATA["closing_chest_without_carrot"] + + def _apply_actions_to_quest(self, actions, quest, state=None): + state = state or self.game.world.state.copy() for action in actions: assert not quest.done state.apply(action) quest.update(action, state) - assert quest.done - return quest + return state def test_completed(self): quest = QuestProgression(self.quest, KnowledgeBase.default()) - quest = self._apply_actions_to_quest(self.eventA.actions, quest) - assert quest.completed - assert not quest.failed + self._apply_actions_to_quest(self.quest.win_event.events[0].actions, quest) + assert quest.done + assert quest.completed and not quest.failed assert quest.winning_policy is None # Alternative winning strategy. quest = QuestProgression(self.quest, KnowledgeBase.default()) - quest = self._apply_actions_to_quest(self.eventB.actions, quest) - assert quest.completed - assert not quest.failed + self._apply_actions_to_quest(self.quest.win_event.events[1].actions, quest) + assert quest.done + assert quest.completed and not quest.failed + assert quest.winning_policy is None + + # Alternative winning strategy but with carrot in inventory. + quest = QuestProgression(self.quest, KnowledgeBase.default()) + state = self._apply_actions_to_quest(self.eating_carrot.actions[:1], quest) # Take carrot. + self._apply_actions_to_quest(self.quest.win_event.events[1].actions, quest, state) + assert quest.done + assert quest.completed and not quest.failed assert quest.winning_policy is None def test_failed(self): + # onion_eaten -> eating_carrot != eating_carrot -> onion_eaten + # Eating the carrot *after* eating the onion causes the game to be lost. quest = QuestProgression(self.quest, KnowledgeBase.default()) - quest = self._apply_actions_to_quest(self.eating_carrot.actions, quest) - assert not quest.completed - assert quest.failed + state = self._apply_actions_to_quest(self.onion_eaten.actions, quest) + self._apply_actions_to_quest(self.eating_carrot.actions, quest, state) + assert quest.done + assert not quest.completed and quest.failed assert quest.winning_policy is None + # Eating the carrot *before* eating the onion does not lose the game, + # but the game becomes unfinishable. + quest = QuestProgression(self.quest, KnowledgeBase.default()) + state = self._apply_actions_to_quest(self.eating_carrot.actions, quest) + self._apply_actions_to_quest(self.onion_eaten.actions, quest, state) + assert quest.done and quest.unfinishable + assert not quest.completed and not quest.failed + quest = QuestProgression(self.quest, KnowledgeBase.default()) - quest = self._apply_actions_to_quest(self.eating_lettuce.actions, quest) - assert not quest.completed - assert quest.failed + self._apply_actions_to_quest(self.closing_chest_without_carrot.actions, quest) + assert quest.done + assert not quest.completed and quest.failed assert quest.winning_policy is None def test_winning_policy(self): kb = KnowledgeBase.default() quest = QuestProgression(self.quest, kb) - quest = self._apply_actions_to_quest(quest.winning_policy, quest) + self._apply_actions_to_quest(quest.winning_policy, quest) assert quest.completed assert not quest.failed assert quest.winning_policy is None @@ -591,13 +646,15 @@ def test_winning_policy(self): # Winning policy should be the shortest one leading to a winning event. state = self.game.world.state.copy() quest = QuestProgression(self.quest, KnowledgeBase.default()) - for i, action in enumerate(self.eventB.actions): + for i, action in enumerate(self.quest.win_event.events[1].actions): if i < 2: - assert quest.winning_policy == self.eventA.actions + assert set(quest.winning_policy).issubset(set(self.quest.win_event.events[0].actions)) + assert not set(quest.winning_policy).issubset(set(self.quest.win_event.events[1].actions)) else: - # After taking the lettuce and putting it in the chest, - # QuestB becomes the shortest one to complete. - assert quest.winning_policy == self.eventB.actions[i:] + # After opening the chest and taking the onion, + # the alternative winning event becomes the shortest one to complete. + assert quest.winning_policy == self.quest.win_event.events[1].actions[i:] + assert not quest.done state.apply(action) quest.update(action, state) @@ -612,134 +669,54 @@ class TestGameProgression(unittest.TestCase): @classmethod def setUpClass(cls): - M = GameMaker() - - # Create a 'bedroom' room. - R1 = M.new_room("bedroom") - R2 = M.new_room("kitchen") - M.set_player(R2) - - path = M.connect(R1.east, R2.west) - path.door = M.new(type='d', name='wooden door') - path.door.add_property("closed") - - carrot = M.new(type='f', name='carrot') - lettuce = M.new(type='f', name='lettuce') - R1.add(carrot) - R1.add(lettuce) - - tomato = M.new(type='f', name='tomato') - pepper = M.new(type='f', name='pepper') - M.inventory.add(tomato) - M.inventory.add(pepper) - - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R2.add(chest) - - # The goals - commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"] - cls.eventA = M.new_event_using_commands(commands) - - commands = ["open wooden door", "go west", "take lettuce", "go east", "insert lettuce into chest"] - cls.eventB = M.new_event_using_commands(commands) - - commands = ["drop pepper"] - cls.eventC = M.new_event_using_commands(commands) - - cls.losing_eventA = Event(conditions={M.new_fact("eaten", carrot)}) - cls.losing_eventB = Event(conditions={M.new_fact("eaten", lettuce)}) - - cls.questA = Quest(win_events=[cls.eventA], fail_events=[cls.losing_eventA]) - cls.questB = Quest(win_events=[cls.eventB], fail_events=[cls.losing_eventB]) - cls.questC = Quest(win_events=[cls.eventC], fail_events=[]) - cls.questD = Quest(win_events=[], fail_events=[cls.losing_eventA, cls.losing_eventB]) - - commands = ["open wooden door", "go west", "take carrot", "eat carrot"] - cls.eating_carrot = M.new_event_using_commands(commands) - commands = ["open wooden door", "go west", "take lettuce", "eat lettuce"] - cls.eating_lettuce = M.new_event_using_commands(commands) - commands = ["eat tomato"] - cls.eating_tomato = M.new_event_using_commands(commands) - commands = ["eat pepper"] - cls.eating_pepper = M.new_event_using_commands(commands) - - M.quests = [cls.questA, cls.questB, cls.questC] - cls.game = M.build() + cls.game = DATA["game"] + cls.quest1 = DATA["quest1"] + cls.quest2 = DATA["quest2"] + cls.eating_carrot = DATA["eating_carrot"] + cls.onion_eaten = DATA["onion_eaten"] + cls.knife_on_counter = DATA["knife_on_counter"] def test_completed(self): game = GameProgression(self.game) - for action in self.eventA.actions + self.eventC.actions: + for action in self.quest1.win_event.events[0].actions + self.quest2.win_event.events[0].actions: assert not game.done game.update(action) - assert not game.done - remaining_actions = self.eventB.actions[1:] # skipping "open door". - assert game.winning_policy == remaining_actions - - for action in self.eventB.actions: - assert not game.done - game.update(action) - - assert game.done - assert game.completed - assert not game.failed - assert game.winning_policy is None - - def test_failed(self): - game = GameProgression(self.game) - action = self.eating_tomato.actions[0] - game.update(action) - assert not game.done - assert not game.completed - assert not game.failed - assert game.winning_policy is not None - - game = GameProgression(self.game) - action = self.eating_pepper.actions[0] - game.update(action) - assert not game.completed - assert game.failed assert game.done + assert game.completed and not game.failed assert game.winning_policy is None + # Alternative quest1 solution game = GameProgression(self.game) - for action in self.eating_carrot.actions: + for action in self.quest1.win_event.events[1].actions + self.quest2.win_event.events[0].actions: assert not game.done game.update(action) assert game.done - assert not game.completed - assert game.failed + assert game.completed and not game.failed assert game.winning_policy is None + def test_failed(self): game = GameProgression(self.game) - for action in self.eating_lettuce.actions: - assert not game.done - game.update(action) - assert game.done - assert not game.completed - assert game.failed - assert game.winning_policy is None - - # Completing QuestA but failing quest B. - game = GameProgression(self.game) - for action in self.eventA.actions: - assert not game.done + # Completing quest2 but failing quest 1. + for action in self.knife_on_counter.actions: game.update(action) + assert not game.quest_progressions[0].done + assert game.quest_progressions[1].done + assert game.quest_progressions[1].completed assert not game.done + assert not game.completed and not game.failed + assert game.winning_policy is not None - game = GameProgression(self.game) - for action in self.eating_lettuce.actions: - assert not game.done + for action in self.onion_eaten.actions + self.eating_carrot.actions: game.update(action) assert game.done - assert not game.completed - assert game.failed + assert game.quest_progressions[0].done + assert game.quest_progressions[0].failed + assert not game.completed and game.failed assert game.winning_policy is None def test_winning_policy(self): @@ -784,6 +761,7 @@ def test_cycle_in_winning_policy(self): commands = ["go north", "take carrot"] M.set_quest_from_commands(commands) + M.set_walkthrough(commands) # TODO: redundant! game = M.build() inform7 = Inform7Game(game) game_progression = GameProgression(game) @@ -808,6 +786,7 @@ def test_cycle_in_winning_policy(self): commands = ["go east", "take apple", "go west", "go north", "drop apple"] M.set_quest_from_commands(commands) + M.set_walkthrough(commands) # TODO: redundant! game = M.build() game_progression = GameProgression(game) @@ -858,14 +837,15 @@ def test_game_with_multiple_quests(self): quest1.desc = "Fetch the carrot and drop it on the kitchen's ground." quest2 = M.new_quest_using_commands(commands[0] + commands[1]) quest2.desc = "Fetch the lettuce and drop it on the kitchen's ground." - quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2]) + # quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2]) winning_facts = [M.new_fact("in", lettuce, chest), M.new_fact("in", carrot, chest), M.new_fact("closed", chest)] - quest3.win_events[0].set_conditions(winning_facts) + quest3 = Quest(win_event=EventCondition(winning_facts)) quest3.desc = "Put the lettuce and the carrot into the chest before closing it." M.quests = [quest1, quest2, quest3] + M.set_walkthrough(commands[0] + commands[1] + commands[2]) assert len(M.quests) == len(commands) game = M.build() diff --git a/textworld/generator/text_generation.py b/textworld/generator/text_generation.py index 3b1cdfca..816fbe97 100644 --- a/textworld/generator/text_generation.py +++ b/textworld/generator/text_generation.py @@ -5,7 +5,8 @@ import re from collections import OrderedDict -from textworld.generator.game import Quest, Event, Game +from textworld.generator.game import Quest, Game +from textworld.generator.game import AbstractEvent from textworld.generator.text_grammar import Grammar from textworld.generator.text_grammar import fix_determinant @@ -381,15 +382,18 @@ def generate_instruction(action, grammar, game, counts): def assign_description_to_quest(quest: Quest, game: Game, grammar: Grammar): + if quest.win_event is None: + return "" + event_descriptions = [] - for event in quest.win_events: + for event in quest.win_event: event_descriptions += [describe_event(event, game, grammar)] quest_desc = " OR ".join(desc for desc in event_descriptions if desc) return quest_desc -def describe_event(event: Event, game: Game, grammar: Grammar) -> str: +def describe_event(event: AbstractEvent, game: Game, grammar: Grammar) -> str: """ Assign a descripton to a quest. """ diff --git a/textworld/testing.py b/textworld/testing.py index ad6a4e0f..46782a69 100644 --- a/textworld/testing.py +++ b/textworld/testing.py @@ -6,12 +6,13 @@ import sys import contextlib -from typing import Tuple +from typing import Tuple, Optional import numpy as np import textworld from textworld.generator.game import Event, Quest, Game +from textworld.generator.game import EventAction, EventCondition, EventOr, EventAnd from textworld.generator.game import GameOptions @@ -38,7 +39,7 @@ def _compile_test_game(game, options: GameOptions) -> str: "instruction_extension": [] } rng_grammar = np.random.RandomState(1234) - grammar = textworld.generator.make_grammar(grammar_flags, rng=rng_grammar) + grammar = textworld.generator.make_grammar(grammar_flags, rng=rng_grammar, kb=options.kb) game.change_grammar(grammar) game_file = textworld.generator.compile_game(game, options) @@ -107,3 +108,102 @@ def build_and_compile_game(options: GameOptions) -> Tuple[Game, str]: game = build_game(options) game_file = _compile_test_game(game, options) return game, game_file + + +def build_complex_test_game(options: Optional[GameOptions] = None): + M = textworld.GameMaker(options) + + # The goal + quest1_cmds1 = ["open chest", "take carrot", "insert carrot into chest", "close chest"] + quest1_cmds2 = ["open chest", "take onion", "insert onion into chest", "close chest"] + quest2_cmds = ["take knife", "put knife on counter"] + + kitchen = M.new_room("kitchen") + M.set_player(kitchen) + + counter = M.new(type='s', name='counter') + chest = M.new(type='c', name='chest') + chest.add_property("closed") + carrot = M.new(type='f', name='carrot') + onion = M.new(type='f', name='onion') + knife = M.new(type='o', name='knife') + kitchen.add(chest, counter, carrot, onion, knife) + + carrot_in_chest = EventCondition(conditions={M.new_fact("in", carrot, chest)}) + onion_in_chest = EventCondition(conditions={M.new_fact("in", onion, chest)}) + closing_chest = EventAction(action=M.new_action("close/c", chest)) + + either_carrot_or_onion_in_chest = EventOr(events=(carrot_in_chest, onion_in_chest)) + closing_chest_with_either_carrot_or_onion = EventAnd(events=(either_carrot_or_onion_in_chest, closing_chest)) + + carrot_in_inventory = EventCondition(conditions={M.new_fact("in", carrot, M.inventory)}) + closing_chest_without_carrot = EventAnd(events=(carrot_in_inventory, closing_chest)) + + eating_carrot = EventAction(action=M.new_action("eat", carrot)) + onion_eaten = EventCondition(conditions={M.new_fact("eaten", onion)}) + + quest1 = Quest( + win_event=closing_chest_with_either_carrot_or_onion, + fail_event=EventOr([ + closing_chest_without_carrot, + EventAnd([ + eating_carrot, + onion_eaten + ]) + ]), + reward=3, + ) + + knife_on_counter = EventCondition(conditions={M.new_fact("on", knife, counter)}) + + quest2 = Quest( + win_event=knife_on_counter, + reward=5, + ) + + carrot_in_chest.name = "carrot_in_chest" + onion_in_chest.name = "onion_in_chest" + closing_chest.name = "closing_chest" + either_carrot_or_onion_in_chest.name = "either_carrot_or_onion_in_chest" + closing_chest_with_either_carrot_or_onion.name = "closing_chest_with_either_carrot_or_onion" + carrot_in_inventory.name = "carrot_in_inventory" + closing_chest_without_carrot.name = "closing_chest_without_carrot" + eating_carrot.name = "eating_carrot" + onion_eaten.name = "onion_eaten" + knife_on_counter.name = "knife_on_counter" + + M.quests = [quest1, quest2] + M.set_walkthrough( + quest1_cmds1, + quest1_cmds2, + quest2_cmds + ) + game = M.build() + + eating_carrot.commands = ["take carrot", "eat carrot"] + eating_carrot.actions = M.get_action_from_commands(eating_carrot.commands) + onion_eaten.commands = ["take onion", "eat onion"] + onion_eaten.actions = M.get_action_from_commands(onion_eaten.commands) + closing_chest_without_carrot.commands = ["take carrot", "open chest", "close chest"] + closing_chest_without_carrot.actions = M.get_action_from_commands(closing_chest_without_carrot.commands) + knife_on_counter.commands = ["take knife", "put knife on counter"] + knife_on_counter.actions = M.get_action_from_commands(knife_on_counter.commands) + + data = { + "game": game, + "quest": quest1, + "quest1": quest1, + "quest2": quest2, + "carrot_in_chest": carrot_in_chest, + "onion_in_chest": onion_in_chest, + "closing_chest": closing_chest, + "either_carrot_or_onion_in_chest": either_carrot_or_onion_in_chest, + "closing_chest_with_either_carrot_or_onion": closing_chest_with_either_carrot_or_onion, + "carrot_in_inventory": carrot_in_inventory, + "closing_chest_without_carrot": closing_chest_without_carrot, + "eating_carrot": eating_carrot, + "onion_eaten": onion_eaten, + "knife_on_counter": knife_on_counter, + } + + return data