Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Environment] MAgent2 #137

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions benchmarl/conf/task/magent/adversarial_pursuit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Defaults list (presumably Hydra-style): `adversarial_pursuit_config` is
# listed before `_self_`, so the values in this file are applied after it.
# NOTE(review): `adversarial_pursuit_config` looks like the schema registered
# from the TaskConfig dataclass in magent/adversarial_pursuit.py -- confirm.
defaults:
- adversarial_pursuit_config
- _self_

# The keys below match the TaskConfig field names one-to-one.
# Presumably they are forwarded as keyword arguments to the MAgent2
# `adversarial_pursuit_v4.parallel_env(...)` constructor -- verify against
# MAgentTask in magent/common.py.
map_size: 45
minimap_mode: False
tag_penalty: -0.2
# Also read by MAgentTask.max_steps() as the episode step limit.
max_cycles: 500
extra_features: False
3 changes: 2 additions & 1 deletion benchmarl/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from .pettingzoo.common import PettingZooTask
from .smacv2.common import Smacv2Task
from .vmas.common import VmasTask
from .magent.common import MAgentTask

# The enum classes for the environments available.
# This is the only object in this file you need to modify when adding a new environment.
# Single effective registry list; the earlier assignment without MAgentTask
# was immediately overwritten and has been dropped.
tasks = [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask, MAgentTask]

# This is a registry mapping "envname/task_name" to the EnvNameTask.TASK_NAME enum
# It is used to automatically load task enums from yaml files.
Expand Down
Empty file.
16 changes: 16 additions & 0 deletions benchmarl/environments/magent/adversarial_pursuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from dataclasses import dataclass, MISSING


@dataclass
class TaskConfig:
    """Typed schema for the adversarial pursuit task configuration.

    Every field defaults to ``dataclasses.MISSING``, which means the field
    has no default: all five values must be supplied by whoever builds the
    config. The field names match the keys of the task's yaml file.
    """

    # Integer map size setting.
    map_size: int = MISSING
    # Boolean toggle for the minimap mode.
    minimap_mode: bool = MISSING
    # Float penalty value (the shipped yaml uses a negative number).
    tag_penalty: float = MISSING
    # Integer cycle limit.
    max_cycles: int = MISSING
    # Boolean toggle for extra features.
    extra_features: bool = MISSING
130 changes: 130 additions & 0 deletions benchmarl/environments/magent/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from typing import Callable, Dict, List, Optional

from torchrl.data import CompositeSpec
from torchrl.envs import EnvBase, PettingZooWrapper

from benchmarl.environments.common import Task

from benchmarl.utils import DEVICE_TYPING


class MAgentTask(Task):
    """Enum for MAgent2 tasks.

    Each member names one MAgent2 scenario. The member name is used as the
    lookup key for the matching ``magent2.environments`` module, and
    ``self.config`` is forwarded to that module's ``parallel_env``
    constructor as keyword arguments.
    """

    ADVERSARIAL_PURSUIT = None
    # Scenarios not enabled yet; enabling one also requires the matching
    # import and dictionary entry in ``__get_env``.
    # BATTLE = None
    # BATTLEFIELD = None
    # COMBINED_ARMS = None
    # GATHER = None
    # TIGER_DEER = None

    def get_env_fun(
        self,
        num_envs: int,
        continuous_actions: bool,
        seed: Optional[int],
        device: DEVICE_TYPING,
    ) -> Callable[[], EnvBase]:
        """Return a zero-argument factory that builds the wrapped environment.

        Args:
            num_envs: Not used by this task.
            continuous_actions: Not used by this task.
            seed: Seed forwarded to ``PettingZooWrapper``.
            device: Device forwarded to ``PettingZooWrapper``.

        Returns:
            A callable that creates a fresh MAgent2 environment and wraps it
            in a ``PettingZooWrapper`` each time it is invoked.
        """
        return lambda: PettingZooWrapper(
            env=self.__get_env(),
            return_state=True,
            seed=seed,
            done_on_any=False,
            use_mask=False,
            device=device,
        )

    def __get_env(self):
        """Create the raw (unwrapped) MAgent2 environment for this task.

        The result is what ``get_env_fun`` hands to ``PettingZooWrapper`` as
        ``env``, so no ``EnvBase`` return annotation is given here.

        Raises:
            ImportError: If ``magent2`` cannot be imported. The original
                import error is chained as the cause.
            ValueError: If this enum member has no MAgent2 environment
                registered in the lookup table below.
        """
        # Imported lazily so that magent2 stays an optional dependency.
        try:
            from magent2.environments import (
                adversarial_pursuit_v4,
                # battle_v4,
                # battlefield_v5,
                # combined_arms_v6,
                # gather_v5,
                # tiger_deer_v4
            )
        except ImportError as err:
            # Put the install hint in the exception itself rather than
            # printing it, so it is visible wherever the error is reported.
            raise ImportError(
                "Module 'magent2' not found, install it using `pip install magent2`"
            ) from err

        envs = {
            "ADVERSARIAL_PURSUIT": adversarial_pursuit_v4,
            # "BATTLE": battle_v4,
            # "BATTLEFIELD": battlefield_v5,
            # "COMBINED_ARMS": combined_arms_v6,
            # "GATHER": gather_v5,
            # "TIGER_DEER": tiger_deer_v4
        }
        try:
            env_module = envs[self.name]
        except KeyError:
            raise ValueError(
                f"{self.name} is not an environment of MAgent2"
            ) from None
        return env_module.parallel_env(**self.config, render_mode="rgb_array")

    def supports_continuous_actions(self) -> bool:
        """Return ``False``: continuous actions are not supported."""
        return False

    def supports_discrete_actions(self) -> bool:
        """Return ``True``: discrete actions are supported."""
        return True

    def has_state(self) -> bool:
        """Return ``True``: the wrapper is built with ``return_state=True``."""
        return True

    def has_render(self, env: EnvBase) -> bool:
        """Return ``True``: the env is created with ``render_mode="rgb_array"``."""
        return True

    def max_steps(self, env: EnvBase) -> int:
        """Return the ``max_cycles`` entry of the task config."""
        return self.config["max_cycles"]

    def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
        """Return the group map exposed by the environment."""
        return env.group_map

    def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        """Return a spec holding a copy of the env's ``"state"`` entry."""
        return CompositeSpec({"state": env.observation_spec["state"].clone()})

    def _group_spec_keeping(
        self, env: EnvBase, keep_key: str, drop_empty_groups: bool = False
    ) -> CompositeSpec:
        """Clone the observation spec, keeping one key per agent group.

        Args:
            env: Environment whose ``observation_spec`` is cloned.
            keep_key: The only per-group entry that is retained.
            drop_empty_groups: When true, groups left with no entries are
                removed from the result.

        Returns:
            The filtered clone, with the top-level ``"state"`` entry removed.
        """
        spec = env.observation_spec.clone()
        for group in self.group_map(env):
            group_spec = spec[group]
            # Materialise the keys first: entries are deleted while looping.
            for key in list(group_spec.keys()):
                if key != keep_key:
                    del group_spec[key]
            if drop_empty_groups and group_spec.is_empty():
                del spec[group]
        # Guarded so a spec without a global state does not raise KeyError.
        if "state" in spec.keys():
            del spec["state"]
        return spec

    def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        """Return the per-group ``"action_mask"`` spec, or ``None`` if absent.

        Groups without an action mask are dropped; if no group has one the
        method returns ``None``.
        """
        mask_spec = self._group_spec_keeping(
            env, "action_mask", drop_empty_groups=True
        )
        if mask_spec.is_empty():
            return None
        return mask_spec

    def observation_spec(self, env: EnvBase) -> CompositeSpec:
        """Return the per-group spec reduced to the ``"observation"`` entry."""
        return self._group_spec_keeping(env, "observation")

    def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        """Return the per-group spec reduced to the ``"info"`` entry.

        Groups without an ``"info"`` entry are kept as empty specs.
        """
        return self._group_spec_keeping(env, "info")

    def action_spec(self, env: EnvBase) -> CompositeSpec:
        """Return the environment's full action spec."""
        return env.full_action_spec

    @staticmethod
    def env_name() -> str:
        """Return the environment family name, ``"magent"``."""
        return "magent"