Replicated SDXL Quantization #76

Open · wants to merge 9 commits into base: main
3 changes: 2 additions & 1 deletion requirements.txt
@@ -4,7 +4,8 @@ google-cloud-storage
absl-py
transformers>=4.25.1
datasets
flax
flax>=0.8.1
aqtp
optax
torch
torchvision
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/base_xl.yml
@@ -154,4 +154,9 @@ lightning_repo: ""
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
lightning_ckpt: ""

enable_mllog: False
enable_mllog: False

quantization: 'int8'
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
quantization_local_shard_count: -1
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
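The two quantization keys above feed the AqtQuantization helper that the rest of this PR threads through the UNet. A minimal sketch of that wiring, assuming a hypothetical maybe_configure_quant helper and the None-means-disabled convention used in max_utils.get_states below (the body of quantizations.configure_quantization itself is not part of this diff):

from aqt.jax.v2.flax import aqt_flax
from maxdiffusion.models import quantizations

def maybe_configure_quant(config):
    # Hypothetical helper: return None when quantization is disabled so callers
    # fall back to plain nn.Dense dot_general and jnp.einsum.
    if config.quantization is None:
        return None
    # Serving configuration used in this PR: activations (lhs) stay in TRAIN mode,
    # weights (rhs) are read from the pre-quantized int8 variables.
    return quantizations.configure_quantization(
        config=config,
        lhs_quant_mode=aqt_flax.QuantMode.TRAIN,
        rhs_quant_mode=aqt_flax.QuantMode.SERVE,
    )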
1 change: 1 addition & 0 deletions src/maxdiffusion/configuration_utils.py
@@ -582,6 +582,7 @@ def to_json_saveable(value):
config_dict.pop("_ignore_files", None)
config_dict.pop("_use_default_values", None)
config_dict.pop("mesh", None)
config_dict.pop("quant", None)

return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

54 changes: 52 additions & 2 deletions src/maxdiffusion/generate_sdxl_replicated.py
@@ -33,15 +33,65 @@
# 1. Let's start by downloading the model and loading it into our pipeline class
# Adhering to JAX's functional approach, the model's parameters are returned separately and
# will have to be passed to the pipeline during inference
from aqt.jax.v2.flax import aqt_flax
from maxdiffusion.models import quantizations
def get_quantized_unet_variables():
# CONVERT mode traces the UNet once so AQT captures quantized weights into an "aqt" collection.
quant = quantizations.configure_quantization(config=None, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.CONVERT)
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
revision="refs/pr/95",
dtype=jnp.bfloat16,
split_head_dim=False,
quant=quant,
)
latents = jnp.ones((4, 4,128,128), dtype=jnp.float32)
timesteps = jnp.ones((4,))
encoder_hidden_states = jnp.ones((4, 77, 2048))

added_cond_kwargs = {
"text_embeds": jnp.ones((4, 1280), dtype=jnp.float32),
"time_ids": jnp.ones((4, 6), dtype=jnp.float32),
}
_, quantized_unet_vars = pipeline.unet.apply(
# Seed an empty "aqt" collection so the CONVERT pass can populate it with quantized weights.
params["unet"] | {"aqt" : {}},
latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
added_cond_kwargs=added_cond_kwargs,
rngs={"params": jax.random.PRNGKey(0)},
mutable=True,
)
# Keep only the quantized collections; drop the pipeline and the original float params.
del pipeline
del params
del quantized_unet_vars["params"]
return quantized_unet_vars


quantized_unet_vars = get_quantized_unet_variables()

quant = quantizations.configure_quantization(config=None, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.SERVE)
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=False, quant=quant,
)

print("params loaded keys ", params.keys())
breakpoint()
# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in
# float32 to keep maximal precision
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state
# Swap in the quantized UNet variables captured during the CONVERT pass.
params["unet"] = quantized_unet_vars

# 3. Next, we define the different inputs to the pipeline
default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
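Stripped of the debugging scaffolding, the change to this script is a two-pass pattern: one CONVERT pass to materialize the int8 variables, then a SERVE pass that runs inference with them. A condensed restatement, using only the calls that appear above:

# Pass 1 (CONVERT): trace the UNet on dummy inputs so AQT records the
# quantized weights into an "aqt" variable collection.
quantized_unet_vars = get_quantized_unet_variables()

# Pass 2 (SERVE): reload the pipeline with serving-mode quantization and swap
# the quantized variable tree in for the UNet parameters before inference.
quant = quantizations.configure_quantization(
    config=None,
    lhs_quant_mode=aqt_flax.QuantMode.TRAIN,
    rhs_quant_mode=aqt_flax.QuantMode.SERVE,
)
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    revision="refs/pr/95",
    split_head_dim=False,
    quant=quant,
)
params["unet"] = quantized_unet_vars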
18 changes: 12 additions & 6 deletions src/maxdiffusion/max_utils.py
@@ -330,6 +330,7 @@ def setup_initial_state(model, tx, config, mesh, model_params, unboxed_abstract_
init_train_state_partial = functools.partial(init_train_state, model=model, tx=tx, training=training)

sharding = PositionalSharding(mesh.devices).replicate()
# TODO - Inspect structure of sharding?
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding)
model_params = jax.tree_util.tree_map(partial_device_put_replicated, model_params)

@@ -344,17 +345,22 @@

state_mesh_shardings = jax.tree_util.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)

return state, state_mesh_shardings

def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, training=True):
def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, training=True, q_v=None):

# Needed to initialize weights on multi-host with addressable devices.
quant_enabled = config.quantization is not None
if config.train_new_unet:
unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["eval_only"])(rng, eval_only=False)
else:
unet_variables = pipeline.unet.init_weights(rng, eval_only=True)

unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, unet_variables, training=training)
#unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["quantization_enabled"])(rng, quantization_enabled=quant_enabled)
unet_variables = pipeline.unet.init_weights(rng, eval_only=True, quantization_enabled=quant_enabled)
if q_v:
unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, q_v, training=training)
else:
unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, unet_variables, training=training)
if config.train_new_unet:
unet_params = unet_variables
else:
@@ -364,7 +370,7 @@ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, trainin
tx,
config,
mesh,
unet_params,
q_v,
unboxed_abstract_state,
state_mesh_annotations,
training=training)
@@ -381,7 +387,7 @@ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, trainin
state_mesh_annotations,
training=training
)

return unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings

# Learning Rate Schedule
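The new q_v argument lets the serving path derive the abstract state and shardings from the quantized UNet variables rather than the original bf16 params. A hypothetical call site (not shown in this diff), reusing the quantized_unet_vars produced by the CONVERT pass:

unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = max_utils.get_states(
    mesh, tx, rng, config, pipeline, unet_params, vae_params,
    training=False, q_v=quantized_unet_vars,
)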
37 changes: 35 additions & 2 deletions src/maxdiffusion/models/attention_flax.py
@@ -23,6 +23,8 @@
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

from ..import common_types, max_logging
from . import quantizations


Array = common_types.Array
Mesh = common_types.Mesh
@@ -36,6 +38,11 @@
HEAD = common_types.HEAD
D_KV = common_types.D_KV

Quant = quantizations.AqtQuantization

def _maybe_aqt_einsum(quant: Quant):
# Fall back to plain jnp.einsum when quantization is disabled.
return jnp.einsum if quant is None else quant.einsum()

class AttentionOp(nn.Module):
mesh: Mesh
attention_kernel: str
@@ -48,6 +55,7 @@ class AttentionOp(nn.Module):
flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
flash_min_seq_length: int = 4096
flash_block_sizes: BlockSizes = None
quant: Quant = None
dtype: DType = jnp.float32

def check_attention_inputs(
@@ -385,6 +393,9 @@ class FlaxAttention(nn.Module):
jax mesh is required if attention is set to flash.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
quant (`AqtQuantization`, *optional*, defaults to None):
Configures AQT quantization github.com/google/aqt.


"""
query_dim: int
@@ -402,6 +413,7 @@ class FlaxAttention(nn.Module):
key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
quant: Quant = None

def setup(self):

@@ -421,17 +433,26 @@ def setup(self):
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
flash_block_sizes=self.flash_block_sizes,
dtype=self.dtype
dtype=self.dtype,
quant=self.quant
)

qkv_init_kernel = nn.with_logical_partitioning(
nn.initializers.lecun_normal(),
("embed","heads")
)

dot_general_cls = None
if self.quant:
# Route the q/k/v and output projections through AQT's quantized dot_general.
dot_general_cls = self.quant.dot_general_cls()

self.query = nn.Dense(
inner_dim,
kernel_init=qkv_init_kernel,
dot_general_cls=dot_general_cls,
use_bias=False,
dtype=self.dtype,
name="to_q"
@@ -440,6 +461,7 @@ def setup(self):
self.key = nn.Dense(
inner_dim,
kernel_init=qkv_init_kernel,
dot_general_cls=dot_general_cls,
use_bias=False,
dtype=self.dtype,
name="to_k"
@@ -448,6 +470,7 @@ def setup(self):
self.value = nn.Dense(
inner_dim,
kernel_init=qkv_init_kernel,
dot_general_cls=dot_general_cls,
use_bias=False,
dtype=self.dtype,
name="to_v")
@@ -458,6 +481,7 @@ def setup(self):
nn.initializers.lecun_normal(),
("heads","embed")
),
dot_general_cls=dot_general_cls,
dtype=self.dtype,
name="to_out_0")
self.dropout_layer = nn.Dropout(rate=self.dropout)
@@ -520,6 +544,8 @@ class FlaxBasicTransformerBlock(nn.Module):
Overrides default block sizes for flash attention.
mesh (`jax.sharding.mesh`, *optional*, defaults to `None`):
jax mesh is required if attention is set to flash.
quant (`AqtQuantization`, *optional*, defaults to None):
Configures AQT quantization github.com/google/aqt.
"""
dim: int
n_heads: int
@@ -533,6 +559,7 @@ class FlaxBasicTransformerBlock(nn.Module):
flash_min_seq_length: int = 4096
flash_block_sizes: BlockSizes = None
mesh: jax.sharding.Mesh = None
quant: Quant = None

def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
@@ -548,6 +575,7 @@ def setup(self):
flash_block_sizes=self.flash_block_sizes,
mesh=self.mesh,
dtype=self.dtype,
quant=self.quant,
)
# cross attention
self.attn2 = FlaxAttention(
@@ -562,6 +590,7 @@ def setup(self):
flash_block_sizes=self.flash_block_sizes,
mesh=self.mesh,
dtype=self.dtype,
quant=self.quant,
)
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -625,6 +654,8 @@ class FlaxTransformer2DModel(nn.Module):
Overrides default block sizes for flash attention.
mesh (`jax.sharding.mesh`, *optional*, defaults to `None`):
jax mesh is required if attention is set to flash.
quant (`AqtQuantization`, *optional*, defaults to None):
Configures AQT quantization github.com/google/aqt.
"""
in_channels: int
n_heads: int
@@ -641,6 +672,7 @@ def setup(self):
flash_block_sizes: BlockSizes = None
mesh: jax.sharding.Mesh = None
norm_num_groups: int = 32
quant: Quant = None

def setup(self):
self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5)
@@ -679,7 +711,8 @@ def setup(self):
attention_kernel=self.attention_kernel,
flash_min_seq_length=self.flash_min_seq_length,
flash_block_sizes=self.flash_block_sizes,
mesh=self.mesh
mesh=self.mesh,
quant=self.quant
)
for _ in range(self.depth)
]
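For einsum-based contractions the analogue of dot_general_cls is the _maybe_aqt_einsum helper defined above. Its call sites fall outside the hunks shown here, so the snippet below is only an illustrative sketch of how an attention implementation might use it:

# jnp.einsum when quantization is disabled, AQT's quantized einsum otherwise.
einsum = _maybe_aqt_einsum(self.quant)
# Illustrative attention-score contraction; the subscripts are an example only.
attn_weights = einsum("bqhd,bkhd->bhqk", query_states, key_states)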