Commit
working distributed sweep for lr pred, loading/saving checkpoints, and eval
aayushg55 committed Oct 28, 2024
1 parent 2cecfd1 commit a1981ed
Showing 12 changed files with 418 additions and 36 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -126,4 +126,5 @@ results

!examples/benchmarks/compression/results/

wandb/
wandb/
temp/
Empty file added create_dataset.ipynb
Empty file.
1 change: 1 addition & 0 deletions src/ppo/base_env.py
@@ -7,6 +7,7 @@ def __init__(self, observation_shape=None, act_shape=None) -> None:
self.num_steps = 100
self.observation_shape = observation_shape
self.action_shape = act_shape
self.eval_mode = 'train'
super().__init__()

@abstractmethod
7 changes: 4 additions & 3 deletions src/ppo/base_policy.py
@@ -82,6 +82,7 @@ def __init__(self,
device='cuda'
):
super().__init__()
print(f'initializing policy with device: {device}')
self.actor = actor.to(device)
self.critic = critic.to(device)

@@ -131,9 +132,9 @@ def optimizer_step(self, actor_loss, critic_loss):
# print(f"Did optimizer step with actor loss: {actor_loss}")

# # Critic update
# self.critic_optimizer.zero_grad()
# critic_loss.backward()
# self.critic_optimizer.step()
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()

def predict_values(self, obs):
"""
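The optimizer_step hunk above re-enables the critic update alongside the actor update. A minimal self-contained sketch of that two-optimizer pattern, with illustrative networks and toy losses (none of the names below come from the repository):

import torch
import torch.nn as nn

# Illustrative stand-ins for the policy's actor/critic networks and their optimizers.
actor, critic = nn.Linear(8, 2), nn.Linear(8, 1)
actor_opt = torch.optim.Adam(actor.parameters(), lr=3e-4)
critic_opt = torch.optim.Adam(critic.parameters(), lr=1e-3)

def optimizer_step(actor_loss: torch.Tensor, critic_loss: torch.Tensor) -> None:
    # Actor update: clear stale gradients, backprop the policy loss, step.
    actor_opt.zero_grad()
    actor_loss.backward()
    actor_opt.step()
    # Critic update: the part this commit uncomments, mirroring the actor update.
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

obs = torch.randn(4, 8)
actor_loss = -actor(obs).log_softmax(dim=-1)[:, 0].mean()                # toy policy loss
critic_loss = (critic(obs).squeeze(-1) - torch.ones(4)).pow(2).mean()    # toy value loss
optimizer_step(actor_loss, critic_loss)

Because the actor and critic are separate modules here, the two backward passes touch disjoint graphs and no retain_graph handling is needed.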
12 changes: 7 additions & 5 deletions src/ppo/lr_predictor/actor_critic.py
@@ -20,17 +20,19 @@ def __init__(self, env: Env, input_dim: int = 1024, h_dim: int = 64):
self.network = nn.Sequential(*self.layers)
self.env = env

def forward(self, obs: Tensor):
def forward(self, obs: Tensor):
obs = obs.to(torch.int)
enc_images = self.env.get_encoded_images(obs)

# Ensure obs is properly reshaped for the network
batch_size = obs.shape[0] if len(obs.shape) > 3 else 1
# obs = obs.view(batch_size, -1) # Flatten input to (batch_size, features)
# return self.network(obs)

return self.network(enc_images)

# For debugging, simply return mean of psnr's over diff lr for this img
values = self.env.get_mean_reward(obs)
return values
# values = self.env.get_mean_reward(obs)
# return values

class LRActor(Actor):
def __init__(self, env: Env, lrs: list[float] = None, input_dim: int = 1024, h_dim: int = 64):
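With this change the critic's forward pass runs the environment's cached image embeddings through its MLP instead of returning the debug mean-PSNR shortcut. A minimal sketch of such a value head over a 1024-dimensional embedding (the class name and layer sizes below are illustrative, not the repository's critic):

import torch
import torch.nn as nn

class EmbeddingCritic(nn.Module):
    """Toy value head over a frozen image embedding (e.g. a 1024-d DINOv2 feature)."""
    def __init__(self, input_dim: int = 1024, h_dim: int = 64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),   # scalar state value per image
        )

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        return self.network(embedding)

critic = EmbeddingCritic()
print(critic(torch.randn(4, 1024)).shape)   # torch.Size([4, 1])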
27 changes: 16 additions & 11 deletions src/ppo/lr_predictor/env.py
@@ -1,6 +1,7 @@
import numpy as np
import torch
import json
import warnings
import os
from collections import defaultdict
from examples.image_fitting import SimpleTrainer
@@ -16,13 +17,14 @@ class LREnv(Env):
"""
def __init__(
self,
lrs: list[float],
dataset_path: str,
num_points: int,
iterations: int,
num_iterations: int,
# TODO: remove observation_shape?
observation_shape: tuple,
action_shape: tuple,
n_trials: int = 10,
num_trials: int = 1,
device='cuda',
img_encoder: str = 'dino'
):
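The constructor now takes the learning-rate grid explicitly and renames iterations/n_trials to num_iterations/num_trials. A hypothetical construction call matching the new signature; the import path is inferred from the diff's file location, and the dataset path and numeric values are placeholders, not repository defaults:

from src.ppo.lr_predictor.env import LREnv   # import path assumed from src/ppo/lr_predictor/env.py

lrs = [round(0.005 + i * 0.001, 5) for i in range(40)]   # the grid previously hard-coded in the class
env = LREnv(
    lrs=lrs,
    dataset_path="data/images",      # placeholder dataset location
    num_points=2000,                 # placeholder point count for SimpleTrainer
    num_iterations=1000,             # placeholder per-trial training iterations
    observation_shape=(1,),          # just the image index, per the diff
    action_shape=(len(lrs),),        # assumed: one discrete action per candidate lr
    num_trials=1,
    device='cuda',
    img_encoder='dino',
)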
@@ -31,18 +33,20 @@ def __init__(
#
self.max_steps = 1
self.num_points = num_points
self.n_trials = n_trials
self.iterations = iterations
# self.lrs = [0.007 + i*0.001 for i in range(10)] #(1, 0.1, 0.01,)
self.lrs = [0.005 + i*0.001 for i in range(40)] #(1, 0.1, 0.01,)
self.num_trials = num_trials
self.num_iterations = num_iterations

self.lrs = lrs
self.lrs = [round(lr, 5) for lr in self.lrs]

self.device = device
# self.observation_shape = img.shape
self.action_shape = action_shape

if img_encoder == 'dino':
self.img_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.img_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
self.encoded_img_shape = (self.img_encoder.embed_dim,)
print("Using DINO large distilled as encoder")
self.observation_shape = (1,) # Just img index
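For reference, a standalone version of the warning-suppressed encoder load used above; the first call downloads the model via torch.hub, and the ViT-L/14 distilled variant exposes a 1024-dimensional embedding:

import warnings
import torch

with warnings.catch_warnings():
    warnings.simplefilter("ignore")      # silence torch.hub / dependency warnings, as in the diff
    img_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
img_encoder.eval()
print(img_encoder.embed_dim)             # 1024 for dinov2_vitl14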
@@ -65,7 +69,7 @@ def __init__(

self.encoded_images.append(encoded_img)
print("=" * 100)
print(f'num images: {self.num_images}\n num_trials: {self.n_trials}\n num_points: {self.num_points}\n iterations: {self.iterations}, num learning rates: {len(self.lrs)}')
print(f'num images: {self.num_images}\n num_trials: {self.num_trials}\n num_points: {self.num_points}\n num_iterations: {self.num_iterations}, num learning rates: {len(self.lrs)}')
print("=" * 100)
self.encoded_images = torch.stack(self.encoded_images)
self.original_images = torch.stack(self.orig_images)
@@ -76,7 +80,7 @@ def __init__(
current_dir, f"{dataset_name}_lr_losses_final.json"
)

self.num_images = 8
# self.num_images = 8

# compute losses for each LR
if os.path.exists(losses_json_path):
@@ -89,12 +93,12 @@ def __init__(
print(f"precomputing image {i+1}/{self.num_images}")
lr_losses = {str(lr): [] for lr in self.lrs}
for lr in self.lrs:
for trial_num in range(self.n_trials):
for trial_num in range(self.num_trials):
print("*" * 50)
print(f'currently training with lr={lr}, trial {trial_num}')
trainer = SimpleTrainer(gt_image=original_img, num_points=num_points)
losses, _ = trainer.train(
iterations=self.iterations,
iterations=self.num_iterations,
lr=lr,
save_imgs=False,
model_type='3dgs',
@@ -133,6 +137,7 @@ def __init__(
self.lr_losses_tensor = torch.stack(lr_losses_list)

self.psnr = 10 * torch.log10(1 / self.lr_losses_tensor[:, :, -1])
self.psnr = (self.psnr - 14.0) / 9.0

self.psnr_stats = {
"mean": self.psnr.mean(dim=1),
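The last hunk converts each trial's final MSE into PSNR and then rescales it, presumably so the critic sees roughly unit-scale targets. A self-contained sketch of that preprocessing with toy loss values; the constants 14.0 and 9.0 are taken directly from the diff, and the axis order is illustrative:

import torch

# final_mse: final reconstruction MSE per image / learning-rate pair (toy values).
final_mse = torch.tensor([[0.010, 0.004],
                          [0.020, 0.003]])

psnr = 10 * torch.log10(1.0 / final_mse)   # PSNR in dB for images scaled to [0, 1]
psnr_scaled = (psnr - 14.0) / 9.0          # affine rescale toward roughly [-1, 1]
print(psnr)          # ≈ [[20.00, 23.98], [16.99, 25.23]] dB
print(psnr_scaled)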

Large diffs are not rendered by default.

34 changes: 29 additions & 5 deletions src/ppo/ppo.py
@@ -48,8 +48,7 @@ def __init__(self, policy: Policy, env: Env,
"""
self.policy = policy.to(device)
self.env = env

assert env.device == device, "Policy and environment devices must match."
assert env.device == device or device == 'cuda', f"Policy device ({device}) and environment device ({env.device}) must match."

# PPO Hyperparameters
self.clip_epsilon = clip_epsilon
@@ -159,6 +158,7 @@ def compute_loss(self, batch_data):
# Compute the surrogate objectives (clipped vs unclipped)
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
# print(f"ratios: {ratios}")
# print(f"log_probs_new: {log_probs_new}")
# print(f"old_log_probs: {old_log_probs}")
# actor_loss = -(log_probs_new * rewards).mean()
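The surrogate computation above is the standard PPO clipped objective. A tiny numeric sketch of how the ratio clipping bounds the update (the values and epsilon below are illustrative, not taken from the training config):

import torch

clip_epsilon = 0.2
log_probs_new = torch.tensor([-0.9, -1.2, -0.4])
old_log_probs = torch.tensor([-1.0, -1.0, -1.0])
advantages = torch.tensor([0.5, -0.3, 1.0])

ratios = torch.exp(log_probs_new - old_log_probs)                            # pi_new / pi_old
surr1 = ratios * advantages                                                  # unclipped objective
surr2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantages # clipped objective
actor_loss = -torch.min(surr1, surr2).mean()                                 # pessimistic (clipped) bound
print(ratios, actor_loss)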
@@ -233,7 +233,7 @@ def _log_summary(self):
"""
Print a summary of the current training progress.
"""
print(f"Iteration {len(self.logger['timesteps'])}:")
print(f"Update {len(self.logger['timesteps'])}:")
print(f" Timesteps so far: {self.logger['t_so_far']}")
print(f" Actor Loss: {self.logger['actor_losses'][-1]:.6f}")
print(f" Entropy: {self.logger['entropy'][-1]:.4f}")
@@ -242,6 +242,7 @@ def _log_summary(self):
print(f" Avg Rewards: {self.logger['avg_rewards'][-1]:.4f}")
print(f" Avg Advantages: {self.logger['avg_advantages'][-1]:.4f}")
print(f" Avg Critic Values: {self.logger['avg_critic_values'][-1]:.4f}")

if self.log_callback:
num_match = self.log_callback(self.policy)
self.logger["num_match"].append(num_match)
@@ -303,6 +304,29 @@ def plot_training_progress(self):
plt.title('PG')

# Save the plot to a file using attributes
plt.tight_layout()
plt.savefig(self.plots_path)
# plt.tight_layout()
# plt.savefig(self.plots_path)


def save_policy(self, path="policy_checkpoint.pth"):
checkpoint = {
'actor_state_dict': self.policy.actor.state_dict(),
'critic_state_dict': self.policy.critic.state_dict(),
'actor_optimizer_state_dict': self.policy.actor_optimizer.state_dict(),
'critic_optimizer_state_dict': self.policy.critic_optimizer.state_dict(),
}
torch.save(checkpoint, path)
print(f"Policy saved to {path}")

def load_policy(self, path="policy_checkpoint.pth", device='cuda'):
checkpoint = torch.load(path, map_location=device)

# Load actor and critic state dictionaries
self.policy.actor.load_state_dict(checkpoint['actor_state_dict'])
self.policy.critic.load_state_dict(checkpoint['critic_state_dict'])

# Load optimizers' state dictionaries
self.policy.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
self.policy.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])

print(f"Policy loaded from {path} onto {device}")