diff --git a/shortfin/python/shortfin_apps/llm/components/debug_service.py b/shortfin/python/shortfin_apps/llm/components/debug_service.py
index 2eef3b6d8..7e801f4ce 100644
--- a/shortfin/python/shortfin_apps/llm/components/debug_service.py
+++ b/shortfin/python/shortfin_apps/llm/components/debug_service.py
@@ -15,7 +15,7 @@
 import json
 import numpy as np
 import pandas as pd
-from typing import List
+from typing import List, Dict, Any
 from pprint import pformat
 
 logger = logging.getLogger(__name__)
@@ -30,9 +30,9 @@
     logger.info("DEBUG_LLM_SERVICE=True")
     dump_id = 0
     boot_timestamp = datetime.now().isoformat()
-    DEBUG_DATA_DIR = Path.home() / "sfdebug"
+    DEBUG_DATA_DIR = Path.home() / ".shortfin/debug/"
     DUMP_DIR_THIS_SESSION = (
-        DEBUG_DATA_DIR / f"llm_service_invocation_dumps_from_{boot_timestamp}"
+        DEBUG_DATA_DIR / "llm_service_invocation_dumps" / f"{boot_timestamp}"
     )
     DUMP_DIR_THIS_SESSION.mkdir(parents=True, exist_ok=False)
     logger.info(
@@ -41,26 +41,28 @@
 
 
 async def pre_invocation_debug_dump(
-    phase,
-    is_decode,
-    device0,
-    fn,
-    req_bs,
-    bsl,
-    seq_stride,
-    block_count,
-    req_count,
-    exec_requests,
-    tokens,
-    start_positions,
-    seq_lens,
-    seq_block_ids,
-    model_params,
-    args,
+    executor: "InferenceExecutorProcess", local_vars: Dict[str, Any]
 ):
     """Comprehensive debug dump before inference invocation."""
-    if not SHORTFIN_DEBUG_LLM_SERVICE:
-        return
+
+    # Extract variables from locals
+    is_decode = local_vars["is_decode"]
+    device0 = local_vars["device0"]
+    fn = local_vars["fn"]
+    req_bs = local_vars["req_bs"]
+    bsl = local_vars["bsl"]
+    seq_stride = local_vars["seq_stride"]
+    block_count = local_vars["block_count"]
+    req_count = local_vars["req_count"]
+    tokens = local_vars["tokens"]
+    start_positions = local_vars.get("start_positions")
+    seq_lens = local_vars["seq_lens"]
+    seq_block_ids = local_vars["seq_block_ids"]
+    args = local_vars["args"]
+
+    phase = executor.phase
+    exec_requests = executor.exec_requests
+    model_params = executor.service.model_params
 
     global dump_id
     dump_path = DUMP_DIR_THIS_SESSION / f"{dump_id}"
diff --git a/shortfin/python/shortfin_apps/llm/components/generate.py b/shortfin/python/shortfin_apps/llm/components/generate.py
index 698f779fb..9e9fea692 100644
--- a/shortfin/python/shortfin_apps/llm/components/generate.py
+++ b/shortfin/python/shortfin_apps/llm/components/generate.py
@@ -49,7 +49,11 @@ def __init__(
         self.eos_token_id = eos_token_id
 
     async def run(self):
-        exec = InferenceExecRequest(InferencePhase.PREFILL, self.input_token_ids)
+        exec = InferenceExecRequest(
+            phase=InferencePhase.PREFILL,
+            input_token_ids=self.input_token_ids,
+            rid=self.gen_req.rid,
+        )
         try:
             self.client.batcher.submit(exec)
             await exec.done
diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py
index 9e2ab7179..724f71569 100644
--- a/shortfin/python/shortfin_apps/llm/components/messages.py
+++ b/shortfin/python/shortfin_apps/llm/components/messages.py
@@ -21,12 +21,13 @@ class InferencePhase(Enum):
 class InferenceExecRequest(sf.Message):
     """Performs a prefill operation."""
 
-    def __init__(self, phase: InferencePhase, input_token_ids: list[int]):
+    def __init__(self, phase: InferencePhase, input_token_ids: list[int], rid=None):
         super().__init__()
         self.phase = phase
         self.start_position: int = 0
         self.input_token_ids = input_token_ids
         self.done = sf.VoidFuture()
+        self.rid = rid
 
         # Response control.
         # If True, return all sequence position logits. If False, return only
diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py
index 0caa21194..31edc7f51 100644
--- a/shortfin/python/shortfin_apps/llm/components/service.py
+++ b/shortfin/python/shortfin_apps/llm/components/service.py
@@ -29,6 +29,14 @@
     isolation.name.lower(): isolation for isolation in sf.ProgramIsolation
 }
 
+import os
+
+SHORTFIN_DEBUG_LLM_SERVICE = os.getenv(
+    "SHORTFIN_DEBUG_LLM_SERVICE", "False"
+).lower() in ("true", "yes", "1", "y")
+if SHORTFIN_DEBUG_LLM_SERVICE:
+    from .debug_service import pre_invocation_debug_dump
+
 
 class GenerateService:
     """Top level service interface for generating text against a model."""
@@ -435,32 +443,8 @@ async def run(self):
             )
 
             # pre-invocation args dump
-            try:
-                from .debug_service import pre_invocation_debug_dump
-
-                await pre_invocation_debug_dump(
-                    phase=self.phase,
-                    is_decode=is_decode,
-                    device0=device0,
-                    fn=fn,
-                    req_bs=req_bs,
-                    bsl=bsl,
-                    seq_stride=seq_stride,
-                    block_count=block_count,
-                    req_count=req_count,
-                    exec_requests=self.exec_requests,
-                    tokens=tokens,
-                    start_positions=start_positions if is_decode else None,
-                    seq_lens=seq_lens,
-                    seq_block_ids=seq_block_ids,
-                    model_params=self.service.model_params,
-                    args=args,
-                )
-            except Exception as e:
-                err_msg = (
-                    f"Error Type: {type(e).__name__}\n" f"Error Message: {str(e)}\n"
-                )
-                logger.info(f"Non-critical failure: debug logging failed due to {e}")
+            if SHORTFIN_DEBUG_LLM_SERVICE:
+                await pre_invocation_debug_dump(executor=self, local_vars=locals())
 
             # invoke VMFB
             (logits,) = await fn(*args, fiber=self.fiber)
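
Reviewer note: the core change in `service.py`/`debug_service.py` is that the executor now hands its whole local frame to the dump helper via `locals()` instead of forwarding sixteen keyword arguments. Below is a minimal, self-contained sketch of that handoff pattern; the names `bsl`, `req_count`, and `start_positions` are stand-ins for illustration, not the real executor state.

```python
from typing import Any, Dict


def dump_state(local_vars: Dict[str, Any]) -> None:
    """Pull only the names we need out of the caller's local frame."""
    bsl = local_vars["bsl"]                              # required key
    req_count = local_vars["req_count"]                  # required key
    start_positions = local_vars.get("start_positions")  # may be absent (prefill)
    print(f"bsl={bsl} req_count={req_count} start_positions={start_positions}")


def run_step() -> None:
    bsl = 32
    req_count = 4
    # Single handoff instead of a long keyword list at the call site.
    dump_state(local_vars=locals())


run_step()
```

The trade-off is that the helper is now implicitly coupled to the caller's variable names, which is why the diff fetches `start_positions` with `.get()`: that variable only exists on the decode path.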