diff --git a/.ci-cd/build.sh b/.ci-cd/build.sh
index 54858a58c..232d1f43d 100755
--- a/.ci-cd/build.sh
+++ b/.ci-cd/build.sh
@@ -116,6 +116,7 @@ function test() {
         alf.networks.q_networks_test \
         alf.networks.relu_mlp_test \
         alf.networks.value_networks_test \
+        alf.optimizers.dadapt_optimizers_test \
         alf.optimizers.nero_plus_test \
         alf.optimizers.optimizers_test \
         alf.optimizers.trusted_updater_test \
diff --git a/alf/optimizers/__init__.py b/alf/optimizers/__init__.py
index 08a954bd4..f70de6b5b 100644
--- a/alf/optimizers/__init__.py
+++ b/alf/optimizers/__init__.py
@@ -17,6 +17,8 @@
 from .optimizers import AdamW
 from .optimizers import SGD
 from .optimizers import NeroPlus
+from .optimizers import DAdaptSGD
+from .optimizers import DAdaptAdam
 
 from typing import Any
 Optimizer = Any
diff --git a/alf/optimizers/dadapt_adam.py b/alf/optimizers/dadapt_adam.py
new file mode 100644
index 000000000..a47c98009
--- /dev/null
+++ b/alf/optimizers/dadapt_adam.py
@@ -0,0 +1,263 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree:
+# https://github.com/facebookresearch/dadaptation/blob/main/LICENSE
+
+import logging
+import math
+from typing import TYPE_CHECKING, Any, Callable, Optional
+
+import torch
+import torch.distributed as dist
+from torch.optim import Optimizer
+
+if TYPE_CHECKING:
+    from torch.optim.optimizer import _params_t
+else:
+    _params_t = Any
+
+
+class DAdaptAdam(Optimizer):
+    r"""
+    Implements Adam with D-Adaptation automatic step-sizes.
+    Leave LR set to 1 unless you encounter instability.
+
+    Arguments:
+        params (iterable):
+            Iterable of parameters to optimize or dicts defining parameter groups.
+        lr (float):
+            Learning rate adjustment parameter. Increases or decreases the
+            D-adapted learning rate.
+        betas (Tuple[float, float], optional):
+            Coefficients used for computing running averages of the gradient
+            and its square (default: (0.9, 0.999)).
+        eps (float):
+            Term added to the denominator outside of the root operation to
+            improve numerical stability (default: 1e-8).
+        weight_decay (float):
+            Weight decay, i.e., an L2 penalty (default: 0).
+        log_every (int):
+            Log via ``logging.info`` every k steps (default: 0, no logging).
+        decouple (bool):
+            Use AdamW-style decoupled weight decay.
+        use_bias_correction (bool):
+            Turn on Adam's bias correction. Off by default.
+        d0 (float):
+            Initial D estimate for D-adaptation (default: 1e-6). Rarely needs
+            changing.
+        growth_rate (float):
+            Prevent the D estimate from growing faster than this multiplicative
+            rate. Default is inf, for unrestricted. Values like 1.02 give a
+            kind of learning rate warmup effect.
+        fsdp_in_use (bool):
+            If you're using sharded parameters, this should be set to True. The
+            optimizer will attempt to auto-detect this, but if you're using an
+            implementation other than PyTorch's builtin version, the
+            auto-detection won't work.
+    """
+
+    def __init__(self,
+                 params,
+                 lr=1.0,
+                 betas=(0.9, 0.999),
+                 eps=1e-8,
+                 weight_decay=0,
+                 log_every=0,
+                 decouple=False,
+                 use_bias_correction=False,
+                 d0=1e-6,
+                 growth_rate=float('inf'),
+                 fsdp_in_use=False):
+        if not 0.0 < d0:
+            raise ValueError("Invalid d0 value: {}".format(d0))
+        if not 0.0 < lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 < eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(
+                betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(
+                betas[1]))
+
+        if decouple:
+            logging.info("Using decoupled weight decay")
+
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            d=d0,
+            k=0,
+            numerator_weighted=0.0,
+            log_every=log_every,
+            growth_rate=growth_rate,
+            use_bias_correction=use_bias_correction,
+            decouple=decouple,
+            fsdp_in_use=fsdp_in_use)
+        self.d0 = d0
+        super().__init__(params, defaults)
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return False
+
+    @property
+    def supports_flat_params(self):
+        return True
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        sk_l1 = 0.0
+
+        group = self.param_groups[0]
+        use_bias_correction = group['use_bias_correction']
+        numerator_weighted = group['numerator_weighted']
+        beta1, beta2 = group['betas']
+        k = group['k']
+
+        d = group['d']
+        lr = max(group['lr'] for group in self.param_groups)
+
+        if use_bias_correction:
+            bias_correction = ((1 - beta2**(k + 1))**0.5) / (
+                1 - beta1**(k + 1))
+        else:
+            bias_correction = 1
+
+        # dlr is the D-adapted learning rate actually used for this step.
+        dlr = d * lr * bias_correction
+
+        growth_rate = group['growth_rate']
+        decouple = group['decouple']
+        log_every = group['log_every']
+        fsdp_in_use = group['fsdp_in_use']
+
+        sqrt_beta2 = beta2**(0.5)
+
+        numerator_acum = 0.0
+
+        for group in self.param_groups:
+            decay = group['weight_decay']
+            k = group['k']
+            eps = group['eps']
+            group_lr = group['lr']
+
+            if group_lr not in [lr, 0.0]:
+                raise RuntimeError(
+                    "Setting different lr values in different parameter "
+                    "groups is only supported for values of 0")
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                if hasattr(p, "_fsdp_flattened"):
+                    fsdp_in_use = True
+
+                grad = p.grad.data
+
+                # Apply weight decay (coupled variant)
+                if decay != 0 and not decouple:
+                    grad.add_(p.data, alpha=decay)
+
+                state = self.state[p]
+
+                # State initialization
+                if 'step' not in state:
+                    state['step'] = 0
+                    state['s'] = torch.zeros_like(p.data).detach()
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data).detach()
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data).detach()
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                s = state['s']
+
+                if group_lr > 0.0:
+                    denom = exp_avg_sq.sqrt().add_(eps)
+                    numerator_acum += dlr * torch.dot(
+                        grad.flatten(),
+                        s.div(denom).flatten()).item()
+
+                    # Adam EMA updates
+                    exp_avg.mul_(beta1).add_(grad, alpha=dlr * (1 - beta1))
+                    exp_avg_sq.mul_(beta2).addcmul_(
+                        grad, grad, value=1 - beta2)
+
+                    s.mul_(sqrt_beta2).add_(grad, alpha=dlr * (1 - sqrt_beta2))
+                    sk_l1 += s.abs().sum().item()
+
+        ######
+
+        numerator_weighted = sqrt_beta2 * numerator_weighted + (
+            1 - sqrt_beta2) * numerator_acum
+        d_hat = d
+
+        # if we have not made any progress, return
+        # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0)
+        if sk_l1 == 0:
+            return loss
+
+        if lr > 0.0:
+            if fsdp_in_use:
+                dist_tensor = torch.zeros(2).cuda()
+                dist_tensor[0] = numerator_weighted
+                dist_tensor[1] = sk_l1
+                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+                global_numerator_weighted = dist_tensor[0]
+                global_sk_l1 = dist_tensor[1]
+            else:
+                global_numerator_weighted = numerator_weighted
+                global_sk_l1 = sk_l1
+
+            d_hat = global_numerator_weighted / (
+                (1 - sqrt_beta2) * global_sk_l1)
+            d = max(d, min(d_hat, d * growth_rate))
+
+            if log_every > 0 and k % log_every == 0:
+                logging.info(
+                    f"lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. "
+                    f"sk_l1={global_sk_l1:1.1e} "
+                    f"numerator_weighted={global_numerator_weighted:1.1e}")
+
+        for group in self.param_groups:
+            group['numerator_weighted'] = numerator_weighted
+            group['d'] = d
+
+            decay = group['weight_decay']
+            k = group['k']
+            eps = group['eps']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+
+                state = self.state[p]
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                state['step'] += 1
+
+                denom = exp_avg_sq.sqrt().add_(eps)
+
+                # Apply weight decay (decoupled variant)
+                if decay != 0 and decouple:
+                    p.data.add_(p.data, alpha=-decay * dlr)
+
+                ### Take step
+                p.data.addcdiv_(exp_avg, denom, value=-1)
+
+            group['k'] = k + 1
+
+        return loss
diff --git a/alf/optimizers/dadapt_optimizers_test.py b/alf/optimizers/dadapt_optimizers_test.py
new file mode 100644
index 000000000..3d9749dd2
--- /dev/null
+++ b/alf/optimizers/dadapt_optimizers_test.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2023 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+from absl import logging
+import torch
+import torch.nn.functional as F
+
+import alf
+
+from alf.optimizers import DAdaptSGD, DAdaptAdam
+from alf.utils.datagen import load_mnist
+
+
+class DadaptOptimizersTest(parameterized.TestCase, alf.test.TestCase):
+    def test_dadapt_sgd(self):
+        train_set, test_set = load_mnist(train_bs=256, test_bs=256)
+        num_classes = len(train_set.dataset.classes)
+        model = alf.layers.Sequential(
+            alf.layers.Conv2D(1, 32, 3, strides=2, padding=1),
+            alf.layers.Conv2D(32, 32, 3, strides=2, padding=1),
+            alf.layers.Conv2D(32, 32, 3, strides=2, padding=1),
+            alf.layers.Reshape(-1),
+            alf.layers.FC(
+                4 * 4 * 32,
+                num_classes,
+                weight_opt_args=dict(
+                    fixed_norm=False,
+                    l2_regularization=1e-3,
+                    zero_mean=True,
+                    max_norm=float('inf'))))
+        opt = DAdaptSGD()
+        opt.add_param_group(dict(params=list(model.parameters())))
+
+        for epoch in range(5):
+            for data, target in train_set:
+                logits = model(data)
+                loss = F.cross_entropy(logits, target)
+                opt.zero_grad()
+                loss.backward()
+                opt.step()
+            correct = 0
+            total = 0
+            for data, target in test_set:
+                logits = model(data)
+                correct += (logits.argmax(dim=1) == target).sum()
+                total += target.numel()
+            logging.info("epoch=%s loss=%s acc=%s" %
+                         (epoch, loss.item(), correct.item() / total))
+        self.assertGreater(correct / total, 0.97)
+
+    @parameterized.parameters((True, ), (False, ))
+    def test_dadapt_adam(self, decouple):
+        train_set, test_set = load_mnist(train_bs=256, test_bs=256)
+        num_classes = len(train_set.dataset.classes)
+        model = alf.layers.Sequential(
+            alf.layers.Conv2D(1, 32, 3, strides=2, padding=1),
+            alf.layers.Conv2D(32, 32, 3, strides=2, padding=1),
+            alf.layers.Conv2D(32, 32, 3, strides=2, padding=1),
+            alf.layers.Reshape(-1),
+            alf.layers.FC(
+                4 * 4 * 32,
+                num_classes,
+                weight_opt_args=dict(
+                    fixed_norm=False,
+                    l2_regularization=1e-3,
+                    zero_mean=True,
+                    max_norm=float('inf'))))
+        opt = DAdaptAdam(decouple=decouple)
+        opt.add_param_group(dict(params=list(model.parameters())))
+
+        for epoch in range(5):
+            for data, target in train_set:
+                logits = model(data)
+                loss = F.cross_entropy(logits, target)
+                opt.zero_grad()
+                loss.backward()
+                opt.step()
+            correct = 0
+            total = 0
+            for data, target in test_set:
+                logits = model(data)
+                correct += (logits.argmax(dim=1) == target).sum()
+                total += target.numel()
+            logging.info("epoch=%s loss=%s acc=%s" %
+                         (epoch, loss.item(), correct.item() / total))
+        self.assertGreater(correct / total, 0.97)
+
+
+if __name__ == '__main__':
+    alf.test.main()
diff --git a/alf/optimizers/dadapt_sgd.py b/alf/optimizers/dadapt_sgd.py
new file mode 100644
index 000000000..7db7bfb77
--- /dev/null
+++ b/alf/optimizers/dadapt_sgd.py
@@ -0,0 +1,216 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree:
+# https://github.com/facebookresearch/dadaptation/blob/main/LICENSE
+
+import logging
+import math
+
+import torch
+import torch.distributed as dist
+from torch.optim import Optimizer
+
+
+class DAdaptSGD(Optimizer):
+    r"""
+    Implements SGD with D-Adaptation automatic step-sizes. Leave LR set to 1
+    unless you encounter instability.
+
+    Arguments:
+        params (iterable):
+            Iterable of parameters to optimize or dicts defining parameter groups.
+        lr (float):
+            Learning rate adjustment parameter. Increases or decreases the
+            D-adapted learning rate.
+        momentum (float):
+            Momentum value in the range [0,1) (default: 0).
+        weight_decay (float):
+            Weight decay, i.e., an L2 penalty (default: 0).
+        log_every (int):
+            Log via ``logging.info`` every k steps (default: 0, no logging).
+        d0 (float):
+            Initial D estimate for D-adaptation (default: 1e-6). Rarely needs
+            changing.
+        growth_rate (float):
+            Prevent the D estimate from growing faster than this multiplicative
+            rate. Default is inf, for unrestricted. More conservative values
+            like 1.02 may help if training is unstable.
+        fsdp_in_use (bool):
+            If you're using sharded parameters, this should be set to True.
+            The optimizer will attempt to auto-detect this, but if you're
+            using an implementation other than PyTorch's builtin version,
+            the auto-detection won't work.
+    """
+
+    def __init__(self,
+                 params,
+                 lr=1.0,
+                 momentum=0.0,
+                 weight_decay=0,
+                 log_every=0,
+                 d0=1e-6,
+                 growth_rate=float('inf'),
+                 fsdp_in_use=False):
+
+        if not 0.0 < d0:
+            raise ValueError("Invalid d0 value: {}".format(d0))
+        if not 0.0 < lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            weight_decay=weight_decay,
+            k=0,
+            log_every=log_every,
+            numerator_weighted=0.0,
+            d=d0,
+            growth_rate=growth_rate,
+            fsdp_in_use=fsdp_in_use)
+        self.loggables = {}
+
+        try:
+            self.rank = torch.distributed.get_rank()
+        except Exception:
+            self.rank = 0
+
+        super().__init__(params, defaults)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        group = self.param_groups[0]
+        lr = max(group['lr'] for group in self.param_groups)
+
+        decay = group['weight_decay']
+        momentum = group['momentum']
+        log_every = group['log_every']
+        ck = 1 - momentum
+        k = group['k']
+
+        numerator_weighted = group['numerator_weighted']
+        growth_rate = group['growth_rate']
+        d = group['d']
+        fsdp_in_use = group['fsdp_in_use']
+
+        group = self.param_groups[0]
+
+        sk_sq = 0.0
+
+        if k == 0:
+            g_sq = 0.0
+            for group in self.param_groups:
+                group_lr = group['lr']
+                for p in group['params']:
+                    if p.grad is None:
+                        continue
+                    if hasattr(p, "_fsdp_flattened"):
+                        fsdp_in_use = True
+                    grad = p.grad.data
+
+                    # Apply weight decay
+                    if decay != 0:
+                        grad.add(p.data, alpha=decay)
+
+                    state = self.state[p]
+
+                    if group_lr > 0.0:
+                        g_sq += (grad * grad).sum().item()
+
+            if fsdp_in_use:
+                dist_tensor = torch.zeros(1).cuda()
+                dist_tensor[0] = g_sq
+                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+                global_gsq = dist_tensor[0]
+            else:
+                global_gsq = g_sq
+            group['g0_norm'] = g0_norm = math.sqrt(global_gsq)
+
+        g0_norm = group['g0_norm']
+
+        # D-adapted learning rate, normalized by the initial gradient norm.
+        dlr = d * lr / g0_norm
+
+        for group in self.param_groups:
+            group_lr = group['lr']
+            if group_lr not in [lr, 0.0]:
+                raise RuntimeError(
+                    "Setting different lr values in different parameter "
+                    "groups is only supported for values of 0")
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                state = self.state[p]
+
+                if 'z' not in state:
+                    z = state['z'] = torch.clone(p.data).detach()
+                    s = state['s'] = torch.zeros_like(p.data).detach()
+                    x0 = state['x0'] = torch.clone(p.data).detach()
+
+                # Apply weight decay
+                if decay != 0:
+                    grad.add_(p.data, alpha=decay)
+
+                s = state['s']
+
+                if group_lr > 0.0:
+                    numerator_weighted += dlr * torch.dot(
+                        grad.flatten(), s.flatten()).item()
+
+                    s.data.add_(grad, alpha=dlr)
+                    sk_sq += (s * s).sum().item()
+        ######
+
+        d_hat = d
+
+        if lr > 0.0:
+            if fsdp_in_use:
+                dist_tensor = torch.zeros(2).cuda()
+                dist_tensor[0] = sk_sq
+                dist_tensor[1] = numerator_weighted
+                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+                global_sk_sq = dist_tensor[0]
+                global_numerator_weighted = dist_tensor[1]
+            else:
+                global_sk_sq = sk_sq
+                global_numerator_weighted = numerator_weighted
+
+            # if we have not made any progress, return
+            # if we have any gradients available, will have sk_sq > 0
+            # (unless \|g\|=0)
+            if global_sk_sq == 0:
+                return loss
+
+            d_hat = 2 * global_numerator_weighted / math.sqrt(global_sk_sq)
+            d = max(d, min(d_hat, d * growth_rate))
+
+            if log_every > 0 and k % log_every == 0:
+                logging.info(
+                    f"(r={self.rank},k={k}) dlr: {dlr} d_hat: {d_hat}, "
+                    f"d: {d}. sk_norm={math.sqrt(global_sk_sq)} "
+                    f"numerator_weighted={global_numerator_weighted} "
+                    f"g0_norm={g0_norm}")
+
+        for group in self.param_groups:
+            group['numerator_weighted'] = numerator_weighted
+            group['d'] = d
+            group['g0_norm'] = g0_norm
+            ######################################
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                state = self.state[p]
+
+                s = state['s']
+                x0 = state['x0']
+                z = state['z']
+
+                # z step
+                z.data.copy_(x0 - s)
+
+                # x step
+                p.data.mul_(1 - ck).add_(z, alpha=ck)
+
+            group['k'] = k + 1
+
+        return loss
diff --git a/alf/optimizers/optimizers.py b/alf/optimizers/optimizers.py
index 71ad4a240..35f2af459 100644
--- a/alf/optimizers/optimizers.py
+++ b/alf/optimizers/optimizers.py
@@ -20,7 +20,7 @@
 import alf
 from alf.utils import common
 from alf.utils import tensor_utils
-from . import adam_tf, adamw, nero_plus
+from . import adam_tf, adamw, nero_plus, dadapt_sgd, dadapt_adam
 from .utils import get_opt_arg
 
 
@@ -327,3 +327,8 @@ def add_param_group(self, param_group):
 AdamTF = alf.configurable('AdamTF')(wrap_optimizer(adam_tf.AdamTF))
 
 NeroPlus = alf.configurable('NeroPlus')(wrap_optimizer(nero_plus.NeroPlus))
+
+DAdaptSGD = alf.configurable('DAdaptSGD')(wrap_optimizer(dadapt_sgd.DAdaptSGD))
+
+DAdaptAdam = alf.configurable('DAdaptAdam')(wrap_optimizer(
+    dadapt_adam.DAdaptAdam))
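
Usage note (not part of the diff): below is a minimal sketch of how the newly registered optimizers could be driven, mirroring the pattern used in dadapt_optimizers_test.py above. The toy model, the batch tensors, and the hyperparameter values are illustrative assumptions; the only APIs taken from this diff are DAdaptAdam/DAdaptSGD and the add_param_group call.

import torch
import torch.nn.functional as F

from alf.optimizers import DAdaptAdam

# Toy network standing in for an ALF model (hypothetical).
model = torch.nn.Sequential(
    torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))

# Per the docstring, leave lr at 1.0 and let D-Adaptation estimate the step
# size; decouple=True selects AdamW-style decoupled weight decay.
opt = DAdaptAdam(lr=1.0, weight_decay=1e-4, decouple=True)
opt.add_param_group(dict(params=list(model.parameters())))

data = torch.randn(8, 4)
target = torch.randint(0, 2, (8, ))
loss = F.cross_entropy(model(data), target)
opt.zero_grad()
loss.backward()
opt.step()

Because both classes are wrapped with alf.configurable, they should also be selectable from an ALF config in the same way as the existing AdamTF/NeroPlus wrappers, e.g. alf.config('DAdaptAdam', decouple=True).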