-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
#104: Added broadcast collective operation #169
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,18 @@ jobs: | |
runs-on: ubuntu-latest | ||
|
||
steps: | ||
|
||
- name: Free Disk Space (Ubuntu) | ||
uses: jlumbroso/[email protected] | ||
with: | ||
tool-cache: false | ||
android: true | ||
dotnet: true | ||
haskell: true | ||
large-packages: true | ||
docker-images: true | ||
swap-storage: false | ||
|
||
- uses: actions/checkout@v2 | ||
|
||
- name: Setup Python & Poetry Environment | ||
|
138 changes: 138 additions & 0 deletions
138
exasol_advanced_analytics_framework/udf_communication/broadcast_operation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
from typing import Optional | ||
|
||
import structlog | ||
from structlog.typing import FilteringBoundLogger | ||
|
||
from exasol_advanced_analytics_framework.udf_communication import messages | ||
from exasol_advanced_analytics_framework.udf_communication.peer import Peer | ||
from exasol_advanced_analytics_framework.udf_communication.peer_communicator import PeerCommunicator | ||
from exasol_advanced_analytics_framework.udf_communication.serialization import serialize_message, deserialize_message | ||
from exasol_advanced_analytics_framework.udf_communication.socket_factory.abstract import SocketFactory, Frame | ||
|
||
# Module-level structlog logger; each BroadcastOperation binds its own
# sequence_number onto a child of this logger.
_LOGGER: FilteringBoundLogger = structlog.getLogger()

# Rank of the leader instance within the localhost communicator group.
LOCALHOST_LEADER_RANK = 0
# Rank of the leader node within the multi-node communicator group.
MULTI_NODE_LEADER_RANK = 0
|
||
|
||
class BroadcastOperation:
    """Broadcast a byte value from the multi-node leader to all instances.

    The broadcast runs in two stages:

    1. The multi-node leader sends the value to the localhost leader of
       every other node (via the multi-node communicator).
    2. Each localhost leader forwards the value to the other instances on
       its own node (via the localhost communicator).

    Calling the operation returns the broadcast value on every instance.
    Each transferred message consists of two frames: a serialized
    ``messages.Broadcast`` header followed by the raw value frame.
    """

    def __init__(self,
                 sequence_number: int,
                 value: Optional[bytes],
                 localhost_communicator: PeerCommunicator,
                 multi_node_communicator: PeerCommunicator,
                 socket_factory: SocketFactory):
        """
        :param sequence_number: identifies this broadcast round; incoming
            messages with a different sequence number raise RuntimeError.
        :param value: payload to distribute; only the multi-node leader is
            expected to pass a non-None value (receivers pass None).
        :param localhost_communicator: communicator among instances on this node.
        :param multi_node_communicator: communicator among node leaders.
        :param socket_factory: used to create outgoing message frames.
        """
        self._socket_factory = socket_factory
        self._value = value
        self._sequence_number = sequence_number
        self._multi_node_communicator = multi_node_communicator
        self._localhost_communicator = localhost_communicator
        self._logger = _LOGGER.bind(
            sequence_number=self._sequence_number,
        )

    def __call__(self) -> bytes:
        """Run the broadcast and return the value on this instance."""
        if self._localhost_communicator.rank > LOCALHOST_LEADER_RANK:
            return self._receive_from_localhost_leader()
        return self._send_messages_to_local_peers()

    def _receive_from_localhost_leader(self) -> bytes:
        """Non-leader local instance: receive the value from the localhost leader."""
        self._logger.info("_receive_from_localhost_leader")
        leader = self._localhost_communicator.leader
        frames = self._localhost_communicator.recv(peer=leader)
        # Frame 0 is the serialized Broadcast header, frame 1 the payload.
        message = deserialize_message(frames[0].to_bytes(), messages.Message)
        specific_message_obj = self._get_and_check_specific_message_obj(message)
        self._check_sequence_number(specific_message_obj=specific_message_obj)
        return frames[1].to_bytes()

    def _send_messages_to_local_peers(self) -> bytes:
        """Localhost leader: obtain the value and fan it out on this node."""
        # Use the module constant instead of a magic 0 for consistency with
        # LOCALHOST_LEADER_RANK above.
        if self._multi_node_communicator.rank > MULTI_NODE_LEADER_RANK:
            return self._forward_from_multi_node_leader()
        return self._send_messages_from_multi_node_leaders()

    def _forward_from_multi_node_leader(self) -> bytes:
        """Localhost leader of a non-leader node: receive from the multi-node
        leader and forward the value to the local peers."""
        self._logger.info("_forward_from_multi_node_leader")
        value_frame = self.receive_value_frame_from_multi_node_leader()
        leader = self._localhost_communicator.leader
        peers = [peer for peer in self._localhost_communicator.peers() if peer != leader]

        for peer in peers:
            frames = self._construct_broadcast_message(
                destination=peer,
                leader=leader,
                value_frame=value_frame
            )
            self._localhost_communicator.send(peer=peer, message=frames)

        return value_frame.to_bytes()

    def receive_value_frame_from_multi_node_leader(self) -> Frame:
        """Receive and validate the broadcast message from the multi-node
        leader; return the payload frame unchanged so it can be forwarded
        without copying."""
        leader = self._multi_node_communicator.leader
        frames = self._multi_node_communicator.recv(leader)
        self._logger.info("received")
        message = deserialize_message(frames[0].to_bytes(), messages.Message)
        specific_message_obj = self._get_and_check_specific_message_obj(message)
        self._check_sequence_number(specific_message_obj=specific_message_obj)
        return frames[1]

    def _send_messages_from_multi_node_leaders(self) -> bytes:
        """Multi-node leader: send to other nodes' leaders, then to local
        peers, and return the original value.

        NOTE(review): self._value is Optional[bytes]; on the leader it is
        expected to be non-None — confirm callers guarantee this.
        """
        self._send_messages_to_local_leaders()
        self._send_messages_to_local_peers_from_multi_node_leaders()
        return self._value

    def _send_messages_to_local_leaders(self):
        """Send the value to the localhost leader of every other node."""
        # NOTE(review): this None-guard is inconsistent with
        # _send_messages_to_local_peers, which dereferences
        # self._multi_node_communicator unconditionally — confirm whether
        # a None multi-node communicator can actually reach this point.
        if self._multi_node_communicator is None:
            return

        self._logger.info("_send_messages_to_local_leaders")
        leader = self._multi_node_communicator.leader
        peers = [peer for peer in self._multi_node_communicator.peers() if peer != leader]

        for peer in peers:
            value_frame = self._socket_factory.create_frame(self._value)
            frames = self._construct_broadcast_message(
                destination=peer,
                leader=leader,
                value_frame=value_frame
            )
            self._multi_node_communicator.send(peer=peer, message=frames)

    def _send_messages_to_local_peers_from_multi_node_leaders(self):
        """Send the value to the other instances on the leader's own node."""
        self._logger.info("_send_messages_to_local_peers_from_multi_node_leaders")
        leader = self._localhost_communicator.leader
        peers = [p for p in self._localhost_communicator.peers() if p != leader]
        for peer in peers:
            value_frame = self._socket_factory.create_frame(self._value)
            frames = self._construct_broadcast_message(
                destination=peer,
                leader=leader,
                value_frame=value_frame
            )
            self._localhost_communicator.send(peer=peer, message=frames)

    def _check_sequence_number(self, specific_message_obj: messages.Broadcast):
        """Raise RuntimeError if the message belongs to a different broadcast round."""
        if specific_message_obj.sequence_number != self._sequence_number:
            # Bug fix: the original message interpolated self._sequence_number
            # twice, so the received sequence number was never reported.
            raise RuntimeError(
                f"Got message with different sequence number. "
                f"We expect the sequence number {self._sequence_number} "
                f"but we got {specific_message_obj.sequence_number} "
                f"in message {specific_message_obj}")

    def _get_and_check_specific_message_obj(self, message: messages.Message) -> messages.Broadcast:
        """Unwrap the message envelope and ensure it is a Broadcast message."""
        specific_message_obj = message.__root__
        if not isinstance(specific_message_obj, messages.Broadcast):
            # Bug fix: report the type of the unwrapped object; type(message)
            # is always the messages.Message envelope and carries no information.
            raise TypeError(f"Received the wrong message type. "
                            f"Expected {messages.Broadcast.__name__} "
                            f"got {type(specific_message_obj)}. "
                            f"For message {message}.")
        return specific_message_obj

    def _construct_broadcast_message(self, destination: Peer, leader: Peer, value_frame: Frame):
        """Build the two-frame wire message: [serialized header, payload]."""
        message = messages.Broadcast(sequence_number=self._sequence_number,
                                     destination=destination,
                                     source=leader)
        serialized_message = serialize_message(message)
        frames = [
            self._socket_factory.create_frame(serialized_message),
            value_frame
        ]
        return frames
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
143 changes: 143 additions & 0 deletions
143
tests/integration_tests/without_db/udf_communication/test_broadcast.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import time | ||
from pathlib import Path | ||
from typing import List, Dict, Tuple | ||
|
||
import structlog | ||
import zmq | ||
from structlog import WriteLoggerFactory | ||
from structlog.tracebacks import ExceptionDictTransformer | ||
from structlog.types import FilteringBoundLogger | ||
|
||
from exasol_advanced_analytics_framework.udf_communication.communicator import Communicator | ||
from exasol_advanced_analytics_framework.udf_communication.ip_address import Port, IPAddress | ||
from exasol_advanced_analytics_framework.udf_communication.socket_factory.zmq_wrapper import ZMQSocketFactory | ||
from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ | ||
ConditionalMethodDropper | ||
from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, \ | ||
BidirectionalQueue, assert_processes_finish, \ | ||
CommunicatorTestProcessParameter | ||
|
||
# Configure structlog once at import time: write JSON log records to a file
# next to this module (same name, ".log" suffix).
structlog.configure(
    context_class=dict,
    logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
    processors=[
        structlog.contextvars.merge_contextvars,
        # Drop debug and info records before any further processing.
        ConditionalMethodDropper(method_name="debug"),
        ConditionalMethodDropper(method_name="info"),
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(),
        structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
        structlog.processors.CallsiteParameterAdder(),
        structlog.processors.JSONRenderer()
    ]
)

# Module logger used by the test helpers below.
LOGGER: FilteringBoundLogger = structlog.get_logger(__name__)
|
||
|
||
def run(parameter: CommunicatorTestProcessParameter,
        queue: BidirectionalQueue):
    """Process entry point: build a Communicator, take part in one broadcast,
    and report the decoded result (or a failure string) through the queue."""
    try:
        # Node n0 hosts the multi-node discovery service.
        leads_discovery = parameter.node_name == "n0"
        zmq_context = zmq.Context()
        factory = ZMQSocketFactory(zmq_context)
        communicator = Communicator(
            multi_node_discovery_port=Port(port=44444),
            local_discovery_port=parameter.local_discovery_port,
            multi_node_discovery_ip=IPAddress(ip_address="127.0.0.1"),
            node_name=parameter.node_name,
            instance_name=parameter.instance_name,
            listen_ip=IPAddress(ip_address="127.0.0.1"),
            group_identifier=parameter.group_identifier,
            number_of_nodes=parameter.number_of_nodes,
            number_of_instances_per_node=parameter.number_of_instances_per_node,
            is_discovery_leader_node=leads_discovery,
            socket_factory=factory
        )
        # Only the multi-node leader supplies the payload; everyone else
        # passes None and receives the broadcast value.
        value = b"Success" if communicator.is_multi_node_leader() else None
        result = communicator.broadcast(value)
        LOGGER.info("result", result=result, instance_name=parameter.instance_name, node_name=parameter.node_name)
        queue.put(result.decode("utf-8"))
    except Exception as e:
        LOGGER.exception("Exception during test")
        queue.put(f"Failed during test: {e}")
|
||
|
||
# How many times each topology is exercised per test.
REPETITIONS_FOR_FUNCTIONALITY = 1


def test_functionality_2_1():
    """Two nodes, one instance per node."""
    run_test_with_repetitions(
        number_of_nodes=2,
        number_of_instances_per_node=1,
        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
    )


def test_functionality_1_2():
    """Single node, two instances."""
    run_test_with_repetitions(
        number_of_nodes=1,
        number_of_instances_per_node=2,
        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
    )


def test_functionality_2_2():
    """Two nodes, two instances per node."""
    run_test_with_repetitions(
        number_of_nodes=2,
        number_of_instances_per_node=2,
        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
    )


def test_functionality_3_3():
    """Three nodes, three instances per node."""
    run_test_with_repetitions(
        number_of_nodes=3,
        number_of_instances_per_node=3,
        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
    )
|
||
|
||
def run_test_with_repetitions(number_of_nodes: int, number_of_instances_per_node: int, repetitions: int):
    """Run the broadcast test `repetitions` times for the given topology,
    asserting after each iteration that every process reported the expected
    result. A fresh group identifier per iteration isolates the discovery.
    """
    for i in range(repetitions):
        # monotonic_ns gives a unique group id so concurrent/repeated runs
        # cannot cross-talk via the shared discovery port.
        group = f"{time.monotonic_ns()}"
        # Idiom fix: the originals were f-strings without placeholders (F541);
        # the emitted text is unchanged.
        LOGGER.info("Start iteration",
                    iteration=i + 1,
                    repetitions=repetitions,
                    group_identifier=group,
                    number_of_nodes=number_of_nodes,
                    number_of_instances_per_node=number_of_instances_per_node)
        start_time = time.monotonic()
        expected_result_of_threads, actual_result_of_threads = \
            run_test(group_identifier=group,
                     number_of_nodes=number_of_nodes,
                     number_of_instances_per_node=number_of_instances_per_node)
        assert expected_result_of_threads == actual_result_of_threads
        end_time = time.monotonic()
        LOGGER.info("Finish iteration",
                    iteration=i + 1,
                    repetitions=repetitions,
                    group_identifier=group,
                    number_of_nodes=number_of_nodes,
                    number_of_instances_per_node=number_of_instances_per_node,
                    duration=end_time - start_time)
|
||
|
||
def run_test(group_identifier: str, number_of_nodes: int, number_of_instances_per_node: int):
    """Spawn one process per (node, instance), wait for all of them, and
    return (expected, actual) result maps keyed by (node_name, instance_name).
    """
    parameters = [
        CommunicatorTestProcessParameter(
            node_name=f"n{n}",
            instance_name=f"i{i}",
            group_identifier=group_identifier,
            number_of_nodes=number_of_nodes,
            number_of_instances_per_node=number_of_instances_per_node,
            # Each node gets its own local discovery port.
            local_discovery_port=Port(port=44445 + n),
            seed=0)
        for n in range(number_of_nodes)
        for i in range(number_of_instances_per_node)]
    processes: List[TestProcess[CommunicatorTestProcessParameter]] = [
        TestProcess(parameter, run=run) for parameter in parameters
    ]
    for proc in processes:
        proc.start()
    assert_processes_finish(processes, timeout_in_seconds=180)
    # Collect what each process reported via its queue.
    actual_result_of_threads: Dict[Tuple[str, str], str] = {
        (proc.parameter.node_name, proc.parameter.instance_name): proc.get()
        for proc in processes
    }
    # Every process — leader or receiver — must end up with the broadcast value.
    expected_result_of_threads: Dict[Tuple[str, str], str] = {
        key: "Success" for key in actual_result_of_threads
    }
    return expected_result_of_threads, actual_result_of_threads
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we could use a slimmer image like
alpine
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a docker image, only GitHub provided images supported and they add all the garbage
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right, my bad 😬. sry cache was still occupied by the previous context 😅