Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
muchvo committed Apr 22, 2024
1 parent 8f8b459 commit 5e1b3e3
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 11 deletions.
37 changes: 37 additions & 0 deletions omnisafe/common/control_barrier_function/crabs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def __init__(
auto_squeeze=True,
output_activation=None,
) -> None:
"""Initialize the multi-layer perceptron."""
layers = [] # type: ignore
for in_features, out_features in zip(n_units[:-1], n_units[1:]):
if layers:
Expand Down Expand Up @@ -216,6 +217,11 @@ def __init__(self, dim_state, normalizer, n_units, cfgs, *, name='') -> None:
self.automatic_optimization = False

def init_cfgs(self, cfgs):
"""Initialize the configuration.
Args:
cfgs: The configurations.
"""
self.batch_size = cfgs.batch_size
self.weight_decay = cfgs.weight_decay
self.lr = cfgs.lr
Expand Down Expand Up @@ -336,6 +342,11 @@ def __init__(self, h, model: EnsembleModel, policy: ConstraintActorQCritic, cfgs
self.init_cfgs(cfgs)

def init_cfgs(self, cfgs):
"""Initialize the configuration.
Args:
cfgs: The configurations.
"""
self.eps = cfgs.obj.eps
self.neg_coef = cfgs.obj.neg_coef

Expand All @@ -356,6 +367,14 @@ def u(self, states, actions=None):
return all_nh.max(dim=0).values

def obj_eval(self, s):
"""Short cut for barrier function.
Args:
s: The states.
Returns:
dict: The results of the barrier function.
"""
h = self.h(s)
u = self.u(s)

Expand All @@ -378,6 +397,7 @@ class GatedTransitionModel(TransitionModel):
"""Gated transition model for dynamics."""

def __init__(self, *args, **kwargs) -> None:
"""Initialize the gated transition model."""
super().__init__(*args, **kwargs)
self.gate_net = MultiLayerPerceptron(
[self.dim_state + self.dim_action, 256, 256, self.dim_state * 2],
Expand Down Expand Up @@ -410,6 +430,11 @@ class BasePolicy(abc.ABC):

@abc.abstractmethod
def get_actions(self, states):
"""Sample the actions.
Args:
states (torch.Tensor): The states.
"""
pass


Expand All @@ -422,6 +447,7 @@ class ExplorationPolicy(nn.Module, BasePolicy):
"""

def __init__(self, policy, core: CrabsCore) -> None:
"""Initialize the exploration policy."""
super().__init__()
self.policy = policy
self.crabs = core
Expand Down Expand Up @@ -482,6 +508,14 @@ def step(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
return self.forward(obs)

def get_actions(self, states):
"""Sample the actions.
Args:
states (torch.Tensor): The states.
Returns:
torch.Tensor: The sampled actions.
"""
return self(states)


Expand Down Expand Up @@ -523,6 +557,7 @@ class MeanPolicy(DetNetPolicy):
"""

def __init__(self, policy) -> None:
"""Initialize the mean policy."""
super().__init__()
self.policy = policy

Expand Down Expand Up @@ -562,6 +597,7 @@ class AddGaussianNoise(NetPolicy):
"""

def __init__(self, policy: NetPolicy, mean, std) -> None:
"""Initialize the policy with Gaussian noise."""
super().__init__()
self.policy = policy
self.mean = mean
Expand Down Expand Up @@ -591,6 +627,7 @@ class UniformPolicy(NetPolicy):
"""

def __init__(self, dim_action) -> None:
"""Initialize the uniform policy."""
super().__init__()
self.dim_action = dim_action

Expand Down
37 changes: 26 additions & 11 deletions omnisafe/common/control_barrier_function/crabs/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Barrier(nn.Module):
"""

def __init__(self, net, env_barrier_fn, s0, cfgs) -> None:
"""Initialize the barrier function."""
super().__init__()
self.net = net
self.env_barrier_fn = env_barrier_fn
Expand All @@ -63,20 +64,19 @@ def forward(self, states: torch.Tensor) -> torch.Tensor:


class StateBox:
"""State box for the environment.
Args:
shape (Tuple): Shape of the state box.
s0 (torch.Tensor): Initial state.
device (torch.device): Device to run the state box.
expansion (float, optional): Expansion factor for the state box. Defaults to 1.5.
logger ([type], optional): Logger for the state box. Defaults to None.
"""
INF = 1e10

def __init__(self, shape, s0, device, expansion=1.5, logger=None) -> None:
"""State box for the environment.
This class find a box such that the given barrier function is negative inside the box.
Args:
shape (Tuple): Shape of the state box.
s0 (torch.Tensor): Initial state.
device (torch.device): Device to run the state box.
expansion (float, optional): Expansion factor for the state box. Defaults to 1.5.
logger ([type], optional): Logger for the state box. Defaults to None.
"""
"""Initialize the state box."""
self._max = torch.full(shape, -self.INF, device=device)
self._min = torch.full(shape, +self.INF, device=device)
self.center = None
Expand Down Expand Up @@ -162,6 +162,7 @@ class SLangevinOptimizer(nn.Module):
"""

def __init__(self, core: CrabsCore, state_box: StateBox, device, cfgs, logger) -> None:
"""Initialize the optimizer."""
super().__init__()
self.core = core
self.state_box = state_box
Expand Down Expand Up @@ -198,6 +199,11 @@ def __init__(self, core: CrabsCore, state_box: StateBox, device, cfgs, logger) -
self.reinit()

def init_cfgs(self, cfgs):
"""Initialize the configuration.
Args:
cfgs: Configuration for the optimization.
"""
self.temperature = cfgs.temperature

self.filter = cfgs.filter
Expand Down Expand Up @@ -374,6 +380,7 @@ def __init__(
state_box: StateBox,
logger=None,
) -> None:
"""Initialize the optimizer."""
super().__init__()
self.obj_eval = obj_eval
self.s = nn.Parameter(torch.randn(100_000, *state_box.shape), requires_grad=False)
Expand Down Expand Up @@ -407,6 +414,7 @@ def __init__(
state_box: StateBox,
logger=None,
) -> None:
"""Initialize the optimizer."""
super().__init__()
self.obj_eval = obj_eval
self.z = nn.Parameter(torch.randn(10000, *state_box.shape), requires_grad=True)
Expand Down Expand Up @@ -469,6 +477,7 @@ class PolicyAdvTraining:
"""

def __init__(self, policy, s_opt, obj_eval, cfgs) -> None:
"""Initialize the optimizer."""
self.policy = policy
self.s_opt = s_opt
self.obj_eval = obj_eval
Expand Down Expand Up @@ -549,6 +558,7 @@ def __init__(
cfgs=None,
logger=None,
) -> None:
"""Initialize the optimizer."""
super().__init__()
self.h = h
self.obj_eval = obj_eval
Expand All @@ -573,6 +583,11 @@ def __init__(
self.opt = torch.optim.Adam(self.h.parameters(), lr=self.lr, weight_decay=self.weight_decay)

def init_cfgs(self, cfgs):
"""Initialize the configuration.
Args:
cfgs: Configuration for the optimizer.
"""
self.weight_decay = cfgs.weight_decay
self.lr = cfgs.lr
self.lambda_2 = cfgs.lambda_2
Expand Down
1 change: 1 addition & 0 deletions omnisafe/common/control_barrier_function/crabs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Normalizer(nn.Module):
"""

def __init__(self, dim, *, clip=10) -> None:
"""Initialize the normalizer."""
super().__init__()
self.register_buffer('mean', torch.zeros(dim))
self.register_buffer('std', torch.ones(dim))
Expand Down
Loading

0 comments on commit 5e1b3e3

Please sign in to comment.