Skip to content

Commit

Permalink
#104: Added broadcast collective operation (#169)
Browse files Browse the repository at this point in the history
* Add broadcast operation
* Add jlumbroso/free-disk-space to integration_tests_with_db.yaml to solve sporadic issues while pulling docker-db images

---------

Co-authored-by: Nicola Coretti <[email protected]>
  • Loading branch information
tkilias and Nicoretti authored Oct 18, 2023
1 parent 769e874 commit dea65c5
Show file tree
Hide file tree
Showing 6 changed files with 514 additions and 1 deletion.
12 changes: 12 additions & 0 deletions .github/workflows/integration_tests_with_db.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ jobs:
runs-on: ubuntu-latest

steps:

- name: Free Disk Space (Ubuntu)
uses: jlumbroso/[email protected]
with:
tool-cache: false
android: true
dotnet: true
haskell: true
large-packages: true
docker-images: true
swap-storage: false

- uses: actions/checkout@v2

- name: Setup Python & Poetry Environment
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from typing import Optional

import structlog
from structlog.typing import FilteringBoundLogger

from exasol_advanced_analytics_framework.udf_communication import messages
from exasol_advanced_analytics_framework.udf_communication.peer import Peer
from exasol_advanced_analytics_framework.udf_communication.peer_communicator import PeerCommunicator
from exasol_advanced_analytics_framework.udf_communication.serialization import serialize_message, deserialize_message
from exasol_advanced_analytics_framework.udf_communication.socket_factory.abstract import SocketFactory, Frame

_LOGGER: FilteringBoundLogger = structlog.getLogger()

LOCALHOST_LEADER_RANK = 0
MULTI_NODE_LEADER_RANK = 0


class BroadcastOperation:
    """Broadcast a value from the multi-node leader to every instance.

    The value travels in two hops:

    1. The multi-node leader sends it to the localhost leader of every
       other node (via the multi-node communicator).
    2. Each localhost leader forwards it to the remaining instances on
       its own node (via the localhost communicator).

    Calling the operation returns the broadcast value on every instance.
    """

    def __init__(self,
                 sequence_number: int,
                 value: Optional[bytes],
                 localhost_communicator: PeerCommunicator,
                 multi_node_communicator: PeerCommunicator,
                 socket_factory: SocketFactory):
        self._socket_factory = socket_factory
        # Only the multi-node leader is expected to supply a non-None value;
        # all other instances receive it over the wire.
        self._value = value
        self._sequence_number = sequence_number
        self._multi_node_communicator = multi_node_communicator
        self._localhost_communicator = localhost_communicator
        self._logger = _LOGGER.bind(
            sequence_number=self._sequence_number,
        )

    def __call__(self) -> bytes:
        """Run the broadcast and return the broadcast value as bytes."""
        if self._localhost_communicator.rank > LOCALHOST_LEADER_RANK:
            return self._receive_from_localhost_leader()
        return self._send_messages_to_local_peers()

    def _receive_from_localhost_leader(self) -> bytes:
        """Non-leader instance: receive the value from this node's localhost leader."""
        self._logger.info("_receive_from_localhost_leader")
        leader = self._localhost_communicator.leader
        frames = self._localhost_communicator.recv(peer=leader)
        # Frame 0 carries the serialized Broadcast control message,
        # frame 1 carries the payload.
        message = deserialize_message(frames[0].to_bytes(), messages.Message)
        specific_message_obj = self._get_and_check_specific_message_obj(message)
        self._check_sequence_number(specific_message_obj=specific_message_obj)
        return frames[1].to_bytes()

    def _send_messages_to_local_peers(self) -> bytes:
        """Localhost leader: dispatch on the multi-node rank."""
        # NOTE(review): unlike _send_messages_to_local_leaders, no None check on
        # self._multi_node_communicator here — presumably every localhost leader
        # has a multi-node communicator; confirm against Communicator's setup.
        if self._multi_node_communicator.rank > MULTI_NODE_LEADER_RANK:
            return self._forward_from_multi_node_leader()
        return self._send_messages_from_multi_node_leaders()

    def _forward_from_multi_node_leader(self) -> bytes:
        """Localhost leader on a non-leader node: receive the value from the
        multi-node leader and fan it out to the other local instances."""
        self._logger.info("_forward_from_multi_node_leader")
        value_frame = self.receive_value_frame_from_multi_node_leader()
        leader = self._localhost_communicator.leader
        peers = [peer for peer in self._localhost_communicator.peers() if peer != leader]

        for peer in peers:
            frames = self._construct_broadcast_message(
                destination=peer,
                leader=leader,
                value_frame=value_frame
            )
            self._localhost_communicator.send(peer=peer, message=frames)

        return value_frame.to_bytes()

    def receive_value_frame_from_multi_node_leader(self) -> Frame:
        """Receive and validate the broadcast message from the multi-node
        leader and return the payload frame (frame 1)."""
        leader = self._multi_node_communicator.leader
        frames = self._multi_node_communicator.recv(leader)
        self._logger.info("received")
        message = deserialize_message(frames[0].to_bytes(), messages.Message)
        specific_message_obj = self._get_and_check_specific_message_obj(message)
        self._check_sequence_number(specific_message_obj=specific_message_obj)
        return frames[1]

    def _send_messages_from_multi_node_leaders(self) -> bytes:
        """Multi-node leader: send the value to the other localhost leaders and
        to the other instances on its own node, then return it locally."""
        self._send_messages_to_local_leaders()
        self._send_messages_to_local_peers_from_multi_node_leaders()
        # On the multi-node leader the caller must have provided the value,
        # so this is never None here.
        return self._value

    def _send_messages_to_local_leaders(self):
        """Send the value to the localhost leader of every other node."""
        if self._multi_node_communicator is None:
            return

        self._logger.info("_send_messages_to_local_leaders")
        leader = self._multi_node_communicator.leader
        peers = [peer for peer in self._multi_node_communicator.peers() if peer != leader]

        for peer in peers:
            # Each send needs its own frame instance for the payload.
            value_frame = self._socket_factory.create_frame(self._value)
            frames = self._construct_broadcast_message(
                destination=peer,
                leader=leader,
                value_frame=value_frame
            )
            self._multi_node_communicator.send(peer=peer, message=frames)

    def _send_messages_to_local_peers_from_multi_node_leaders(self):
        """Send the value to the other instances on the leader's own node."""
        self._logger.info("_send_messages_to_local_peers_from_multi_node_leaders")
        leader = self._localhost_communicator.leader
        peers = [p for p in self._localhost_communicator.peers() if p != leader]
        for peer in peers:
            value_frame = self._socket_factory.create_frame(self._value)
            frames = self._construct_broadcast_message(
                destination=peer,
                leader=leader,
                value_frame=value_frame
            )
            self._localhost_communicator.send(peer=peer, message=frames)

    def _check_sequence_number(self, specific_message_obj: messages.Broadcast):
        """Raise if the received message belongs to a different collective round."""
        if specific_message_obj.sequence_number != self._sequence_number:
            # Fixed: the original interpolated self._sequence_number twice and
            # never reported the sequence number actually received.
            raise RuntimeError(
                f"Got message with different sequence number. "
                f"We expect the sequence number {self._sequence_number} "
                f"but we got {specific_message_obj.sequence_number} in message {specific_message_obj}")

    def _get_and_check_specific_message_obj(self, message: messages.Message) -> messages.Broadcast:
        """Unwrap the Message union and ensure the payload is a Broadcast."""
        specific_message_obj = message.__root__
        if not isinstance(specific_message_obj, messages.Broadcast):
            # Fixed: report the type of the unwrapped payload; type(message) is
            # always the Message wrapper and therefore uninformative.
            raise TypeError(f"Received the wrong message type. "
                            f"Expected {messages.Broadcast.__name__} got {type(specific_message_obj)}. "
                            f"For message {message}.")
        return specific_message_obj

    def _construct_broadcast_message(self, destination: Peer, leader: Peer, value_frame: Frame):
        """Build the two-frame wire format: [serialized Broadcast, payload]."""
        message = messages.Broadcast(sequence_number=self._sequence_number,
                                     destination=destination,
                                     source=leader)
        serialized_message = serialize_message(message)
        frames = [
            self._socket_factory.create_frame(serialized_message),
            value_frame
        ]
        return frames
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, List

from exasol_advanced_analytics_framework.udf_communication.broadcast_operation import BroadcastOperation
from exasol_advanced_analytics_framework.udf_communication.discovery import localhost, multi_node
from exasol_advanced_analytics_framework.udf_communication.gather_operation import GatherOperation
from exasol_advanced_analytics_framework.udf_communication.ip_address import Port, IPAddress
Expand All @@ -9,6 +10,7 @@
LOCALHOST_LEADER_RANK = 0
MULTI_NODE_LEADER_RANK = 0


class Communicator:

def __init__(self,
Expand Down Expand Up @@ -95,6 +97,14 @@ def gather(self, value: bytes) -> Optional[List[bytes]]:
number_of_instances_per_node=self._number_of_instances_per_node)
return gather()

def broadcast(self, value: Optional[bytes]) -> bytes:
    """Broadcast *value* from the multi-node leader to all instances and
    return it on every instance."""
    broadcast_operation = BroadcastOperation(
        sequence_number=self._next_sequence_number(),
        value=value,
        localhost_communicator=self._localhost_communicator,
        multi_node_communicator=self._multi_node_communicator,
        socket_factory=self._socket_factory,
    )
    return broadcast_operation()

def is_multi_node_leader(self):
if self._multi_node_communicator is not None:
return self._multi_node_communicator.rank == MULTI_NODE_LEADER_RANK
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ class Gather(BaseMessage, frozen=True):
position: int


class Broadcast(BaseMessage, frozen=True):
    # Control message for the broadcast collective operation; sent ahead of
    # the payload frame and validated by the receiver.
    message_type: Literal["Broadcast"] = "Broadcast"
    # Peer that sent the value (the leader for this hop).
    source: Peer
    # Peer the value is addressed to.
    destination: Peer
    # Identifies the collective round; receivers reject mismatches.
    sequence_number: int


class Message(BaseModel, frozen=True):
__root__: Union[
Ping,
Expand All @@ -146,5 +153,6 @@ class Message(BaseModel, frozen=True):
AcknowledgeCloseConnection,
ConnectionIsClosed,
Timeout,
Gather
Gather,
Broadcast
]
143 changes: 143 additions & 0 deletions tests/integration_tests/without_db/udf_communication/test_broadcast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import time
from pathlib import Path
from typing import List, Dict, Tuple

import structlog
import zmq
from structlog import WriteLoggerFactory
from structlog.tracebacks import ExceptionDictTransformer
from structlog.types import FilteringBoundLogger

from exasol_advanced_analytics_framework.udf_communication.communicator import Communicator
from exasol_advanced_analytics_framework.udf_communication.ip_address import Port, IPAddress
from exasol_advanced_analytics_framework.udf_communication.socket_factory.zmq_wrapper import ZMQSocketFactory
from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \
ConditionalMethodDropper
from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, \
BidirectionalQueue, assert_processes_finish, \
CommunicatorTestProcessParameter

# Route structlog output as JSON lines into a .log file next to this module.
structlog.configure(
    context_class=dict,
    logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
    processors=[
        structlog.contextvars.merge_contextvars,
        # Drop debug and info calls entirely to keep the log small.
        ConditionalMethodDropper(method_name="debug"),
        ConditionalMethodDropper(method_name="info"),
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(),
        structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
        structlog.processors.CallsiteParameterAdder(),
        structlog.processors.JSONRenderer()
    ]
)

LOGGER: FilteringBoundLogger = structlog.get_logger(__name__)


def run(parameter: CommunicatorTestProcessParameter,
        queue: BidirectionalQueue):
    """Worker-process entry point: build a Communicator, take part in one
    broadcast, and report the decoded result (or a failure text) via *queue*."""
    try:
        zmq_context = zmq.Context()
        communicator = Communicator(
            multi_node_discovery_port=Port(port=44444),
            local_discovery_port=parameter.local_discovery_port,
            multi_node_discovery_ip=IPAddress(ip_address="127.0.0.1"),
            node_name=parameter.node_name,
            instance_name=parameter.instance_name,
            listen_ip=IPAddress(ip_address="127.0.0.1"),
            group_identifier=parameter.group_identifier,
            number_of_nodes=parameter.number_of_nodes,
            number_of_instances_per_node=parameter.number_of_instances_per_node,
            # Node "n0" always hosts the discovery leader in these tests.
            is_discovery_leader_node=parameter.node_name == "n0",
            socket_factory=ZMQSocketFactory(zmq_context)
        )
        # Only the multi-node leader supplies the payload; after the broadcast
        # every instance should hold it.
        value = b"Success" if communicator.is_multi_node_leader() else None
        result = communicator.broadcast(value)
        LOGGER.info("result", result=result, instance_name=parameter.instance_name, node_name=parameter.node_name)
        queue.put(result.decode("utf-8"))
    except Exception as e:
        LOGGER.exception("Exception during test")
        queue.put(f"Failed during test: {e}")


REPETITIONS_FOR_FUNCTIONALITY = 1


def test_functionality_2_1():
    """Broadcast across 2 nodes with 1 instance each."""
    run_test_with_repetitions(
        number_of_nodes=2,
        number_of_instances_per_node=1,
        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
    )


def test_functionality_1_2():
    """Broadcast on a single node with 2 instances."""
    run_test_with_repetitions(
        number_of_nodes=1,
        number_of_instances_per_node=2,
        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
    )


def test_functionality_2_2():
    """Broadcast across 2 nodes with 2 instances each."""
    run_test_with_repetitions(
        number_of_nodes=2,
        number_of_instances_per_node=2,
        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
    )


def test_functionality_3_3():
    """Broadcast across 3 nodes with 3 instances each."""
    run_test_with_repetitions(
        number_of_nodes=3,
        number_of_instances_per_node=3,
        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
    )


def run_test_with_repetitions(number_of_nodes: int, number_of_instances_per_node: int, repetitions: int):
    """Run the broadcast scenario *repetitions* times, each with a fresh group
    identifier, and assert that every process reports the expected result.

    :param number_of_nodes: simulated cluster size
    :param number_of_instances_per_node: UDF instances per node
    :param repetitions: how many independent iterations to run
    """
    for i in range(repetitions):
        # A unique group identifier isolates the peers of each iteration
        # from those of previous iterations.
        group = f"{time.monotonic_ns()}"
        # Fixed: dropped the extraneous f-prefix on placeholder-less log
        # messages ("Start iteration"/"Finish iteration", ruff F541).
        LOGGER.info("Start iteration",
                    iteration=i + 1,
                    repetitions=repetitions,
                    group_identifier=group,
                    number_of_nodes=number_of_nodes,
                    number_of_instances_per_node=number_of_instances_per_node)
        start_time = time.monotonic()
        expected_result_of_threads, actual_result_of_threads = \
            run_test(group_identifier=group,
                     number_of_nodes=number_of_nodes,
                     number_of_instances_per_node=number_of_instances_per_node)
        assert expected_result_of_threads == actual_result_of_threads
        end_time = time.monotonic()
        LOGGER.info("Finish iteration",
                    iteration=i + 1,
                    repetitions=repetitions,
                    group_identifier=group,
                    number_of_nodes=number_of_nodes,
                    number_of_instances_per_node=number_of_instances_per_node,
                    duration=end_time - start_time)


def run_test(group_identifier: str, number_of_nodes: int, number_of_instances_per_node: int):
    """Spawn one process per (node, instance) pair, run the broadcast, and
    return (expected, actual) result dicts keyed by (node_name, instance_name)."""
    parameters = []
    for node_index in range(number_of_nodes):
        for instance_index in range(number_of_instances_per_node):
            parameters.append(
                CommunicatorTestProcessParameter(
                    node_name=f"n{node_index}",
                    instance_name=f"i{instance_index}",
                    group_identifier=group_identifier,
                    number_of_nodes=number_of_nodes,
                    number_of_instances_per_node=number_of_instances_per_node,
                    # Each node gets its own local discovery port.
                    local_discovery_port=Port(port=44445 + node_index),
                    seed=0))
    processes: List[TestProcess[CommunicatorTestProcessParameter]] = [
        TestProcess(parameter, run=run) for parameter in parameters
    ]
    for process in processes:
        process.start()
    assert_processes_finish(processes, timeout_in_seconds=180)
    actual_result_of_threads: Dict[Tuple[str, str], str] = {}
    expected_result_of_threads: Dict[Tuple[str, str], str] = {}
    for process in processes:
        result_key = (process.parameter.node_name, process.parameter.instance_name)
        actual_result_of_threads[result_key] = process.get()
        # Every instance is expected to end up holding the broadcast payload.
        expected_result_of_threads[result_key] = "Success"
    return expected_result_of_threads, actual_result_of_threads
Loading

0 comments on commit dea65c5

Please sign in to comment.