Skip to content

Commit

Permalink
Some updates on Traceable controling process, removed redundant eleme…
Browse files Browse the repository at this point in the history
…nt of proposition.activate, etc.
  • Loading branch information
HakiRose committed Feb 26, 2020
1 parent acf2275 commit e4c6812
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 58 deletions.
5 changes: 4 additions & 1 deletion textworld/challenges/spaceship/content_check_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def make_game(settings: Mapping[str, str], options: Optional[GameOptions] = None

from textworld.challenges.spaceship.maker import test_commands
test_commands(gm, [
'open Blue box',
'open Red box',
'look',
'close Red box',
# 'open Red box',
# 'open Blue box',
Expand Down Expand Up @@ -179,7 +181,8 @@ def quest_design(game):
game._entities['r_0'],
game._entities['s_0'],
game._entities['c_0'],
game._entities['cpu_0'])})
game._entities['cpu_0'])},
output_verb_tense_postcond={'closed': 'has been'})
quests.append(Quest(win_events=[win_quest], fail_events=[], reward=1))

# win_quest1 = EventCondition(conditions={game.new_fact("has_been__closed", game._entities['c_0'])})
Expand Down
65 changes: 34 additions & 31 deletions textworld/generator/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,26 +104,23 @@ def set_events(self):
@classmethod
def add_propositions(cls, props: Iterable[Proposition]) -> Iterable[Proposition]:
for prop in props:
if not prop.name.startswith("is__"):
prop.activate[0] = True

if prop.verb == "has been":
prop.activate[1] = 1
if not prop.name.startswith("is__") and (prop.verb == "has been"):
prop.activate = 1

return props

@classmethod
def set_activated(cls, prop: Proposition):
if prop.activate[0] and not prop.activate[1]:
prop.activate[1] = 1
if not prop.activate:
prop.activate = 1

@classmethod
def remove(cls, prop: Proposition, state: State):
if prop.name.startswith('is__'):
if not prop.name.startswith('was__'):
return

if (prop.activate[0] and prop.activate[1]) and (prop in state.get_facts()):
if Proposition(prop.definition, prop.arguments) not in state.get_facts():
if prop.activate and (prop in state.facts):
if Proposition(prop.definition, prop.arguments) not in state.facts:
state.remove_fact(prop)


Expand Down Expand Up @@ -185,6 +182,11 @@ def set_conditions(self, conditions: Iterable[Proposition]) -> Action:
event = PropositionControl(conditions, self.verb_tense)
traceable = event.traceable_propositions
condition = Action("trigger", preconditions=conditions, postconditions=list(conditions) + [event.addon])

# The corresponding traceable(s) should be active in state set to be considered for the event.
if condition.has_traceable():
condition.activate_traceable()

return condition, traceable

def __hash__(self) -> int:
Expand Down Expand Up @@ -946,6 +948,15 @@ def _find_shorter_policy(policy):

return compressed

def will_trigger(self, state: State, action: Action):
if isinstance(self.event, EventCondition):
triggered = self.event.is_triggering(state)

if isinstance(self.event, EventAction):
triggered = self.event.is_triggering(action)

return triggered


class QuestProgression:
""" QuestProgression keeps track of the completion of a quest.
Expand Down Expand Up @@ -1048,20 +1059,7 @@ def __init__(self, game: Game, track_quests: bool = True) -> None:
def valid_actions_gen(self):
potential_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(),
self.game.kb.types.constants_mapping))
a = []
for act in potential_actions:
k = []
for prop in [list(act.preconditions) + list(act.added)][0]:
if not prop.name.startswith('is__'):
w = [p for p in self.state.get_facts() if not p.name.startswith('is__') and (p.name == prop.name)][0]
k.append(w.activate[0] and (w.activate[1] == 1))
else:
k.append(prop.activate[0] and (prop.activate[1] == 1))

if all(k):
a.append(act)

return a
return [act for act in potential_actions if act.is_valid()]

@property
def done(self) -> bool:
Expand Down Expand Up @@ -1126,14 +1124,19 @@ def winning_policy(self) -> Optional[List[Action]]:
# Discard all "trigger" actions.
return tuple(a for a in master_quest_tree.flatten() if a.name != "trigger")

def add_traceables(self):
def add_traceables(self, action):
s = self.state.facts
for quest_progression in self.quest_progressions:
if quest_progression.quest.reward >= 0:
if not quest_progression.completed and (quest_progression.quest.reward >= 0):
for win_event in quest_progression.win_events:
if win_event.event.traceable:
self.state.add_facts(PropositionControl.add_propositions(win_event.event.traceable))
if win_event.event.traceable and not (win_event.event.traceable in s):
if win_event.will_trigger(self.state, action):
self.state.add_facts(PropositionControl.add_propositions(win_event.event.traceable))

def traceable_manager(self):
if not self.state.has_traceable():
return

for prop in self.state.get_facts():
if not prop.name.startswith('is__'):
PropositionControl.set_activated(prop)
Expand All @@ -1145,9 +1148,9 @@ def update(self, action: Action) -> None:
Args:
action: Action affecting the state of the game.
"""
# Update world facts.
self.state.apply(self.state.state_action_valisate(action))
self.add_traceables()
# Update world facts
self.state.apply(action)
self.add_traceables(action)

# Update all quest progressions given the last action and new state.
for quest_progression in self.quest_progressions:
Expand Down
16 changes: 7 additions & 9 deletions textworld/generator/maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,14 +687,7 @@ def new_fact(self, name: str, *entities: List["WorldEntity"]) -> Proposition:
*entities: A list of entities as arguments to the new fact.
"""
args = [entity.var for entity in entities]
# if name.count('__') == 0:
# verb = 'is'
# definition = name
# name = verb + '__' + definition
# else:
# verb = name[:name.find('__')].replace('_', ' ')
# definition = name[name.find('__')+2:]
# return Proposition(name, arguments=args, verb=verb, definition=definition)

return Proposition(name, args)

def new_rule_fact(self, name: str, *entities: List["WorldEntity"]) -> Union[None, Action]:
Expand All @@ -719,7 +712,12 @@ def new_conditions(conditions, args):
precond = new_conditions(rule.preconditions, args)
postcond = new_conditions(rule.postconditions, args)

return Action(rule.name, precond, postcond)
action = Action(rule.name, precond, postcond)

if action.has_traceable():
action.activate_traceable()

return action

return None

Expand Down
53 changes: 36 additions & 17 deletions textworld/logic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,8 @@ class Proposition:

__slots__ = ("name", "arguments", "signature", "_hash", "verb", "definition", "activate")

def __init__(self, name: str, arguments: Iterable[Variable] = [], verb: str = None, definition: str = None):
def __init__(self, name: str, arguments: Iterable[Variable] = [], verb: str = None, definition: str = None,
activate: int = 0):
"""
Create a Proposition.
Expand Down Expand Up @@ -650,9 +651,9 @@ def __init__(self, name: str, arguments: Iterable[Variable] = [], verb: str = No
self._hash = hash((self.name, self.arguments, self.verb, self.definition))

if self.verb == 'is':
self.activate = [True, 1]
else:
self.activate = [False, 0]
activate = 1

self.activate = activate

@property
def names(self) -> Collection[str]:
Expand All @@ -676,7 +677,8 @@ def __repr__(self):

def __eq__(self, other):
if isinstance(other, Proposition):
return (self.name, self.arguments, self.verb, self.definition) == (other.name, other.arguments, other.verb, other.definition)
return (self.name, self.arguments, self.verb, self.definition, self.activate) == \
(other.name, other.arguments, other.verb, other.definition, other.activate)
else:
return NotImplemented

Expand Down Expand Up @@ -706,7 +708,8 @@ def serialize(self) -> Mapping:
"name": self.name,
"arguments": [var.serialize() for var in self.arguments],
"verb": self.verb,
"definition": self.definition
"definition": self.definition,
"activate": self.activate
}

@classmethod
Expand All @@ -715,7 +718,8 @@ def deserialize(cls, data: Mapping) -> "Proposition":
args = [Variable.deserialize(arg) for arg in data["arguments"]]
verb = data["verb"]
definition = data["definition"]
return cls(name, args, verb, definition)
activate = data["activate"]
return cls(name, args, verb, definition, activate)


@total_ordering
Expand Down Expand Up @@ -1090,6 +1094,20 @@ def inverse(self, name=None) -> "Action":
name = self.name
return Action(name, self.postconditions, self.preconditions)

def has_traceable(self):
for prop in self.all_propositions:
if not prop.name.startswith('is__'):
return True
return False

def activate_traceable(self):
for prop in self.all_propositions:
if not prop.name.startswith('is__'):
prop.activate = 1

def is_valid(self):
return all([prop.activate == 1 for prop in self.all_propositions])


class Rule:
"""
Expand Down Expand Up @@ -1208,7 +1226,10 @@ def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Action:
"""
pre_inst = [pred.instantiate(mapping) for pred in self.preconditions]
post_inst = [pred.instantiate(mapping) for pred in self.postconditions]
return Action(self.name, pre_inst, post_inst)
action = Action(self.name, pre_inst, post_inst)
if action.has_traceable():
action.activate_traceable()
return action

def match(self, action: Action) -> Optional[Mapping[Placeholder, Variable]]:
"""
Expand Down Expand Up @@ -1600,7 +1621,7 @@ def are_facts(self, props: Iterable[Proposition]) -> bool:
if not self.is_fact(prop):
return False

if not prop.activate[0] or not (prop.activate[1]):
if not prop.activate:
return False

return True
Expand Down Expand Up @@ -1671,14 +1692,6 @@ def is_sequence_applicable(self, actions: Iterable[Action]) -> bool:

return True

def state_action_valisate(self, action: Action):
for prop in action.all_propositions:
if not prop.name.startswith('is__'):
w = [p for p in self.get_facts() if not p.name.startswith('is__') and (p.name == prop.name)][0]
prop.activate[0], prop.activate[1] = w.activate[0], w.activate[1]

return action

def apply(self, action: Action) -> bool:
"""
Apply an action to the state.
Expand Down Expand Up @@ -1964,3 +1977,9 @@ def get_facts(self):
for fact in sorted(facts):
all_facts.append(fact)
return all_facts

def has_traceable(self):
for prop in self.facts:
if not prop.name.startswith('is__'):
return True
return False

0 comments on commit e4c6812

Please sign in to comment.