(Studio2) Refactors SD pipeline to rely on turbine-models pipeline, fixes to LLM, gitignore (#2129)

* Shark Studio SDXL support, HIP driver support, simpler device info, small fixes

* Fixups to llm API/UI and ignore user config files.

* Small fixes for unifying pipelines.

* Update requirements.txt for iree-turbine (#2130)

* Fix Llama2 on CPU (#2133)

* Filesystem cleanup and custom model fixes (#2127)

* Fix some formatting issues

* Remove IREE pin (fixes exe issue) (#2126)

* Update find links for IREE packages (#2136)

* Shark Studio SDXL support, HIP driver support, simpler device info, small fixes

* Abstract out SD pipelines from Studio Webui (WIP)

* Switch from pin to minimum torch version and fix index url

* Fix device parsing.

* Fix linux setup

* Fix custom weights.

---------

Co-authored-by: saienduri <[email protected]>
Co-authored-by: gpetters-amd <[email protected]>
Co-authored-by: gpetters94 <[email protected]>
4 people authored May 28, 2024
1 parent fd07cae commit 68e9281
Showing 19 changed files with 335 additions and 454 deletions.
.github/workflows/test-studio.yml (2 changes: 0 additions & 2 deletions)

@@ -81,6 +81,4 @@ jobs:
       source shark.venv/bin/activate
       pip install -r requirements.txt --no-cache-dir
       pip install -e .
-      pip uninstall -y torch
-      pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
       python apps/shark_studio/tests/api_test.py
.gitignore (8 changes: 7 additions & 1 deletion)

@@ -164,14 +164,15 @@ cython_debug/
 # vscode related
 .vscode
 
-# Shark related artefacts
+# Shark related artifacts
 *venv/
 shark_tmp/
 *.vmfb
 .use-iree
 tank/dict_configs.py
 *.csv
 reproducers/
+apps/shark_studio/web/configs
 
 # ORT related artefacts
 cache_models/

@@ -188,6 +189,11 @@ variants.json
 # models folder
 apps/stable_diffusion/web/models/
 
+# model artifacts (SHARK)
+*.tempfile
+*.mlir
+*.vmfb
+
 # Stencil annotators.
 stencil_annotator/
 
apps/shark_studio/api/llm.py (16 changes: 7 additions & 9 deletions)

@@ -155,7 +155,9 @@ def __init__(
                 use_auth_token=hf_auth_token,
             )
         elif not os.path.exists(self.tempfile_name):
-            self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
+            self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
+                "initializer"
+            ](
                 self.hf_model_name,
                 hf_auth_token,
                 compile_to="torch",

@@ -258,8 +260,7 @@ def format_out(results):
 
         history.append(format_out(token))
         while (
-            format_out(token)
-            != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
+            format_out(token) != llm_model_map[self.hf_model_name]["stop_token"]
             and len(history) < self.max_tokens
         ):
             dec_time = time.time()

@@ -273,10 +274,7 @@
 
         self.prev_token_len = token_len + len(history)
 
-        if (
-            format_out(token)
-            == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
-        ):
+        if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]:
             break
 
         for i in range(len(history)):

@@ -310,7 +308,7 @@ def chat_hf(self, prompt):
             self.first_input = False
 
             history.append(int(token))
-        while token != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
+        while token != llm_model_map[self.hf_model_name]["stop_token"]:
             dec_time = time.time()
             result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
             history.append(int(token))

@@ -321,7 +319,7 @@
 
         self.prev_token_len = token_len + len(history)
 
-        if token == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
+        if token == llm_model_map[self.hf_model_name]["stop_token"]:
            break
        for i in range(len(history)):
            if type(history[i]) != int:
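The recurring change in this file replaces the hardcoded "meta-llama/Llama-2-7b-chat-hf" key with the instance's self.hf_model_name, so every model registered in llm_model_map resolves its own initializer and stop token. A minimal sketch of that pattern follows, with hypothetical values; llm_model_map's real definition is not part of this diff, and the stub names and the stop-token id here are assumptions for illustration only.

    def export_llama_to_torch_ir(hf_model_name, hf_auth_token, compile_to="torch"):
        # Hypothetical stub standing in for the real exporter invoked
        # via llm_model_map[...]["initializer"] in __init__ above.
        ...

    llm_model_map = {
        # Hypothetical entry; only the keys this diff touches are shown.
        "meta-llama/Llama-2-7b-chat-hf": {
            "initializer": export_llama_to_torch_ir,
            "stop_token": 2,  # assumed EOS token id, for illustration only
        },
    }

    def hit_stop_token(token_id, hf_model_name):
        # Per-model lookup: the decode loop no longer hardcodes a model name.
        return token_id == llm_model_map[hf_model_name]["stop_token"]

With this shape, supporting a new model becomes a map entry rather than another edit to the decode loops.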
Diffs for the remaining 16 changed files are not shown.