
Commit

Merge branch 'master' into feature/random-agent
MischaPanch authored Aug 24, 2024
2 parents 1eaf276 + 002ffd9 commit 21c21ab
Showing 97 changed files with 1,976 additions and 652 deletions.
19 changes: 7 additions & 12 deletions docs/02_notebooks/L0_overview.ipynb
@@ -15,15 +15,6 @@
"Before we get started, we must first install Tianshou's library and Gym environment by running the commands below. This tutorials will always keep up with the latest version of Tianshou since they also serve as a test for the latest version. If you are using an older version of Tianshou, please refer to the [documentation](https://tianshou.readthedocs.io/en/latest/) of your version.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !pip install tianshou gym"
]
},
{
"cell_type": "markdown",
"metadata": {
@@ -67,7 +58,7 @@
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
@@ -114,8 +105,12 @@
")\n",
"\n",
"# collector\n",
"train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n",
"test_collector = Collector(policy, test_envs)\n",
"train_collector = Collector[CollectStats](\n",
" policy,\n",
" train_envs,\n",
" VectorReplayBuffer(20000, len(train_envs)),\n",
")\n",
"test_collector = Collector[CollectStats](policy, test_envs)\n",
"\n",
"# trainer\n",
"train_result = OnpolicyTrainer(\n",
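The change that recurs throughout this commit is that `Collector` is now instantiated with an explicit statistics type, `Collector[CollectStats]`, and `CollectStats` is added to the corresponding import. Below is a minimal sketch of the pattern, assuming a Tianshou version in which `Collector` is generic over its stats type; the environment setup is illustrative, and `policy` stands for any already-constructed Tianshou policy (it is not defined in the sketch).

```python
import gymnasium as gym

from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env import DummyVectorEnv

# Illustrative vectorized environments (two CartPole instances).
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])

# `policy` is assumed to be an already-constructed policy, e.g. the PPOPolicy
# built in the notebook; it is not defined in this sketch.
train_collector = Collector[CollectStats](
    policy,
    train_envs,
    VectorReplayBuffer(20000, len(train_envs)),
)

# collect() is expected to return a CollectStats object describing the rollout.
stats = train_collector.collect(n_step=1_000)
print(stats)
```

Parameterizing the collector this way tells static type checkers which concrete stats type `collect()` returns, which is presumably the motivation for threading `CollectStats` through all of the notebooks and example scripts below.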
6 changes: 3 additions & 3 deletions docs/02_notebooks/L5_Collector.ipynb
@@ -58,7 +58,7 @@
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PGPolicy\n",
"from tianshou.utils.net.common import Net\n",
@@ -94,7 +94,7 @@
" action_space=env.action_space,\n",
" action_scaling=False,\n",
")\n",
"test_collector = Collector(policy, test_envs)"
"test_collector = Collector[CollectStats](policy, test_envs)"
]
},
{
@@ -187,7 +187,7 @@
"train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n",
"replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
"\n",
"train_collector = Collector(policy, train_envs, replayBuffer)"
"train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)"
]
},
{
21 changes: 6 additions & 15 deletions docs/02_notebooks/L6_Trainer.ipynb
@@ -54,12 +54,8 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-06T15:34:02.969675Z",
"start_time": "2024-05-06T15:34:00.747309Z"
},
"editable": true,
"id": "do-xZ-8B7nVH",
"slideshow": {
@@ -77,7 +73,7 @@
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PGPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
@@ -88,13 +84,8 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-06T15:34:07.536452Z",
"start_time": "2024-05-06T15:34:03.636670Z"
}
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_env_num = 4\n",
@@ -131,8 +122,8 @@
"\n",
"# Create the replay buffer and the collector\n",
"replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
"test_collector = Collector(policy, test_envs)\n",
"train_collector = Collector(policy, train_envs, replayBuffer)"
"test_collector = Collector[CollectStats](policy, test_envs)\n",
"train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)"
]
},
{
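For orientation, the collectors built in this notebook are subsequently handed to a trainer, as in the truncated `OnpolicyTrainer(` call visible in the overview notebook above. The sketch below shows one plausible way to wire them together; the keyword names follow the Tianshou 1.x trainer API as I understand it, the hyperparameter values are purely illustrative, and `policy`, `train_collector`, and `test_collector` are assumed to exist as in the notebook.

```python
from tianshou.trainer import OnpolicyTrainer

# Illustrative hyperparameters; not taken from the notebook.
result = OnpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=50_000,
    step_per_collect=2_000,
    repeat_per_collect=10,
    episode_per_test=10,
    batch_size=256,
).run()
print(result)
```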
6 changes: 3 additions & 3 deletions docs/02_notebooks/L7_Experiment.ipynb
@@ -71,7 +71,7 @@
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
@@ -202,12 +202,12 @@
},
"outputs": [],
"source": [
"train_collector = Collector(\n",
"train_collector = Collector[CollectStats](\n",
" policy=policy,\n",
" env=train_envs,\n",
" buffer=VectorReplayBuffer(20000, len(train_envs)),\n",
")\n",
"test_collector = Collector(policy=policy, env=test_envs)"
"test_collector = Collector[CollectStats](policy=policy, env=test_envs)"
]
},
{
6 changes: 6 additions & 0 deletions docs/spelling_wordlist.txt
@@ -282,3 +282,9 @@ autocompletion
codebase
indexable
sliceable
gaussian
logprob
monte
carlo
subclass
subclassing
8 changes: 4 additions & 4 deletions examples/atari/atari_c51.py
@@ -9,7 +9,7 @@
from atari_network import C51
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import C51Policy
from tianshou.policy.base import BasePolicy
@@ -112,8 +112,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -173,7 +173,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_dqn.py
@@ -9,7 +9,7 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import DQNPolicy
from tianshou.policy.base import BasePolicy
@@ -148,8 +148,8 @@ def main(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -215,7 +215,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_fqf.py
@@ -9,7 +9,7 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import FQFPolicy
from tianshou.policy.base import BasePolicy
@@ -125,8 +125,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -186,7 +186,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_iqn.py
@@ -9,7 +9,7 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import IQNPolicy
from tianshou.policy.base import BasePolicy
@@ -122,8 +122,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -183,7 +183,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_ppo.py
@@ -11,7 +11,7 @@
from torch.distributions import Categorical
from torch.optim.lr_scheduler import LambdaLR

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import ICMPolicy, PPOPolicy
from tianshou.policy.base import BasePolicy
@@ -190,8 +190,8 @@ def dist(logits: torch.Tensor) -> Categorical:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -243,7 +243,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_qrdqn.py
@@ -9,7 +9,7 @@
from atari_network import QRDQN
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import QRDQNPolicy
from tianshou.policy.base import BasePolicy
@@ -116,8 +116,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -177,7 +177,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
13 changes: 9 additions & 4 deletions examples/atari/atari_rainbow.py
@@ -9,7 +9,12 @@
from atari_network import Rainbow
from atari_wrapper import make_atari_env

from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.data import (
Collector,
CollectStats,
PrioritizedVectorReplayBuffer,
VectorReplayBuffer,
)
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import C51Policy, RainbowPolicy
from tianshou.policy.base import BasePolicy
@@ -142,8 +147,8 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
weight_norm=not args.no_weight_norm,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -213,7 +218,7 @@ def watch() -> None:
alpha=args.alpha,
beta=args.beta,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
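The rainbow script additionally imports `PrioritizedVectorReplayBuffer`, which is why its import is rewrapped into a multi-line form when `CollectStats` is added. Below is a minimal sketch of constructing such a buffer and the typed collector around it; the values are illustrative (the script takes them from command-line arguments), the keyword names `total_size` and `buffer_num` are assumptions about the current buffer API, and `policy` and `train_envs` are assumed to be built as in atari_rainbow.py.

```python
from tianshou.data import Collector, CollectStats, PrioritizedVectorReplayBuffer

# Illustrative values; the script reads these from its argparse arguments.
buffer = PrioritizedVectorReplayBuffer(
    total_size=100_000,  # total number of stored transitions across sub-buffers
    buffer_num=4,        # one sub-buffer per training environment
    alpha=0.5,           # prioritization exponent
    beta=0.4,            # importance-sampling correction exponent
)

# `policy` and `train_envs` are assumed to exist, as in the example script.
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
```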
8 changes: 4 additions & 4 deletions examples/atari/atari_sac.py
@@ -9,7 +9,7 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import DiscreteSACPolicy, ICMPolicy
from tianshou.policy.base import BasePolicy
@@ -173,8 +173,8 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -226,7 +226,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
