[Do not merge] PTQ - AQT Integration #72

Open · wants to merge 7 commits into base: main
3 changes: 2 additions & 1 deletion requirements.txt
@@ -4,7 +4,8 @@ google-cloud-storage
absl-py
transformers>=4.25.1
datasets
flax
flax>=0.8.1
aqtp
optax
torch
torchvision
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/base_xl.yml
@@ -154,4 +154,9 @@ lightning_repo: ""
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
lightning_ckpt: ""

enable_mllog: False
enable_mllog: False

quantization: 'int8'
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
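These keys are read by the new quantizations module added later in this PR: an empty `quantization` string disables quantization, `'int8'` builds an AQT `config_v4`, and `quantization_local_shard_count` is forwarded as `LocalAqt.contraction_axis_shard_count` for the backward-pass range finding. A minimal sketch of that flow, using `types.SimpleNamespace` as a stand-in for the real maxdiffusion config object (the shard count of 1 is illustrative; in a real run the -1 sentinel is resolved to the number of slices elsewhere):

from types import SimpleNamespace
from maxdiffusion.models import quantizations

# int8 PTQ, with the range-finding shard count set to an illustrative value.
cfg = SimpleNamespace(quantization='int8', quantization_local_shard_count=1)
quant = quantizations.configure_quantization(cfg)  # AqtQuantization wrapping config_v4

# An empty string disables quantization entirely.
cfg_off = SimpleNamespace(quantization='', quantization_local_shard_count=1)
assert quantizations.configure_quantization(cfg_off) is None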
18 changes: 12 additions & 6 deletions src/maxdiffusion/max_utils.py
@@ -330,6 +330,7 @@ def setup_initial_state(model, tx, config, mesh, model_params, unboxed_abstract_
init_train_state_partial = functools.partial(init_train_state, model=model, tx=tx, training=training)

sharding = PositionalSharding(mesh.devices).replicate()
# TODO - Inspect structure of sharding?
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding)
model_params = jax.tree_util.tree_map(partial_device_put_replicated, model_params)

@@ -344,17 +345,22 @@ def setup_initial_state(model, tx, config, mesh, model_params, unboxed_abstract_

state_mesh_shardings = jax.tree_util.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)

return state, state_mesh_shardings

def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, training=True):
def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, training=True, q_v=None):

# Needed to initialize weights on multi-host with addressable devices.
quant_enabled = config.quantization is not None
if config.train_new_unet:
unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["eval_only"])(rng, eval_only=False)
else:
unet_variables = pipeline.unet.init_weights(rng, eval_only=True)

unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, unet_variables, training=training)
#unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["quantization_enabled"])(rng, quantization_enabled=quant_enabled)
unet_variables = pipeline.unet.init_weights(rng, eval_only=True, quantization_enabled=quant_enabled)
if q_v:
unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, q_v, training=training)
else:
unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, unet_variables, training=training)
if config.train_new_unet:
unet_params = unet_variables
else:
@@ -364,7 +370,7 @@ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, trainin
tx,
config,
mesh,
unet_params,
q_v,
unboxed_abstract_state,
state_mesh_annotations,
training=training)
@@ -381,7 +387,7 @@ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, trainin
state_mesh_annotations,
training=training
)

# breakpoint()
return unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings

# Learning Rate Schedule
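The new `q_v` argument lets a caller hand `get_states` an already-quantized UNet variable pytree: when it is set, the abstract state comes from `q_v` rather than from the freshly initialized `unet_variables`, and in this draft the UNet's `setup_initial_state` is seeded from `q_v` as well. A rough calling sketch; `mesh`, `tx`, `rng`, `config`, `pipeline` and the parameter pytrees are assumed to come from the existing maxdiffusion setup, and `quantized_unet_vars` is a hypothetical pytree captured earlier (for example by running the UNet once with AQT in CONVERT mode):

from maxdiffusion import max_utils

# q_v drives both get_abstract_state and setup_initial_state for the UNet.
unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = (
    max_utils.get_states(
        mesh, tx, rng, config, pipeline, unet_params, vae_params,
        training=False,
        q_v=quantized_unet_vars,
    )
)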
34 changes: 32 additions & 2 deletions src/maxdiffusion/models/attention_flax.py
@@ -23,6 +23,8 @@
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

from ..import common_types, max_logging
from . import quantizations


Array = common_types.Array
Mesh = common_types.Mesh
@@ -36,6 +38,11 @@
HEAD = common_types.HEAD
D_KV = common_types.D_KV

Quant = quantizations.AqtQuantization

def _maybe_aqt_einsum(quant: Quant):
return jnp.einsum if quant is None else quant.einsum()

class AttentionOp(nn.Module):
mesh: Mesh
attention_kernel: str
@@ -48,6 +55,7 @@ class AttentionOp(nn.Module):
flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
flash_min_seq_length: int = 4096
flash_block_sizes: BlockSizes = None
quant: Quant = None
dtype: DType = jnp.float32

def check_attention_inputs(
@@ -385,6 +393,9 @@ class FlaxAttention(nn.Module):
jax mesh is required if attention is set to flash.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
quant (`AqtQuantization`, *optional*, defaults to None):
Configures AQT quantization github.com/google/aqt.


"""
query_dim: int
@@ -402,6 +413,7 @@ class FlaxAttention(nn.Module):
key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
quant: Quant = None

def setup(self):

@@ -421,17 +433,23 @@ def setup(self):
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
flash_block_sizes=self.flash_block_sizes,
dtype=self.dtype
dtype=self.dtype,
quant=self.quant
)

qkv_init_kernel = nn.with_logical_partitioning(
nn.initializers.lecun_normal(),
("embed","heads")
)

dot_general_cls = None
if self.quant:
dot_general_cls = self.quant.dot_general_cls()

self.query = nn.Dense(
inner_dim,
kernel_init=qkv_init_kernel,
dot_general_cls=dot_general_cls,
use_bias=False,
dtype=self.dtype,
name="to_q"
@@ -440,6 +458,7 @@ def setup(self):
self.key = nn.Dense(
inner_dim,
kernel_init=qkv_init_kernel,
dot_general_cls=dot_general_cls,
use_bias=False,
dtype=self.dtype,
name="to_k"
@@ -448,6 +467,7 @@ def setup(self):
self.value = nn.Dense(
inner_dim,
kernel_init=qkv_init_kernel,
dot_general_cls=dot_general_cls,
use_bias=False,
dtype=self.dtype,
name="to_v")
@@ -458,6 +478,7 @@ def setup(self):
nn.initializers.lecun_normal(),
("heads","embed")
),
dot_general_cls=dot_general_cls,
dtype=self.dtype,
name="to_out_0")
self.dropout_layer = nn.Dropout(rate=self.dropout)
@@ -520,6 +541,8 @@ class FlaxBasicTransformerBlock(nn.Module):
Overrides default block sizes for flash attention.
mesh (`jax.sharding.mesh`, *optional*, defaults to `None`):
jax mesh is required if attention is set to flash.
quant (`AqtQuantization`, *optional*, defaults to None):
Configures AQT quantization github.com/google/aqt.
"""
dim: int
n_heads: int
@@ -533,6 +556,7 @@ class FlaxBasicTransformerBlock(nn.Module):
flash_min_seq_length: int = 4096
flash_block_sizes: BlockSizes = None
mesh: jax.sharding.Mesh = None
quant: Quant = None

def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
@@ -548,6 +572,7 @@ def setup(self):
flash_block_sizes=self.flash_block_sizes,
mesh=self.mesh,
dtype=self.dtype,
quant=self.quant,
)
# cross attention
self.attn2 = FlaxAttention(
@@ -562,6 +587,7 @@ def setup(self):
flash_block_sizes=self.flash_block_sizes,
mesh=self.mesh,
dtype=self.dtype,
quant=self.quant,
)
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -625,6 +651,8 @@ class FlaxTransformer2DModel(nn.Module):
Overrides default block sizes for flash attention.
mesh (`jax.sharding.mesh`, *optional*, defaults to `None`):
jax mesh is required if attention is set to flash.
quant (`AqtQuantization`, *optional*, defaults to None):
Configures AQT quantization github.com/google/aqt.
"""
in_channels: int
n_heads: int
@@ -641,6 +669,7 @@ class FlaxTransformer2DModel(nn.Module):
flash_block_sizes: BlockSizes = None
mesh: jax.sharding.Mesh = None
norm_num_groups: int = 32
quant: Quant = None

def setup(self):
self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5)
@@ -679,7 +708,8 @@ def setup(self):
attention_kernel=self.attention_kernel,
flash_min_seq_length=self.flash_min_seq_length,
flash_block_sizes=self.flash_block_sizes,
mesh=self.mesh
mesh=self.mesh,
quant=self.quant
)
for _ in range(self.depth)
]
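The quantization hook is threaded through this file in two ways: the q/k/v/out `nn.Dense` projections receive `dot_general_cls` from `AqtQuantization.dot_general_cls()` (presumably the reason for the flax>=0.8.1 bump in requirements.txt), and `_maybe_aqt_einsum` falls back to `jnp.einsum` when no quant config is set. A standalone sketch of the same `dot_general_cls` wiring outside maxdiffusion, reusing the PR's quantizations module with a `SimpleNamespace` stand-in for the config:

import jax
import jax.numpy as jnp
import flax.linen as nn
from types import SimpleNamespace
from maxdiffusion.models import quantizations

cfg = SimpleNamespace(quantization='int8', quantization_local_shard_count=1)  # illustrative
quant = quantizations.configure_quantization(cfg)
dot_general_cls = quant.dot_general_cls() if quant else None

# Same wiring as FlaxAttention.setup(): the projection runs through AQT's int8
# dot_general when quantization is enabled, and plain XLA matmul otherwise.
to_q = nn.Dense(320, use_bias=False, dot_general_cls=dot_general_cls)
x = jnp.ones((1, 77, 768), dtype=jnp.float32)
variables = to_q.init(jax.random.PRNGKey(0), x)
out = to_q.apply(variables, x)  # shape (1, 77, 320)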
137 changes: 137 additions & 0 deletions src/maxdiffusion/models/quantizations.py
@@ -0,0 +1,137 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import functools

from aqt.jax.v2 import config as aqt_config

from aqt.jax.v2.flax import aqt_flax
from ..common_types import Config
from dataclasses import dataclass
import jax.numpy as jnp


@dataclass
class AqtQuantization:
""" Configures AQT quantization github.com/google/aqt. """
quant_dg: aqt_config.DotGeneral
lhs_quant_mode: aqt_flax.QuantMode
rhs_quant_mode: aqt_flax.QuantMode

def dot_general_cls(self):
""" Returns dot_general configured with aqt params. """
aqt_dg_cls = functools.partial(
aqt_flax.AqtDotGeneral,
self.quant_dg,
lhs_quant_mode=self.lhs_quant_mode,
rhs_quant_mode=self.rhs_quant_mode,
lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION,
rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
)
return aqt_dg_cls

def einsum(self):
""" Returns einsum configured with aqt params """
aqt_einsum = functools.partial(aqt_flax.AqtEinsum(
cfg=self.quant_dg,
lhs_quant_mode=self.lhs_quant_mode,
rhs_quant_mode=self.rhs_quant_mode,
lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION,
rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
)
)
return aqt_einsum

def _get_quant_config(config):
if not config.quantization or config.quantization == '':
return None
elif config.quantization == "int8":
if config.quantization_local_shard_count == 0:
drhs_bits = None
drhs_accumulator_dtype = None
drhs_local_aqt=None
else:
drhs_bits = 8
drhs_accumulator_dtype = jnp.int32
print(config.quantization_local_shard_count) # -1
drhs_local_aqt = aqt_config.LocalAqt(contraction_axis_shard_count=config.quantization_local_shard_count)
return aqt_config.config_v4(
fwd_bits=8,
dlhs_bits=8,
drhs_bits=drhs_bits,
rng_type='jax.uniform',
dlhs_local_aqt=None,
drhs_local_aqt=drhs_local_aqt,
fwd_accumulator_dtype=jnp.int32,
dlhs_accumulator_dtype=jnp.int32,
drhs_accumulator_dtype=drhs_accumulator_dtype,
)
else:
raise ValueError(f'Invalid value configured for quantization {config.quantization}.')

def in_convert_mode(quant):
return quant and (quant.quant_mode == aqt_flax.QuantMode.CONVERT)

def in_serve_mode(quant):
return quant and (quant.quant_mode == aqt_flax.QuantMode.SERVE)

def get_quant_mode(quant_mode_str: str = 'train'):
""" Set quant mode."""
if quant_mode_str == 'train':
return aqt_flax.QuantMode.TRAIN
elif quant_mode_str == 'serve':
return aqt_flax.QuantMode.SERVE
elif quant_mode_str == 'convert':
return aqt_flax.QuantMode.CONVERT
else:
raise ValueError(f'Invalid quantization mode {quant_mode_str}.')
return None

def configure_quantization(config: Config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.TRAIN):
""" Configure quantization based on user config and quant mode."""
quant_cfg = _get_quant_config(config)
if quant_cfg:
return AqtQuantization(quant_dg=quant_cfg, lhs_quant_mode=lhs_quant_mode, rhs_quant_mode=rhs_quant_mode)
return None

# @dataclass
# class AqtQuantization:
# """ Configures AQT quantization github.com/google/aqt. """
# quant_dg: aqt_config.DotGeneral
# quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN




# def dot_general_cls_aqt(self, aqt_cfg, lhs_quant_mode, rhs_quant_mode):
# """ Returns dot_general configured with aqt params. """
# aqt_dg_cls = functools.partial(
# aqt_flax.AqtDotGeneral,
# aqt_cfg,
# lhs_quant_mode=lhs_quant_mode,
# rhs_quant_mode=rhs_quant_mode,
# lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION,
# rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
# )
# return aqt_dg_cls

# def einsum_aqt(self, aqt_cfg, lhs_quant_mode, rhs_quant_mode):
# return functools.partial(
# aqt_flax.AqtEinsum,
# aqt_cfg,
# lhs_quant_mode=lhs_quant_mode,
# rhs_quant_mode=rhs_quant_mode,
# lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION,
# rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
# )

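For post-training quantization, the lhs/rhs quant modes are what separate the conversion pass from serving: with the freezer modes configured above, CONVERT materializes the quantized right-hand-side (weight) values while SERVE reads them back instead of re-quantizing. A hedged sketch of building the two configurations with the helpers in this file; how the frozen weights are checkpointed and restored between the two passes is outside this module and not shown:

from types import SimpleNamespace
from maxdiffusion.models import quantizations

cfg = SimpleNamespace(quantization='int8', quantization_local_shard_count=1)  # illustrative

# Pass 1 (conversion): run the float checkpoint once so AQT captures int8 weights.
convert_quant = quantizations.configure_quantization(
    cfg, rhs_quant_mode=quantizations.get_quant_mode('convert'))

# Pass 2 (serving): the frozen int8 weights are consumed instead of recomputed.
serve_quant = quantizations.configure_quantization(
    cfg, rhs_quant_mode=quantizations.get_quant_mode('serve'))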