diff --git a/sumo_rl/__init__.py b/sumo_rl/__init__.py
index fcad983d..804c0204 100755
--- a/sumo_rl/__init__.py
+++ b/sumo_rl/__init__.py
@@ -1,3 +1,3 @@
-from sumo_rl.environment.env import SumoEnvironment
+from sumo_rl.environment.env import SumoEnvironment, TrafficSignal
 from sumo_rl.environment.env import env, parallel_env
 from sumo_rl.environment.resco_envs import grid4x4, arterial4x4, ingolstadt1, ingolstadt7, ingolstadt21, cologne1, cologne3, cologne8
\ No newline at end of file
diff --git a/sumo_rl/environment/env.py b/sumo_rl/environment/env.py
index fd32ac7b..5f895572 100755
--- a/sumo_rl/environment/env.py
+++ b/sumo_rl/environment/env.py
@@ -50,6 +50,7 @@ class SumoEnvironment(gym.Env):
     :param max_green: (int) Max green time in a phase
     :single_agent: (bool) If true, it behaves like a regular gym.Env. Else, it behaves like a MultiagentEnv (https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/multi_agent_env.py)
     :reward_fn: (str/function/dict) String with the name of the reward function used by the agents, a reward function, or dictionary with reward functions assigned to individual traffic lights by their keys
+    :observation_fn: (str/function) String with the name of the observation function or a callable observation function itself
     :add_system_info: (bool) If true, it computes system metrics (total queue, total waiting time, average speed) in the info dictionary
     :add_per_agent_info: (bool) If true, it computes per-agent (per-traffic signal) metrics (average accumulated waiting time, average queue) in the info dictionary
     :sumo_seed: (int/string) Random seed for sumo. If 'random' it uses a randomly chosen seed.
@@ -82,6 +83,7 @@ def __init__(
         max_green: int = 50,
         single_agent: bool = False,
         reward_fn: Union[str,Callable,dict] = 'diff-waiting-time',
+        observation_fn: Union[str,Callable] = 'default',
         add_system_info: bool = True,
         add_per_agent_info: bool = True,
         sumo_seed: Union[str,int] = 'random',
@@ -133,33 +135,28 @@ def __init__(
             traci.start([sumolib.checkBinary('sumo'), '-n', self._net], label='init_connection'+self.label)
             conn = traci.getConnection('init_connection'+self.label)
         self.ts_ids = list(conn.trafficlight.getIDList())
+        self.observation_fn = observation_fn
 
         if isinstance(self.reward_fn, dict):
-            self.traffic_signals = dict()
-            for key, reward_fn_value in self.reward_fn.items():
-                self.traffic_signals[key] = TrafficSignal(
-                    self,
-                    key,
-                    self.delta_time,
-                    self.yellow_time,
-                    self.min_green,
-                    self.max_green,
-                    self.begin_time,
-                    reward_fn_value,
-                    conn
-                )
+            self.traffic_signals = {ts: TrafficSignal(self,
+                                                      ts,
+                                                      self.delta_time,
+                                                      self.yellow_time,
+                                                      self.min_green,
+                                                      self.max_green,
+                                                      self.begin_time,
+                                                      self.reward_fn[ts],
+                                                      conn) for ts in self.reward_fn.keys()}
         else:
-            self.traffic_signals = {
-                ts: TrafficSignal(self,
-                                  ts,
-                                  self.delta_time,
-                                  self.yellow_time,
-                                  self.min_green,
-                                  self.max_green,
-                                  self.begin_time,
-                                  self.reward_fn,
-                                  conn) for ts in self.ts_ids
-            }
+            self.traffic_signals = {ts: TrafficSignal(self,
+                                                      ts,
+                                                      self.delta_time,
+                                                      self.yellow_time,
+                                                      self.min_green,
+                                                      self.max_green,
+                                                      self.begin_time,
+                                                      self.reward_fn,
+                                                      conn) for ts in self.ts_ids}
 
         conn.close()
 
@@ -223,19 +220,15 @@ def reset(self, seed: Optional[int] = None, **kwargs):
         self._start_simulation()
 
         if isinstance(self.reward_fn, dict):
-            self.traffic_signals = dict()
-            for key, reward_fn_value in self.reward_fn.items():
-                self.traffic_signals[key] = TrafficSignal(
-                    self,
-                    key,
-                    self.delta_time,
-                    self.yellow_time,
-                    self.min_green,
-                    self.max_green,
-                    self.begin_time,
-                    reward_fn_value,
-                    self.sumo
-                )
+            self.traffic_signals = {ts: TrafficSignal(self,
+                                                      ts,
+                                                      self.delta_time,
+                                                      self.yellow_time,
+                                                      self.min_green,
+                                                      self.max_green,
+                                                      self.begin_time,
+                                                      self.reward_fn[ts],
+                                                      self.sumo) for ts in self.reward_fn.keys()}
         else:
             self.traffic_signals = {ts: TrafficSignal(self,
                                                       ts,
diff --git a/sumo_rl/environment/traffic_signal.py b/sumo_rl/environment/traffic_signal.py
index def983dc..60d2fab1 100755
--- a/sumo_rl/environment/traffic_signal.py
+++ b/sumo_rl/environment/traffic_signal.py
@@ -54,6 +54,20 @@ def __init__(self,
         self.reward_fn = reward_fn
         self.sumo = sumo
 
+        if type(self.reward_fn) is str:
+            if self.reward_fn in TrafficSignal.reward_fns.keys():
+                self.reward_fn = TrafficSignal.reward_fns[self.reward_fn]
+            else:
+                raise NotImplementedError(f'Reward function {self.reward_fn} not implemented')
+
+        if isinstance(self.env.observation_fn, Callable):
+            self.observation_fn = self.env.observation_fn
+        else:
+            if self.env.observation_fn in TrafficSignal.observation_fns.keys():
+                self.observation_fn = TrafficSignal.observation_fns[self.env.observation_fn]
+            else:
+                raise NotImplementedError(f'Observation function {self.env.observation_fn} not implemented')
+
         self.build_phases()
 
         self.lanes = list(dict.fromkeys(self.sumo.trafficlight.getControlledLanes(self.id)))  # Remove duplicates and keep order
@@ -134,27 +148,10 @@ def set_next_phase(self, new_phase):
         self.time_since_last_phase_change = 0
 
     def compute_observation(self):
-        phase_id = [1 if self.green_phase == i else 0 for i in range(self.num_green_phases)]  # one-hot encoding
-        min_green = [0 if self.time_since_last_phase_change < self.min_green + self.yellow_time else 1]
-        density = self.get_lanes_density()
-        queue = self.get_lanes_queue()
-        observation = np.array(phase_id + min_green + density + queue, dtype=np.float32)
-        return observation
+        return self.observation_fn(self)
 
     def compute_reward(self):
-        if type(self.reward_fn) is str:
-            if self.reward_fn == 'diff-waiting-time':
-                self.last_reward = self._diff_waiting_time_reward()
-            elif self.reward_fn == 'average-speed':
-                self.last_reward = self._average_speed_reward()
-            elif self.reward_fn == 'queue':
-                self.last_reward = self._queue_reward()
-            elif self.reward_fn == 'pressure':
-                self.last_reward = self._pressure_reward()
-            else:
-                raise NotImplementedError(f'Reward function {self.reward_fn} not implemented')
-        else:
-            self.last_reward = self.reward_fn(self)
+        self.last_reward = self.reward_fn(self)
         return self.last_reward
 
     def _pressure_reward(self):
@@ -172,6 +169,14 @@ def _diff_waiting_time_reward(self):
         self.last_measure = ts_wait
         return reward
 
+    def _observation_fn_default(self):
+        phase_id = [1 if self.green_phase == i else 0 for i in range(self.num_green_phases)]  # one-hot encoding
+        min_green = [0 if self.time_since_last_phase_change < self.min_green + self.yellow_time else 1]
+        density = self.get_lanes_density()
+        queue = self.get_lanes_queue()
+        observation = np.array(phase_id + min_green + density + queue, dtype=np.float32)
+        return observation
+
     def get_accumulated_waiting_time_per_lane(self):
         wait_time_per_lane = []
         for lane in self.lanes:
@@ -220,3 +225,28 @@ def _get_veh_list(self):
         for lane in self.lanes:
             veh_list += self.sumo.lane.getLastStepVehicleIDs(lane)
         return veh_list
+
+    @classmethod
+    def register_reward_fn(cls, fn):
+        if fn.__name__ in cls.reward_fns.keys():
+            raise KeyError(f'Reward function {fn.__name__} already exists')
+
+        cls.reward_fns[fn.__name__] = fn
+
+    @classmethod
+    def register_observation_fn(cls, fn):
+        if fn.__name__ in cls.observation_fns.keys():
+            raise KeyError(f'Observation function {fn.__name__} already exists')
+
+        cls.observation_fns[fn.__name__] = fn
+
+    reward_fns = {
+        'diff-waiting-time': _diff_waiting_time_reward,
+        'average-speed': _average_speed_reward,
+        'queue': _queue_reward,
+        'pressure': _pressure_reward
+    }
+
+    observation_fns = {
+        'default': _observation_fn_default
+    }
\ No newline at end of file
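For reference, a minimal usage sketch of the hook this patch adds (not part of the diff itself): the network/route file paths and the queue_only_observation name below are placeholders, and since the declared observation_space still reflects the default encoding, a custom observation with a different shape may need matching adjustments elsewhere.

    import numpy as np

    from sumo_rl import SumoEnvironment, TrafficSignal

    def queue_only_observation(ts):
        # ts is the TrafficSignal instance; reuse its built-in lane helpers
        density = ts.get_lanes_density()
        queue = ts.get_lanes_queue()
        return np.array(density + queue, dtype=np.float32)

    # Register the function and refer to it by name...
    TrafficSignal.register_observation_fn(queue_only_observation)
    env = SumoEnvironment(net_file='my.net.xml', route_file='my.rou.xml',
                          observation_fn='queue_only_observation')

    # ...or pass the callable directly:
    # env = SumoEnvironment(net_file='my.net.xml', route_file='my.rou.xml',
    #                       observation_fn=queue_only_observation)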