From 758eca3aac9e8d27366a4b20873ba6df92e4a708 Mon Sep 17 00:00:00 2001
From: Aayush
Date: Sat, 2 Nov 2024 13:23:10 -0700
Subject: [PATCH] questionably successful dino + lr prediction

---
 src/ppo/base_policy.py                   |   9 +-
 src/ppo/lr_predictor/actor_critic.py     |  24 ++++--
 src/ppo/ppo.py                           |  79 +++--------
 src/scripts/distr_wb_ppo_lr_predictor.py | 105 ++++++++++++++---------
 src/scripts/distributed_sweep.sh         |   2 +-
 src/scripts/test.py                      |  45 ++++++++++
 src/utils.py                             |  12 ++-
 7 files changed, 160 insertions(+), 116 deletions(-)
 create mode 100644 src/scripts/test.py

diff --git a/src/ppo/base_policy.py b/src/ppo/base_policy.py
index 338ac086..abb69974 100644
--- a/src/ppo/base_policy.py
+++ b/src/ppo/base_policy.py
@@ -90,7 +90,7 @@ def __init__(self,
         # print(f"critic: {self.critic}")
         self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
         self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
-    
+
     def select_action(self, obs):
         """
         Given an observation, use the actor to select an action and return
@@ -135,6 +135,13 @@ def optimizer_step(self, actor_loss, critic_loss):
         self.critic_optimizer.zero_grad()
         critic_loss.backward()
         self.critic_optimizer.step()
+
+    def set_env(self, train_mode: bool):
+        """
+        Set the env based on train or test mode.
+        """
+        self.actor.set_env(train_mode)
+        self.critic.set_env(train_mode)
 
     def predict_values(self, obs):
         """
diff --git a/src/ppo/lr_predictor/actor_critic.py b/src/ppo/lr_predictor/actor_critic.py
index f1432eaa..270bd9ba 100644
--- a/src/ppo/lr_predictor/actor_critic.py
+++ b/src/ppo/lr_predictor/actor_critic.py
@@ -9,7 +9,7 @@
 
 
 class LRCritic(Critic):
-    def __init__(self, env: Env, input_dim: int = 1024, h_dim: int = 64):
+    def __init__(self, train_env: Env, test_env: Env, input_dim: int = 1024, h_dim: int = 64):
         super().__init__()
         # should map obs (image) to value
         self.layers = [
@@ -18,7 +18,14 @@ def __init__(self, env: Env, input_dim: int = 1024, h_dim: int = 64):
             nn.Linear(h_dim, 1),
         ]
         self.network = nn.Sequential(*self.layers)
-        self.env = env
+        self.train_env = train_env
+        self.test_env = test_env
+
+        self.set_env(train_mode=True)
+
+    def set_env(self, train_mode: bool):
+        assert train_mode or self.test_env is not None, "Test env must be set if train mode is False"
+        self.env = self.train_env if train_mode else self.test_env
 
     def forward(self, obs: Tensor):
         obs = obs.to(torch.int)
@@ -35,13 +42,13 @@ def forward(self, obs: Tensor):
     # return values
 
 class LRActor(Actor):
-    def __init__(self, env: Env, lrs: list[float] = None, input_dim: int = 1024, h_dim: int = 64):
+    def __init__(self, train_env: Env, test_env: Env = None, lrs: list[float] = None, input_dim: int = 1024, h_dim: int = 64):
         super().__init__()
         if lrs:
             self.lrs = Tensor(lrs)
         else:
             self.lrs = Tensor([10**-i for i in range(10)])
-        self.lrs = self.lrs.to(device=env.device)
+        self.lrs = self.lrs.to(device=train_env.device)
         print(f"Actor using learning rates: {self.lrs}")
 
         self.layers = [
@@ -51,7 +58,14 @@ def __init__(self, env: Env, lrs: list[float] = None, input_dim: int = 1024, h_d
             nn.Softmax(dim=-1)
         ]
         self.network = nn.Sequential(*self.layers)
-        self.env = env
+        self.train_env = train_env
+        self.test_env = test_env
+
+        self.set_env(train_mode=True)
+
+    def set_env(self, train_mode: bool):
+        assert train_mode or self.test_env is not None, "Test env must be set if train mode is False"
+        self.env = self.train_env if train_mode else self.test_env
 
     def forward(self, obs: Tensor):
         # convert from torch float to int
diff --git a/src/ppo/ppo.py b/src/ppo/ppo.py
index f985bfbf..779bcdcc 100644
--- a/src/ppo/ppo.py
+++ b/src/ppo/ppo.py
@@ -14,14 +14,14 @@
 from .base_env import Env
 from .rollout_buffer import RolloutBuffer
 from src.utils import *
-    
+
 class PPO:
     def __init__(self, policy: Policy, env: Env,
                  clip_epsilon=0.2,
                  gamma=1,
                  gae_lambda=0.95,
                  normalize_advantages=True,
-                 entropy_coeff=0.0,
+                 entropy_schedule=lambda t: 0.2,
                  n_epochs=5,
                  batch_size=10,
                  buffer_size=20,
@@ -30,6 +30,7 @@ def __init__(self, policy: Policy, env: Env,
                  shuffle=False,
                  log_callback=None,
                  plots_path=None,
+                 **kwargs
                  ):
         """
 
@@ -40,7 +41,6 @@ def __init__(self, policy: Policy, env: Env,
         - gamma: Discount factor for future rewards (1 means full future reward impact).
         - gae_lambda: Trade-off between bias and variance for GAE (lambda-return).
         - normalize_advantages: Normalize advantages for better numerical stability.
-        - entropy_coeff: Adds entropy to encourage exploration (higher value -> more exploration).
         - n_epochs: Number of passes over the data to update the policy.
         - batch_size: Number of samples per batch for policy updates.
        - buffer_size: Number of steps to store in the rollout buffer.
@@ -55,12 +55,12 @@ def __init__(self, policy: Policy, env: Env,
         self.gamma = gamma  # Discount factor for future rewards.
         self.gae_lambda = gae_lambda  # Governs the advantage estimation trade-off.
         self.normalize_advantages = normalize_advantages
-        self.entropy_coeff = entropy_coeff  # Encourage exploration.
         self.n_epochs = n_epochs
         self.batch_size = batch_size
         self.buffer_size = buffer_size
 
         self.shuffle = shuffle
+        self.entropy_schedule = entropy_schedule
 
         # Initialize the rollout buffer to store experiences
         self.rollout_buffer = RolloutBuffer(
@@ -82,6 +82,7 @@ def __init__(self, policy: Policy, env: Env,
             "avg_critic_values": [],
             "avg_advantages": [],
             "entropy": [],
+            "entropy_coeff": [],
             "surr_loss": [],
             "timesteps": [],
             "num_match": []
@@ -169,7 +170,7 @@ def compute_loss(self, batch_data):
         # print(f"actor loss before: {actor_loss}")
         # print(f"surr1: {surr1}, surr2: {surr2}")
         # print(f"ratios: {ratios}")
-        entropy_loss = self.entropy_coeff * entropy.mean()
+        entropy_loss = self.entropy_schedule(self.logger["i_so_far"]) * entropy.mean()
         actor_loss -= entropy_loss
 
         # print(f"actions: {actions}")
@@ -200,6 +201,8 @@ def update(self):
         self.logger["actor_losses"].append(actor_loss.item())
         self.logger["critic_losses"].append(critic_loss.item())
         self.logger["entropy"].append(entropy.item())
+
+        self.logger["entropy_coeff"].append(self.entropy_schedule(self.logger["i_so_far"]))
         self.logger["avg_rewards"].append(self.rollout_buffer.rewards.mean().item())
         self.logger["avg_advantages"].append(self.rollout_buffer.advantages.mean().item())
 
@@ -219,15 +222,13 @@ def train(self, total_timesteps):
             i_so_far += 1
 
             self.logger["timesteps"].append(t_so_far)
+            self.logger["i_so_far"] = i_so_far
 
             # Perform policy (actor+critic) updates
             self.update()
 
             if i_so_far % self.log_interval == 0:
                 self._log_summary()
-
-            if self.plots_path:
-                self.plot_training_progress()
 
     def _log_summary(self):
         """
@@ -237,6 +238,7 @@
         print(f" Timesteps so far: {self.logger['t_so_far']}")
         print(f" Actor Loss: {self.logger['actor_losses'][-1]:.6f}")
         print(f" Entropy: {self.logger['entropy'][-1]:.4f}")
+        print(f" Entropy Coefficient: {self.logger['entropy_coeff'][-1]:.4f}")
         print(f" Surrogate Loss: {self.logger['surr_loss'][-1]:.4f}")
         print(f" Critic Loss: {self.logger['critic_losses'][-1]:.4f}")
         print(f" Avg Rewards: {self.logger['avg_rewards'][-1]:.4f}")
@@ -247,66 +249,7 @@
         num_match = self.log_callback(self.policy)
         self.logger["num_match"].append(num_match)
         print(f" Num Match: {num_match}")
-        print("=" * 50)
-
-    def plot_training_progress(self):
-        # Use logging data to plot (actor loss, critic loss, rewards, avg critic values, avg advantages, and entropy)
-        plt.figure(figsize=(24, 14))
-
-        plt.subplot(4, 2, 1)
-        plt.plot(self.logger['actor_losses'], label='Actor Loss')
-        plt.xlabel('Timesteps')
-        plt.ylabel('Loss')
-        plt.title('Actor Loss')
-
-        plt.subplot(4, 2, 2)
-        plt.plot(self.logger['critic_losses'], label='Critic Loss')
-        plt.xlabel('Timesteps')
-        plt.ylabel('Loss')
-        plt.title('Critic Loss')
-
-        plt.subplot(4, 2, 3)
-        plt.plot(self.logger['avg_rewards'], label='Avg Rewards')
-        plt.xlabel('Timesteps')
-        plt.ylabel('Rewards')
-        plt.title('Average Rewards')
-
-        plt.subplot(4, 2, 4)
-        plt.plot(self.logger['avg_critic_values'], label='Avg Critic Values')
-        plt.xlabel('Timesteps')
-        plt.ylabel('Values')
-        plt.title('Average Critic Values')
-
-        plt.subplot(4, 2, 5)
-        plt.plot(self.logger['avg_advantages'], label='Avg Advantages')
-        plt.xlabel('Timesteps')
-        plt.ylabel('Advantages')
-        plt.title('Average Advantages')
-
-        # make y scale integers
-        plt.subplot(4, 2, 6)
-        plt.plot(self.logger['num_match'], label='Num Match')
-        plt.xlabel('Timesteps')
-        plt.ylabel('Num Match')
-        plt.title('Number of Matches')
-        plt.gca().yaxis.set_major_locator(plt.MaxNLocator(integer=True))
-
-        plt.subplot(4, 2, 7)
-        plt.plot(self.logger['entropy'], label='Entropy')
-        plt.xlabel('Timesteps')
-        plt.ylabel('Entropy')
-        plt.title('Entropy')
-
-        plt.subplot(4, 2, 8)
-        plt.plot(self.logger['surr_loss'], label='PG Loss')
-        plt.xlabel('Timesteps')
-        plt.ylabel('PG Loss')
-        plt.title('PG')
-
-        # Save the plot to a file using attributes
-        # plt.tight_layout()
-        # plt.savefig(self.plots_path)
-
+        print("=" * 50)
 
     def save_policy(self, path="policy_checkpoint.pth"):
         checkpoint = {
diff --git a/src/scripts/distr_wb_ppo_lr_predictor.py b/src/scripts/distr_wb_ppo_lr_predictor.py
index 4f187e94..21c99c92 100644
--- a/src/scripts/distr_wb_ppo_lr_predictor.py
+++ b/src/scripts/distr_wb_ppo_lr_predictor.py
@@ -5,11 +5,11 @@
 import os
 from dataclasses import dataclass
 import tyro
-from src.ppo.ppo import PPO
+from src.ppo.ppo import PPO, Env
 from src.ppo.base_policy import Policy
 from src.ppo.lr_predictor.actor_critic import LRCritic, LRActor
 from src.ppo.lr_predictor.env import LREnv
-from src.utils import image_path_to_tensor, seed_everything
+from src.utils import image_path_to_tensor, seed_everything, LinearSchedule
 
 device = 'cuda'
 seed_everything(42)
@@ -18,7 +18,7 @@
 width: int = 256
 # im_path = 'images/adam.jpg'
 im_path = None
-if not im_path: 
+if not im_path:
     gt_image = torch.ones((height, width, 3)) * 1.0
     # make top left and bottom right red, blue
     gt_image[: height // 2, : width // 2, :] = torch.tensor([1.0, 0.0, 0.0])
@@ -55,7 +55,7 @@ class PPOConfig:
     batch_size: int = 32
     buffer_size: int = 128
     num_updates: int = 300
-    entropy_coeff: float = 0.2
+    entropy_schedule: LinearSchedule = LinearSchedule(300, 0.5, 0.1, 0.1)
     log_interval: int = 1
     actor_lr: float = 3e-4
     critic_lr: float = 3e-4
@@ -80,15 +80,23 @@ class EnvConfig:
     img_encoder: str = 'dino'
 
 class WandbCallback:
-    def __init__(self, policy, env, ppo, is_sweep: bool = False):
+    def __init__(self, policy, train_env : Env, ppo, test_env: Env = None, is_sweep: bool = False):
         self.policy = policy
-        self.env = env
+        self.train_env = train_env
+        self.test_env = test_env
+
         self.ppo = ppo
         self.is_sweep = is_sweep
         self.iteration = 0
 
     def __call__(self, policy):
-        num_matches = eval_policy(policy, self.env)
+        num_matches = eval_policy(policy, self.train_env)
+
+        num_test_matches = -1
+        if self.test_env is not None:
+            policy.set_env(train_mode=False)
+            num_test_matches = eval_policy(policy, self.test_env, verbose=False)
+            policy.set_env(train_mode=True)
 
         # Base metrics logged in both modes
         metrics = {
@@ -97,10 +105,13 @@ def __call__(self, policy):
             "critic_loss": self.ppo.logger["critic_losses"][-1]
                 if self.ppo.logger["critic_losses"] else None,
             "avg_reward": self.ppo.logger["avg_rewards"][-1] if self.ppo.logger["avg_rewards"] else None,
+            "entropy_coeff": self.ppo.logger["entropy_coeff"][-1],
             "entropy": self.ppo.logger["entropy"][-1],
             "surrogate_loss": self.ppo.logger["surr_loss"][-1],
             "avg_advantage": self.ppo.logger["avg_advantages"][-1],
             "avg_critic_value": self.ppo.logger["avg_critic_values"][-1],
+
+            "num_matches_test": num_test_matches,
         }
 
         wandb.log(metrics, step=self.iteration)
@@ -113,14 +124,17 @@ def setup_wandb_sweep(method: str='grid'):
         'command': ['/secondary/home/aayushg/miniconda3/envs/gsplat/bin/python3', '${program}', '--run_sweep_agent'],  # Updated this line
         'method': method,
         'parameters': {
-            'n_epochs': {'values': [3, 5, 7]},
-            'batch_size': {'values': [256, 512, 1024]},
+            'n_epochs': {'values': [5]},
+            'batch_size': {'values': [1024]},
             'buffer_multiplier': {'values': [1]},
             'num_updates': {'values': [300, 400, 500]},
-            'entropy_coeff': {'values': [0.05, 0.1, 0.15, 0.20]},
+            'entropy_coeff_schedule': {'values': ['linear']},
+            'entropy_coeff_start': {'values': [.8, .6, .4, .2]},
+            'entropy_coeff_stride': {'values': [.05, .1, .25, .5]},
+            'entropy_coeff_freq': {'values': [0.01, 0.05, 0.1, 0.2]},
             'actor_lr': {'values': [3e-4, 1e-3]},
             'critic_lr': {'values': [1e-4]},
-            'clip_epsilon': {'values': [0.15, 0.18, 0.2]}
+            'clip_epsilon': {'values': [.18]}
             # 'critic_lr': {'values': [1e-4, 3e-4, 1e-3]}
         },
         'metric': {
@@ -139,6 +153,12 @@ def setup_wandb_sweep(method: str='grid'):
 def train(config: PPOConfig, is_sweep: bool = False, save_ckpt: bool = False, load_ckpt_id: str = None) -> None:
     """Main training function used for both sweep and single runs"""
     if is_sweep:
+        entropy_schedule = LinearSchedule(
+            num_updates=wandb.config.num_updates,
+            start=wandb.config.entropy_coeff_start,
+            stride=wandb.config.entropy_coeff_stride,
+            frequency=wandb.config.entropy_coeff_freq
+        )
         # For sweep: compute buffer_size from multiplier
         buffer_size = wandb.config.batch_size * wandb.config.buffer_multiplier
         config = PPOConfig(
@@ -146,7 +166,7 @@ def train(config: PPOConfig, is_sweep: bool = False, save_ckpt: bool = False, lo
             batch_size=wandb.config.batch_size,
             buffer_size=buffer_size,
             num_updates=wandb.config.num_updates,
-            entropy_coeff=wandb.config.entropy_coeff,
+            entropy_schedule=entropy_schedule,
             actor_lr=wandb.config.actor_lr,
             critic_lr=wandb.config.critic_lr,
             clip_epsilon=wandb.config.clip_epsilon,
@@ -156,7 +176,6 @@ def train(config: PPOConfig, is_sweep: bool = False, save_ckpt: bool = False, lo
 
     # Initialize environment and models
     train_env_config = EnvConfig(device=config.device)
-
     train_env = LREnv(
         dataset_path=train_env_config.dataset_path,
         num_points=train_env_config.num_points,
@@ -168,14 +187,30 @@ def train(config: PPOConfig, is_sweep: bool = False, save_ckpt: bool = False, lo
         lrs=train_env_config.lrs,
         num_trials=train_env_config.num_trials,
     )
-    
+
+    # Eval
+    test_env_config = EnvConfig(device=config.device)
+    test_env = LREnv(
+        dataset_path=test_env_config.dataset_path+"_test",
+        num_points=test_env_config.num_points,
+        num_iterations=test_env_config.iterations,
+        observation_shape=test_env_config.observation_shape,
+        action_shape=test_env_config.action_shape,
+        device=test_env_config.device,
+        img_encoder=test_env_config.img_encoder,
+        lrs = test_env_config.lrs,
+        num_trials = test_env_config.num_trials,
+    )
+
     actor = LRActor(
-        env=train_env,
+        train_env=train_env,
+        test_env=test_env,
         lrs=train_env.lrs,
         input_dim=train_env.encoded_img_shape[0]
     )
     critic = LRCritic(
-        env=train_env,
+        train_env=train_env,
+        test_env=test_env,
         input_dim=train_env.encoded_img_shape[0]
     )
     policy = Policy(
@@ -195,7 +230,7 @@ def train(config: PPOConfig, is_sweep: bool = False, save_ckpt: bool = False, lo
         buffer_size=config.buffer_size,
         log_interval=config.log_interval,
         device=config.device,
-        entropy_coeff=config.entropy_coeff,
+        entropy_schedule=config.entropy_schedule,
         clip_epsilon=config.clip_epsilon,
         shuffle=True,
         normalize_advantages=True,
@@ -203,7 +238,7 @@ def train(config: PPOConfig, is_sweep: bool = False, save_ckpt: bool = False, lo
     )
 
     # Then create and set the callback
-    wandb_callback = WandbCallback(policy, train_env, ppo, is_sweep=is_sweep)
+    wandb_callback = WandbCallback(policy, train_env, ppo, test_env=test_env, is_sweep=is_sweep)
     ppo.log_callback = wandb_callback
 
     # Train
@@ -218,25 +253,9 @@ def train(config: PPOConfig, is_sweep: bool = False, save_ckpt: bool = False, lo
         ppo.save_policy(ckpt_path)
         wandb.save(ckpt_path)
 
-    # Eval
-    test_env_config = EnvConfig(device=config.device)
-    test_env = LREnv(
-        dataset_path=test_env_config.dataset_path+"_test",
-        num_points=test_env_config.num_points,
-        num_iterations=test_env_config.iterations,
-        observation_shape=test_env_config.observation_shape,
-        action_shape=test_env_config.action_shape,
-        device=test_env_config.device,
-        img_encoder=test_env_config.img_encoder,
-        lrs = test_env_config.lrs,
-        num_trials = test_env_config.num_trials,
-    )
-
-    actor.env = test_env
-    critic.env = test_env
-    num_match = eval_policy(policy, test_env, verbose=True)
-    wandb.log({"num_matches_test": num_match})
-    print(f"Number of matches on test set: {num_match}")
+    ppo.policy.set_env(train_mode=False)
+    num_test_matches = eval_policy(policy, test_env, verbose=True)
+    print(f"Number of matches on test set: {num_test_matches}")
 
 
 def train_with_wandb(
@@ -297,16 +316,22 @@ def main(
        )
        return
 
-    
+
    else: # train without sweep
-        # our
+        entropy_schedule = LinearSchedule(
+            num_updates=300,
+            start=0.5,
+            stride=0.1,
+            frequency=0.1
+        )
+
        config = PPOConfig(
            batch_size=64,
            buffer_size=512,
            clip_epsilon=.18,
            n_epochs=7,
            num_updates=300,
-            entropy_coeff=0.15,
+            entropy_schedule=entropy_schedule,
            device=device
        )
 
diff --git a/src/scripts/distributed_sweep.sh b/src/scripts/distributed_sweep.sh
index c484bc15..f39625b2 100755
--- a/src/scripts/distributed_sweep.sh
+++ b/src/scripts/distributed_sweep.sh
@@ -1,5 +1,5 @@
 SCRIPT="src/scripts/distr_wb_ppo_lr_predictor.py"
-PROJECT_NAME="ppo_lr_predictor_distr_no_multiplier_final"
+PROJECT_NAME="ppo_lr_predictor_distr_entropy_sched_debug"
 ENTITY="rl_gsplat"
 COMMAND="python $SCRIPT --sweep --project_name $PROJECT_NAME --entity $ENTITY"
 $COMMAND
diff --git a/src/scripts/test.py b/src/scripts/test.py
new file mode 100644
index 00000000..62e96f7f
--- /dev/null
+++ b/src/scripts/test.py
@@ -0,0 +1,45 @@
+import ray
+import torch
+import torch.nn as nn
+
+import os
+
+cuda_devices = os.getenv("CUDA_VISIBLE_DEVICES")
+if cuda_devices is not None:
+    print(f"CUDA_VISIBLE_DEVICES={cuda_devices}")
+else:
+    print("CUDA_VISIBLE_DEVICES is not set.")
+
+
+# Initialize Ray
+ray.init()
+
+# Define your model training function
+@ray.remote(num_gpus=1)
+def train_model_on_gpu(gpu_id, config):
+    device = torch.device(f"cuda:{gpu_id}")
+    print(device)
+
+    # Define a simple model and move it to the specific GPU
+    model = nn.Linear(10, 10).to(device)
+    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
+
+    # Simulate a training loop
+    for epoch in range(config["epochs"]):
+        inputs = torch.randn(32, 10).to(device)
+        outputs = model(inputs)
+        loss = outputs.sum()
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+    return f"Finished training on GPU {gpu_id} with config {config}"
+
+# Launch asynchronous training tasks on multiple GPUs
+configs = [{"lr": 0.01, "epochs": 5}, {"lr": 0.001, "epochs": 5}]
+futures = [train_model_on_gpu.remote(i, config) for i, config in enumerate(configs)]
+
+# Collect results (non-blocking)
+results = ray.get(futures)
+print(results)
diff --git a/src/utils.py b/src/utils.py
index 19ba4dd0..f6918da6 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -32,4 +32,14 @@ def dino_preprocess(img_path: str) -> torch.tensor:
     ])
 
     input_tensor = preprocess(input_image)
-    return input_tensor
\ No newline at end of file
+    return input_tensor
+
+class LinearSchedule:
+    def __init__(self, num_updates: int, start: float, stride: float, frequency: float):
+        self.num_updates = num_updates
+        self.start = start
+        self.stride = stride
+        self.frequency = frequency
+
+    def __call__(self, t: int) -> float:
+        return max(0, self.start - self.stride * (t // (self.num_updates * self.frequency)))