#104: Added broadcast collective operation #169

Merged 4 commits on Oct 18, 2023
12 changes: 12 additions & 0 deletions .github/workflows/integration_tests_with_db.yaml
@@ -11,6 +11,18 @@ jobs:
runs-on: ubuntu-latest
Member:

maybe we could use a slimmer image like alpine?

Collaborator Author:

Not a Docker image; only GitHub-provided runner images are supported, and they add all the garbage

Member:

right, my bad 😬. sorry, my cache was still occupied by the previous context 😅


steps:

- name: Free Disk Space (Ubuntu)
uses: jlumbroso/[email protected]
with:
tool-cache: false
android: true
dotnet: true
haskell: true
large-packages: true
docker-images: true
swap-storage: false

- uses: actions/checkout@v2

- name: Setup Python & Poetry Environment
138 changes: 138 additions & 0 deletions exasol_advanced_analytics_framework/udf_communication/broadcast_operation.py
@@ -0,0 +1,138 @@
from typing import Optional

import structlog
from structlog.typing import FilteringBoundLogger

from exasol_advanced_analytics_framework.udf_communication import messages
from exasol_advanced_analytics_framework.udf_communication.peer import Peer
from exasol_advanced_analytics_framework.udf_communication.peer_communicator import PeerCommunicator
from exasol_advanced_analytics_framework.udf_communication.serialization import serialize_message, deserialize_message
from exasol_advanced_analytics_framework.udf_communication.socket_factory.abstract import SocketFactory, Frame

_LOGGER: FilteringBoundLogger = structlog.getLogger()

LOCALHOST_LEADER_RANK = 0
MULTI_NODE_LEADER_RANK = 0


class BroadcastOperation:

def __init__(self,
sequence_number: int,
value: Optional[bytes],
localhost_communicator: PeerCommunicator,
multi_node_communicator: PeerCommunicator,
socket_factory: SocketFactory):
self._socket_factory = socket_factory
self._value = value
self._sequence_number = sequence_number
self._multi_node_communicator = multi_node_communicator
self._localhost_communicator = localhost_communicator
self._logger = _LOGGER.bind(
sequence_number=self._sequence_number,
)

def __call__(self) -> bytes:
if self._localhost_communicator.rank > LOCALHOST_LEADER_RANK:
return self._receive_from_localhost_leader()
return self._send_messages_to_local_peers()

def _receive_from_localhost_leader(self) -> bytes:
self._logger.info("_receive_from_localhost_leader")
leader = self._localhost_communicator.leader
frames = self._localhost_communicator.recv(peer=leader)
message = deserialize_message(frames[0].to_bytes(), messages.Message)
specific_message_obj = self._get_and_check_specific_message_obj(message)
self._check_sequence_number(specific_message_obj=specific_message_obj)
return frames[1].to_bytes()

def _send_messages_to_local_peers(self) -> bytes:
        if self._multi_node_communicator.rank > MULTI_NODE_LEADER_RANK:
return self._forward_from_multi_node_leader()
return self._send_messages_from_multi_node_leaders()

def _forward_from_multi_node_leader(self) -> bytes:
self._logger.info("_forward_from_multi_node_leader")
        value_frame = self._receive_value_frame_from_multi_node_leader()
leader = self._localhost_communicator.leader
peers = [peer for peer in self._localhost_communicator.peers() if peer != leader]

for peer in peers:
frames = self._construct_broadcast_message(
destination=peer,
leader=leader,
value_frame=value_frame
)
self._localhost_communicator.send(peer=peer, message=frames)

return value_frame.to_bytes()

    def _receive_value_frame_from_multi_node_leader(self) -> Frame:
leader = self._multi_node_communicator.leader
frames = self._multi_node_communicator.recv(leader)
self._logger.info("received")
message = deserialize_message(frames[0].to_bytes(), messages.Message)
specific_message_obj = self._get_and_check_specific_message_obj(message)
self._check_sequence_number(specific_message_obj=specific_message_obj)
return frames[1]

def _send_messages_from_multi_node_leaders(self) -> bytes:
self._send_messages_to_local_leaders()
self._send_messages_to_local_peers_from_multi_node_leaders()
return self._value

def _send_messages_to_local_leaders(self):
if self._multi_node_communicator is None:
return

self._logger.info("_send_messages_to_local_leaders")
leader = self._multi_node_communicator.leader
peers = [peer for peer in self._multi_node_communicator.peers() if peer != leader]

for peer in peers:
value_frame = self._socket_factory.create_frame(self._value)
frames = self._construct_broadcast_message(
destination=peer,
leader=leader,
value_frame=value_frame
)
self._multi_node_communicator.send(peer=peer, message=frames)

def _send_messages_to_local_peers_from_multi_node_leaders(self):
self._logger.info("_send_messages_to_local_peers_from_multi_node_leaders")
leader = self._localhost_communicator.leader
peers = [p for p in self._localhost_communicator.peers() if p != leader]
for peer in peers:
value_frame = self._socket_factory.create_frame(self._value)
frames = self._construct_broadcast_message(
destination=peer,
leader=leader,
value_frame=value_frame
)
self._localhost_communicator.send(peer=peer, message=frames)

def _check_sequence_number(self, specific_message_obj: messages.Broadcast):
if specific_message_obj.sequence_number != self._sequence_number:
raise RuntimeError(
f"Got message with different sequence number. "
f"We expect the sequence number {self._sequence_number} "
f"but we got {self._sequence_number} in message {specific_message_obj}")

def _get_and_check_specific_message_obj(self, message: messages.Message) -> messages.Broadcast:
specific_message_obj = message.__root__
if not isinstance(specific_message_obj, messages.Broadcast):
raise TypeError(f"Received the wrong message type. "
f"Expected {messages.Broadcast.__name__} got {type(message)}. "
f"For message {message}.")
return specific_message_obj

def _construct_broadcast_message(self, destination: Peer, leader: Peer, value_frame: Frame):
message = messages.Broadcast(sequence_number=self._sequence_number,
destination=destination,
source=leader)
serialized_message = serialize_message(message)
frames = [
self._socket_factory.create_frame(serialized_message),
value_frame
]
return frames
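
For orientation: the dispatch in __call__ picks one of three roles from the instance's two ranks. Below is a minimal standalone sketch of that decision (hedged; the communicators and actual messaging are omitted, and rank 0 is the leader on both communicators, per the constants above):

# Standalone sketch of BroadcastOperation.__call__'s role dispatch.
def broadcast_role(localhost_rank: int, multi_node_rank: int) -> str:
    if localhost_rank > 0:
        # Not the local leader: wait for the value from this node's leader.
        return "receive_from_localhost_leader"
    if multi_node_rank > 0:
        # Local leader but not the global leader: receive from the
        # multi-node leader, then fan out to this node's peers.
        return "forward_from_multi_node_leader"
    # Global leader: send to the other nodes' leaders and to local peers.
    return "send_as_multi_node_leader"

assert broadcast_role(0, 0) == "send_as_multi_node_leader"
assert broadcast_role(0, 2) == "forward_from_multi_node_leader"
assert broadcast_role(3, 1) == "receive_from_localhost_leader"
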
exasol_advanced_analytics_framework/udf_communication/communicator.py
@@ -1,5 +1,6 @@
from typing import Optional, List

from exasol_advanced_analytics_framework.udf_communication.broadcast_operation import BroadcastOperation
from exasol_advanced_analytics_framework.udf_communication.discovery import localhost, multi_node
from exasol_advanced_analytics_framework.udf_communication.gather_operation import GatherOperation
from exasol_advanced_analytics_framework.udf_communication.ip_address import Port, IPAddress
@@ -9,6 +10,7 @@
LOCALHOST_LEADER_RANK = 0
MULTI_NODE_LEADER_RANK = 0


class Communicator:

def __init__(self,
@@ -95,6 +97,14 @@ def gather(self, value: bytes) -> Optional[List[bytes]]:
number_of_instances_per_node=self._number_of_instances_per_node)
return gather()

def broadcast(self, value: Optional[bytes]) -> bytes:
sequence_number = self._next_sequence_number()
operation = BroadcastOperation(sequence_number=sequence_number, value=value,
localhost_communicator=self._localhost_communicator,
multi_node_communicator=self._multi_node_communicator,
socket_factory=self._socket_factory)
return operation()

def is_multi_node_leader(self):
if self._multi_node_communicator is not None:
return self._multi_node_communicator.rank == MULTI_NODE_LEADER_RANK
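
Taken together with is_multi_node_leader(), the intended calling pattern looks roughly like the sketch below (hedged; communicator construction is omitted and the payload is a made-up example):

def broadcast_step(communicator) -> bytes:
    # Only the multi-node leader supplies the payload; all other
    # instances pass None and receive the leader's bytes.
    value = b"example-payload" if communicator.is_multi_node_leader() else None
    # Every instance, leader included, returns the same bytes.
    return communicator.broadcast(value)
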
exasol_advanced_analytics_framework/udf_communication/messages.py
@@ -125,6 +125,13 @@ class Gather(BaseMessage, frozen=True):
position: int


class Broadcast(BaseMessage, frozen=True):
message_type: Literal["Broadcast"] = "Broadcast"
source: Peer
destination: Peer
sequence_number: int


class Message(BaseModel, frozen=True):
__root__: Union[
Ping,
@@ -146,5 +153,6 @@ class Message(BaseModel, frozen=True):
AcknowledgeCloseConnection,
ConnectionIsClosed,
Timeout,
Gather
Gather,
Broadcast
]
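
The message_type: Literal[...] field is the tag that lets deserialize_message resolve the __root__ union to the right variant. A simplified, self-contained sketch of the same pattern follows (pydantic v1 style, matching the __root__ usage above; the real Broadcast also carries source/destination peers):

from typing import Literal, Union

from pydantic import BaseModel


class Gather(BaseModel, frozen=True):
    message_type: Literal["Gather"] = "Gather"
    sequence_number: int


class Broadcast(BaseModel, frozen=True):
    message_type: Literal["Broadcast"] = "Broadcast"
    sequence_number: int


class Message(BaseModel, frozen=True):
    __root__: Union[Gather, Broadcast]


# Round trip: the "message_type" tag in the JSON decides the variant.
raw = Message(__root__=Broadcast(sequence_number=1)).json()
parsed = Message.parse_raw(raw).__root__
assert isinstance(parsed, Broadcast) and parsed.sequence_number == 1
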
143 changes: 143 additions & 0 deletions tests/integration_tests/without_db/udf_communication/test_broadcast.py
@@ -0,0 +1,143 @@
import time
from pathlib import Path
from typing import List, Dict, Tuple

import structlog
import zmq
from structlog import WriteLoggerFactory
from structlog.tracebacks import ExceptionDictTransformer
from structlog.types import FilteringBoundLogger

from exasol_advanced_analytics_framework.udf_communication.communicator import Communicator
from exasol_advanced_analytics_framework.udf_communication.ip_address import Port, IPAddress
from exasol_advanced_analytics_framework.udf_communication.socket_factory.zmq_wrapper import ZMQSocketFactory
from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \
ConditionalMethodDropper
from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, \
BidirectionalQueue, assert_processes_finish, \
CommunicatorTestProcessParameter

structlog.configure(
context_class=dict,
logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
processors=[
structlog.contextvars.merge_contextvars,
ConditionalMethodDropper(method_name="debug"),
ConditionalMethodDropper(method_name="info"),
structlog.processors.add_log_level,
structlog.processors.TimeStamper(),
structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
structlog.processors.CallsiteParameterAdder(),
structlog.processors.JSONRenderer()
]
)

LOGGER: FilteringBoundLogger = structlog.get_logger(__name__)


def run(parameter: CommunicatorTestProcessParameter,
queue: BidirectionalQueue):
try:
is_discovery_leader_node = parameter.node_name == "n0"
context = zmq.Context()
socket_factory = ZMQSocketFactory(context)
communicator = Communicator(
multi_node_discovery_port=Port(port=44444),
local_discovery_port=parameter.local_discovery_port,
multi_node_discovery_ip=IPAddress(ip_address="127.0.0.1"),
node_name=parameter.node_name,
instance_name=parameter.instance_name,
listen_ip=IPAddress(ip_address="127.0.0.1"),
group_identifier=parameter.group_identifier,
number_of_nodes=parameter.number_of_nodes,
number_of_instances_per_node=parameter.number_of_instances_per_node,
is_discovery_leader_node=is_discovery_leader_node,
socket_factory=socket_factory
)
value = None
if communicator.is_multi_node_leader():
value = b"Success"
result = communicator.broadcast(value)
LOGGER.info("result", result=result, instance_name=parameter.instance_name, node_name=parameter.node_name)
queue.put(result.decode("utf-8"))
except Exception as e:
LOGGER.exception("Exception during test")
queue.put(f"Failed during test: {e}")


REPETITIONS_FOR_FUNCTIONALITY = 1


def test_functionality_2_1():
run_test_with_repetitions(number_of_nodes=2,
number_of_instances_per_node=1,
repetitions=REPETITIONS_FOR_FUNCTIONALITY)


def test_functionality_1_2():
run_test_with_repetitions(number_of_nodes=1,
number_of_instances_per_node=2,
repetitions=REPETITIONS_FOR_FUNCTIONALITY)


def test_functionality_2_2():
run_test_with_repetitions(number_of_nodes=2,
number_of_instances_per_node=2,
repetitions=REPETITIONS_FOR_FUNCTIONALITY)


def test_functionality_3_3():
run_test_with_repetitions(number_of_nodes=3,
number_of_instances_per_node=3,
repetitions=REPETITIONS_FOR_FUNCTIONALITY)


def run_test_with_repetitions(number_of_nodes: int, number_of_instances_per_node: int, repetitions: int):
for i in range(repetitions):
group = f"{time.monotonic_ns()}"
LOGGER.info(f"Start iteration",
iteration=i + 1,
repetitions=repetitions,
group_identifier=group,
number_of_nodes=number_of_nodes,
number_of_instances_per_node=number_of_instances_per_node)
start_time = time.monotonic()
expected_result_of_threads, actual_result_of_threads = \
run_test(group_identifier=group,
number_of_nodes=number_of_nodes,
number_of_instances_per_node=number_of_instances_per_node)
assert expected_result_of_threads == actual_result_of_threads
end_time = time.monotonic()
LOGGER.info(f"Finish iteration",
iteration=i + 1,
repetitions=repetitions,
group_identifier=group,
number_of_nodes=number_of_nodes,
number_of_instances_per_node=number_of_instances_per_node,
duration=end_time - start_time)


def run_test(group_identifier: str, number_of_nodes: int, number_of_instances_per_node: int):
parameters = [
CommunicatorTestProcessParameter(
node_name=f"n{n}",
instance_name=f"i{i}",
group_identifier=group_identifier,
number_of_nodes=number_of_nodes,
number_of_instances_per_node=number_of_instances_per_node,
local_discovery_port=Port(port=44445 + n),
seed=0)
for n in range(number_of_nodes)
for i in range(number_of_instances_per_node)]
processes: List[TestProcess[CommunicatorTestProcessParameter]] = \
[TestProcess(parameter, run=run) for parameter in parameters]
for process in processes:
process.start()
assert_processes_finish(processes, timeout_in_seconds=180)
actual_result_of_threads: Dict[Tuple[str, str], str] = {}
expected_result_of_threads: Dict[Tuple[str, str], str] = {}
for process in processes:
result_key = (process.parameter.node_name, process.parameter.instance_name)
actual_result_of_threads[result_key] = process.get()
expected_result_of_threads[result_key] = "Success"
return expected_result_of_threads, actual_result_of_threads
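
Assuming a standard pytest setup, the broadcast tests can also be invoked programmatically (path as added in this PR):

import pytest

# Run only the broadcast integration tests; no database is required.
pytest.main(["tests/integration_tests/without_db/udf_communication/test_broadcast.py", "-v"])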