use SharedMemory for unroller params pulling
hnyu committed Dec 12, 2024
1 parent fbf6bc6 commit 7866483
Showing 1 changed file with 35 additions and 17 deletions.
alf/algorithms/distributed_off_policy_algorithm.py
@@ -23,6 +23,7 @@
 
 import torch
 import torch.multiprocessing as mp
+from multiprocessing.shared_memory import SharedMemory
 
 import alf
 from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
@@ -152,6 +153,18 @@ def __init__(self,
         self._port = port
         self._ddp_rank = max(0, PerProcessContext().ddp_rank)
 
+    def _opt_free_state_dict(self) -> dict:
+        """Return the ``self._core_alg`` state dict without optimizer entries.
+        This dict is used for param syncing between a trainer and an unroller.
+        Optimizers sometimes have large state vectors which we want to exclude.
+        """
+        return {
+            k: v
+            for k, v in self._core_alg.state_dict().items()
+            if '_optimizers.' not in k
+        }
+
+
     ###############################
     ######### Forward calls #######
     ###############################
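A minimal sketch of what this filter keeps and drops, using hypothetical state dict keys (only the '_optimizers.' substring test mirrors the code above):

    import torch

    # Hypothetical keys for illustration; real ALF state dicts nest deeper.
    state_dict = {
        '_encoder.weight': torch.zeros(2, 2),
        '_optimizers.0.state.exp_avg': torch.zeros(4),
    }
    opt_free = {k: v for k, v in state_dict.items() if '_optimizers.' not in k}
    assert list(opt_free) == ['_encoder.weight']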
@@ -238,18 +251,20 @@ def receive_experience_data(replay_buffer: ReplayBuffer,
         time.sleep(0.1)
 
 
-def pull_params_from_trainer(shared_dict: dict, unroller_id: str):
-    """ Once new params arrive, we put it in the shared dict and mark the
-    ``params_updated`` as True. Later after the current unroll finishes,
-    the unroller can load the new params.
+def pull_params_from_trainer(memory_name: str, unroller_id: str):
+    """Once new params arrive, we put them in the shared memory and raise the
+    update flag. After the current unroll finishes, the unroller can load
+    the new params.
     """
     socket, _ = create_zmq_socket(
         zmq.DEALER, _trainer_addr_config.ip,
         _trainer_addr_config.port + _params_port_offset,
         unroller_id + "_params")
+    params = SharedMemory(name=memory_name)
     while True:
-        shared_dict['params'] = socket.recv()
-        shared_dict['params_updated'] = True
+        data = socket.recv()
+        params.buf[1:] = data  # write the payload before raising the flag,
+        params.buf[:1] = b'1'  # so a reader never sees a half-written update
         socket.send_string(UnrollerMessage.OK)
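For illustration, a single-process sketch of the one-byte flag protocol used above, with a made-up payload; the payload region is filled before the flag byte is raised:

    from multiprocessing.shared_memory import SharedMemory

    payload = b'serialized params'
    shm = SharedMemory(create=True, size=len(payload) + 1)
    shm.buf[:1] = b'0'                         # flag byte: no update yet
    # Writer: fill the payload region first, then raise the flag.
    shm.buf[1:1 + len(payload)] = payload
    shm.buf[:1] = b'1'
    # Reader: consume the payload and lower the flag.
    if bytes(shm.buf[:1]) == b'1':
        data = bytes(shm.buf[1:1 + len(payload)])
        shm.buf[:1] = b'0'
    assert data == payload
    shm.close()
    shm.unlink()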


@@ -342,7 +357,7 @@ def _send_params_to_unroller(self, unroller_id: str) -> bool:
         """
         # Get all parameters/buffers in a state dict and send them out
         buffer = io.BytesIO()
-        torch.save(self._core_alg.state_dict(), buffer)
+        torch.save(self._opt_free_state_dict(), buffer)
         self._params_socket.send_multipart(
             [unroller_id + b'_params',
              buffer.getvalue()])
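A self-contained sketch of the serialize/deserialize round-trip behind this send; the toy nn.Linear stands in for _core_alg:

    import io

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)                  # hypothetical stand-in for _core_alg
    buffer = io.BytesIO()
    torch.save(net.state_dict(), buffer)
    blob = buffer.getvalue()               # raw bytes, ready for a ZMQ frame
    restored = torch.load(io.BytesIO(blob), map_location='cpu')
    net.load_state_dict(restored)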
@@ -571,17 +586,18 @@ def _register_to_trainer(self):
     def _create_pull_params_subprocess(self):
         # Compute the total size of the params
         buffer = io.BytesIO()
-        torch.save(self._core_alg.state_dict(), buffer)
+        torch.save(self._opt_free_state_dict(), buffer)
         size = len(buffer.getvalue())
-        # Create a shared dict
-        self._shared_dict = mp.Manager().dict()
-        self._shared_dict['params_updated'] = False
-        self._shared_dict['params'] = bytes(size)
+        # Create a shared memory block to store the new params.
+        # The first byte indicates whether the params have been updated.
+        self._params = SharedMemory(create=True, size=size + 1, name='params')
+        # Initialize the flag byte to b'0' (not updated)
+        self._params.buf[:1] = b'0'
 
         mp.set_start_method('fork', force=True)
         process = mp.Process(
             target=pull_params_from_trainer,
-            args=(self._shared_dict, self._id),
+            args=(self._params.name, self._id),
             daemon=True)
         process.start()
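A runnable sketch (POSIX-only, since it uses the 'fork' start method) of handing a named shared memory block to a subprocess, as done above; the segment name and size are made up. Passing the name rather than the object lets the child attach to the same block without pickling the buffer:

    import torch.multiprocessing as mp
    from multiprocessing.shared_memory import SharedMemory

    def child(name):
        shm = SharedMemory(name=name)      # attach to the existing block by name
        shm.buf[:1] = b'1'
        shm.close()

    if __name__ == '__main__':
        shm = SharedMemory(create=True, size=8, name='demo_params')
        shm.buf[:1] = b'0'
        mp.set_start_method('fork', force=True)
        p = mp.Process(target=child, args=(shm.name,), daemon=True)
        p.start()
        p.join()
        assert bytes(shm.buf[:1]) == b'1'
        shm.close()
        shm.unlink()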

@@ -628,12 +644,14 @@ def _check_paramss_update(self) -> bool:
         """Returns True if params have been updated.
         """
         # Check if the params have been updated
-        if self._shared_dict['params_updated']:
-            buffer = io.BytesIO(self._shared_dict['params'])
+        if bytes(self._params.buf[:1]) == b'1':
+            params = bytes(self._params.buf[1:])
+            buffer = io.BytesIO(params)
             state_dict = torch.load(buffer, map_location='cpu')
-            self._core_alg.load_state_dict(state_dict)
+            # We might only update part of the params
+            self._core_alg.load_state_dict(state_dict, strict=False)
             logging.debug("Params updated from the trainer.")
-            self._shared_dict['params_updated'] = False
+            self._params.buf[:1] = b'0'
             return True
         return False
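For reference, strict=False lets load_state_dict accept a partial dict: missing keys keep their current values and are reported rather than raising an error. A sketch on a made-up toy module:

    import torch
    import torch.nn as nn

    net = nn.Linear(2, 2)                          # hypothetical module
    partial = {'weight': torch.zeros(2, 2)}        # no 'bias' entry
    result = net.load_state_dict(partial, strict=False)
    assert result.missing_keys == ['bias']
    assert result.unexpected_keys == []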
