Skip to content

Commit

Permalink
support callbacks inside JIT boundary, include grad histograms (#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Dec 4, 2024
1 parent d125206 commit f5a6878
Show file tree
Hide file tree
Showing 24 changed files with 636 additions and 237 deletions.
16 changes: 8 additions & 8 deletions docs/dev/Trackers.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,23 @@ Given Levanter's historical dependency on W&B, the interface is designed to look
The methods currently exposed are:

* [levanter.tracker.current_tracker][]: returns the current tracker instance or sets it.
* [levanter.tracker.log_metrics][]: logs a dictionary of metrics for a given step.
* [levanter.tracker.log][]: logs a dictionary of metrics for a given step.
* [levanter.tracker.log_summary][]: logs a dictionary of "summary" information, analogous to W&B's version.
* [levanter.tracker.get_tracker][]: returns a tracker with the given name.
* [levanter.tracker.jit_log_metrics][]: a version of [levanter.tracker.log_metrics][] that works inside JAX jit.
* [levanter.tracker.jit_log][]: a version of [levanter.tracker.log][] that accumulates metrics inside of a `jit`-ted function.

A basic example of using the tracker interface is shown below:

```python
import wandb
from levanter.tracker import current_tracker, log_metrics, log_summary
import levanter.tracker as tracker
from levanter.tracker.wandb import WandbTracker

with current_tracker(WandbTracker(wandb.init())):
with tracker.current_tracker(WandbTracker(wandb.init())):
for step in range(100):
log_metrics({"loss": 100 -0.01 * step}, step=step)
tracker.log({"loss": 100 - 0.01 * step}, step=step)

log_summary({"best_loss": 0.0})
tracker.log_summary({"best_loss": 0.0})
```

A more typical example would be to use it in a config file, as we do with Trainer:
Expand Down Expand Up @@ -73,13 +73,13 @@ TODO: expand this section.
::: levanter.tracker.current_tracker
::: levanter.tracker.log_metrics
::: levanter.tracker.log
::: levanter.tracker.log_summary
::: levanter.tracker.get_tracker
::: levanter.tracker.jit_log_metrics
::: levanter.tracker.jit_log
### Trackers
Expand Down
1 change: 1 addition & 0 deletions infra/run.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
umask 000
LEV_ROOT=$(dirname "$(readlink -f $0)")/..
ulimit -s 65536

# figure out venv, first check if we wrote a path in infra/venv_path
if [ ! -d "$VENV" ] && [ -f "$LEV_ROOT/infra/venv_path.txt" ]; then
Expand Down
Loading

0 comments on commit f5a6878

Please sign in to comment.