Merge pull request #462 from DefTruth/main
[Parallel] Avoid OOM while batch size > 1
zRzRzRzRzRzRzR authored Nov 6, 2024
2 parents 3710a61 + bb69713 commit 4aebdb4
Showing 1 changed file with 4 additions and 1 deletion.

tools/parallel_inference/parallel_inference_xdit.py
@@ -61,11 +61,14 @@ def main():
     )
     if args.enable_sequential_cpu_offload:
         pipe.enable_model_cpu_offload(gpu_id=local_rank)
-        pipe.vae.enable_tiling()
     else:
         device = torch.device(f"cuda:{local_rank}")
         pipe = pipe.to(device)
 
+    # Always enable tiling and slicing to avoid VAE OOM while batch size > 1
+    pipe.vae.enable_slicing()
+    pipe.vae.enable_tiling()
+
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
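
For context, the two calls this patch hoists out of the CPU-offload branch are standard diffusers VAE API. Below is a minimal standalone sketch of the same memory-saving setup, assuming a stock CogVideoXPipeline; the checkpoint name, prompt, and frame count are illustrative and not taken from this script:

import torch
from diffusers import CogVideoXPipeline

# Illustrative checkpoint; the parallel script constructs its pipeline elsewhere.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# enable_slicing() decodes the latent batch one sample at a time, and
# enable_tiling() decodes each frame in overlapping spatial tiles, so peak
# VAE decode memory no longer grows with batch size times resolution.
# This is why the patch applies both calls unconditionally.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

video = pipe(prompt="a panda playing guitar", num_frames=49).frames[0]

Moving the calls below the if/else also fixes an asymmetry in the old code: tiling was enabled only on the CPU-offload path, so the plain-GPU path could still hit OOM in the VAE decode.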

