Fix differences in LLM export for the direct vs paged cache (#347)
Until the work on unifying the cache interfaces lands, there are some
differences between the sharded, direct, and paged caches.
The direct cache uses a list of tensors (two per transformer block), while
the paged cache has a single slab tensor and the sharded paged cache expects a list of shards.
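
For context, the following rough sketch (not part of the commit) illustrates the three cache-state layouts the export script has to handle; the sizes and variable names below are made-up placeholders, not values taken from sharktank:

import torch

# Assumed example sizes, for illustration only.
block_count, bs, seq_len, attn_heads, head_dim = 4, 1, 64, 8, 32
page_count, page_slab_flat_dim = 256, 16384
shard_count = 2

# Direct cache: a Python list of tensors, two (K and V) per transformer block.
direct_state = [
    torch.zeros(bs, seq_len, attn_heads, head_dim) for _ in range(2 * block_count)
]

# Paged cache: a single flat "slab" tensor holding all pages.
paged_state = [torch.zeros(page_count, page_slab_flat_dim)]

# Sharded paged cache: every cache entry is split into per-device shards, so the
# exporter sees a list of shard lists rather than plain tensors.
sharded_paged_state = [
    [
        torch.zeros(page_count, page_slab_flat_dim // shard_count)
        for _ in range(shard_count)
    ]
]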
sogartar authored Oct 28, 2024
1 parent fe56b30 commit 98392d0
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -116,35 +116,38 @@ def setup_cache(model, shard_count):
                page_count=hp.context_length // llama_config.block_seq_stride
            )
            page_dim = torch.export.Dim("page")

            dynamic_shapes = [{0: page_dim}]
            unpacked = cache_state
            arg_affinities = {}
            shard_dim = None

            # Need to unpack the state when sharded
            if llama_config.tensor_parallelism_size > 1:
                shard_dim = cache_state[0].shard_dim

                unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
                dynamic_shapes = [
                    [ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
                ]

                for i in range(llama_config.tensor_parallelism_size):
                    arg_affinities[i] = DeviceAffinity(str(i))

            return unpacked, shard_dim, dynamic_shapes, arg_affinities

        elif model.config.kv_cache_type == "direct":
            cache_state = model.cache.allocate(bs=1)
            # Direct cache dimensions:
            # 2 * transformer_block_count of...
            # [bs, seq_length, attn_head_count, attn_head_dim]
            dynamic_shapes = [None]
            arg_affinities = {}
            shard_dim = None
            return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities
        else:
            raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")

        unpacked = cache_state
        dynamic_shapes = dynamic_shapes
        arg_affinities = {}
        shard_dim = None

        # Need to unpack the state when sharded
        if llama_config.tensor_parallelism_size > 1:
            shard_dim = cache_state[0].shard_dim

            unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
            dynamic_shapes = [
                [ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
            ]

            for i in range(llama_config.tensor_parallelism_size):
                arg_affinities[i] = DeviceAffinity(str(i))

        return torch.stack(unpacked), shard_dim, dynamic_shapes, arg_affinities

    def repack_cache(cache, shard_dim):
        return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]

@@ -184,7 +187,13 @@ def generate_batch_prefill(bs: int):
            arg_device=arg_affinities,
        )
        def _(model, tokens, seq_lens, seq_block_ids, cs):
            cache_tensors = torch.unbind(cs)
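            # The unsharded direct cache is exported as a single stacked tensor
            # (see setup_cache above), so unbind it back into per-block tensors;
            # the paged and sharded caches already arrive as a list.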
            if (
                model.config.tensor_parallelism_size == 1
                and model.config.kv_cache_type == "direct"
            ):
                cache_tensors = torch.unbind(cs)
            else:
                cache_tensors = cs

            sl = tokens.shape[1]
            input_mask = model.input_mask(seq_lens, sl)