From 0146f45f4aae2b2ca982d97f108e15604a286863 Mon Sep 17 00:00:00 2001 From: Emre Kuru Date: Mon, 6 Nov 2023 11:52:26 +0300 Subject: [PATCH] fix: explainee and recommender behavouir abstract classes --- pyxmas/protocol/explainee.py | 23 +++++++++++++++-------- pyxmas/protocol/recommender.py | 18 +++++++++++------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/pyxmas/protocol/explainee.py b/pyxmas/protocol/explainee.py index 96329ff..b86b109 100644 --- a/pyxmas/protocol/explainee.py +++ b/pyxmas/protocol/explainee.py @@ -38,6 +38,7 @@ def __init__(self, query: data.Query, recipient: str, thread: str = None, impl: self._recipient = recipient def setup(self) -> None: + self.add_state(StateInit, initial=True) for state in _get_all_state_classes(): self.add_transitions(state, state.reachable_states(), error=StateError) @@ -87,7 +88,7 @@ def reachable_states(cls) -> List[type]: async def run(self): try: - self.log(f"Entering state {self.name()}") + self.log(msg = f"Entering state {self.name()}") await self.action() except Exception as e: self.log(msg=f"Error in state {self.name()} | {str(e)}") @@ -95,9 +96,9 @@ async def run(self): self.set_next_state(StateError) finally: if self.next_state is None: - self.log(f"Moving out from state {self.name()}, no next state") + self.log(msg = f"Moving out from state {self.name()}, no next state") else: - self.log(f"Moving from state {self.name()} to state {self.next_state.name()}") + self.log(msg = f"Moving from state {self.name()} to state {self.next_state}") async def action(self): ... @@ -131,13 +132,16 @@ def reachable_states(cls) -> Iterable[type]: class StateInit(ExplaineeState): async def action(self): self.log(msg=f"Issuing query ${self.parent.query}") - message = messages.QueryMessage( + + message = messages.QueryMessage.create( query=self.parent.query, to=self.parent.recipient ) + await self.send(message) self.memory['history'].append(message) self.log(msg=f"Sent ${message}") + self.set_next_state(StateAwaitingRecommendation) @classmethod @@ -156,11 +160,12 @@ async def action(self): if not isinstance(message, messages.RecommendationMessage): raise RuntimeError("Last message is not a recommendation message") self.log(msg=f"Computing answer for recommendation: {message.recommendation}") - answer = self.parent.handle_recommendation(message) + answer = await self.parent.handle_recommendation(message) self.log(msg=f"Computed answer for recommendation {message.recommendation}: ${answer}") await self.send(answer) self.log(msg=f"Sent answer for recommendation {message.recommendation}: ${answer}") self.memory['history'].append(answer) + if isinstance(answer, messages.DisapproveMessage) or isinstance(answer, messages.CollisionMessage): self.set_next_state(StateAwaitingRecommendation) elif isinstance(answer, messages.WhyMessage): @@ -179,6 +184,7 @@ def reachable_states(cls) -> Iterable[type]: class StateAwaitingDetails(ExplaineeState): + async def action(self): message = await self.receive() if message is None: @@ -189,11 +195,12 @@ async def action(self): if not isinstance(message, messages.MoreDetailsMessage): raise RuntimeError("Last message is not an explanation message") self.log(msg=f"Computing answer for explanation: {message.explanation}") - answer = self.parent.handle_details(message) + answer = await self.parent.handle_details(message) self.log(msg=f"Computed answer for explanation {message.explanation}: ${answer}") await self.send(answer) self.log(msg=f"Sent answer for explanation {message.explanation}: ${answer}") self.memory['history'].append(answer) + if isinstance(answer, messages.DisapproveMessage) or isinstance(answer, messages.CollisionMessage): self.set_next_state(StateAwaitingRecommendation) elif isinstance(answer, messages.UnclearExplanationMessage): @@ -239,7 +246,7 @@ async def action(self): if not isinstance(message, messages.ComparisonMessage): raise RuntimeError("Last message is not a comparison message") self.log(msg=f"Computing answer for message: {message}") - answer = self.parent.handle_comparison(message) + answer = await self.parent.handle_comparison(message) self.log(msg=f"Computed answer for comparison: ${answer}") await self.send(answer) self.log(msg=f"Sent answer for comparison: ${answer}") @@ -263,7 +270,7 @@ async def action(self): if not isinstance(message, messages.InvalidAlternativeMessage): raise RuntimeError("Last message is not an invalid alternative message") self.log(msg=f"Computing answer for message: {message}") - answer = self.parent.handle_invalid_alternative(message) + answer = await self.parent.handle_invalid_alternative(message) self.log(msg=f"Computed answer for comparison: ${answer}") await self.send(answer) self.log(msg=f"Sent answer for comparison: ${answer}") diff --git a/pyxmas/protocol/recommender.py b/pyxmas/protocol/recommender.py index 6a655da..aaa3bd2 100644 --- a/pyxmas/protocol/recommender.py +++ b/pyxmas/protocol/recommender.py @@ -5,6 +5,7 @@ import pyxmas.protocol as protocol import pyxmas.protocol.messages as messages import pyxmas.protocol.data as data +import asyncio __all__ = [ @@ -40,6 +41,7 @@ def __init__(self, thread: str = None, impl: data.Types = None): super().__init__(thread, impl) def setup(self) -> None: + self.add_state(state = StateIdle, initial=True) for state in _get_all_state_classes(): self.add_transitions(state, state.reachable_states(), error=StateError) @@ -118,7 +120,7 @@ def reachable_states(cls) -> List[type]: async def run(self): try: - self.log(f"Entering state {self.name()}") + self.log(msg = f"Entering state {self.name()}") await self.action() except Exception as e: self.log(msg=f"Error in state {self.name()} | {str(e)}") @@ -126,9 +128,9 @@ async def run(self): self.set_next_state(StateError) finally: if self.next_state is None: - self.log(f"Moving out from state {self.name()}, no next state") + self.log(msg = f"Moving out from state {self.name()}, no next state") else: - self.log(f"Moving from state {self.name()} to state {self.next_state.name()}") + self.log(msg = f"Moving from state {self.name()} to state {self.next_state}") async def action(self): ... @@ -231,7 +233,9 @@ async def action(self): self.log(msg="Store last message in history") for message_type, handler in self._handlers.items(): if isinstance(message, message_type): - handler(message) + result = handler(message) + if asyncio.iscoroutine(result): + await result return del self.memory["history"][-1] self.log(msg="Remove last message from history") @@ -284,10 +288,10 @@ async def action(self): self.log(msg=f"Computed contrastive explanation: {explanation}") if is_valid: reply = message.make_comparison_reply(explanation) - self.log(f"Sending comparative reply {reply}") + self.log(msg = f"Sending comparative reply {reply}") else: reply = message.make_invalid_alternative_reply(explanation) - self.log(f"Sending invalid reply {reply}") + self.log(msg = f"Sending invalid reply {reply}") await self.send(reply.delegate) self.log(msg=f"Reply sent") self.set_next_state(StateWaitingComparisonFeedback if is_valid else StateWaitingInvalidFeedback) @@ -309,7 +313,7 @@ async def action(self): explanation = await self.parent.compute_explanation(message.query, message.recommendation) self.log(msg=f"Computed explanation: {explanation}") reply = message.make_more_details_reply(explanation) - self.log(f"Sending more details reply {reply}") + self.log(msg = f"Sending more details reply {reply}") await self.send(reply.delegate) self.log(msg=f"Reply sent") self.set_next_state(StateWaitingExplanationFeedback)