[Feature] Replay buffer creation is unified
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 5, 2023
1 parent 87b2f6d commit fd5222d
Showing 10 changed files with 30 additions and 212 deletions.
31 changes: 12 additions & 19 deletions benchmarl/algorithms/common.py
@@ -8,9 +8,12 @@
 from torchrl.data import (
     CompositeSpec,
     DiscreteTensorSpec,
+    LazyTensorStorage,
     OneHotDiscreteTensorSpec,
     ReplayBuffer,
+    TensorDictReplayBuffer,
 )
+from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
 from torchrl.objectives import LossModule
 from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater

@@ -123,14 +126,15 @@ def get_replay_buffer(
         self,
         group: str,
     ) -> ReplayBuffer:
-        return self._get_replay_buffer(
-            group=group,
-            memory_size=self.experiment_config.replay_buffer_memory_size(
-                self.on_policy
-            ),
-            sampling_size=self.experiment_config.train_minibatch_size(self.on_policy),
-            traj_len=self.experiment_config.traj_len,
-            storing_device=self.device,
+        memory_size = self.experiment_config.replay_buffer_memory_size(self.on_policy)
+        sampling_size = self.experiment_config.train_minibatch_size(self.on_policy)
+        storing_device = self.device
+        sampler = SamplerWithoutReplacement() if self.on_policy else RandomSampler()
+
+        return TensorDictReplayBuffer(
+            storage=LazyTensorStorage(memory_size, device=storing_device),
+            sampler=sampler,
+            batch_size=sampling_size,
         )

     def get_policy_for_loss(self, group: str) -> TensorDictModule:
@@ -200,17 +204,6 @@ def _get_policy_for_collection(
     ) -> TensorDictModule:
         raise NotImplementedError

-    @abstractmethod
-    def _get_replay_buffer(
-        self,
-        group: str,
-        memory_size: int,
-        sampling_size: int,
-        traj_len: int,
-        storing_device: DEVICE_TYPING,
-    ) -> ReplayBuffer:
-        raise NotImplementedError
-
     @abstractmethod
     def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         raise NotImplementedError
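For context, a minimal sketch of what the unified construction above produces at runtime. This is illustrative only: the buffer size, batch size, and the on_policy flag stand in for values BenchMARL pulls from its experiment config, assuming a recent torchrl.

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement

    on_policy = True  # illustrative; BenchMARL derives this from the algorithm

    # One code path for every algorithm: only the sampler differs between
    # the on-policy and off-policy regimes.
    sampler = SamplerWithoutReplacement() if on_policy else RandomSampler()
    buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(1_000, device="cpu"),  # allocated on first write
        sampler=sampler,
        batch_size=64,
    )

    # Store a batch of transitions, then draw a training minibatch.
    data = TensorDict({"obs": torch.randn(128, 4)}, batch_size=[128])
    buffer.extend(data)
    minibatch = buffer.sample()  # TensorDict with batch_size [64]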
23 changes: 2 additions & 21 deletions benchmarl/algorithms/iddpg.py
@@ -4,19 +4,13 @@
 import torch
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictSequential
-from torchrl.data import (
-    CompositeSpec,
-    ReplayBuffer,
-    TensorDictReplayBuffer,
-    UnboundedContinuousTensorSpec,
-)
-from torchrl.data.replay_buffers.storages import LazyTensorStorage
+from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
 from torchrl.modules import AdditiveGaussianWrapper, ProbabilisticActor, TanhDelta
 from torchrl.objectives import DDPGLoss, LossModule, ValueEstimators

 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
-from benchmarl.utils import DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import read_yaml_config


 class Iddpg(Algorithm):
@@ -33,19 +27,6 @@ def __init__(
     # Overridden abstract methods
     #############################

-    def _get_replay_buffer(
-        self,
-        group: str,
-        memory_size: int,
-        sampling_size: int,
-        traj_len: int,
-        storing_device: DEVICE_TYPING,
-    ) -> ReplayBuffer:
-        return TensorDictReplayBuffer(
-            storage=LazyTensorStorage(memory_size, device=storing_device),
-            batch_size=sampling_size,
-        )
-
     def _get_loss(
         self, group: str, policy_for_loss: TensorDictModule, continuous: bool
     ) -> Tuple[LossModule, bool]:
25 changes: 2 additions & 23 deletions benchmarl/algorithms/ippo.py
@@ -6,21 +6,14 @@
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from tensordict.nn.distributions import NormalParamExtractor
 from torch.distributions import Categorical
-from torchrl.data import (
-    CompositeSpec,
-    ReplayBuffer,
-    TensorDictReplayBuffer,
-    UnboundedContinuousTensorSpec,
-)
-from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
-from torchrl.data.replay_buffers.storages import LazyTensorStorage
+from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
 from torchrl.modules import ProbabilisticActor, TanhNormal
 from torchrl.modules.distributions import MaskedCategorical
 from torchrl.objectives import ClipPPOLoss, LossModule, ValueEstimators

 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
-from benchmarl.utils import DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import read_yaml_config


 class Ippo(Algorithm):
@@ -47,20 +40,6 @@ def __init__(
     # Overridden abstract methods
     #############################

-    def _get_replay_buffer(
-        self,
-        group: str,
-        memory_size: int,
-        sampling_size: int,
-        traj_len: int,
-        storing_device: DEVICE_TYPING,
-    ) -> ReplayBuffer:
-        return TensorDictReplayBuffer(
-            storage=LazyTensorStorage(memory_size, device=storing_device),
-            sampler=SamplerWithoutReplacement(),
-            batch_size=sampling_size,
-        )
-
     def _get_loss(
         self, group: str, policy_for_loss: TensorDictModule, continuous: bool
     ) -> Tuple[LossModule, bool]:
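The SamplerWithoutReplacement that ippo.py previously requested here is now supplied centrally for all on-policy algorithms. A small sketch of what that sampler means in practice, with illustrative sizes:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers import SamplerWithoutReplacement

    # On-policy minibatching: every stored frame is consumed exactly once per
    # pass over the buffer, which is what PPO-style update epochs expect.
    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(8),
        sampler=SamplerWithoutReplacement(),
        batch_size=4,
    )
    rb.extend(TensorDict({"idx": torch.arange(8)}, batch_size=[8]))

    first = rb.sample()["idx"]
    second = rb.sample()["idx"]
    # first and second are disjoint and together cover all 8 stored entries.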
23 changes: 2 additions & 21 deletions benchmarl/algorithms/iql.py
@@ -3,19 +3,13 @@

 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictSequential
-from torchrl.data import (
-    CompositeSpec,
-    ReplayBuffer,
-    TensorDictReplayBuffer,
-    UnboundedContinuousTensorSpec,
-)
-from torchrl.data.replay_buffers.storages import LazyTensorStorage
+from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
 from torchrl.modules import EGreedyModule, QValueModule
 from torchrl.objectives import DQNLoss, LossModule, ValueEstimators

 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
-from benchmarl.utils import DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import read_yaml_config


 class Iql(Algorithm):
@@ -29,19 +23,6 @@ def __init__(self, delay_value: bool, loss_function: str, **kwargs):
     # Overridden abstract methods
     #############################

-    def _get_replay_buffer(
-        self,
-        group: str,
-        memory_size: int,
-        sampling_size: int,
-        traj_len: int,
-        storing_device: DEVICE_TYPING,
-    ) -> ReplayBuffer:
-        return TensorDictReplayBuffer(
-            storage=LazyTensorStorage(memory_size, device=storing_device),
-            batch_size=sampling_size,
-        )
-
     def _get_loss(
         self, group: str, policy_for_loss: TensorDictModule, continuous: bool
     ) -> Tuple[LossModule, bool]:
23 changes: 2 additions & 21 deletions benchmarl/algorithms/isac.py
@@ -5,13 +5,7 @@
 from tensordict import TensorDictBase
 from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictSequential
 from torch.distributions import Categorical
-from torchrl.data import (
-    CompositeSpec,
-    ReplayBuffer,
-    TensorDictReplayBuffer,
-    UnboundedContinuousTensorSpec,
-)
-from torchrl.data.replay_buffers.storages import LazyTensorStorage
+from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
 from torchrl.modules import MaskedCategorical, ProbabilisticActor, TanhNormal
 from torchrl.objectives import (
     ClipPPOLoss,
@@ -23,7 +17,7 @@

 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
-from benchmarl.utils import DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import read_yaml_config


 class Isac(Algorithm):
@@ -56,19 +50,6 @@ def __init__(
     # Overridden abstract methods
     #############################

-    def _get_replay_buffer(
-        self,
-        group: str,
-        memory_size: int,
-        sampling_size: int,
-        traj_len: int,
-        storing_device: DEVICE_TYPING,
-    ) -> ReplayBuffer:
-        return TensorDictReplayBuffer(
-            storage=LazyTensorStorage(memory_size, device=storing_device),
-            batch_size=sampling_size,
-        )
-
     def _get_loss(
         self, group: str, policy_for_loss: TensorDictModule, continuous: bool
     ) -> Tuple[LossModule, bool]:
23 changes: 2 additions & 21 deletions benchmarl/algorithms/maddpg.py
@@ -4,19 +4,13 @@
 import torch
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictSequential
-from torchrl.data import (
-    CompositeSpec,
-    ReplayBuffer,
-    TensorDictReplayBuffer,
-    UnboundedContinuousTensorSpec,
-)
-from torchrl.data.replay_buffers.storages import LazyTensorStorage
+from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
 from torchrl.modules import AdditiveGaussianWrapper, ProbabilisticActor, TanhDelta
 from torchrl.objectives import DDPGLoss, LossModule, ValueEstimators

 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
-from benchmarl.utils import DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import read_yaml_config


 class Maddpg(Algorithm):
@@ -33,19 +27,6 @@ def __init__(
     # Overridden abstract methods
     #############################

-    def _get_replay_buffer(
-        self,
-        group: str,
-        memory_size: int,
-        sampling_size: int,
-        traj_len: int,
-        storing_device: DEVICE_TYPING,
-    ) -> ReplayBuffer:
-        return TensorDictReplayBuffer(
-            storage=LazyTensorStorage(memory_size, device=storing_device),
-            batch_size=sampling_size,
-        )
-
     def _get_loss(
         self, group: str, policy_for_loss: TensorDictModule, continuous: bool
     ) -> Tuple[LossModule, bool]:
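Conversely, off-policy algorithms such as maddpg now receive a RandomSampler from the shared method, matching the default sampler of the per-algorithm buffers deleted above. A sketch with illustrative sizes:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers import RandomSampler

    # Off-policy: uniform random sampling over everything stored, so old
    # transitions can be revisited across many gradient updates.
    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(1_000),
        sampler=RandomSampler(),
        batch_size=32,
    )
    rb.extend(TensorDict({"reward": torch.randn(100, 1)}, batch_size=[100]))
    for _ in range(3):
        batch = rb.sample()  # indices may repeat across draws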
25 changes: 2 additions & 23 deletions benchmarl/algorithms/mappo.py
@@ -6,20 +6,13 @@
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from tensordict.nn.distributions import NormalParamExtractor
 from torch.distributions import Categorical
-from torchrl.data import (
-    CompositeSpec,
-    ReplayBuffer,
-    TensorDictReplayBuffer,
-    UnboundedContinuousTensorSpec,
-)
-from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
-from torchrl.data.replay_buffers.storages import LazyTensorStorage
+from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
 from torchrl.modules import MaskedCategorical, ProbabilisticActor, TanhNormal
 from torchrl.objectives import ClipPPOLoss, LossModule, ValueEstimators

 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
-from benchmarl.utils import DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import read_yaml_config


 class Mappo(Algorithm):
@@ -46,20 +39,6 @@ def __init__(
     # Overridden abstract methods
     #############################

-    def _get_replay_buffer(
-        self,
-        group: str,
-        memory_size: int,
-        sampling_size: int,
-        traj_len: int,
-        storing_device: DEVICE_TYPING,
-    ) -> ReplayBuffer:
-        return TensorDictReplayBuffer(
-            storage=LazyTensorStorage(memory_size, device=storing_device),
-            sampler=SamplerWithoutReplacement(),
-            batch_size=sampling_size,
-        )
-
     def _get_loss(
         self, group: str, policy_for_loss: TensorDictModule, continuous: bool
     ) -> Tuple[LossModule, bool]:
23 changes: 2 additions & 21 deletions benchmarl/algorithms/masac.py
@@ -5,19 +5,13 @@
 from tensordict import TensorDictBase
 from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictSequential
 from torch.distributions import Categorical
-from torchrl.data import (
-    CompositeSpec,
-    ReplayBuffer,
-    TensorDictReplayBuffer,
-    UnboundedContinuousTensorSpec,
-)
-from torchrl.data.replay_buffers.storages import LazyTensorStorage
+from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
 from torchrl.modules import MaskedCategorical, ProbabilisticActor, TanhNormal
 from torchrl.objectives import DiscreteSACLoss, LossModule, SACLoss, ValueEstimators

 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
-from benchmarl.utils import DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import read_yaml_config


 class Masac(Algorithm):
@@ -50,19 +44,6 @@ def __init__(
     # Overridden abstract methods
     #############################

-    def _get_replay_buffer(
-        self,
-        group: str,
-        memory_size: int,
-        sampling_size: int,
-        traj_len: int,
-        storing_device: DEVICE_TYPING,
-    ) -> ReplayBuffer:
-        return TensorDictReplayBuffer(
-            storage=LazyTensorStorage(memory_size, device=storing_device),
-            batch_size=sampling_size,
-        )
-
     def _get_loss(
         self, group: str, policy_for_loss: TensorDictModule, continuous: bool
     ) -> Tuple[LossModule, bool]:
23 changes: 2 additions & 21 deletions benchmarl/algorithms/qmix.py
@@ -3,19 +3,13 @@

 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictSequential
-from torchrl.data import (
-    CompositeSpec,
-    ReplayBuffer,
-    TensorDictReplayBuffer,
-    UnboundedContinuousTensorSpec,
-)
-from torchrl.data.replay_buffers.storages import LazyTensorStorage
+from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
 from torchrl.modules import EGreedyModule, QMixer, QValueModule
 from torchrl.objectives import LossModule, QMixerLoss, ValueEstimators

 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
-from benchmarl.utils import DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import read_yaml_config


 class Qmix(Algorithm):
@@ -32,19 +26,6 @@ def __init__(
     # Overridden abstract methods
     #############################

-    def _get_replay_buffer(
-        self,
-        group: str,
-        memory_size: int,
-        sampling_size: int,
-        traj_len: int,
-        storing_device: DEVICE_TYPING,
-    ) -> ReplayBuffer:
-        return TensorDictReplayBuffer(
-            storage=LazyTensorStorage(memory_size, device=storing_device),
-            batch_size=sampling_size,
-        )
-
     def _get_loss(
         self, group: str, policy_for_loss: TensorDictModule, continuous: bool
     ) -> Tuple[LossModule, bool]:
(diff for the 10th changed file did not load)
