Skip to content

Commit

Permalink
fix: explainee and recommender behavouir abstract classes
Browse files Browse the repository at this point in the history
  • Loading branch information
emrekuruu committed Nov 6, 2023
1 parent c037962 commit 0146f45
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 15 deletions.
23 changes: 15 additions & 8 deletions pyxmas/protocol/explainee.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, query: data.Query, recipient: str, thread: str = None, impl:
self._recipient = recipient

def setup(self) -> None:
self.add_state(StateInit, initial=True)
for state in _get_all_state_classes():
self.add_transitions(state, state.reachable_states(), error=StateError)

Expand Down Expand Up @@ -87,17 +88,17 @@ def reachable_states(cls) -> List[type]:

async def run(self):
try:
self.log(f"Entering state {self.name()}")
self.log(msg = f"Entering state {self.name()}")
await self.action()
except Exception as e:
self.log(msg=f"Error in state {self.name()} | {str(e)}")
self.memory["last_error"] = e
self.set_next_state(StateError)
finally:
if self.next_state is None:
self.log(f"Moving out from state {self.name()}, no next state")
self.log(msg = f"Moving out from state {self.name()}, no next state")
else:
self.log(f"Moving from state {self.name()} to state {self.next_state.name()}")
self.log(msg = f"Moving from state {self.name()} to state {self.next_state}")

async def action(self):
...
Expand Down Expand Up @@ -131,13 +132,16 @@ def reachable_states(cls) -> Iterable[type]:
class StateInit(ExplaineeState):
async def action(self):
self.log(msg=f"Issuing query ${self.parent.query}")
message = messages.QueryMessage(

message = messages.QueryMessage.create(
query=self.parent.query,
to=self.parent.recipient
)

await self.send(message)
self.memory['history'].append(message)
self.log(msg=f"Sent ${message}")

self.set_next_state(StateAwaitingRecommendation)

@classmethod
Expand All @@ -156,11 +160,12 @@ async def action(self):
if not isinstance(message, messages.RecommendationMessage):
raise RuntimeError("Last message is not a recommendation message")
self.log(msg=f"Computing answer for recommendation: {message.recommendation}")
answer = self.parent.handle_recommendation(message)
answer = await self.parent.handle_recommendation(message)
self.log(msg=f"Computed answer for recommendation {message.recommendation}: ${answer}")
await self.send(answer)
self.log(msg=f"Sent answer for recommendation {message.recommendation}: ${answer}")
self.memory['history'].append(answer)

if isinstance(answer, messages.DisapproveMessage) or isinstance(answer, messages.CollisionMessage):
self.set_next_state(StateAwaitingRecommendation)
elif isinstance(answer, messages.WhyMessage):
Expand All @@ -179,6 +184,7 @@ def reachable_states(cls) -> Iterable[type]:


class StateAwaitingDetails(ExplaineeState):

async def action(self):
message = await self.receive()
if message is None:
Expand All @@ -189,11 +195,12 @@ async def action(self):
if not isinstance(message, messages.MoreDetailsMessage):
raise RuntimeError("Last message is not an explanation message")
self.log(msg=f"Computing answer for explanation: {message.explanation}")
answer = self.parent.handle_details(message)
answer = await self.parent.handle_details(message)
self.log(msg=f"Computed answer for explanation {message.explanation}: ${answer}")
await self.send(answer)
self.log(msg=f"Sent answer for explanation {message.explanation}: ${answer}")
self.memory['history'].append(answer)

if isinstance(answer, messages.DisapproveMessage) or isinstance(answer, messages.CollisionMessage):
self.set_next_state(StateAwaitingRecommendation)
elif isinstance(answer, messages.UnclearExplanationMessage):
Expand Down Expand Up @@ -239,7 +246,7 @@ async def action(self):
if not isinstance(message, messages.ComparisonMessage):
raise RuntimeError("Last message is not a comparison message")
self.log(msg=f"Computing answer for message: {message}")
answer = self.parent.handle_comparison(message)
answer = await self.parent.handle_comparison(message)
self.log(msg=f"Computed answer for comparison: ${answer}")
await self.send(answer)
self.log(msg=f"Sent answer for comparison: ${answer}")
Expand All @@ -263,7 +270,7 @@ async def action(self):
if not isinstance(message, messages.InvalidAlternativeMessage):
raise RuntimeError("Last message is not an invalid alternative message")
self.log(msg=f"Computing answer for message: {message}")
answer = self.parent.handle_invalid_alternative(message)
answer = await self.parent.handle_invalid_alternative(message)
self.log(msg=f"Computed answer for comparison: ${answer}")
await self.send(answer)
self.log(msg=f"Sent answer for comparison: ${answer}")
Expand Down
18 changes: 11 additions & 7 deletions pyxmas/protocol/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pyxmas.protocol as protocol
import pyxmas.protocol.messages as messages
import pyxmas.protocol.data as data
import asyncio


__all__ = [
Expand Down Expand Up @@ -40,6 +41,7 @@ def __init__(self, thread: str = None, impl: data.Types = None):
super().__init__(thread, impl)

def setup(self) -> None:
self.add_state(state = StateIdle, initial=True)
for state in _get_all_state_classes():
self.add_transitions(state, state.reachable_states(), error=StateError)

Expand Down Expand Up @@ -118,17 +120,17 @@ def reachable_states(cls) -> List[type]:

async def run(self):
try:
self.log(f"Entering state {self.name()}")
self.log(msg = f"Entering state {self.name()}")
await self.action()
except Exception as e:
self.log(msg=f"Error in state {self.name()} | {str(e)}")
self.memory["last_error"] = e
self.set_next_state(StateError)
finally:
if self.next_state is None:
self.log(f"Moving out from state {self.name()}, no next state")
self.log(msg = f"Moving out from state {self.name()}, no next state")
else:
self.log(f"Moving from state {self.name()} to state {self.next_state.name()}")
self.log(msg = f"Moving from state {self.name()} to state {self.next_state}")

async def action(self):
...
Expand Down Expand Up @@ -231,7 +233,9 @@ async def action(self):
self.log(msg="Store last message in history")
for message_type, handler in self._handlers.items():
if isinstance(message, message_type):
handler(message)
result = handler(message)
if asyncio.iscoroutine(result):
await result
return
del self.memory["history"][-1]
self.log(msg="Remove last message from history")
Expand Down Expand Up @@ -284,10 +288,10 @@ async def action(self):
self.log(msg=f"Computed contrastive explanation: {explanation}")
if is_valid:
reply = message.make_comparison_reply(explanation)
self.log(f"Sending comparative reply {reply}")
self.log(msg = f"Sending comparative reply {reply}")
else:
reply = message.make_invalid_alternative_reply(explanation)
self.log(f"Sending invalid reply {reply}")
self.log(msg = f"Sending invalid reply {reply}")
await self.send(reply.delegate)
self.log(msg=f"Reply sent")
self.set_next_state(StateWaitingComparisonFeedback if is_valid else StateWaitingInvalidFeedback)
Expand All @@ -309,7 +313,7 @@ async def action(self):
explanation = await self.parent.compute_explanation(message.query, message.recommendation)
self.log(msg=f"Computed explanation: {explanation}")
reply = message.make_more_details_reply(explanation)
self.log(f"Sending more details reply {reply}")
self.log(msg = f"Sending more details reply {reply}")
await self.send(reply.delegate)
self.log(msg=f"Reply sent")
self.set_next_state(StateWaitingExplanationFeedback)
Expand Down

0 comments on commit 0146f45

Please sign in to comment.