diff --git a/src/maxdiffusion/generate.py b/src/maxdiffusion/generate.py index 96bb6738..c5a1bbff 100644 --- a/src/maxdiffusion/generate.py +++ b/src/maxdiffusion/generate.py @@ -168,8 +168,10 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, - loop_body_p, (latents, scheduler_state, unet_state)) + val = (latents, scheduler_state, unet_state) + for i in range(0, config.num_inference_steps): + val = loop_body_p(i, val) + latents, _, _ = val image = vae_decode_p(latents, vae_state) return image diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py index c08123c1..b52214cf 100644 --- a/src/maxdiffusion/generate_sdxl.py +++ b/src/maxdiffusion/generate_sdxl.py @@ -226,8 +226,10 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, - loop_body_p, (latents, scheduler_state, unet_state)) + val = (latents, scheduler_state, unet_state) + for i in range(0, config.num_inference_steps): + val = loop_body_p(i, val) + latents, _, _ = val image = vae_decode_p(latents, vae_state) return image diff --git a/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index bf70e038..9294e481 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -299,7 +299,10 @@ def loop_body(step, args): for i in range(num_inference_steps): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: - latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + val = (latents, scheduler_state) + for i in range(0, num_inference_steps): + val = loop_body(i, val) + latents, _ = val # scale and decode the image latents with vae latents = 1 / self.vae.config.scaling_factor * latents diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index 82600ba3..3865b1d6 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -261,7 +261,10 @@ def loop_body(step, args): for i in range(num_inference_steps): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: - latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + val = (latents, scheduler_state) + for i in range(0, num_inference_steps): + val = loop_body(i, val) + latents, _ = val if return_latents: return latents