Optim config drop stable and add decay (#818)
blahBlahhhJ authored Nov 22, 2024
1 parent f44e5c8 commit dcd4872
Showing 3 changed files with 24 additions and 25 deletions.
2 changes: 1 addition & 1 deletion docs/Configuration-Guide.md
@@ -302,7 +302,7 @@ which are common to all optimizers (and most have to do with learning rate sched
| `lr_schedule` | The type of learning rate schedule for decay. See below. | `cosine` |
| `min_lr_ratio` | The minimum learning rate ratio. | `0.1` |
| `warmup` | Warmup fraction or number of steps | `0.01` |
| `stable` | Stable fraction or number of steps | `0.0` |
| `decay` | Decay fraction or number of steps | `None` |
| `cycles` | The number of cycles for the learning rate, or steps where cycles end | `None` |
| `rewarmup` | The learning rate re-warmup, if using cycles. | `0.0` |

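For context, a minimal sketch of how the new `decay` knob might be used in place of the removed `stable` one. The import path and the fraction-vs-step-count convention are assumptions taken from the diffs below, not verified API documentation:

```python
from levanter.optim.config import AdamConfig  # module path assumed from src/levanter/optim/config.py

# Warm up for 1% of steps, hold the peak learning rate, then decay over the final 20% of steps.
# Previously the held portion was set directly via `stable`; now it is whatever remains after
# warmup and decay are accounted for.
config = AdamConfig(
    learning_rate=1e-3,
    warmup=0.01,  # fraction of steps, or an absolute number of steps
    decay=0.2,    # fraction of steps; None (the default) decays everything after warmup
)
sched_fn = config.lr_scheduler(10_000)  # step -> learning rate, per lr_scheduler below
```
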
12 changes: 8 additions & 4 deletions src/levanter/optim/config.py
@@ -26,8 +26,8 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
"""The lr scheduler operates on 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]"""
warmup: float = 0.01
"""fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup"""
stable: float = 0.00
"""fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown"""
decay: Optional[float] = None
"""fraction of training steps to use as decay, or steps to use. None means full decay"""
rewarmup: float = 0.0
"If using a cycle, how much of the cycle to use as re-warmup. 0.0 means no re-warmup."
cooldown: Optional[float] = None
@@ -174,8 +174,12 @@ def lr_scheduler(self, num_train_steps):
schedules.append(warmup)
boundaries.append(start + warmup_steps)

stable_steps = _convert_ratio_or_steps(self.stable, cycle_steps)
lr_decay_steps = cycle_steps - stable_steps - warmup_steps
lr_decay_steps = (
_convert_ratio_or_steps(self.decay, cycle_steps)
if self.decay is not None
else cycle_steps - warmup_steps
)
stable_steps = cycle_steps - warmup_steps - lr_decay_steps

if stable_steps != 0:
stable = optax.constant_schedule(self.learning_rate)
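
To make the new bookkeeping concrete, here is a small standalone sketch (not the Levanter code itself) of how a cycle is split into warmup, stable, and decay steps under the logic in the hunk above; `_convert_ratio_or_steps` is re-implemented here under the assumption that values below 1.0 are interpreted as fractions of the cycle:

```python
from typing import Optional, Tuple


def _convert_ratio_or_steps(value: float, total_steps: int) -> int:
    # Assumed behavior: values < 1.0 are fractions of total_steps, larger values are raw step counts.
    return round(value * total_steps) if value < 1.0 else int(value)


def split_cycle(cycle_steps: int, warmup: float, decay: Optional[float]) -> Tuple[int, int, int]:
    """Return (warmup_steps, stable_steps, lr_decay_steps) for one cycle."""
    warmup_steps = _convert_ratio_or_steps(warmup, cycle_steps)
    # decay=None reproduces the old default: everything after warmup decays, so stable_steps == 0.
    lr_decay_steps = (
        _convert_ratio_or_steps(decay, cycle_steps)
        if decay is not None
        else cycle_steps - warmup_steps
    )
    stable_steps = cycle_steps - warmup_steps - lr_decay_steps
    return warmup_steps, stable_steps, lr_decay_steps


# A 1000-step cycle with 10% warmup and 10% decay holds the peak LR for the middle 800 steps.
assert split_cycle(1000, 0.1, 0.1) == (100, 800, 100)
assert split_cycle(1000, 0.1, None) == (100, 0, 900)
```
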
35 changes: 15 additions & 20 deletions tests/test_optimizer_config.py
@@ -8,11 +8,10 @@ def test_no_stable_weirdness():
learning_rate=2e-6, # 2x10^-6
weight_decay=0.0,
warmup=0.03,
stable=0.0,
min_lr_ratio=0.0,
lr_schedule="linear",
max_grad_norm=None,
haps=None,
cycles=None,
weight_decay_modules=None,
default_weight_decay_mask=None,
)
@@ -33,10 +32,8 @@ def test_constant_schedule():
learning_rate=1e-3,
weight_decay=0.0,
warmup=0.0,
stable=0.0,
min_lr_ratio=1.0, # No decay
lr_schedule="constant",
haps=None,
cycles=None,
)

@@ -52,10 +49,8 @@ def test_warmup_and_cosine_decay():
learning_rate=1e-2,
weight_decay=0.0,
warmup=0.1, # 10% of steps
stable=0.0,
min_lr_ratio=0.1,
lr_schedule="cosine",
haps=None,
cycles=None,
)

@@ -75,7 +70,6 @@ def test_linear_schedule_with_cycles():
learning_rate=5e-4,
weight_decay=0.0,
warmup=50,
stable=0.0,
min_lr_ratio=0.2,
lr_schedule="linear",
cycles=2,
@@ -105,41 +99,43 @@ def test_linear_schedule_with_cycles():
assert np.isclose(sched_fn(999), 0.2 * 5e-4, atol=1e-5)


def test_haps_schedule():
def test_wsds_schedule():
optimizer = AdamConfig(
learning_rate=1e-3,
weight_decay=0.0,
warmup=0.0,
stable=0.0,
decay=0.1,
min_lr_ratio=0.1,
lr_schedule="cosine",
haps=[300, 700],
cycles=[300, 700],
)

sched_fn = optimizer.lr_scheduler(1000)

# Before first haps
# First cycle
assert np.isclose(sched_fn(0), 1e-3)
assert np.isclose(sched_fn(269), 1e-3)
assert sched_fn(271) < 1e-3

# First haps
# Second cycle
assert np.isclose(sched_fn(300), 1e-3)
assert np.isclose(sched_fn(659), 1e-3)
assert sched_fn(661) < 1e-3

# After first haps
assert sched_fn(301) < 1e-3

# Before second haps
assert sched_fn(699) < sched_fn(301)
# Third cycle
assert np.isclose(sched_fn(701), 1e-3)
assert np.isclose(sched_fn(969), 1e-3)
assert sched_fn(971) < 1e-3


def test_inv_sqrt_decay_schedule():
optimizer = AdamConfig(
learning_rate=1e-3,
weight_decay=0.0,
warmup=0.1,
stable=0.0,
min_lr_ratio=0.1,
lr_schedule="inv_sqrt",
haps=None,
cycles=None,
)

sched_fn = optimizer.lr_scheduler(100_000)
@@ -157,7 +153,6 @@ def test_rewarmup_schedule():
learning_rate=1e-2,
weight_decay=0.0,
warmup=0.2, # 20% of cycle
stable=0.0,
min_lr_ratio=0.2,
lr_schedule="linear",
cycles=2,
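
The assertions in the renamed test_wsds_schedule above follow from the per-cycle arithmetic: with cycles=[300, 700] over 1000 steps and decay=0.1, each cycle holds the peak learning rate until its final 10% and then decays. A rough sanity check (a sketch, assuming cycle boundaries at steps 0, 300, 700, and 1000 and per-cycle cosine decay):

```python
# Assumed cycle boundaries for cycles=[300, 700] over 1000 total steps.
boundaries = [0, 300, 700, 1000]
for start, end in zip(boundaries, boundaries[1:]):
    decay_steps = round(0.1 * (end - start))
    print(f"cycle [{start}, {end}): stable until ~step {end - decay_steps}, then cosine decay")
# cycle [0, 300): stable until ~step 270, then cosine decay
# cycle [300, 700): stable until ~step 660, then cosine decay
# cycle [700, 1000): stable until ~step 970, then cosine decay
```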
