Skip to content

Commit

Permalink
Add RecurrentActorCore from dm-acme
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanluoyc committed Oct 15, 2023
1 parent 7ae027c commit 3aafc8d
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions corax/agents/jax/actor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ class ActorCore(Generic[State, Extras]):
Tuple[networks_lib.Action, types.NestedArray],
]

# Signature of a recurrent policy: it is called with
# (params, PRNG key, observation, recurrent state) and returns a tuple of
# (action, recurrent state).
RecurrentPolicy = Callable[
    [networks_lib.Params, PRNGKey, networks_lib.Observation, RecurrentState],
    Tuple[networks_lib.Action, RecurrentState],
]

# Union of the two feed-forward policy signatures. Note that RecurrentPolicy is
# not a member of this union.
Policy = Union[FeedForwardPolicy, FeedForwardPolicyWithExtra]

Expand Down Expand Up @@ -111,3 +115,44 @@ def unvectorized_select_action(
select_action=unvectorized_select_action,
get_extras=actor_core.get_extras,
)


# NOTE(review): mappable_dataclass=False is presumably required so that the
# positional construction used by batched_recurrent_to_actor_core works --
# confirm against the chex.dataclass documentation.
@chex.dataclass(frozen=True, mappable_dataclass=False)
class SimpleActorCoreRecurrentState(Generic[RecurrentState]):
    """Actor state pairing a PRNG key with a policy's recurrent state.

    Field order matters: batched_recurrent_to_actor_core constructs instances
    positionally as (rng, recurrent_state).
    """

    # PRNG key carried by the actor; select_action splits it on every call and
    # stores one half back into the next state.
    rng: PRNGKey
    # Recurrent state of the policy, stored after utils.squeeze_batch_dim has
    # been applied to it.
    recurrent_state: RecurrentState


def batched_recurrent_to_actor_core(
    recurrent_policy: RecurrentPolicy, initial_core_state: RecurrentState
) -> ActorCore[
    SimpleActorCoreRecurrentState[RecurrentState], Mapping[str, jnp.ndarray]
]:
    """Builds an ActorCore around a recurrent policy.

    On every step the observation and the stored recurrent state are passed
    through `utils.add_batch_dim` before the policy is called, and the policy's
    outputs are passed through `utils.squeeze_batch_dim` before being returned.

    Args:
      recurrent_policy: Callable invoked as
        `recurrent_policy(params, key, observation, recurrent_state)` and
        returning `(action, new_recurrent_state)`.
      initial_core_state: Initial recurrent state. It is passed through
        `utils.squeeze_batch_dim` once, when this function is called, and the
        result is reused by every `init` call.

    Returns:
      An ActorCore whose state is a `SimpleActorCoreRecurrentState` and whose
      extras expose the current recurrent state under the key "core_state".
    """
    # Squeeze once up front; `init` closes over the squeezed value.
    squeezed_initial_state = utils.squeeze_batch_dim(initial_core_state)

    def init(rng: PRNGKey) -> SimpleActorCoreRecurrentState[RecurrentState]:
        # Fresh actor state: the caller's key plus the shared initial state.
        return SimpleActorCoreRecurrentState(
            rng=rng, recurrent_state=squeezed_initial_state
        )

    def select_action(
        params: networks_lib.Params,
        observation: networks_lib.Observation,
        state: SimpleActorCoreRecurrentState[RecurrentState],
    ):
        # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
        # One half of the split key is kept for the next step, the other is
        # consumed by the policy on this step.
        next_rng, policy_rng = jax.random.split(state.rng)
        batched_observation = utils.add_batch_dim(observation)
        batched_recurrent_state = utils.add_batch_dim(state.recurrent_state)
        policy_output = recurrent_policy(
            params, policy_rng, batched_observation, batched_recurrent_state
        )
        action, next_recurrent_state = utils.squeeze_batch_dim(policy_output)
        next_state = SimpleActorCoreRecurrentState(
            rng=next_rng, recurrent_state=next_recurrent_state
        )
        return action, next_state

    def get_extras(
        state: SimpleActorCoreRecurrentState[RecurrentState],
    ) -> Mapping[str, jnp.ndarray]:
        # Expose the recurrent state carried by the actor state.
        return {"core_state": state.recurrent_state}

    return ActorCore(init=init, select_action=select_action, get_extras=get_extras)

0 comments on commit 3aafc8d

Please sign in to comment.