Skip to content

Commit

Permalink
handle dead unrollers
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyu committed Dec 9, 2024
1 parent db924e1 commit fbf6bc6
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
39 changes: 33 additions & 6 deletions alf/algorithms/distributed_off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
class UnrollerMessage(object):
    """String constants for control messages an unroller sends the trainer."""

    # Marks the end of the experience stream for the current segment.
    EXP_SEG_END = 'unroller: last_seg_exp'
    # Acknowledgement that a params message has been received.
    OK = 'unroller: ok'


def get_local_ip():
Expand Down Expand Up @@ -248,6 +250,7 @@ def pull_params_from_trainer(shared_dict: dict, unroller_id: str):
while True:
shared_dict['params'] = socket.recv()
shared_dict['params_updated'] = True
socket.send_string(UnrollerMessage.OK)


@alf.configurable(whitelist=[
Expand Down Expand Up @@ -328,17 +331,36 @@ def _observe_for_replay(self, exp: Experience):
def is_main_ddp_rank(self):
    """Return True iff this process is DDP rank 0 (the main rank)."""
    rank = self._ddp_rank
    return rank == 0

def _send_params_to_unroller(self, unroller_id: str) -> bool:
    """Send model params to a specified unroller and wait for its ack.

    The trainer serializes ``self._core_alg``'s state dict and sends it on
    ``self._params_socket``; it then polls (non-blockingly) for the
    unroller's ``UnrollerMessage.OK`` acknowledgement for up to ~1 second.

    Args:
        unroller_id: id (bytes str) of the unroller.

    Returns:
        bool: True if the unroller acknowledged receipt (i.e. it is still
            alive); False if no ack arrived within the ~1s window.
    """
    # Serialize all parameters/buffers into one in-memory blob.
    buffer = io.BytesIO()
    torch.save(self._core_alg.state_dict(), buffer)
    self._params_socket.send_multipart(
        [unroller_id + b'_params',
         buffer.getvalue()])
    success = False
    # Poll with NOBLOCK so a dead unroller cannot block the trainer
    # forever; 100 retries * 0.01s sleep ~= 1s total wait.
    for _ in range(100):
        try:
            _, message = self._params_socket.recv_multipart(
                flags=zmq.NOBLOCK)
            assert message == UnrollerMessage.OK.encode()
            logging.debug(
                f"[worker-0] Params sent to unroller {unroller_id.decode()}."
            )
            success = True
            break
        except zmq.Again:
            # No ack yet; back off briefly before retrying.
            time.sleep(0.01)
    return success

def _create_unroller_registration_thread(self):
self._new_unroller_ips_and_ports = mp.Queue()
Expand Down Expand Up @@ -445,8 +467,13 @@ def _train_iter_off_policy(self):
if (self.is_main_ddp_rank and alf.summary.get_global_counter() %
self._push_params_every_n_iters == 0):
# Sending params to all the connected unrollers.
dead_unrollers = []
for unroller_id in self._connected_unrollers:
self._send_params_to_unroller(unroller_id)
if not self._send_params_to_unroller(unroller_id):
dead_unrollers.append(unroller_id)
# remove dead unrollers
for unroller_id in dead_unrollers:
self._connected_unrollers.remove(unroller_id)

return steps

Expand Down
1 change: 1 addition & 0 deletions alf/examples/distributed_sac_cartpole_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import alf
import alf.algorithms.distributed_off_policy_algorithm

alf.import_config("sac_cart_pole_conf.py")
# Distributed training only supports a single environment
Expand Down

0 comments on commit fbf6bc6

Please sign in to comment.