
questionably successful dino + lr prediction
aayushg55 committed Nov 2, 2024
1 parent a1981ed commit 758eca3
Showing 7 changed files with 160 additions and 116 deletions.
9 changes: 8 additions & 1 deletion src/ppo/base_policy.py
@@ -90,7 +90,7 @@ def __init__(self,
# print(f"critic: {self.critic}")
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

def select_action(self, obs):
"""
Given an observation, use the actor to select an action and return
@@ -135,6 +135,13 @@ def optimizer_step(self, actor_loss, critic_loss):
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()

def set_env(self, train_mode: bool):
"""
Set the env based on train or test mode.
"""
self.actor.set_env(train_mode)
self.critic.set_env(train_mode)

def predict_values(self, obs):
"""
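The new set_env hook on the policy simply propagates the train/test toggle to both the actor and the critic. A minimal usage sketch follows; the helper with_test_env and its fn callback are hypothetical and not part of this commit — only policy.set_env comes from the diff.

    # Hypothetical helper: run an evaluation callback against the test env,
    # then restore the training env. Only policy.set_env(...) is from this commit.
    def with_test_env(policy, fn):
        policy.set_env(train_mode=False)     # actor and critic now use test_env
        try:
            return fn(policy)
        finally:
            policy.set_env(train_mode=True)  # switch back for further training

The try/finally guarantees the policy is returned to train mode even if the evaluation callback raises.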
24 changes: 19 additions & 5 deletions src/ppo/lr_predictor/actor_critic.py
@@ -9,7 +9,7 @@


class LRCritic(Critic):
def __init__(self, env: Env, input_dim: int = 1024, h_dim: int = 64):
def __init__(self, train_env: Env, test_env: Env, input_dim: int = 1024, h_dim: int = 64):
super().__init__()
# should map obs (image) to value
self.layers = [
@@ -18,7 +18,14 @@ def __init__(self, env: Env, input_dim: int = 1024, h_dim: int = 64):
nn.Linear(h_dim, 1),
]
self.network = nn.Sequential(*self.layers)
self.env = env
self.train_env = train_env
self.test_env = test_env

self.set_env(train_mode=True)

def set_env(self, train_mode: bool):
assert train_mode or self.test_env is not None, "Test env must be set if train mode is False"
self.env = self.train_env if train_mode else self.test_env

def forward(self, obs: Tensor):
obs = obs.to(torch.int)
@@ -35,13 +42,13 @@ def forward(self, obs: Tensor):
# return values

class LRActor(Actor):
def __init__(self, env: Env, lrs: list[float] = None, input_dim: int = 1024, h_dim: int = 64):
def __init__(self, train_env: Env, test_env: Env = None, lrs: list[float] = None, input_dim: int = 1024, h_dim: int = 64):
super().__init__()
if lrs:
self.lrs = Tensor(lrs)
else:
self.lrs = Tensor([10**-i for i in range(10)])
self.lrs = self.lrs.to(device=env.device)
self.lrs = self.lrs.to(device=train_env.device)

print(f"Actor using learning rates: {self.lrs}")
self.layers = [
@@ -51,7 +58,14 @@ def __init__(self, env: Env, lrs: list[float] = None, input_dim: int = 1024, h_d
nn.Softmax(dim=-1)
]
self.network = nn.Sequential(*self.layers)
self.env = env
self.train_env = train_env
self.test_env = test_env

self.set_env(train_mode=True)

def set_env(self, train_mode: bool):
assert train_mode or self.test_env is not None, "Test env must be set if train mode is False"
self.env = self.train_env if train_mode else self.test_env

def forward(self, obs: Tensor):
# convert from torch float to int
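Both networks now keep a train and a test environment and expose the same set_env switch; the actor still scores a fixed grid of candidate learning rates (defaulting to 1e0 through 1e-9) via a softmax head. A construction sketch under assumed inputs: make_env is a placeholder factory, and the lr grid below is illustrative, not from this commit.

    # Placeholder wiring; make_env and the chosen lr grid are illustrative only.
    train_env = make_env(split="train")
    test_env = make_env(split="test")

    actor = LRActor(train_env, test_env, lrs=[1e-2, 1e-3, 1e-4], input_dim=1024, h_dim=64)
    critic = LRCritic(train_env, test_env, input_dim=1024, h_dim=64)

    # Training runs against train_env; flip both networks for evaluation.
    actor.set_env(train_mode=False)
    critic.set_env(train_mode=False)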
79 changes: 11 additions & 68 deletions src/ppo/ppo.py
@@ -14,14 +14,14 @@
from .base_env import Env
from .rollout_buffer import RolloutBuffer
from src.utils import *

class PPO:
def __init__(self, policy: Policy, env: Env,
clip_epsilon=0.2,
gamma=1,
gae_lambda=0.95,
normalize_advantages=True,
entropy_coeff=0.0,
entropy_schedule=lambda t: 0.2,
n_epochs=5,
batch_size=10,
buffer_size=20,
@@ -30,6 +30,7 @@ def __init__(self, policy: Policy, env: Env,
shuffle=False,
log_callback=None,
plots_path=None,

**kwargs
):
"""
@@ -40,7 +41,6 @@
- gamma: Discount factor for future rewards (1 means full future reward impact).
- gae_lambda: Trade-off between bias and variance for GAE (lambda-return).
- normalize_advantages: Normalize advantages for better numerical stability.
- entropy_coeff: Adds entropy to encourage exploration (higher value -> more exploration).
- n_epochs: Number of passes over the data to update the policy.
- batch_size: Number of samples per batch for policy updates.
- buffer_size: Number of steps to store in the rollout buffer.
@@ -55,12 +55,12 @@
self.gamma = gamma # Discount factor for future rewards.
self.gae_lambda = gae_lambda # Governs the advantage estimation trade-off.
self.normalize_advantages = normalize_advantages
self.entropy_coeff = entropy_coeff # Encourage exploration.
self.n_epochs = n_epochs
self.batch_size = batch_size
self.buffer_size = buffer_size

self.shuffle = shuffle
self.entropy_schedule = entropy_schedule

# Initialize the rollout buffer to store experiences
self.rollout_buffer = RolloutBuffer(
@@ -82,6 +82,7 @@ def __init__(self, policy: Policy, env: Env,
"avg_critic_values": [],
"avg_advantages": [],
"entropy": [],
"entropy_coeff": [],
"surr_loss": [],
"timesteps": [],
"num_match": []
@@ -169,7 +170,7 @@ def compute_loss(self, batch_data):
# print(f"actor loss before: {actor_loss}")
# print(f"surr1: {surr1}, surr2: {surr2}")
# print(f"ratios: {ratios}")
entropy_loss = self.entropy_coeff * entropy.mean()
entropy_loss = self.entropy_schedule(self.logger["i_so_far"]) * entropy.mean()
actor_loss -= entropy_loss

# print(f"actions: {actions}")
@@ -200,6 +201,8 @@ def update(self):
self.logger["actor_losses"].append(actor_loss.item())
self.logger["critic_losses"].append(critic_loss.item())
self.logger["entropy"].append(entropy.item())

self.logger["entropy_coeff"].append(self.entropy_schedule(self.logger["i_so_far"]))

self.logger["avg_rewards"].append(self.rollout_buffer.rewards.mean().item())
self.logger["avg_advantages"].append(self.rollout_buffer.advantages.mean().item())
@@ -219,15 +222,13 @@ def train(self, total_timesteps):
i_so_far += 1

self.logger["timesteps"].append(t_so_far)
self.logger["i_so_far"] = i_so_far

# Perform policy (actor+critic) updates
self.update()

if i_so_far % self.log_interval == 0:
self._log_summary()

if self.plots_path:
self.plot_training_progress()

def _log_summary(self):
"""
@@ -237,6 +238,7 @@ def _log_summary(self):
print(f" Timesteps so far: {self.logger['t_so_far']}")
print(f" Actor Loss: {self.logger['actor_losses'][-1]:.6f}")
print(f" Entropy: {self.logger['entropy'][-1]:.4f}")
print(f" Entropy Coefficient: {self.logger['entropy_coeff'][-1]:.4f}")
print(f" Surrogate Loss: {self.logger['surr_loss'][-1]:.4f}")
print(f" Critic Loss: {self.logger['critic_losses'][-1]:.4f}")
print(f" Avg Rewards: {self.logger['avg_rewards'][-1]:.4f}")
@@ -247,66 +249,7 @@ def _log_summary(self):
num_match = self.log_callback(self.policy)
self.logger["num_match"].append(num_match)
print(f" Num Match: {num_match}")
print("=" * 50)

def plot_training_progress(self):
# Use logging data to plot (actor loss, critic loss, rewards, avg critic values, avg advantages, and entropy)
plt.figure(figsize=(24, 14))

plt.subplot(4, 2, 1)
plt.plot(self.logger['actor_losses'], label='Actor Loss')
plt.xlabel('Timesteps')
plt.ylabel('Loss')
plt.title('Actor Loss')

plt.subplot(4, 2, 2)
plt.plot(self.logger['critic_losses'], label='Critic Loss')
plt.xlabel('Timesteps')
plt.ylabel('Loss')
plt.title('Critic Loss')

plt.subplot(4, 2, 3)
plt.plot(self.logger['avg_rewards'], label='Avg Rewards')
plt.xlabel('Timesteps')
plt.ylabel('Rewards')
plt.title('Average Rewards')

plt.subplot(4, 2, 4)
plt.plot(self.logger['avg_critic_values'], label='Avg Critic Values')
plt.xlabel('Timesteps')
plt.ylabel('Values')
plt.title('Average Critic Values')

plt.subplot(4, 2, 5)
plt.plot(self.logger['avg_advantages'], label='Avg Advantages')
plt.xlabel('Timesteps')
plt.ylabel('Advantages')
plt.title('Average Advantages')

# make y scale integers
plt.subplot(4, 2, 6)
plt.plot(self.logger['num_match'], label='Num Match')
plt.xlabel('Timesteps')
plt.ylabel('Num Match')
plt.title('Number of Matches')
plt.gca().yaxis.set_major_locator(plt.MaxNLocator(integer=True))

plt.subplot(4, 2, 7)
plt.plot(self.logger['entropy'], label='Entropy')
plt.xlabel('Timesteps')
plt.ylabel('Entropy')
plt.title('Entropy')

plt.subplot(4, 2, 8)
plt.plot(self.logger['surr_loss'], label='PG Loss')
plt.xlabel('Timesteps')
plt.ylabel('PG Loss')
plt.title('PG')

# Save the plot to a file using attributes
# plt.tight_layout()
# plt.savefig(self.plots_path)

print("=" * 50)

def save_policy(self, path="policy_checkpoint.pth"):
checkpoint = {
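With plot_training_progress and its call in train removed, plotting now has to happen outside the class, driven by the logger dict that PPO still fills in. A minimal external sketch, assuming access to a trained PPO instance (ppo and the output path below are placeholders):

    import matplotlib.pyplot as plt

    # External plotting sketch; reads logger keys that PPO still populates.
    def plot_logger(ppo, plots_path="training_progress.png"):
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        series = [
            ("actor_losses", "Actor Loss"),
            ("critic_losses", "Critic Loss"),
            ("avg_rewards", "Average Rewards"),
            ("entropy_coeff", "Entropy Coefficient"),
        ]
        for ax, (key, title) in zip(axes.flat, series):
            ax.plot(ppo.logger[key])
            ax.set_xlabel("Updates")
            ax.set_title(title)
        fig.tight_layout()
        fig.savefig(plots_path)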
