[RLlib; Offline RL] Enable GPU and multi-GPU training for offline algorithms. #47929

Open · wants to merge 13 commits into base: master
rllib/BUILD: 39 additions & 0 deletions

@@ -341,6 +341,19 @@ py_test(
     args = ["--as-test", "--enable-new-api-stack"]
 )

+py_test(
+    name = "learning_tests_cartpole_bc_gpu",
+    main = "tuned_examples/bc/cartpole_bc.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
+    size = "medium",
+    srcs = ["tuned_examples/bc/cartpole_bc.py"],
+    # Include the offline data files.
+    data = [
+        "tests/data/cartpole/cartpole-v1_large",
+    ],
+    args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"]
+)
+

Contributor: Awesooomeee!!! This is so cool!

 # CQL
 # Pendulum
 py_test(

@@ -356,6 +369,19 @@ py_test(
     args = ["--as-test", "--enable-new-api-stack"]
 )

+py_test(
+    name = "learning_tests_pendulum_cql_gpu",
+    main = "tuned_examples/cql/pendulum_cql.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_cartpole", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
+    size = "large",
+    srcs = ["tuned_examples/cql/pendulum_cql.py"],
+    # Include the zipped json data file as well.
+    data = [
+        "tests/data/pendulum/pendulum-v1_enormous",
+    ],
+    args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"]
+)
+

Contributor: same! :D

 # DQN
 # CartPole
 py_test(

@@ -564,6 +590,19 @@ py_test(
     args = ["--as-test", "--enable-new-api-stack"]
 )

+py_test(
+    name = "learning_tests_cartpole_marwil_gpu",
+    main = "tuned_examples/marwil/cartpole_marwil.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
+    size = "large",
+    srcs = ["tuned_examples/marwil/cartpole_marwil.py"],
+    # Include the offline data files.
+    data = [
+        "tests/data/cartpole/cartpole-v1_large",
+    ],
+    args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"]
+)
+

Contributor: same same :D

 # PPO
 # CartPole
 py_test(
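For reference, a minimal sketch of what `--num-gpus=1` amounts to for these tuned examples (an assumption about the shared example-script plumbing; the exact flag wiring lives in RLlib's example utilities):

import gymnasium  # noqa: F401  # env registration
from ray.rllib.algorithms.bc import BCConfig

config = (
    BCConfig()
    .environment("CartPole-v1")
    # Assumed mapping: `--num-gpus=1` requests one Learner backed by one GPU.
    .learners(
        num_learners=1,
        num_gpus_per_learner=1,
    )
)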
rllib/algorithms/algorithm.py: 3 additions & 3 deletions

@@ -847,13 +847,13 @@ def setup(self, config: AlgorithmConfig) -> None:
             # Provide the actor handles for the learners for module
             # updating during preprocessing.
             self.offline_data.learner_handles = self.learner_group._workers
-            # Provide the module_spec. Note, in the remote case this is needed
-            # because the learner module cannot be copied, but must be built.
-            self.offline_data.module_spec = module_spec
         # Otherwise we can simply pass in the local learner.
         else:
             self.offline_data.learner_handles = [self.learner_group._learner]

+        # Provide the module_spec. Note, in the remote case this is needed
+        # because the learner module cannot be copied, but must be built.
+        self.offline_data.module_spec = module_spec
         # Provide the `OfflineData` instance with space information. It might
         # need it for reading recorded experiences.
         self.offline_data.spaces = self.env_runner_group.get_spaces()
rllib/algorithms/bc/bc.py: 1 addition & 0 deletions

@@ -98,6 +98,7 @@ def build_learner_connector(
         # Remove unneeded connectors from the MARWIL connector pipeline.
         pipeline.remove("AddOneTsToEpisodesAndTruncate")
         pipeline.remove("GeneralAdvantageEstimation")
+        pipeline.remove("NumpyToTensor")

Contributor: Aaahh! So this was one of the problems? That we were already converting everything to torch tensors?

Collaborator (Author): Yes, it was one of them. The major one, however, was that we were passing in a learner that runs on the GPU, and that learner needed to be serialized by Ray to be sent to the data workers. Deserializing it there errored out.

         return pipeline

rllib/algorithms/cql/cql.py: 1 addition & 0 deletions

@@ -225,6 +225,7 @@ def build_learner_connector(
             AddObservationsFromEpisodesToBatch,
             AddNextObservationsFromEpisodesToTrainBatch(),
         )
+        pipeline.remove("NumpyToTensor")

         return pipeline

rllib/algorithms/marwil/marwil.py: 4 additions & 0 deletions

@@ -3,6 +3,7 @@
 from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
 from ray.rllib.algorithms.marwil.marwil_catalog import MARWILCatalog
+from ray.rllib.connectors.common.tensor_to_numpy import TensorToNumpy
 from ray.rllib.connectors.learner import (
     AddObservationsFromEpisodesToBatch,
     AddOneTsToEpisodesAndTruncate,

@@ -361,6 +362,9 @@ def build_learner_connector(
         pipeline.append(
             GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_)
         )
+        pipeline.append(
+            TensorToNumpy(),
+        )

Contributor: Makes sense. The GAE connector outputs tensors already, since it requires a VF forward pass (with the tensors coming from NumpyToTensor). The pipeline does get a little more complicated now, but I feel like it's still OK (not too crazy: connector pieces are named properly, each piece performs a well-distinguished task, ...).

Collaborator (Author): Yes. And keep in mind: the data workers run in parallel and prefetch batches, which actually makes the pipeline quite smooth. One connector piece more or less will not make a big difference, if at all. Users usually have enough resources to run multiple data workers in parallel, and they should.

sven1977 marked this conversation as resolved.

         return pipeline

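The resulting MARWIL learner connector pipeline, roughly (order inferred from the diffs in this PR; a sketch, not a literal pipeline repr):

# ... episode-to-batch connector pieces, operating on numpy arrays ...
# NumpyToTensor                  # GAE needs tensors for its VF forward pass
# GeneralAdvantageEstimation(gamma=..., lambda_=...)
# TensorToNumpy()                # hand numpy back to ray.data; the Learner
#                                # moves the batch to its own device later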
rllib/algorithms/marwil/tests/test_marwil.py: 20 additions & 10 deletions

@@ -4,9 +4,10 @@

 import ray
 import ray.rllib.algorithms.marwil as marwil
-from ray.rllib.core import DEFAULT_MODULE_ID
+from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY
+from ray.rllib.env import INPUT_ENV_SPACES
 from ray.rllib.offline.offline_prelearner import OfflinePreLearner
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.metrics import (

@@ -121,7 +122,7 @@ def test_marwil_cont_actions_from_offline_file(self):
             evaluation_parallel_to_training=True,
         )
         .training(
-            train_batch_size_per_learner=2000,
+            train_batch_size_per_learner=1024,
         )
         .offline_data(
             # Learn from offline data.

@@ -167,15 +168,22 @@ def test_marwil_loss_function(self):
         # Sample a batch from the offline data.
         batch = algo.offline_data.data.take_batch(2000)

+        # Get the module state from learners.
+        module_state = algo.learner_group._learner.get_state(
+            components=COMPONENT_RL_MODULE
+        )[COMPONENT_RL_MODULE]
         # Create the prelearner and compute advantages and values.
-        offline_prelearner = OfflinePreLearner(config, algo.learner_group._learner)
+        offline_prelearner = OfflinePreLearner(
+            config,
+            spaces=algo.offline_data.spaces[INPUT_ENV_SPACES],
+            module_spec=algo.offline_data.module_spec,
+            module_state=module_state,
+        )
         # Note, for `ray.data`'s pipeline everything has to be a dictionary
         # therefore the batch is embedded into another dictionary.
         batch = offline_prelearner(batch)["batch"][0]
         if Columns.LOSS_MASK in batch[DEFAULT_MODULE_ID]:
-            loss_mask = (
-                batch[DEFAULT_MODULE_ID][Columns.LOSS_MASK].detach().cpu().numpy()
-            )
+            loss_mask = batch[DEFAULT_MODULE_ID][Columns.LOSS_MASK]
             num_valid = np.sum(loss_mask)

             def possibly_masked_mean(data_):
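The helper's body is folded away in the diff view; a presumed shape, inferred from how `loss_mask` and `num_valid` are used above (an assumption, not the file's actual code):

def possibly_masked_mean(data_):
    # Average only over valid (unmasked) timesteps if a mask exists.
    if Columns.LOSS_MASK in batch[DEFAULT_MODULE_ID]:
        return np.sum(data_[loss_mask]) / num_valid
    return np.mean(data_)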
@@ -186,13 +194,15 @@

         # Calculate our own expected values (to then compare against the
         # agent's loss output).
+        tensor_batch = algo.learner_group._learner._convert_batch_type(batch)
+        tensor_batch = {k: v for k, v in tensor_batch[DEFAULT_MODULE_ID].items()}
         fwd_out = (
             algo.learner_group._learner.module[DEFAULT_MODULE_ID]
             .unwrapped()
-            .forward_train({k: v for k, v in batch[DEFAULT_MODULE_ID].items()})
+            .forward_train(tensor_batch)
         )
         advantages = (
-            batch[DEFAULT_MODULE_ID][Columns.VALUE_TARGETS].detach().cpu().numpy()
+            batch[DEFAULT_MODULE_ID][Columns.VALUE_TARGETS]
             - fwd_out["vf_preds"].detach().cpu().numpy()
         )
         advantages_squared = possibly_masked_mean(np.square(advantages))

@@ -207,7 +217,7 @@
         # Note we need the actual model's logits not the ones from the data set
         # stored in `batch[Columns.ACTION_DIST_INPUTS]`.
         action_dist = action_dist_cls.from_logits(fwd_out[Columns.ACTION_DIST_INPUTS])
-        logp = action_dist.logp(batch[DEFAULT_MODULE_ID][Columns.ACTIONS])
+        logp = action_dist.logp(tensor_batch[Columns.ACTIONS])
         logp = logp.detach().cpu().numpy()

         # Calculate all expected loss components.

@@ -219,7 +229,7 @@
         # calculation above).
         total_loss = algo.learner_group._learner.compute_loss_for_module(
             module_id=DEFAULT_MODULE_ID,
-            batch={k: v for k, v in batch[DEFAULT_MODULE_ID].items()},
+            batch=tensor_batch,
             fwd_out=fwd_out,
             config=config,
         )
rllib/core/learner/learner.py: 4 additions & 1 deletion

@@ -1382,7 +1382,10 @@ def _update_from_batch_or_episodes(
         # Convert input batch into a tensor batch (MultiAgentBatch) on the correct
         # device (e.g. GPU). We move the batch already here to avoid having to move
         # every single minibatch that is created in the `batch_iter` below.
-        if self._learner_connector is None:
+        # Note, if we have a learner connector, but a `MultiAgentBatch` is passed in,
+        # we are in an offline setting.
+        # TODO (simon, sven): Check, if DreamerV3 has the same setting.
+        if self._learner_connector is None or episodes is None:
             batch = self._convert_batch_type(batch)
             batch = self._set_slicing_by_batch_id(batch, value=True)

Contributor: We'll see what the tests are doing :) DreamerV3 does not use connectors. It passes the batch from the replay buffer directly into `update_from_batch`.
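A hedged usage sketch of the new offline path (a hypothetical minimal call; in practice the offline algorithms' training step drives this):

import numpy as np
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

# Dummy CartPole-shaped arrays standing in for OfflinePreLearner output.
obs = np.zeros((32, 4), dtype=np.float32)
actions = np.zeros((32,), dtype=np.int64)

ma_batch = MultiAgentBatch(
    {"default_policy": SampleBatch({"obs": obs, "actions": actions})},
    env_steps=32,
)
# With `episodes=None`, the Learner itself converts the batch to tensors on
# its own device (e.g. GPU) once, before slicing out minibatches.
learner.update_from_batch(batch=ma_batch)  # `learner`: a built RLlib Learner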
rllib/offline/offline_data.py: 18 additions & 22 deletions
@@ -1,12 +1,13 @@
 import logging
 from pathlib import Path
 import pyarrow.fs
-import ray
 import time
 import types

+import ray
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.core import COMPONENT_RL_MODULE
+from ray.rllib.core.learner import Learner
 from ray.rllib.env import INPUT_ENV_SPACES
 from ray.rllib.offline.offline_prelearner import OfflinePreLearner
 from ray.rllib.utils.annotations import (
@@ -124,32 +125,27 @@ def sample(
             # (b) Rematerialize the data every couple of iterations. This is
             # is costly.
             if not self.data_is_mapped:
-                # Constructor `kwargs` for the `OfflinePreLearner`.
-                fn_constructor_kwargs = {
-                    "config": self.config,
-                    "learner": self.learner_handles[0],
-                    "spaces": self.spaces[INPUT_ENV_SPACES],
-                }
-                # If we have multiple learners, add to the constructor `kwargs`.
-                if num_shards > 1:
-                    # Call here the learner to get an up-to-date module state.
-                    # TODO (simon): This is a workaround as along as learners cannot
-                    # receive any calls from another actor.
+                # Call here the learner to get an up-to-date module state.
+                # TODO (simon): This is a workaround as along as learners cannot
+                # receive any calls from another actor.
+                if num_shards > 1 or not isinstance(self.learner_handles[0], Learner):
                     module_state = ray.get(
                         self.learner_handles[0].get_state.remote(
                             component=COMPONENT_RL_MODULE
                         )
-                    )
-                    # Add constructor `kwargs` when using remote learners.
-                    fn_constructor_kwargs.update(
-                        {
-                            "learner": self.learner_handles,
-                            "locality_hints": self.locality_hints,
-                            "module_spec": self.module_spec,
-                            "module_state": module_state,
-                        }
-                    )
+                    )[COMPONENT_RL_MODULE]
+                else:
+                    module_state = self.learner_handles[0].get_state(
+                        component=COMPONENT_RL_MODULE
+                    )[COMPONENT_RL_MODULE]
+
+                # Constructor `kwargs` for the `OfflinePreLearner`.
+                fn_constructor_kwargs = {
+                    "config": self.config,
+                    "spaces": self.spaces[INPUT_ENV_SPACES],
+                    "module_spec": self.module_spec,
+                    "module_state": module_state,
+                }
                 self.data = self.data.map_batches(
                     self.prelearner_class,
                     fn_constructor_kwargs=fn_constructor_kwargs,
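For context, the `ray.data` mechanics this relies on (a sketch; `map_batches` with a callable class constructs one instance of that class per worker actor from `fn_constructor_kwargs`, which is why everything in the kwargs must be serializable):

import ray
from ray.rllib.offline.offline_prelearner import OfflinePreLearner

# Illustrative read; the real dataset source comes from the `input_` config.
ds = ray.data.read_parquet("tests/data/cartpole/cartpole-v1_large")
ds = ds.map_batches(
    OfflinePreLearner,                            # callable class
    fn_constructor_kwargs=fn_constructor_kwargs,  # as constructed above
    batch_size=1024,                              # hypothetical batch size
    concurrency=2,                                # two parallel prelearners
)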
rllib/offline/offline_prelearner.py: 9 additions & 42 deletions

@@ -1,13 +1,9 @@
 import gymnasium as gym
 import logging
 import numpy as np
-import random
 from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING

-import ray
-from ray.actor import ActorHandle
 from ray.rllib.core.columns import Columns
-from ray.rllib.core.learner import Learner
 from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

@@ -57,23 +53,21 @@ class OfflinePreLearner:
     This class is an essential part of the new `Offline RL API` of `RLlib`.
     It is a callable class that is run in `ray.data.Dataset.map_batches`
     when iterating over batches for training. It's basic function is to
-    convert data in batch from rows to episodes (`SingleAGentEpisode`s
+    convert data in batch from rows to episodes (`SingleAgentEpisode`s
     for now) and to then run the learner connector pipeline to convert
     further to trainable batches. These batches are used directly in the
     `Learner`'s `update` method.

     The main reason to run these transformations inside of `map_batches`
     is for better performance. Batches can be pre-fetched in `ray.data`
-    and therefore batch trransformation can be run highly parallelized to
+    and therefore batch transformation can be run highly parallelized to
     the `Learner''s `update`.

     This class can be overridden to implement custom logic for transforming
     batches and make them 'Learner'-ready. When deriving from this class
     the `__call__` method and `_map_to_episodes` can be overridden to induce
     custom logic for the complete transformation pipeline (`__call__`) or
-    for converting to episodes only ('_map_to_episodes`). For an example
-    how this class can be used to also compute values and advantages see
-    `rllib.algorithm.marwil.marwil_prelearner.MAWRILOfflinePreLearner`.
+    for converting to episodes only ('_map_to_episodes`).

     Custom `OfflinePreLearner` classes can be passed into
     `AlgorithmConfig.offline`'s `prelearner_class`. The `OfflineData` class

@@ -84,46 +78,19 @@
     def __init__(
         self,
         config: "AlgorithmConfig",
-        learner: Union[Learner, list[ActorHandle]],
         spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
-        locality_hints: Optional[list] = None,
         module_spec: Optional[MultiRLModuleSpec] = None,
         module_state: Optional[Dict[ModuleID, Any]] = None,
     ):

         self.config = config
         self.input_read_episodes = self.config.input_read_episodes
         self.input_read_sample_batches = self.config.input_read_sample_batches
-        # We need this learner to run the learner connector pipeline.
-        # If it is a `Learner` instance, the `Learner` is local.
-        if isinstance(learner, Learner):
-            self._learner = learner
-            self.learner_is_remote = False
-            self._module = self._learner._module
-        # Otherwise we have remote `Learner`s.
-        else:
-            # TODO (simon): Check with the data team how to get at
-            # initialization the data block location.
-            node_id = ray.get_runtime_context().get_node_id()
-            # Shuffle indices such that not each data block syncs weights
-            # with the same learner in case there are multiple learners
-            # on the same node like the `PreLearner`.
-            indices = list(range(len(locality_hints)))
-            random.shuffle(indices)
-            locality_hints = [locality_hints[i] for i in indices]
-            learner = [learner[i] for i in indices]
-            # Choose a learner from the same node.
-            for i, hint in enumerate(locality_hints):
-                if hint == node_id:
-                    self._learner = learner[i]
-            # If no learner has been chosen, there is none on the same node.
-            if not self._learner:
-                # Then choose a learner randomly.
-                self._learner = learner[random.randint(0, len(learner) - 1)]
-            self.learner_is_remote = True
-            # Build the module from spec. Note, this will be a MultiRLModule.
-            self._module = module_spec.build()
-            self._module.set_state(module_state)

+        # Build the module from spec. Note, this will be a MultiRLModule.
+        # TODO (simon): Check, if this builds automatically on GPU if
+        # available.
+        self._module = module_spec.build()
+        self._module.set_state(module_state)

Contributor: Nice simplification!
         # Store the observation and action space if defined, otherwise we
         # set them to `None`. Note, if `None` the `convert_from_jsonable`
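A hedged sketch of the customization hook mentioned in the `OfflinePreLearner` docstring above (the subclass and its wiring are hypothetical; the real entry point is the `prelearner_class` setting):

from ray.rllib.offline.offline_prelearner import OfflinePreLearner

class MyPreLearner(OfflinePreLearner):
    def __call__(self, batch):
        # Run the default rows -> episodes -> train-batch transformation.
        out = super().__call__(batch)
        # ... custom post-processing of out["batch"] would go here ...
        return out

# Assumed wiring, per the docstring:
# config.offline_data(prelearner_class=MyPreLearner)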
rllib/tuned_examples/bc/cartpole_bc.py: 2 additions & 2 deletions

@@ -55,7 +55,7 @@
         # Concurrency defines the number of processes that run the
         # `map_batches` transformations. This should be aligned with the
         # 'prefetch_batches' argument in 'iter_batches_kwargs'.
-        map_batches_kwargs={"concurrency": 2, "num_cpus": 2},
+        map_batches_kwargs={"concurrency": 2, "num_cpus": 2, "num_gpus": 0},
         # This data set is small so do not prefetch too many batches and use no
         # local shuffle.
         iter_batches_kwargs={

@@ -66,7 +66,7 @@
         # mode in a single RLlib training iteration. Leave this to `None` to
         # run an entire epoch on the dataset during a single RLlib training
         # iteration. For single-learner mode 1 is the only option.
-        dataset_num_iters_per_learner=1 if args.num_gpus == 0 else None,
+        dataset_num_iters_per_learner=1 if args.num_gpus < 2 else None,
     )
     .training(
         # To increase learning speed with multiple learners,
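In short, the resource split this example encodes (a sketch restating the two changed lines; the `"num_gpus": 0` is deliberate, since the prelearner module only runs the connector pipeline): the `ray.data` prelearner actors stay on CPU, while the GPU(s) belong exclusively to the Learner(s).

# CPU side: two parallel OfflinePreLearner actors transform and prefetch
# batches; nothing in them is pinned to a device.
map_batches_kwargs = {"concurrency": 2, "num_cpus": 2, "num_gpus": 0}

# Learner side: with fewer than two GPUs there is a single Learner, for which
# exactly one dataset iteration per RLlib training step is supported (per the
# comment in the file); with two or more, `None` lets each remote Learner
# stream a full epoch per iteration.
dataset_num_iters_per_learner = 1 if args.num_gpus < 2 else None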