diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 5b43c1310..71d54c234 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -340,8 +340,9 @@ async def run(self): for r in self.exec_requests: assert r.start_position == 0 + extra_token_slots = 1 if is_decode else 0 bsl = max( - (r.start_position + len(r.input_token_ids)) for r in self.exec_requests + (extra_token_slots + len(r.input_token_ids)) for r in self.exec_requests ) bsl = int(math.ceil(bsl / seq_stride) * seq_stride) block_count = bsl // seq_stride @@ -389,13 +390,17 @@ async def run(self): if self.phase == InferencePhase.DECODE: start_positions_host = start_positions.for_transfer() with start_positions_host.map(discard=True) as m: - m.fill(0) + m.fill( + 1 + ) # Pad unused requests. Must pad with nonzero value because division by 0 floods clobber page (page 0) in cache with NaN values. m.items = [req.start_position for req in self.exec_requests] start_positions_host.copy_to(start_positions) seq_lens_host = seq_lens.for_transfer() with seq_lens_host.map(discard=True) as m: - m.fill(0) + m.fill( + 1 + ) # Pad unused requests. Must pad with nonzero value because division by 0 floods clobber page (page 0) in cache with NaN values. m.items = [ req.start_position + len(req.input_token_ids) for req in self.exec_requests