[RLlib; Offline RL] Enable GPU and multi-GPU training for offline algorithms. #47929
@@ -341,6 +341,19 @@ py_test(
     args = ["--as-test", "--enable-new-api-stack"]
 )
 
+py_test(
+    name = "learning_tests_cartpole_bc_gpu",
+    main = "tuned_examples/bc/cartpole_bc.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
+    size = "medium",
+    srcs = ["tuned_examples/bc/cartpole_bc.py"],
+    # Include the offline data files.
+    data = [
+        "tests/data/cartpole/cartpole-v1_large",
+    ],
+    args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"]
+)
+
 # CQL
 # Pendulum
 py_test(
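The new GPU test targets differ from their CPU counterparts only by the `--num-gpus=1` argument. As a rough illustration of how such a flag is typically consumed by a tuned-example script, here is a minimal sketch (the helper name `parse_gpu_args` and the mapping to learner counts are assumptions for illustration, not RLlib's actual argument parser):

```python
import argparse


def parse_gpu_args(argv):
    # Hypothetical sketch: parse the flags the BUILD targets pass in.
    parser = argparse.ArgumentParser()
    parser.add_argument("--as-test", action="store_true")
    parser.add_argument("--enable-new-api-stack", action="store_true")
    parser.add_argument("--num-gpus", type=int, default=0)
    args = parser.parse_args(argv)
    # Assumed convention: one learner per requested GPU; with no GPUs,
    # training falls back to a single CPU-only (local) learner.
    num_learners = max(args.num_gpus, 0)
    num_gpus_per_learner = 1 if args.num_gpus > 0 else 0
    return num_learners, num_gpus_per_learner


print(parse_gpu_args(["--as-test", "--enable-new-api-stack", "--num-gpus=1"]))
# (1, 1)
```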
@@ -356,6 +369,19 @@ py_test(
     args = ["--as-test", "--enable-new-api-stack"]
 )
 
+py_test(
+    name = "learning_tests_pendulum_cql_gpu",
+    main = "tuned_examples/cql/pendulum_cql.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_cartpole", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
+    size = "large",
+    srcs = ["tuned_examples/cql/pendulum_cql.py"],
+    # Include the zipped json data file as well.
+    data = [
+        "tests/data/pendulum/pendulum-v1_enormous",
+    ],
+    args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"]
+)
+
 # DQN
 # CartPole
 py_test(

Review comment: same! :D
@@ -564,6 +590,19 @@ py_test(
     args = ["--as-test", "--enable-new-api-stack"]
 )
 
+py_test(
+    name = "learning_tests_cartpole_marwil_gpu",
+    main = "tuned_examples/marwil/cartpole_marwil.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
+    size = "large",
+    srcs = ["tuned_examples/marwil/cartpole_marwil.py"],
+    # Include the offline data files.
+    data = [
+        "tests/data/cartpole/cartpole-v1_large",
+    ],
+    args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"]
+)
+
 # PPO
 # CartPole
 py_test(

Review comment: same same :D
@@ -98,6 +98,7 @@ def build_learner_connector(
         # Remove unneeded connectors from the MARWIL connector pipeline.
         pipeline.remove("AddOneTsToEpisodesAndTruncate")
         pipeline.remove("GeneralAdvantageEstimation")
+        pipeline.remove("NumpyToTensor")
 
         return pipeline

Review comment: Aaahh! So this was one of the problems? That we were already converting everything to torch tensors?
Reply: Yes, it was one of them. The major one, however, was that we were passing in a learner that runs on the GPU, and that learner needed to be serialized by Ray to send it to the data workers. Deserializing it there errored out.
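To illustrate the remove-by-name semantics used above, here is a toy pipeline sketch (illustrative stand-in classes, not RLlib's actual `ConnectorPipelineV2` or connector implementations): dropping the tensor-conversion piece keeps the pipeline output as plain Python/numpy data, which the `map_batches` data workers can serialize cheaply.

```python
class ToyPipeline:
    """Toy stand-in for a connector pipeline with remove-by-name semantics."""

    def __init__(self, pieces):
        # Each piece is a callable; its class name serves as its lookup key.
        self.pieces = list(pieces)

    def remove(self, name):
        # Drop every piece whose class name matches, mirroring
        # `pipeline.remove("NumpyToTensor")` in the diff above.
        self.pieces = [p for p in self.pieces if type(p).__name__ != name]

    def __call__(self, batch):
        for piece in self.pieces:
            batch = piece(batch)
        return batch


class NumpyToTensor:
    def __call__(self, batch):
        # Placeholder for a framework-tensor conversion step.
        return {k: ("tensor", v) for k, v in batch.items()}


class AddColumn:
    def __call__(self, batch):
        batch["advantages"] = [0.0] * len(batch["obs"])
        return batch


pipeline = ToyPipeline([AddColumn(), NumpyToTensor()])
# Keep batches as plain data so the data workers can serialize them;
# the actual tensor conversion then happens later on the learner side.
pipeline.remove("NumpyToTensor")
out = pipeline({"obs": [1, 2, 3]})
print(out)  # {'obs': [1, 2, 3], 'advantages': [0.0, 0.0, 0.0]}
```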
@@ -3,6 +3,7 @@ | |
from ray.rllib.algorithms.algorithm import Algorithm | ||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided | ||
from ray.rllib.algorithms.marwil.marwil_catalog import MARWILCatalog | ||
from ray.rllib.connectors.common.tensor_to_numpy import TensorToNumpy | ||
from ray.rllib.connectors.learner import ( | ||
AddObservationsFromEpisodesToBatch, | ||
AddOneTsToEpisodesAndTruncate, | ||
|
@@ -361,6 +362,9 @@ def build_learner_connector( | |
pipeline.append( | ||
GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_) | ||
) | ||
pipeline.append( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense. GAE connector outputs tensors already due to it requiring a VF forward pass (with the tensors coming from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. And keep in mind: The data workers will run in parallel and prefetch batches which will actually make the pipeline quite smooth. Another connector piece or one less will not make a big difference, if at all. User usually have enough resources to run multiple data workers in parallel and they should. |
||
TensorToNumpy(), | ||
) | ||
sven1977 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return pipeline | ||
|
||
|
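The reasoning behind appending the converter as the last piece can be sketched as follows (toy stand-in functions and a `FakeTensor` class for illustration, not RLlib's `GeneralAdvantageEstimation` or `TensorToNumpy`): the advantage-estimation step emits framework tensors because it runs a value-function forward pass, so a final tensor-to-numpy piece restores a serialization-friendly output for the data workers.

```python
class FakeTensor:
    """Toy stand-in for a framework tensor (e.g. a torch.Tensor)."""

    def __init__(self, data):
        self.data = list(data)

    def numpy(self):
        return list(self.data)


def general_advantage_estimation(batch):
    # Toy stand-in: the VF forward pass yields tensors, not numpy data.
    batch["advantages"] = FakeTensor([0.1, 0.2])
    return batch


def tensor_to_numpy(batch):
    # Final piece: convert tensors back to plain data so that the
    # `map_batches` output can be pickled across processes.
    return {
        k: v.numpy() if isinstance(v, FakeTensor) else v
        for k, v in batch.items()
    }


pipeline = [general_advantage_estimation, tensor_to_numpy]
batch = {"obs": [1, 2]}
for piece in pipeline:
    batch = piece(batch)
print(batch)  # {'obs': [1, 2], 'advantages': [0.1, 0.2]}
```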
@@ -1382,7 +1382,10 @@ def _update_from_batch_or_episodes( | |
# Convert input batch into a tensor batch (MultiAgentBatch) on the correct | ||
# device (e.g. GPU). We move the batch already here to avoid having to move | ||
# every single minibatch that is created in the `batch_iter` below. | ||
if self._learner_connector is None: | ||
# Note, if we have a learner connector, but a `MultiAgentBatch` is passed in, | ||
# we are in an offline setting. | ||
# TODO (simon, sven): Check, if DreamerV3 has the same setting. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'll see what tests doing :) DreamerV3 does not use connectors. It passes the batch from the replay buffer directly into |
||
if self._learner_connector is None or episodes is None: | ||
batch = self._convert_batch_type(batch) | ||
batch = self._set_slicing_by_batch_id(batch, value=True) | ||
|
||
|
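The widened condition can be read as a small predicate (a toy sketch with an assumed helper name, mirroring the `if` in the diff above): convert and move the whole batch up front either when there is no learner connector at all, or when a ready-made `MultiAgentBatch` arrived without episodes, which is the offline case.

```python
def should_convert_batch_directly(has_learner_connector, episodes):
    # Mirrors `if self._learner_connector is None or episodes is None:`
    # from the diff above: convert/move the whole batch immediately when
    # there is no connector, or when an offline `MultiAgentBatch` was
    # passed in without any episodes.
    return (not has_learner_connector) or episodes is None


# Online new API stack: connector + episodes -> the connector handles it.
print(should_convert_batch_directly(True, ["episode"]))  # False
# Offline RL: connector present, but a batch (no episodes) was passed in.
print(should_convert_batch_directly(True, None))  # True
```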
@@ -1,13 +1,9 @@
 import gymnasium as gym
 import logging
 import numpy as np
-import random
 from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
 
-import ray
-from ray.actor import ActorHandle
 from ray.rllib.core.columns import Columns
-from ray.rllib.core.learner import Learner
 from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
@@ -57,23 +53,21 @@ class OfflinePreLearner:
     This class is an essential part of the new `Offline RL API` of `RLlib`.
     It is a callable class that is run in `ray.data.Dataset.map_batches`
     when iterating over batches for training. It's basic function is to
-    convert data in batch from rows to episodes (`SingleAGentEpisode`s
+    convert data in batch from rows to episodes (`SingleAgentEpisode`s
     for now) and to then run the learner connector pipeline to convert
     further to trainable batches. These batches are used directly in the
     `Learner`'s `update` method.
 
     The main reason to run these transformations inside of `map_batches`
     is for better performance. Batches can be pre-fetched in `ray.data`
-    and therefore batch trransformation can be run highly parallelized to
+    and therefore batch transformation can be run highly parallelized to
     the `Learner''s `update`.
 
     This class can be overridden to implement custom logic for transforming
     batches and make them 'Learner'-ready. When deriving from this class
     the `__call__` method and `_map_to_episodes` can be overridden to induce
     custom logic for the complete transformation pipeline (`__call__`) or
-    for converting to episodes only ('_map_to_episodes`). For an example
-    how this class can be used to also compute values and advantages see
-    `rllib.algorithm.marwil.marwil_prelearner.MAWRILOfflinePreLearner`.
+    for converting to episodes only ('_map_to_episodes`).
 
     Custom `OfflinePreLearner` classes can be passed into
     `AlgorithmConfig.offline`'s `prelearner_class`. The `OfflineData` class
@@ -84,46 +78,19 @@ class OfflinePreLearner: | |
def __init__( | ||
self, | ||
config: "AlgorithmConfig", | ||
learner: Union[Learner, list[ActorHandle]], | ||
spaces: Optional[Tuple[gym.Space, gym.Space]] = None, | ||
locality_hints: Optional[list] = None, | ||
module_spec: Optional[MultiRLModuleSpec] = None, | ||
module_state: Optional[Dict[ModuleID, Any]] = None, | ||
): | ||
|
||
self.config = config | ||
self.input_read_episodes = self.config.input_read_episodes | ||
self.input_read_sample_batches = self.config.input_read_sample_batches | ||
# We need this learner to run the learner connector pipeline. | ||
# If it is a `Learner` instance, the `Learner` is local. | ||
if isinstance(learner, Learner): | ||
self._learner = learner | ||
self.learner_is_remote = False | ||
self._module = self._learner._module | ||
# Otherwise we have remote `Learner`s. | ||
else: | ||
# TODO (simon): Check with the data team how to get at | ||
# initialization the data block location. | ||
node_id = ray.get_runtime_context().get_node_id() | ||
# Shuffle indices such that not each data block syncs weights | ||
# with the same learner in case there are multiple learners | ||
# on the same node like the `PreLearner`. | ||
indices = list(range(len(locality_hints))) | ||
random.shuffle(indices) | ||
locality_hints = [locality_hints[i] for i in indices] | ||
learner = [learner[i] for i in indices] | ||
# Choose a learner from the same node. | ||
for i, hint in enumerate(locality_hints): | ||
if hint == node_id: | ||
self._learner = learner[i] | ||
# If no learner has been chosen, there is none on the same node. | ||
if not self._learner: | ||
# Then choose a learner randomly. | ||
self._learner = learner[random.randint(0, len(learner) - 1)] | ||
self.learner_is_remote = True | ||
# Build the module from spec. Note, this will be a MultiRLModule. | ||
self._module = module_spec.build() | ||
self._module.set_state(module_state) | ||
|
||
# Build the module from spec. Note, this will be a MultiRLModule. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice simplification! |
||
# TODO (simon): Check, if this builds automatically on GPU if | ||
# available. | ||
self._module = module_spec.build() | ||
self._module.set_state(module_state) | ||
|
||
# Store the observation and action space if defined, otherwise we | ||
# set them to `None`. Note, if `None` the `convert_from_jsonable` | ||
|
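The simplified constructor boils down to build-then-restore: each data worker builds its own module copy from a spec and loads the learner's weights, instead of holding (and having to deserialize) a learner handle. A minimal sketch with toy stand-in classes (not RLlib's actual `MultiRLModuleSpec` or module API):

```python
class ToyModule:
    """Toy stand-in for a (multi-)RL module holding restorable state."""

    def __init__(self):
        self.state = {}

    def set_state(self, state):
        self.state = dict(state)


class ToyModuleSpec:
    """Toy stand-in for a module spec: build() returns a fresh module."""

    def build(self):
        return ToyModule()


# Each map_batches worker builds its own module from the spec and
# restores the learner's weights, rather than receiving a learner
# handle that would need to be serialized (GPU state and all).
module_spec = ToyModuleSpec()
module_state = {"default_policy": {"w": [0.5, -0.5]}}
module = module_spec.build()
module.set_state(module_state)
print(module.state["default_policy"]["w"])  # [0.5, -0.5]
```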
Review comment: Awesooomeee!!! This is so cool!