Skip to content

Commit

Permalink
improve code style
Browse files: browse the repository at this point in the history
  • Loading branch information
muchvo committed May 5, 2024
1 parent 960ef39 commit fd48922
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 5 deletions.
4 changes: 3 additions & 1 deletion omnisafe/adapter/crabs_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from omnisafe.adapter.offpolicy_adapter import OffPolicyAdapter
from omnisafe.common.buffer import VectorOffPolicyBuffer
from omnisafe.common.control_barrier_function.crabs.models import MeanPolicy
from omnisafe.common.logger import Logger
from omnisafe.envs.crabs_env import CRABSEnv
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
Expand Down Expand Up @@ -55,14 +56,15 @@ def __init__( # pylint: disable=too-many-arguments
"""Initialize an instance of :class:`CRABSAdapter`."""
super().__init__(env_id, num_envs, seed, cfgs)
self._env: CRABSEnv
self._eval_env: CRABSEnv
self.n_expl_episodes = 0
self._max_ep_len = self._env.env.spec.max_episode_steps # type: ignore
self.horizon = self._max_ep_len

def eval_policy( # pylint: disable=too-many-locals
self,
episode: int,
agent: ConstraintActorQCritic,
agent: ConstraintActorQCritic | MeanPolicy,
logger: Logger,
) -> None:
"""Rollout the environment with deterministic agent action.
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/algorithms/off_policy/crabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def learn(self):
eval_start = time.time()
self._env.eval_policy(
episode=self._cfgs.train_cfgs.raw_policy_episodes,
agent=self.mean_policy, # type: ignore
agent=self.mean_policy,
logger=self._logger,
)
eval_time += time.time() - eval_start
Expand Down
5 changes: 4 additions & 1 deletion omnisafe/common/control_barrier_function/crabs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Utils for CRABS."""
# pylint: disable=all
import os
from typing import List, Union

import pytorch_lightning as pl
import requests
Expand Down Expand Up @@ -129,7 +130,7 @@ def get_pretrained_model(model_path, model_url, device):


def create_model_and_trainer(cfgs, dim_state, dim_action, normalizer, device):
def make_model(i, model_type) -> TransitionModel:
def make_model(i, model_type) -> nn.Module:
if model_type == 'GatedTransitionModel':
return GatedTransitionModel(
dim_state,
Expand All @@ -153,6 +154,8 @@ def make_model(i, model_type) -> TransitionModel:

model = EnsembleModel(models).to(device)

devices: Union[List[int], int]

if str(device).startswith('cuda'):
accelerator = 'gpu'
devices = [int(str(device)[-1])]
Expand Down
11 changes: 9 additions & 2 deletions omnisafe/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ def __load_model_and_env(

def _init_crabs(self, model_params: dict) -> None:
mean_policy = MeanPolicy(self._actor)
assert self._env is not None, 'The environment must be provided or created.'
assert self._actor is not None, 'The actor must be provided or created.'
assert (
self._env.observation_space.shape is not None
), 'The observation space does not exist.'
assert self._env.action_space.shape is not None, 'The action space does not exist.'
normalizer = CRABSNormalizer(self._env.observation_space.shape[0], clip=1000).to(
torch.device('cpu'),
)
Expand All @@ -327,7 +333,8 @@ def _init_crabs(self, model_params: dict) -> None:
normalizer,
MultiLayerPerceptron([self._env.observation_space.shape[0], 256, 256, 1]),
),
self._env._env.env.barrier_fn, # pylint: disable=protected-access
# pylint: disable-next=protected-access
self._env._env.env.barrier_fn, # type: ignore
s0,
self._cfgs.lyapunov,
).to(torch.device('cpu'))
Expand All @@ -342,7 +349,7 @@ def _init_crabs(self, model_params: dict) -> None:
),
core,
)
self._actor.predict = self._actor.step
self._actor.predict = self._actor.step # type: ignore

# pylint: disable-next=too-many-locals
def load_saved(
Expand Down

0 comments on commit fd48922

Please sign in to comment.