add request id propagation to InferenceExecProcess and simplify the debug interface a little; remove the catch-all try statement: if we're not debugging we don't need it, and if we are debugging we don't want the debug code to fail silently when the log level suppresses info messages
renxida committed Dec 10, 2024
1 parent e7e6f01 commit 6c9d802
Showing 4 changed files with 40 additions and 49 deletions.
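For orientation before the diffs: a minimal, self-contained sketch of the request-id propagation this commit wires up. The class and function names here are illustrative stand-ins, not shortfin APIs; only the constructor shape and the rid=None default mirror the changes below.

import asyncio
from enum import Enum, auto


class Phase(Enum):
    PREFILL = auto()
    DECODE = auto()


class ExecRequest:
    # Mirrors InferenceExecRequest below: rid is optional, so existing
    # callers that don't pass a request id keep working.
    def __init__(self, phase: Phase, input_token_ids: list[int], rid=None):
        self.phase = phase
        self.input_token_ids = input_token_ids
        self.rid = rid


async def generate(input_token_ids: list[int], rid: str):
    # Analogous to generate.py below: the generate request's id rides along
    # with the exec request so downstream work can be correlated per request.
    exec_req = ExecRequest(
        phase=Phase.PREFILL, input_token_ids=input_token_ids, rid=rid
    )
    print(f"submitting phase={exec_req.phase.name} rid={exec_req.rid}")


asyncio.run(generate([1, 2, 3], rid="req-42"))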
44 changes: 23 additions & 21 deletions shortfin/python/shortfin_apps/llm/components/debug_service.py
@@ -15,7 +15,7 @@
import json
import numpy as np
import pandas as pd
from typing import List
from typing import List, Dict, Any
from pprint import pformat

logger = logging.getLogger(__name__)
@@ -30,9 +30,9 @@
logger.info("DEBUG_LLM_SERVICE=True")
dump_id = 0
boot_timestamp = datetime.now().isoformat()
DEBUG_DATA_DIR = Path.home() / "sfdebug"
DEBUG_DATA_DIR = Path.home() / ".shortfin/debug/"
DUMP_DIR_THIS_SESSION = (
DEBUG_DATA_DIR / f"llm_service_invocation_dumps_from_{boot_timestamp}"
DEBUG_DATA_DIR / "llm_service_invocation_dumps" / "{boot_timestamp}"
)
DUMP_DIR_THIS_SESSION.mkdir(parents=True, exist_ok=False)
logger.info(
@@ -41,26 +41,28 @@


async def pre_invocation_debug_dump(
phase,
is_decode,
device0,
fn,
req_bs,
bsl,
seq_stride,
block_count,
req_count,
exec_requests,
tokens,
start_positions,
seq_lens,
seq_block_ids,
model_params,
args,
executor: "InferenceExecutorProcess", local_vars: Dict[str, Any]
):
"""Comprehensive debug dump before inference invocation."""
if not SHORTFIN_DEBUG_LLM_SERVICE:
return

# Extract variables from locals
is_decode = local_vars["is_decode"]
device0 = local_vars["device0"]
fn = local_vars["fn"]
req_bs = local_vars["req_bs"]
bsl = local_vars["bsl"]
seq_stride = local_vars["seq_stride"]
block_count = local_vars["block_count"]
req_count = local_vars["req_count"]
tokens = local_vars["tokens"]
start_positions = local_vars.get("start_positions")
seq_lens = local_vars["seq_lens"]
seq_block_ids = local_vars["seq_block_ids"]
args = local_vars["args"]

phase = executor.phase
exec_requests = executor.exec_requests
model_params = executor.service.model_params

global dump_id
dump_path = DUMP_DIR_THIS_SESSION / f"{dump_id}"
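The rewritten entry point above takes the executor plus the caller's frame locals instead of a long explicit argument list. A minimal, self-contained sketch of that pattern under assumed names (debug_dump and run_step are made up for illustration; only the executor/local_vars shape mirrors the diff):

import asyncio
from typing import Any, Dict


async def debug_dump(executor: object, local_vars: Dict[str, Any]) -> None:
    # The hook pulls what it needs out of the caller's locals, so the call
    # site stays a one-liner and adding a dumped field never changes its
    # signature.
    tokens = local_vars["tokens"]
    start_positions = local_vars.get("start_positions")  # optional, as above
    print(f"dump: tokens={tokens} start_positions={start_positions}")


async def run_step(debug_enabled: bool) -> None:
    tokens = [101, 102, 103]
    if debug_enabled:
        # Deliberately no try/except: while debugging, a broken dump should
        # fail loudly rather than disappear into an info-level log line.
        await debug_dump(executor=None, local_vars=locals())


asyncio.run(run_step(debug_enabled=True))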
6 changes: 5 additions & 1 deletion shortfin/python/shortfin_apps/llm/components/generate.py
@@ -49,7 +49,11 @@ def __init__(
self.eos_token_id = eos_token_id

async def run(self):
exec = InferenceExecRequest(InferencePhase.PREFILL, self.input_token_ids)
exec = InferenceExecRequest(
phase=InferencePhase.PREFILL,
input_token_ids=self.input_token_ids,
rid=self.gen_req.rid,
)
try:
self.client.batcher.submit(exec)
await exec.done
3 changes: 2 additions & 1 deletion shortfin/python/shortfin_apps/llm/components/messages.py
@@ -21,12 +21,13 @@ class InferencePhase(Enum):
class InferenceExecRequest(sf.Message):
"""Performs a prefill operation."""

def __init__(self, phase: InferencePhase, input_token_ids: list[int]):
def __init__(self, phase: InferencePhase, input_token_ids: list[int], rid=None):
super().__init__()
self.phase = phase
self.start_position: int = 0
self.input_token_ids = input_token_ids
self.done = sf.VoidFuture()
self.rid = rid

# Response control.
# If True, return all sequence position logits. If False, return only
36 changes: 10 additions & 26 deletions shortfin/python/shortfin_apps/llm/components/service.py
@@ -29,6 +29,14 @@
isolation.name.lower(): isolation for isolation in sf.ProgramIsolation
}

import os

SHORTFIN_DEBUG_LLM_SERVICE = os.getenv(
"SHORTFIN_DEBUG_LLM_SERVICE", "False"
).lower() in ("true", "yes", "1", "y")
if SHORTFIN_DEBUG_LLM_SERVICE:
from .debug_service import pre_invocation_debug_dump


class GenerateService:
"""Top level service interface for generating text against a model."""
@@ -435,32 +443,8 @@ async def run(self):
)

# pre-invocation args dump
try:
from .debug_service import pre_invocation_debug_dump

await pre_invocation_debug_dump(
phase=self.phase,
is_decode=is_decode,
device0=device0,
fn=fn,
req_bs=req_bs,
bsl=bsl,
seq_stride=seq_stride,
block_count=block_count,
req_count=req_count,
exec_requests=self.exec_requests,
tokens=tokens,
start_positions=start_positions if is_decode else None,
seq_lens=seq_lens,
seq_block_ids=seq_block_ids,
model_params=self.service.model_params,
args=args,
)
except Exception as e:
err_msg = (
f"Error Type: {type(e).__name__}\n" f"Error Message: {str(e)}\n"
)
logger.info(f"Non-critical failure: debug logging failed due to {e}")
if SHORTFIN_DEBUG_LLM_SERVICE:
await pre_invocation_debug_dump(executor=self, local_vars=locals())

# invoke VMFB
(logits,) = await fn(*args, fiber=self.fiber)
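A short usage note, hedged because the server entry point itself is outside this diff: the flag added above is read once, at import time of service.py, so it has to be present in the environment before the service process starts.

import os

# Accepted truthy spellings, per the parsing in service.py above (any case):
#   "true", "yes", "1", "y"
os.environ.setdefault("SHORTFIN_DEBUG_LLM_SERVICE", "1")

# With the flag set, debug_service.py creates one directory per boot under
#   ~/.shortfin/debug/llm_service_invocation_dumps/<boot timestamp>/
# and a numbered subdirectory per pre-invocation dump.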
