Mixed Precision and Resource Envs #65

Open. Wants to merge 30 commits into base: dev.
Changes from all commits (30 commits):
e2f71bc  wip compute_context.py (dlwh, Jan 23, 2024)
1b7eb81  wip (dlwh, Jan 25, 2024)
8f3e2c8  ok, current functionality still works (dlwh, Jan 26, 2024)
b7798d9  towards DTypeish (dlwh, Jan 27, 2024)
da76dd5  use semantic dtypes in linear (dlwh, Jan 27, 2024)
d39c6e0  use semantic dtypes in nn layers where it makes sense (dlwh, Jan 27, 2024)
e5e20c7  rename compute_context->resource_env, mv mesh to end (dlwh, Jan 27, 2024)
260627d  typo (dlwh, Jan 27, 2024)
2ea2878  fix imports (dlwh, Jan 27, 2024)
a652e9c  fix ordering of mesh again (dlwh, Jan 27, 2024)
195b587  add a short mixed precision doc for new mixed precision stuff (dlwh, Jan 28, 2024)
2972e31  fix some docs (dlwh, Jan 28, 2024)
acdece2  wat (dlwh, Jan 29, 2024)
8e7d4f3  cleanup the new ResourceEnv stuff, fix a few bugs related to threading (dlwh, Jan 29, 2024)
3bcd67b  wip (dlwh, Jan 31, 2024)
3a203ba  rename to mixed-precision.md (dlwh, Jan 31, 2024)
3c886ea  modernizing the partitioning module (dlwh, Jan 31, 2024)
cd6592a  switch to not having a default jmp.policy (dlwh, Jan 31, 2024)
637428d  wip (dlwh, Jan 31, 2024)
21aa128  fix named_jit without a mesh (dlwh, Jan 31, 2024)
e4dcf3a  fixes for levanter migration (dlwh, Jan 31, 2024)
a0dc5a8  use our own dtypelike when we need to (dlwh, Jan 31, 2024)
3a69dcf  str for ResourceEnv (dlwh, Jan 31, 2024)
9c86449  add strenum (dlwh, Jan 31, 2024)
db00171  missed a function (dlwh, Feb 1, 2024)
3b6deb0  wip docs for ResourceEnv (dlwh, Feb 1, 2024)
e89b983  fix tests with multiple devices (dlwh, Feb 1, 2024)
cd83402  use a mesh when we have it in shard (dlwh, Feb 1, 2024)
5cf4d71  factor out some loss functions (dlwh, Feb 1, 2024)
bcc2ce0  Merge remote-tracking branch 'origin/dev' into jamp (dlwh, Feb 6, 2024)
208 changes: 208 additions & 0 deletions docs/mixed-precision.md
@@ -0,0 +1,208 @@
# Mixed Precision in Haliax

## JMP

Haliax's mixed precision is currently built on [JMP](https://github.com/deepmind/jmp), a simple library for mixed precision training.
JMP's principal class is [jmp.Policy][], which holds three dtypes:

* `param_dtype`: The dtype of the model parameters.
* `compute_dtype`: The dtype that computations are performed in.
* `output_dtype`: The dtype of the model outputs, typically used for loss functions or other more "numerically-unstable" operations.

Policies are typically represented as a string like `"p=f32,c=bf16,o=f32"`, which means that the parameters are stored in `f32`, computations are performed in `bf16`, and outputs are in `f32`.

Once you have a policy, you can convert an arbitrary PyTree between dtypes using [jmp.Policy.cast_to_param][], [jmp.Policy.cast_to_compute][], and [jmp.Policy.cast_to_output][]:

```python
import jmp
import haliax as hax
import jax.numpy as jnp

policy = jmp.get_policy("p=f32,c=bf16,o=f32")

D = hax.Axis("D", 16)
x = hax.arange(D, dtype=float)

assert policy.cast_to_compute(x).dtype == jnp.bfloat16
assert policy.cast_to_output(x).dtype == jnp.float32
assert policy.cast_to_param(x).dtype == jnp.float32
```

### Scaling

JMP also supports loss scaling for FP16 training, but we don't typically use it ourselves, preferring to
stick with `bfloat16`, which has the same exponent range as `float32` and generally doesn't need loss scaling.
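
For reference, here is a minimal sketch of how JMP's loss scaling is meant to be used (this is plain JMP, not a Haliax API, and the loss function is a stand-in):

```python
import jax
import jax.numpy as jnp
import jmp

# A dynamic loss scale grows periodically and shrinks whenever gradients overflow.
loss_scale = jmp.DynamicLossScale(jnp.float32(2 ** 15))

def loss_fn(params, batch):
    return jnp.mean((params * batch) ** 2)  # stand-in loss

def scaled_loss_fn(params, batch, loss_scale):
    return loss_scale.scale(loss_fn(params, batch))

params = jnp.ones(4, dtype=jnp.float16)
batch = jnp.ones(4, dtype=jnp.float16)

grads = jax.grad(scaled_loss_fn)(params, batch, loss_scale)
grads = loss_scale.unscale(grads)        # undo the scaling before the optimizer step
grads_finite = jmp.all_finite(grads)     # did anything overflow?
loss_scale = loss_scale.adjust(grads_finite)
# skip the parameter update when grads_finite is False, e.g. with jmp.select_tree
```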

## `SemanticDType` and `DTypeish`

Haliax extends this core idea from jmp.Policy to add an explicit [haliax.SemanticDType][] enum,
with three entries: `"compute"`, `"param"`, and `"output"`. Instances of this enum can be
resolved to a specific `dtype` using the `to_dtype` method and a [jmp.Policy][].
In addition, you can convert all floating point arrays in a PyTree using [haliax.mixed_precision.cast_floating][]:

```python
import haliax.mixed_precision as hmp

# x and policy come from the previous snippet
assert hmp.cast_floating(x, "compute", policy).dtype == jnp.bfloat16
assert hmp.cast_floating(x, "param", policy).dtype == jnp.float32
```

`cast_floating` actually accepts a [haliax.DTypeish][] for the second parameter,
which is a union of "real" dtypes (like `jnp.bfloat16`), `SemanticDType`s, and strings like `"compute"`.
This is useful for writing generic code that can accept either a dtype or a SemanticDType.
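
For instance, continuing the snippet above (a sketch; the exact promotion rules live in `cast_floating` itself), a concrete dtype is applied directly while a semantic name is resolved through the policy:

```python
# a concrete dtype ignores the policy entirely
assert hmp.cast_floating(x, jnp.float16, policy).dtype == jnp.float16

# a semantic name (or SemanticDType) is resolved via the policy, here o=f32
assert hmp.cast_floating(x, "output", policy).dtype == jnp.float32
```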

The `policy` argument is optional, and if not provided, Haliax will use the global [haliax.ResourceEnv][] to
determine the current policy. See the next section for more details.


## haliax.ResourceEnv

See also [ResourceEnvs](resource-env.md).

Haliax uses a (technically, thread-local) context manager called [haliax.ResourceEnv][] to manage
both [partitioning](partitioning.md) and mixed precision. For mixed precision, the `ResourceEnv` holds a
`jmp.Policy` that can be accessed via the `policy` attribute or via the [haliax.current_mp_policy][] function:

```python
import haliax as hax
import haliax.mixed_precision as hmp
import jax.numpy as jnp

# x is the f32 array from the earlier snippets
with hax.resource_env(mp="p=f32,c=bf16,o=f32"):
    assert hmp.cast_floating(x, "compute").dtype == jnp.bfloat16
    assert hmp.cast_floating(x, "param").dtype == jnp.float32

# Outside the context manager there is no policy, so casting to a semantic dtype is a no-op
assert hmp.cast_floating(x, "compute").dtype == jnp.float32
assert hmp.cast_floating(x, "param").dtype == jnp.float32
```

There is no "default" policy. If you don't set one, casting to a `SemanticDType` will be a no-op,
meaning that mixed precision is entirely opt-in.

## NN Modules

Many Haliax modules, including [haliax.nn.Linear][], [haliax.nn.LayerNorm][], and [haliax.nn.Conv][], accept
an optional `compute_dtype` argument in their `init` and `__init__` methods.
This argument defaults to `"compute"`, but it can be set to `"param"`, `"output"`, or a specific dtype to
override the global policy. (And again, if there is no global policy, a
semantic dtype is a no-op.)


```python
import haliax as hax
import jax.numpy as jnp
import jax.random as jrandom

In = hax.Axis("In", 16)
Out = hax.Axis("Out", 32)

linear = hax.nn.Linear.init(In, Out, key=jrandom.PRNGKey(0))
assert linear.weight.dtype == jnp.float32
assert linear.bias.dtype == jnp.float32
input = hax.arange(In, dtype=jnp.bfloat16)

out = linear(input)
assert out.dtype == jnp.float32

with hax.resource_env(mp="p=f32,c=bf16,o=f32"):
    out = linear(input)
    assert out.dtype == jnp.bfloat16
```
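
As a sketch of overriding the policy per module (assuming the `compute_dtype` keyword described above; the exact spelling may differ), you can pin a particular layer to full precision even when a bf16 compute policy is active:

```python
# hypothetical usage of the compute_dtype override described above
linear_f32 = hax.nn.Linear.init(In, Out, key=jrandom.PRNGKey(0), compute_dtype=jnp.float32)

with hax.resource_env(mp="p=f32,c=bf16,o=f32"):
    assert linear_f32(input).dtype == jnp.float32
```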

## Loss Functions

XXX TODO

## API Reference

::: haliax.DTypeish
::: haliax.SemanticDType
::: haliax.current_mp_policy

::: haliax.mixed_precision.cast_floating


::: haliax.ResourceEnv
::: haliax.resource_env
::: haliax.current_resource_env



## Future: Quantization and FP8

WIP

This is not at all implemented yet, but the plan is to add support for quantization and FP8 in the future.

This section is not going to talk about how specific quantization schemes work, but rather about how
they are implemented, structurally, in the JAX ecosystem.

For purposes of this discussion, I'm going to treat quantization and FP8 as the same thing, since they
end up requiring basically the same infrastructure.

### Quantized Training Overview

Most of this section was pieced together from [this blog post on AQT from Google](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e/),
[the docs for NVIDIA's Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html),
and the libraries themselves. There's no paper for AQT, but I'm sure there will be one soon.
I don't pretend to understand all of this, but I think I have a decent grasp on the basics.

Let's talk just a bit about how quantization works. The basic idea is that you have
an array of floating point values, and you want to convert them to some low-precision
representation, such as INT8 or FP8.

Typically, you don't just project to the nearest representable value, at least not when doing training. Instead,
you want to scale the entire array so that you get as much high-resolution coverage as possible.

So a quantized array is usually stored as the low-precision values together with a floating point scale (often derived from the maximum absolute value of the array), and that scale has to be tracked alongside the values so they can be rescaled later.
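
As a concrete, library-agnostic sketch of that idea, symmetric int8 quantization picks the scale from the largest absolute value so the whole array maps onto the int8 grid:

```python
import jax.numpy as jnp

def quantize_int8(x):
    # choose a scale so that max |x| lands at the edge of the int8 range
    scale = jnp.max(jnp.abs(x)) / 127.0
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.astype(jnp.float32) * scale

x = jnp.array([0.1, -2.5, 3.2, 0.0])
q, scale = quantize_int8(x)
x_approx = dequantize_int8(q, scale)  # close to x, up to quantization error
```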


### Quantization in JAX

There are two relevant libraries I'm basing my understanding of quantization on:

* [TransformerEngine](https://github.com/NVIDIA/TransformerEngine), which is NVIDIA's library for
accelerated training of Transformers, including FP8 support.
* [AQT](https://github.com/google/aqt/), which is Google's library for quantization-aware training that
focuses on integer-based quantization (like int8 and int4).


The way quantization is shaping up to work in JAX is a combination of two mechanisms: "dot injection" and
what I'm going to call "grad hijacking."

#### Dot Injection

The first piece is dot injection, which is a mechanism for injecting alternative versions
of the [jax.lax.dot_general][] primitive into higher level calls like [jax.numpy.einsum][].
(`einsum` is actually the only function in JAX that takes this argument, but it's seen in
[libraries like FLAX](https://github.com/google/flax/blob/61ece402d1b805e5ce797caf74b69ed8a7ae21ce/flax/linen/linear.py#L116-L117).)

This part is fairly intuitive: you are likely going to want custom logic for how to do
matrix multiplication in a quantized setting, and dot injection lets you do that.
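
As a sketch of the mechanism (the injection point here is `jnp.einsum`'s `_dot_general` keyword; the wrapper below just defers to the real primitive, where a quantization library would put its quantize/matmul/dequantize logic):

```python
import jax.numpy as jnp
from jax import lax

def my_dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
    # a real implementation would quantize lhs/rhs here and dequantize the result
    return lax.dot_general(lhs, rhs, dimension_numbers,
                           precision=precision, preferred_element_type=preferred_element_type)

a = jnp.ones((4, 8))
b = jnp.ones((8, 2))
out = jnp.einsum("ij,jk->ik", a, b, _dot_general=my_dot_general)
```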

#### Grad Hijacking

By itself, dot injection isn't quite enough. Quantized training also needs to keep scaling state in the model (per-tensor scales, and for FP8 a history of observed maxima) and to update that state on both the forward and backward passes. The pattern that has emerged is to store the scales alongside the parameters and use custom gradient rules to carry the freshly computed scales out of the forward and backward computations, effectively hijacking the "gradient" of the scale parameters so that a special optimizer step can write the new values back into the model.
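
A heavily simplified sketch of the pattern (this is not any library's actual API): the scale lives in the model state, and a `jax.custom_vjp` rule hijacks the scale's gradient slot to carry a freshly computed scale back out:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def quantized_matmul(x, w, x_scale):
    # forward: fake-quantize the input with the running scale, then matmul
    xq = jnp.clip(jnp.round(x / x_scale), -127, 127) * x_scale
    return xq @ w

def _fwd(x, w, x_scale):
    return quantized_matmul(x, w, x_scale), (x, w)

def _bwd(res, g):
    x, w = res
    # straight-through estimator for the quantization; the "gradient" of the
    # scale is hijacked to carry the new scale computed from this batch
    new_scale = jnp.max(jnp.abs(x)) / 127.0
    return g @ w.T, x.T @ g, new_scale

quantized_matmul.defvjp(_fwd, _bwd)
```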




### Quantization in Haliax

Again, this is not implemented.

My current thinking is to support dot injection similar to what you see in Flax for FP8 and INT8,
but to extend the `ResourceEnv` to also support an automatic dot injection policy. I'm not entirely
sure what this looks like yet.
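
Purely as a strawman, usage might eventually look something like this. None of these names or arguments exist today; this is only meant to illustrate the direction:

```python
# hypothetical: a dot-injection policy carried by the ResourceEnv
with hax.resource_env(mp="p=f32,c=bf16,o=f32", dot="int8"):
    out = linear(input)  # einsums inside Linear would pick up the injected dot_general
```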
83 changes: 83 additions & 0 deletions docs/resource-env.md
@@ -0,0 +1,83 @@
# ResourceEnvs

Haliax currently uses three bits of (optional!) global state:

1. the [jax.sharding.Mesh][] of devices;
2. a [haliax.partitioning.ResourceMapping][] that maps [haliax.Axis][]es to physical axes of the [jax.sharding.Mesh][] (see [Partitioning](partitioning.md));
3. and a `jmp.Policy` that controls [mixed precision](mixed-precision.md) behavior.

(1) and (2) are discussed in [Partitioning](partitioning.md) and (3) is discussed in [Mixed Precision](mixed-precision.md).

Haliax stores these three pieces of state in a [haliax.ResourceEnv][] object, which can be used either as a context
manager (like `Mesh`) or can be passed explicitly to functions that need it.

## Using `ResourceEnv`s

### As a context manager

You can use a `ResourceEnv` as a context manager to temporarily set the global state.

```python
import haliax as hax
import jax
import jmp

from jax.sharding import Mesh

mesh = Mesh(jax.devices(), ("dp",))
resource_mapping = {"embed": "dp"}
mp = jmp.get_policy("p=f32,c=bf16")

with hax.resource_env(resource_mapping, mp, mesh):
    ...  # code that uses the resource env
```

### Explicitly passing a `ResourceEnv`

You can also pass a `ResourceEnv` explicitly to many functions that use one.
This is useful if you want to avoid using global state.

```python
import haliax as hax
import jax
import jmp

from jax.sharding import Mesh

mesh = Mesh(jax.devices(), ("dp",))
resource_mapping = {"embed": "dp"}
mp = jmp.get_policy("p=f32,c=bf16")

env = hax.ResourceEnv(resource_mapping, mp, mesh)

H = hax.Axis("H", 128)
Embed = hax.Axis("embed", 128)

# shard x according to the env's mapping, without touching global state
x = hax.shard(hax.zeros((H, Embed)), env)
```

#### Functions that can take a `ResourceEnv`

This is not an exhaustive list, but here are some functions that can take a `ResourceEnv`
as an explicit argument. Most of these will use the context `ResourceEnv` if one is not provided.
These are all sharding functions.

- [haliax.shard][]
- [haliax.named_jit][]
- [haliax.partitioning.physical_axis_name][]
- [haliax.partitioning.physical_axis_size][]
- [haliax.partitioning.sharding_for_axis][]




## Reference

::: haliax.ResourceEnv

::: haliax.resource_env

::: haliax.current_resource_env
5 changes: 4 additions & 1 deletion mkdocs.yml
@@ -77,7 +77,7 @@ nav:
- "Introduction to Haliax": https://colab.research.google.com/drive/1TiTcQQ4V5mopbgCu1SVl-oqJtXn7rFnC
- "Distributed Training and FSDP": https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz
- "Tensor Parallelism": https://colab.research.google.com/drive/18_BrtDpe1lu89M4T6fKzda8DdSLtFJhi
- "Mixed Precision with `jmp`": https://colab.research.google.com/drive/1_4cikwt-UhSH7yRzNRK8ze9msM9r2mEl?usp=sharing
# - "Mixed Precision with `jmp`": https://colab.research.google.com/drive/1_4cikwt-UhSH7yRzNRK8ze9msM9r2mEl?usp=sharing
- Cheatsheet: 'cheatsheet.md'
- Named Arrays:
- Broadcasting: 'broadcasting.md'
@@ -86,6 +86,9 @@
- Matrix Multiplication: 'matmul.md'
- Neural Networks: 'nn.md'
- Partitioning: 'partitioning.md'
- Advanced Topics:
- Mixed Precision: 'mixed-precision.md'
- "`ResourceEnv`" : 'resource-env.md'
- Higher Order Functions: 'hof.md'
- API Reference: 'api.md'
- FAQ: 'faq.md'
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -25,7 +25,8 @@ dependencies = [
"equinox>=0.10.6",
"jaxtyping>=0.2.20",
"safetensors[numpy]",
"jmp"
"jmp",
"strenum"
]

[project.optional-dependencies]
11 changes: 11 additions & 0 deletions src/haliax/__init__.py
@@ -11,6 +11,7 @@
from jax._src.typing import DTypeLike

import haliax.debug as debug
import haliax.mixed_precision as mixed_precision
import haliax.nn as nn
import haliax.random as random
import haliax.tree_util as tree_util
@@ -19,6 +20,7 @@
from ._src.dot import dot
from ._src.einsum import einsum
from ._src.rearrange import rearrange
from ._src.resource_env import ResourceEnv, current_resource_env, resource_env
from .axis import (
Axis,
AxisSelection,
@@ -55,6 +57,7 @@
updated_slice,
)
from .hof import fold, map, scan, vmap
from .mixed_precision import DTypeish, SemanticDType, current_mp_policy
from .ops import clip, isclose, pad_left, trace, tril, triu, where
from .partitioning import auto_sharded, axis_mapping, fsdp, named_jit, shard, shard_with_axis_mapping
from .specialized_fns import top_k
@@ -845,6 +848,7 @@ def true_divide(x1: NamedOrNumeric, x2: NamedOrNumeric, /) -> NamedOrNumeric:
"random",
"tree_util",
"nn",
"mixed_precision",
"Axis",
"AxisSpec",
"AxisSelection",
@@ -957,6 +961,13 @@ def true_divide(x1: NamedOrNumeric, x2: NamedOrNumeric, /) -> NamedOrNumeric:
"fold",
"map",
"vmap",
"current_resource_env",
"ResourceEnv",
"resource_env",
"current_mp_policy",
"SemanticDType",
"DTypeish",
"DTypeLike",
"trace",
"where",
"clip",