From 6b1559b7642fb8aadf8f9bbc268f03d177223f64 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 21 Mar 2024 12:57:09 -0700 Subject: [PATCH 1/4] disable fori_loop in generate_sdxl.py --- src/maxdiffusion/generate_sdxl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py index c08123c1..b52214cf 100644 --- a/src/maxdiffusion/generate_sdxl.py +++ b/src/maxdiffusion/generate_sdxl.py @@ -226,8 +226,10 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, - loop_body_p, (latents, scheduler_state, unet_state)) + val = (latents, scheduler_state, unet_state) + for i in range(0, config.num_inference_steps): + val = loop_body_p(i, val) + latents, _, _ = val image = vae_decode_p(latents, vae_state) return image From f26688379740d3348a941c9b182c47bf2f5c123e Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 21 Mar 2024 13:00:48 -0700 Subject: [PATCH 2/4] Update generate.py --- src/maxdiffusion/generate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/generate.py b/src/maxdiffusion/generate.py index 96bb6738..c5a1bbff 100644 --- a/src/maxdiffusion/generate.py +++ b/src/maxdiffusion/generate.py @@ -168,8 +168,10 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, - loop_body_p, (latents, scheduler_state, unet_state)) + val = (latents, scheduler_state, unet_state) + for i in range(0, config.num_inference_steps): + val = loop_body_p(i, val) + latents, _, _ = val image = vae_decode_p(latents, vae_state) return image From 0d4805ea9443b17cfa0e8b64b7ebb56b5c95ae12 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 21 Mar 2024 13:02:38 -0700 Subject: [PATCH 3/4] Update pipeline_flax_stable_diffusion_xl.py --- .../stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index 82600ba3..3865b1d6 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -261,7 +261,10 @@ def loop_body(step, args): for i in range(num_inference_steps): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: - latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + val = (latents, scheduler_state) + for i in range(0, num_inference_steps): + val = loop_body(i, val) + latents, _ = val if return_latents: return latents From 58c7783b03e7fa92b61e470d01abcb9f50579c6b Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 21 Mar 2024 13:04:05 -0700 Subject: [PATCH 4/4] Update pipeline_flax_stable_diffusion.py --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index bf70e038..9294e481 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -299,7 +299,10 @@ def loop_body(step, args): for i in range(num_inference_steps): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: - latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + val = (latents, scheduler_state) + for i in range(0, num_inference_steps): + val = loop_body(i, val) + latents, _ = val # scale and decode the image latents with vae latents = 1 / self.vae.config.scaling_factor * latents