diff --git a/requirements.txt b/requirements.txt index 0cf1fbae..b76d4787 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,8 @@ google-cloud-storage absl-py transformers>=4.25.1 datasets -flax +flax>=0.8.1 +aqtp optax torch torchvision diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index d68016e1..d0c23a52 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -154,4 +154,9 @@ lightning_repo: "" # Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. lightning_ckpt: "" -enable_mllog: False \ No newline at end of file +enable_mllog: False + +quantization: 'int8' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 51f9fe8d..841a3b9b 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -330,6 +330,7 @@ def setup_initial_state(model, tx, config, mesh, model_params, unboxed_abstract_ init_train_state_partial = functools.partial(init_train_state, model=model, tx=tx, training=training) sharding = PositionalSharding(mesh.devices).replicate() + # TODO - Inspect structure of sharding? partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding) model_params = jax.tree_util.tree_map(partial_device_put_replicated, model_params) @@ -344,17 +345,22 @@ def setup_initial_state(model, tx, config, mesh, model_params, unboxed_abstract_ state_mesh_shardings = jax.tree_util.tree_map( lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) + return state, state_mesh_shardings -def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, training=True): +def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, training=True, q_v=None): # Needed to initialize weights on multi-host with addressable devices. 
+ quant_enabled = config.quantization is not None if config.train_new_unet: unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["eval_only"])(rng, eval_only=False) else: - unet_variables = pipeline.unet.init_weights(rng, eval_only=True) - - unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, unet_variables, training=training) + #unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["quantization_enabled"])(rng, quantization_enabled=quant_enabled) + unet_variables = pipeline.unet.init_weights(rng, eval_only=True, quantization_enabled=quant_enabled) + if q_v: + unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, q_v, training=training) + else: + unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, unet_variables, training=training) if config.train_new_unet: unet_params = unet_variables else: @@ -364,7 +370,7 @@ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, trainin tx, config, mesh, - unet_params, + q_v, unboxed_abstract_state, state_mesh_annotations, training=training) @@ -381,7 +387,7 @@ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, trainin state_mesh_annotations, training=training ) - + # breakpoint() return unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings # Learning Rate Schedule diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index b53d6f9b..4e5f368a 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -23,6 +23,8 @@ from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from ..import common_types, max_logging +from . import quantizations + Array = common_types.Array Mesh = common_types.Mesh @@ -36,6 +38,11 @@ HEAD = common_types.HEAD D_KV = common_types.D_KV +Quant = quantizations.AqtQuantization + +def _maybe_aqt_einsum(quant: Quant): + return jnp.einsum if quant is None else quant.einsum() + class AttentionOp(nn.Module): mesh: Mesh attention_kernel: str @@ -48,6 +55,7 @@ class AttentionOp(nn.Module): flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None + quant: Quant = None dtype: DType = jnp.float32 def check_attention_inputs( @@ -385,6 +393,9 @@ class FlaxAttention(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. 
+ """ query_dim: int @@ -402,6 +413,7 @@ class FlaxAttention(nn.Module): key_axis_names: AxisNames = (BATCH, LENGTH, HEAD) value_axis_names: AxisNames = (BATCH, LENGTH, HEAD) out_axis_names: AxisNames = (BATCH, LENGTH, HEAD) + quant: Quant = None def setup(self): @@ -421,7 +433,8 @@ def setup(self): use_memory_efficient_attention=self.use_memory_efficient_attention, split_head_dim=self.split_head_dim, flash_block_sizes=self.flash_block_sizes, - dtype=self.dtype + dtype=self.dtype, + quant=self.quant ) qkv_init_kernel = nn.with_logical_partitioning( @@ -429,9 +442,14 @@ def setup(self): ("embed","heads") ) + dot_general_cls = None + if self.quant: + dot_general_cls = self.quant.dot_general_cls() + self.query = nn.Dense( inner_dim, kernel_init=qkv_init_kernel, + dot_general_cls=dot_general_cls, use_bias=False, dtype=self.dtype, name="to_q" @@ -440,6 +458,7 @@ def setup(self): self.key = nn.Dense( inner_dim, kernel_init=qkv_init_kernel, + dot_general_cls=dot_general_cls, use_bias=False, dtype=self.dtype, name="to_k" @@ -448,6 +467,7 @@ def setup(self): self.value = nn.Dense( inner_dim, kernel_init=qkv_init_kernel, + dot_general_cls=dot_general_cls, use_bias=False, dtype=self.dtype, name="to_v") @@ -458,6 +478,7 @@ def setup(self): nn.initializers.lecun_normal(), ("heads","embed") ), + dot_general_cls=dot_general_cls, dtype=self.dtype, name="to_out_0") self.dropout_layer = nn.Dropout(rate=self.dropout) @@ -520,6 +541,8 @@ class FlaxBasicTransformerBlock(nn.Module): Overrides default block sizes for flash attention. mesh (`jax.sharding.mesh`, *optional*, defaults to `None`): jax mesh is required if attention is set to flash. + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. """ dim: int n_heads: int @@ -533,6 +556,7 @@ class FlaxBasicTransformerBlock(nn.Module): flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None mesh: jax.sharding.Mesh = None + quant: Quant = None def setup(self): # self attention (or cross_attention if only_cross_attention is True) @@ -548,6 +572,7 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, + quant=self.quant, ) # cross attention self.attn2 = FlaxAttention( @@ -562,6 +587,7 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, + quant=self.quant, ) self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) @@ -625,6 +651,8 @@ class FlaxTransformer2DModel(nn.Module): Overrides default block sizes for flash attention. mesh (`jax.sharding.mesh`, *optional*, defaults to `None`): jax mesh is required if attention is set to flash. + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. 
""" in_channels: int n_heads: int @@ -641,6 +669,7 @@ class FlaxTransformer2DModel(nn.Module): flash_block_sizes: BlockSizes = None mesh: jax.sharding.Mesh = None norm_num_groups: int = 32 + quant: Quant = None def setup(self): self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5) @@ -679,7 +708,8 @@ def setup(self): attention_kernel=self.attention_kernel, flash_min_seq_length=self.flash_min_seq_length, flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh + mesh=self.mesh, + quant=self.quant ) for _ in range(self.depth) ] diff --git a/src/maxdiffusion/models/quantizations.py b/src/maxdiffusion/models/quantizations.py new file mode 100644 index 00000000..b2fd6c75 --- /dev/null +++ b/src/maxdiffusion/models/quantizations.py @@ -0,0 +1,137 @@ +""" + Copyright 2024 Google LLC + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import functools + +from aqt.jax.v2 import config as aqt_config + +from aqt.jax.v2.flax import aqt_flax +from ..common_types import Config +from dataclasses import dataclass +import jax.numpy as jnp + + +@dataclass +class AqtQuantization: + """ Configures AQT quantization github.com/google/aqt. """ + quant_dg: aqt_config.DotGeneral + lhs_quant_mode: aqt_flax.QuantMode + rhs_quant_mode: aqt_flax.QuantMode + + def dot_general_cls(self): + """ Returns dot_general configured with aqt params. 
""" + aqt_dg_cls = functools.partial( + aqt_flax.AqtDotGeneral, + self.quant_dg, + lhs_quant_mode=self.lhs_quant_mode, + rhs_quant_mode=self.rhs_quant_mode, + lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION, + rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE, + ) + return aqt_dg_cls + + def einsum(self): + """ Returns einsum configured with aqt params """ + aqt_einsum = functools.partial(aqt_flax.AqtEinsum( + cfg=self.quant_dg, + lhs_quant_mode=self.lhs_quant_mode, + rhs_quant_mode=self.rhs_quant_mode, + lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION, + rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE, + ) + ) + return aqt_einsum + +def _get_quant_config(config): + if not config.quantization or config.quantization == '': + return None + elif config.quantization == "int8": + if config.quantization_local_shard_count == 0: + drhs_bits = None + drhs_accumulator_dtype = None + drhs_local_aqt=None + else: + drhs_bits = 8 + drhs_accumulator_dtype = jnp.int32 + print(config.quantization_local_shard_count) # -1 + drhs_local_aqt = aqt_config.LocalAqt(contraction_axis_shard_count=config.quantization_local_shard_count) + return aqt_config.config_v4( + fwd_bits=8, + dlhs_bits=8, + drhs_bits=drhs_bits, + rng_type='jax.uniform', + dlhs_local_aqt=None, + drhs_local_aqt=drhs_local_aqt, + fwd_accumulator_dtype=jnp.int32, + dlhs_accumulator_dtype=jnp.int32, + drhs_accumulator_dtype=drhs_accumulator_dtype, + ) + else: + raise ValueError(f'Invalid value configured for quantization {config.quantization}.') + +def in_convert_mode(quant): + return quant and (quant.quant_mode == aqt_flax.QuantMode.CONVERT) + +def in_serve_mode(quant): + return quant and (quant.quant_mode == aqt_flax.QuantMode.SERVE) + +def get_quant_mode(quant_mode_str: str = 'train'): + """ Set quant mode.""" + if quant_mode_str == 'train': + return aqt_flax.QuantMode.TRAIN + elif quant_mode_str == 'serve': + return aqt_flax.QuantMode.SERVE + elif quant_mode_str == 'convert': + return aqt_flax.QuantMode.CONVERT + else: + raise ValueError(f'Invalid quantization mode {quant_mode_str}.') + return None + +def configure_quantization(config: Config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.TRAIN): + """ Configure quantization based on user config and quant mode.""" + quant_cfg = _get_quant_config(config) + if quant_cfg: + return AqtQuantization(quant_dg=quant_cfg, lhs_quant_mode=lhs_quant_mode, rhs_quant_mode=rhs_quant_mode) + return None + +# @dataclass +# class AqtQuantization: +# """ Configures AQT quantization github.com/google/aqt. """ +# quant_dg: aqt_config.DotGeneral +# quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN + + + + +# def dot_general_cls_aqt(self, aqt_cfg, lhs_quant_mode, rhs_quant_mode): +# """ Returns dot_general configured with aqt params. 
""" +# aqt_dg_cls = functools.partial( +# aqt_flax.AqtDotGeneral, +# aqt_cfg, +# lhs_quant_mode=lhs_quant_mode, +# rhs_quant_mode=rhs_quant_mode, +# lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION, +# rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE, +# ) +# return aqt_dg_cls + +# def einsum_aqt(self, aqt_cfg, lhs_quant_mode, rhs_quant_mode): +# return functools.partial( +# aqt_flax.AqtEinsum, +# aqt_cfg, +# lhs_quant_mode=lhs_quant_mode, +# rhs_quant_mode=rhs_quant_mode, +# lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION, +# rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE, +# ) + \ No newline at end of file diff --git a/src/maxdiffusion/models/unet_2d_blocks_flax.py b/src/maxdiffusion/models/unet_2d_blocks_flax.py index d9ab86c2..8acd86b7 100644 --- a/src/maxdiffusion/models/unet_2d_blocks_flax.py +++ b/src/maxdiffusion/models/unet_2d_blocks_flax.py @@ -18,6 +18,10 @@ from .attention_flax import FlaxTransformer2DModel from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D +from . import quantizations + +Quant = quantizations.AqtQuantization + from ..common_types import BlockSizes class FlaxCrossAttnDownBlock2D(nn.Module): @@ -53,6 +57,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. """ in_channels: int out_channels: int @@ -69,6 +75,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module): flash_block_sizes: BlockSizes = None mesh: jax.sharding.Mesh = None dtype: jnp.dtype = jnp.float32 + quant: Quant = None transformer_layers_per_block: int = 1 norm_num_groups: int = 32 @@ -102,7 +109,8 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, - norm_num_groups=self.norm_num_groups + norm_num_groups=self.norm_num_groups, + quant=self.quant, ) attentions.append(attn_block) @@ -219,6 +227,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. """ in_channels: int out_channels: int @@ -238,6 +248,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 transformer_layers_per_block: int = 1 norm_num_groups: int = 32 + quant: Quant = None + def setup(self): resnets = [] @@ -270,7 +282,8 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, - norm_num_groups=self.norm_num_groups + norm_num_groups=self.norm_num_groups, + quant=self.quant, ) attentions.append(attn_block) @@ -389,6 +402,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. 
""" in_channels: int dropout: float = 0.0 @@ -404,6 +419,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): dtype: jnp.dtype = jnp.float32 transformer_layers_per_block: int = 1 norm_num_groups: int = 32 + quant: Quant = None def setup(self): # there is always at least one resnet @@ -433,7 +449,8 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, - norm_num_groups=self.norm_num_groups + norm_num_groups=self.norm_num_groups, + quant=self.quant, ) attentions.append(attn_block) diff --git a/src/maxdiffusion/models/unet_2d_condition_flax.py b/src/maxdiffusion/models/unet_2d_condition_flax.py index 0c6e7f73..7c28732b 100644 --- a/src/maxdiffusion/models/unet_2d_condition_flax.py +++ b/src/maxdiffusion/models/unet_2d_condition_flax.py @@ -31,8 +31,10 @@ FlaxUpBlock2D, ) +from . import quantizations from ..common_types import BlockSizes +Quant = quantizations.AqtQuantization @flax.struct.dataclass class FlaxUNet2DConditionOutput(BaseOutput): @@ -105,6 +107,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): Overrides default block sizes for flash attention. mesh (`jax.sharding.mesh`, *optional*, defaults to `None`): jax mesh is required if attention is set to flash. + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. """ sample_size: int = 32 @@ -140,8 +144,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): addition_embed_type_num_heads: int = 64 projection_class_embeddings_input_dim: Optional[int] = None norm_num_groups: int = 32 + quant: Quant = None - def init_weights(self, rng: jax.Array, eval_only: bool = False) -> FrozenDict: + def init_weights(self, rng: jax.Array, eval_only: bool = False, quantization_enabled: bool = False) -> FrozenDict: # init input tensors no_devices = jax.device_count() sample_shape = (no_devices, self.in_channels, self.sample_size, self.sample_size) @@ -151,6 +156,8 @@ def init_weights(self, rng: jax.Array, eval_only: bool = False) -> FrozenDict: params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} + if quantization_enabled: + rngs["aqt"] = params_rng added_cond_kwargs = None if self.addition_embed_type == "text_time": @@ -260,6 +267,7 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, + quant=self.quant, ) else: down_block = FlaxDownBlock2D( @@ -288,6 +296,7 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, + quant=self.quant, ) # up @@ -323,6 +332,7 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, + quant=self.quant, ) else: up_block = FlaxUpBlock2D( diff --git a/src/maxdiffusion/pipelines/pipeline_flax_utils.py b/src/maxdiffusion/pipelines/pipeline_flax_utils.py index b28450fa..5a2fd214 100644 --- a/src/maxdiffusion/pipelines/pipeline_flax_utils.py +++ b/src/maxdiffusion/pipelines/pipeline_flax_utils.py @@ -329,6 +329,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P mesh = kwargs.pop("mesh", None) dtype = kwargs.pop("dtype", None) norm_num_groups = kwargs.pop("norm_num_groups", 32) + quant = kwargs.pop("quant", None) # 1. 
Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -513,6 +514,7 @@ def load_module(name, value): mesh=mesh, norm_num_groups=norm_num_groups, dtype=dtype, + quant=quant, ) params[name] = loaded_params elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index eed27f76..8cedb99c 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -111,13 +111,32 @@ def user_init(raw_keys): raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"]) - + raw_keys["num_slices"] = get_num_slices(raw_keys) + raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + if raw_keys["learning_rate_schedule_steps"]==-1: raw_keys["learning_rate_schedule_steps"] = raw_keys["max_train_steps"] if "gs://" in raw_keys["pretrained_model_name_or_path"]: raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp") +def get_num_slices(raw_keys): + if int(raw_keys['compile_topology_num_slices']) > 0: + return raw_keys['compile_topology_num_slices'] + else: + devices = jax.devices() + try: + return 1 + max([d.slice_index for d in devices]) + except: + return 1 + +def get_quantization_local_shard_count(raw_keys): + if raw_keys['quantization_local_shard_count'] == -1: + return raw_keys['num_slices'] + else: + return raw_keys['quantization_local_shard_count'] + + def get_num_target_devices(raw_keys): return len(jax.devices()) diff --git a/src/maxdiffusion/unet_quantization.py b/src/maxdiffusion/unet_quantization.py new file mode 100644 index 00000000..ff176103 --- /dev/null +++ b/src/maxdiffusion/unet_quantization.py @@ -0,0 +1,479 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +import os +import functools +from absl import app +from typing import Sequence +import time +from maxdiffusion.models.unet_2d_condition_flax import FlaxUNet2DConditionModel + +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from jax.sharding import PartitionSpec as P +from jax.experimental.compilation_cache import compilation_cache as cc +from flax.linen import partitioning as nn_partitioning +from jax.sharding import PositionalSharding +from aqt.jax.v2.flax import aqt_flax +import optax + +from maxdiffusion import ( + FlaxStableDiffusionXLPipeline, + FlaxEulerDiscreteScheduler, + FlaxDDPMScheduler +) + + +from maxdiffusion import pyconfig +from maxdiffusion.image_processor import VaeImageProcessor +from maxdiffusion.max_utils import ( + InferenceState, + create_device_mesh, + get_dtype, + get_states, + activate_profiler, + deactivate_profiler, + device_put_replicated, + get_flash_block_sizes, + get_abstract_state, + setup_initial_state, + create_learning_rate_schedule +) +from .maxdiffusion_utils import ( + load_sdxllightning_unet, + get_add_time_ids, + rescale_noise_cfg +) +from .models import quantizations +from jax.tree_util import tree_flatten_with_path, tree_unflatten + + +cc.set_cache_dir(os.path.expanduser("~/jax_cache")) + +def _get_aqt_key_paths(aqt_vars): + """Generate a list of paths which have aqt state""" + aqt_tree_flat, _ = jax.tree_util.tree_flatten_with_path(aqt_vars) + aqt_key_paths = [] + for k, _ in aqt_tree_flat: + pruned_keys = [] + for d in list(k): + if "AqtDotGeneral" in d.key: + pruned_keys.append(jax.tree_util.DictKey(key="kernel")) + break + else: + assert "Aqt" not in d.key, f"Unexpected Aqt op {d.key} in {k}." + pruned_keys.append(d) + aqt_key_paths.append(tuple(pruned_keys)) + return aqt_key_paths + +def remove_quantized_params(params, aqt_vars): + """Remove param values with aqt tensors to Null to optimize memory.""" + aqt_paths = _get_aqt_key_paths(aqt_vars) + tree_flat, tree_struct = tree_flatten_with_path(params) + for i, (k, v) in enumerate(tree_flat): + if k in aqt_paths: + v = {} + tree_flat[i] = v + return tree_unflatten(tree_struct, tree_flat) + + +def get_quantized_unet_variables(config): + + # Setup Mesh + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + batch_size = config.per_device_batch_size * jax.device_count() + + weight_dtype = get_dtype(config) + flash_block_sizes = get_flash_block_sizes(config) + + quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.CONVERT) + pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=weight_dtype, + split_head_dim=config.split_head_dim, + norm_num_groups=config.norm_num_groups, + attention_kernel=config.attention, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + quant=quant, + ) + + k = jax.random.key(0) + latents = jnp.ones((8, 4,128,128), dtype=jnp.float32) + timesteps = jnp.ones((8,)) + encoder_hidden_states = jnp.ones((8, 77, 2048)) + + added_cond_kwargs = { + "text_embeds": jnp.zeros((8, 1280), dtype=jnp.float32), + "time_ids": jnp.zeros((8, 6), dtype=jnp.float32), + } + noise_pred, quantized_unet_vars = pipeline.unet.apply( + params["unet"] | {"aqt" : {}}, + latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + rngs={"params": jax.random.PRNGKey(0)}, + mutable=True, + ) + del 
pipeline + del params + + return quantized_unet_vars + +def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale): + latents, scheduler_state, state = args + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t) + # breakpoint() + noise_pred = model.apply( + {"params" : state.params, "aqt": state.params["aqt"] }, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + rngs={"params": jax.random.PRNGKey(0)} + ).sample + + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_prediction_text, guidance_rescale=guidance_rescale) + + + latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + + return latents, scheduler_state, state + +def loop_body_for_quantization(latents, scheduler_state, state, rng, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale): + # latents, scheduler_state, state, rng = args + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t) + noise_pred, quantized_unet_vars = model.apply( + state.params | {"aqt" : {}}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + rngs={"params": rng}, + mutable=True, + ) + return quantized_unet_vars + + +def get_embeddings(prompt_ids, pipeline, params): + te_1_inputs = prompt_ids[:, 0, :] + te_2_inputs = prompt_ids[:, 1, :] + + prompt_embeds = pipeline.text_encoder( + te_1_inputs, params=params["text_encoder"], output_hidden_states=True + ) + prompt_embeds = prompt_embeds["hidden_states"][-2] + prompt_embeds_2_out = pipeline.text_encoder_2( + te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True + ) + prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2] + text_embeds = prompt_embeds_2_out["text_embeds"] + prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1) + return prompt_embeds, text_embeds + +def tokenize(prompt, pipeline): + inputs = [] + for _tokenizer in [pipeline.tokenizer, pipeline.tokenizer_2]: + text_inputs = _tokenizer( + prompt, + padding="max_length", + max_length=_tokenizer.model_max_length, + truncation=True, + return_tensors="np" + ) + inputs.append(text_inputs.input_ids) + inputs = jnp.stack(inputs,axis=1) + return inputs + +def run(config, q_v): + rng = jax.random.PRNGKey(config.seed) + + # Setup Mesh + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + batch_size = config.per_device_batch_size * jax.device_count() + + weight_dtype = get_dtype(config) + flash_block_sizes = get_flash_block_sizes(config) + + quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.SERVE) + pipeline, params = 
FlaxStableDiffusionXLPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=weight_dtype, + split_head_dim=config.split_head_dim, + norm_num_groups=config.norm_num_groups, + attention_kernel=config.attention, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + quant=quant, + ) + + # if this checkpoint was trained with maxdiffusion + # the training scheduler was saved with it, switch it + # to a Euler scheduler + if isinstance(pipeline.scheduler, FlaxDDPMScheduler): + noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, subfolder="scheduler", dtype=jnp.float32 + ) + pipeline.scheduler = noise_scheduler + params["scheduler"] = noise_scheduler_state + + if config.lightning_repo: + pipeline, params = load_sdxllightning_unet(config, pipeline, params) + + scheduler_state = params.pop("scheduler") + old_params = params + params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), old_params) + params["scheduler"] = scheduler_state + + data_sharding = jax.sharding.NamedSharding(mesh,P(*config.data_sharding)) + + sharding = PositionalSharding(devices_array).replicate() + partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding) + params["text_encoder"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder"]) + params["text_encoder_2"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder_2"]) + + # p = {} + # p["aqt"] = q_v["aqt"] + # # Remove param values which have corresponding qtensors in aqt to save memory. + # p["params"] = remove_quantized_params(q_v["params"], q_v["aqt"]) + learning_rate_scheduler = create_learning_rate_schedule(config) + tx = optax.adamw( + learning_rate=learning_rate_scheduler, + b1=config.adam_b1, + b2=config.adam_b2, + eps=config.adam_eps, + weight_decay=config.adam_weight_decay, + ) + unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(mesh, tx, rng, config, pipeline, q_v, params["vae"], training=False, q_v=q_v) + del params["vae"] + del params["unet"] + # unet_state.params = q_v + # params["unet"] = jax.tree_util.tree_map(partial_device_put_replicated, params["unet"]) + # unet_state = InferenceState(pipeline.unet.apply, params=params["unet"]) + + def get_unet_inputs(rng, config, batch_size, pipeline, params): + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + prompt_ids = [config.prompt] * batch_size + prompt_ids = tokenize(prompt_ids, pipeline) + negative_prompt_ids = [config.negative_prompt] * batch_size + negative_prompt_ids = tokenize(negative_prompt_ids, pipeline) + guidance_scale = config.guidance_scale + guidance_rescale = config.guidance_rescale + num_inference_steps = config.num_inference_steps + height = config.resolution + width = config.resolution + prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, params) + batch_size = prompt_embeds.shape[0] + negative_prompt_embeds, negative_pooled_embeds = get_embeddings(negative_prompt_ids, pipeline, params) + add_time_ids = get_add_time_ids( + (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype + ) + + prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) + add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0) + add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0) + # Ensure model output will be 
`float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + guidance_rescale = jnp.array([guidance_rescale], dtype=jnp.float32) + + latents_shape = ( + batch_size, + pipeline.unet.config.in_channels, + height // vae_scale_factor, + width // vae_scale_factor, + ) + + latents = jax.random.normal(rng, shape=latents_shape, dtype=jnp.float32) + + scheduler_state = pipeline.scheduler.set_timesteps( + params["scheduler"], + num_inference_steps=num_inference_steps, + shape=latents.shape + ) + + latents = latents * scheduler_state.init_noise_sigma + + added_cond_kwargs = {"text_embeds" : add_text_embeds, "time_ids" : add_time_ids} + latents = jax.device_put(latents, data_sharding) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + guidance_scale = jax.device_put(guidance_scale, PositionalSharding(devices_array).replicate()) + added_cond_kwargs['text_embeds'] = jax.device_put(added_cond_kwargs['text_embeds'], data_sharding) + added_cond_kwargs['time_ids'] = jax.device_put(added_cond_kwargs['time_ids'], data_sharding) + + return latents, prompt_embeds, added_cond_kwargs, guidance_scale, guidance_rescale, scheduler_state + + def vae_decode(latents, state, pipeline): + latents = 1 / pipeline.vae.config.scaling_factor * latents + image = pipeline.vae.apply( + {"params" : state.params}, + latents, + method=pipeline.vae.decode + ).sample + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + def get_quantized_unet_vars(unet_state, params, rng, config, batch_size, pipeline): + + (latents, + prompt_embeds, + added_cond_kwargs, + guidance_scale, + guidance_rescale, + scheduler_state) = get_unet_inputs(rng, config, batch_size, pipeline, params) + + loop_body_quant_p = jax.jit(functools.partial(loop_body_for_quantization, + model=pipeline.unet, + pipeline=pipeline, + added_cond_kwargs=added_cond_kwargs, + prompt_embeds=prompt_embeds, + guidance_scale=guidance_scale, + guidance_rescale=guidance_rescale)) + # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + quantized_unet_vars = loop_body_quant_p(latents=latents, scheduler_state=scheduler_state, state=unet_state,rng=rng) + + + return quantized_unet_vars + + def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeline): + + (latents, + prompt_embeds, + added_cond_kwargs, + guidance_scale, + guidance_rescale, + scheduler_state) = get_unet_inputs(rng, config, batch_size, pipeline, params) + + loop_body_p = functools.partial(loop_body, model=pipeline.unet, + pipeline=pipeline, + added_cond_kwargs=added_cond_kwargs, + prompt_embeds=prompt_embeds, + guidance_scale=guidance_scale, + guidance_rescale=guidance_rescale) + vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, + loop_body_p, (latents, scheduler_state, unet_state)) + image = vae_decode_p(latents, vae_state) + return image + + #quantized_unet_vars = get_quantized_unet_vars(unet_state, params, rng, config, batch_size, pipeline) + + #del params + #del pipeline + #del unet_state + #quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.SERVE) + + # pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( + # config.pretrained_model_name_or_path, + # revision=config.revision, + # dtype=weight_dtype, + # split_head_dim=config.split_head_dim, + # 
norm_num_groups=config.norm_num_groups, + # attention_kernel=config.attention, + # flash_block_sizes=flash_block_sizes, + # mesh=mesh, + # quant=quant, + # ) + + # scheduler_state = params.pop("scheduler") + # old_params = params + # params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), old_params) + # params["scheduler"] = scheduler_state + + # data_sharding = jax.sharding.NamedSharding(mesh,P(*config.data_sharding)) + + # sharding = PositionalSharding(devices_array).replicate() + # partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding) + # params["text_encoder"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder"]) + # params["text_encoder_2"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder_2"]) + + # unet_state = InferenceState(pipeline.unet.apply, params=quantized_unet_vars) + # unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(mesh, None, rng, config, pipeline, quantized_unet_vars, params["vae"], training=False) + #del params["vae"] + #del params["unet"] + + + p_run_inference = jax.jit( + functools.partial(run_inference, rng=rng, config=config, batch_size=batch_size, pipeline=pipeline), + in_shardings=(unet_state_mesh_shardings, vae_state_mesh_shardings, None), + out_shardings=None + ) + + s = time.time() + p_run_inference(unet_state, vae_state, params).block_until_ready() + print("compile time: ", (time.time() - s)) + s = time.time() + images = p_run_inference(unet_state, vae_state, params).block_until_ready() + images.block_until_ready() + print("inference time: ",(time.time() - s)) + s = time.time() + images = p_run_inference(unet_state, vae_state, params).block_until_ready() #run_inference(unet_state, vae_state, latents, scheduler_state) + images.block_until_ready() + print("inference time: ",(time.time() - s)) + s = time.time() + images = p_run_inference(unet_state, vae_state, params).block_until_ready() # run_inference(unet_state, vae_state, latents, scheduler_state) + images.block_until_ready() + print("inference time: ",(time.time() - s)) + s = time.time() + activate_profiler(config) + images = p_run_inference(unet_state, vae_state, params).block_until_ready() + deactivate_profiler(config) + images.block_until_ready() + print("inference time: ",(time.time() - s)) + images = jax.experimental.multihost_utils.process_allgather(images) + numpy_images = np.array(images) + images = VaeImageProcessor.numpy_to_pil(numpy_images) + for i, image in enumerate(images): + image.save(f"image_sdxl_{i}.png") + + return images + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + q_v = get_quantized_unet_variables(pyconfig.config) + # breakpoint() + del q_v['params'] + print(q_v.keys()) + # addedkw_args...., params, aqt + run(pyconfig.config, q_v) + +if __name__ == "__main__": + app.run(main)
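
The attention and UNet changes above all thread the same optional `quant` object (an `AqtQuantization` instance) down to the individual `nn.Dense` projections, where `quant.dot_general_cls()` swaps the default dot_general for AQT's int8 implementation. Below is a minimal sketch of that wiring in isolation, assuming `maxdiffusion` and `aqtp` are importable; `TinyProjection` and the `SimpleNamespace` config are illustrative stand-ins, not part of this change:

from types import SimpleNamespace

import jax
import jax.numpy as jnp
import flax.linen as nn

from maxdiffusion.models import quantizations

# Stand-in for the parsed pyconfig object; only the two fields read by
# _get_quant_config() matter here.
config = SimpleNamespace(quantization="int8", quantization_local_shard_count=1)
quant = quantizations.configure_quantization(config)  # both quant modes default to TRAIN

class TinyProjection(nn.Module):
  """Illustrative stand-in for the q/k/v/out projections in FlaxAttention."""
  features: int
  quant: quantizations.AqtQuantization = None

  @nn.compact
  def __call__(self, x):
    # The only quantization-specific change: pass AQT's dot_general class to nn.Dense.
    dot_general_cls = self.quant.dot_general_cls() if self.quant else None
    return nn.Dense(self.features, use_bias=False,
                    dot_general_cls=dot_general_cls, name="to_q")(x)

model = TinyProjection(features=64, quant=quant)
x = jnp.ones((2, 16, 64))
# Mirrors init_weights() above: with quantization enabled the module gets an
# extra "aqt" RNG stream alongside "params".
variables = model.init({"params": jax.random.PRNGKey(0), "aqt": jax.random.PRNGKey(1)}, x)
y = model.apply(variables, x, rngs={"params": jax.random.PRNGKey(0)})
print(y.shape)  # (2, 16, 64)

On the real model the same object is also meant to feed `_maybe_aqt_einsum` for the attention einsums, although the hunks above only wire it into the dense projections.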
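On the configuration side, `quantization_local_shard_count: -1` in base_xl.yml is a sentinel that pyconfig.py resolves at startup: `get_num_slices` takes `compile_topology_num_slices` when it is positive, otherwise it derives the slice count from the devices' `slice_index`, and the sentinel is then replaced by that count before it reaches `aqt_config.LocalAqt(contraction_axis_shard_count=...)` in `_get_quant_config`. The following is a standalone restatement of that resolution for illustration; a plain dict stands in for the parsed yml, and the bare `except:` from the patch is narrowed to the `AttributeError` that non-TPU devices actually raise:

import jax

def get_num_slices(raw_keys):
  if int(raw_keys["compile_topology_num_slices"]) > 0:
    return raw_keys["compile_topology_num_slices"]
  devices = jax.devices()
  try:
    # TPU devices in a multislice job expose a 0-based slice_index.
    return 1 + max(d.slice_index for d in devices)
  except AttributeError:
    return 1  # single slice, or CPU/GPU devices without slice_index

def get_quantization_local_shard_count(raw_keys):
  # -1 means "one range-finding shard per slice", per the comment in base_xl.yml.
  if raw_keys["quantization_local_shard_count"] == -1:
    return raw_keys["num_slices"]
  return raw_keys["quantization_local_shard_count"]

raw_keys = {"compile_topology_num_slices": -1, "quantization_local_shard_count": -1}
raw_keys["num_slices"] = get_num_slices(raw_keys)  # e.g. 2 on a two-slice job, 1 on CPU
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
print(raw_keys)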
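unet_quantization.py produces the int8 UNet in two passes. In `get_quantized_unet_variables` the pipeline is built with `rhs_quant_mode=QuantMode.CONVERT` and the UNet is applied once with `mutable=True`, which writes the quantized kernels and their scales into a new "aqt" variable collection; `run` then rebuilds the pipeline with `rhs_quant_mode=QuantMode.SERVE` and hands those captured variables to `get_states` through the new `q_v` argument, so inference reads the frozen int8 weights. The sketch below condenses that sequence under the assumption that `config` is the pyconfig object built in `main()`; checkpoint kwargs, sharding, the scheduler and real prompts from `run()` are omitted, and the dummy input shapes follow `get_quantized_unet_variables`:

import jax
import jax.numpy as jnp
from aqt.jax.v2.flax import aqt_flax

from maxdiffusion import FlaxStableDiffusionXLPipeline
from maxdiffusion.models import quantizations


def quantize_then_serve_unet(config):
  """Condensed sketch of get_quantized_unet_variables() followed by run()."""
  # Dummy UNet inputs (batch of 1), shaped as in get_quantized_unet_variables().
  latents = jnp.ones((1, 4, 128, 128), dtype=jnp.float32)
  timesteps = jnp.ones((1,), dtype=jnp.int32)
  encoder_hidden_states = jnp.ones((1, 77, 2048), dtype=jnp.float32)
  added_cond_kwargs = {
      "text_embeds": jnp.zeros((1, 1280), dtype=jnp.float32),
      "time_ids": jnp.zeros((1, 6), dtype=jnp.float32),
  }

  # Pass 1: CONVERT. One mutable apply of the float UNet materializes the
  # int8 kernels and scales under the "aqt" collection.
  quant = quantizations.configure_quantization(
      config,
      lhs_quant_mode=aqt_flax.QuantMode.TRAIN,
      rhs_quant_mode=aqt_flax.QuantMode.CONVERT)
  pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
      config.pretrained_model_name_or_path, quant=quant)  # other kwargs as in run()
  _, quantized_vars = pipeline.unet.apply(
      params["unet"] | {"aqt": {}},
      latents, timesteps,
      encoder_hidden_states=encoder_hidden_states,
      added_cond_kwargs=added_cond_kwargs,
      rngs={"params": jax.random.PRNGKey(0)},
      mutable=True)

  # Pass 2: SERVE. Rebuild the UNet so AqtDotGeneral reads the frozen int8
  # weights, then run it with the captured variables.
  quant = quantizations.configure_quantization(
      config,
      lhs_quant_mode=aqt_flax.QuantMode.TRAIN,
      rhs_quant_mode=aqt_flax.QuantMode.SERVE)
  pipeline, _ = FlaxStableDiffusionXLPipeline.from_pretrained(
      config.pretrained_model_name_or_path, quant=quant)
  noise_pred = pipeline.unet.apply(
      {"params": quantized_vars["params"], "aqt": quantized_vars["aqt"]},
      latents, timesteps,
      encoder_hidden_states=encoder_hidden_states,
      added_cond_kwargs=added_cond_kwargs,
      rngs={"params": jax.random.PRNGKey(0)}).sample
  return noise_pred, quantized_vars

In the actual script, `main()` drops `q_v['params']` before calling `run()`, and the commented-out `remove_quantized_params` call is the intended follow-up for also pruning the float kernels that now have int8 counterparts to save memory.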