From 151d3009bce0b2951903fe702692b885582855c6 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 30 May 2024 15:03:58 -0400 Subject: [PATCH] Add compiled pipeline option --- apps/shark_studio/api/sd.py | 12 ++++++++---- apps/shark_studio/web/ui/sd.py | 14 +++++++------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index aa0a66f907..d70b60c7e7 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -181,12 +181,17 @@ def __init__( print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") gc.collect() - def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): + def prepare_pipe( + self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline + ): print(f"\n[LOG] Preparing pipeline...") self.is_img2img = False mlirs = copy.deepcopy(self.model_map) vmfbs = copy.deepcopy(self.model_map) weights = copy.deepcopy(self.model_map) + if not self.is_sdxl: + compiled_pipeline = False + self.compiled_pipeline = compiled_pipeline if custom_weights: custom_weights = os.path.join( @@ -253,7 +258,6 @@ def generate_images( guidance_scale, seed, ondemand, - repeatable_seeds, resample_type, control_mode, hints, @@ -306,7 +310,7 @@ def shark_sd_fn( device: str, target_triple: str, ondemand: bool, - repeatable_seeds: bool, + compiled_pipeline: bool, resample_type: str, controlnets: dict, embeddings: dict, @@ -369,6 +373,7 @@ def shark_sd_fn( "adapters": adapters, "embeddings": embeddings, "is_img2img": is_img2img, + "compiled_pipeline": compiled_pipeline, } submit_run_kwargs = { "prompt": prompt, @@ -378,7 +383,6 @@ def shark_sd_fn( "guidance_scale": guidance_scale, "seed": seed, "ondemand": ondemand, - "repeatable_seeds": repeatable_seeds, "resample_type": resample_type, "control_mode": control_mode, "hints": hints, diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index 20330bcf75..13daa83aa8 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -120,7 +120,7 @@ def pull_sd_configs( device, target_triple, ondemand, - repeatable_seeds, + compiled_pipeline, resample_type, controlnets, embeddings, @@ -179,7 +179,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str): sd_json["device"], sd_json["target_triple"], sd_json["ondemand"], - sd_json["repeatable_seeds"], + sd_json["compiled_pipeline"], sd_json["resample_type"], sd_json["controlnets"], sd_json["embeddings"], @@ -606,9 +606,9 @@ def base_model_changed(base_model_id): interactive=True, visible=True, ) - repeatable_seeds = gr.Checkbox( - cmd_opts.repeatable_seeds, - label="Use Repeatable Seeds for Batches", + compiled_pipeline = gr.Checkbox( + False, + label="Faster txt2img (SDXL only)", ) with gr.Row(): stable_diffusion = gr.Button("Start") @@ -685,7 +685,7 @@ def base_model_changed(base_model_id): device, target_triple, ondemand, - repeatable_seeds, + compiled_pipeline, resample_type, cnet_config, embeddings_config, @@ -741,7 +741,7 @@ def base_model_changed(base_model_id): device, target_triple, ondemand, - repeatable_seeds, + compiled_pipeline, resample_type, cnet_config, embeddings_config,