diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 900c1a9ae..ad297bcce 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -45,6 +45,12 @@ def main(): type=lambda arg: [int(bs) for bs in arg.split(",")], default="4", ) + parser.add_argument( + "--block-seq-stride", + help="Block sequence stride for paged KV cache, must divide evenly into the context length", + type=int, + default="16", + ) parser.add_argument( "--verbose", help="Include verbose logging", @@ -76,6 +82,7 @@ def main(): static_tables=False, # Rely on the compiler for hoisting tables. kv_cache_type="direct" if args.bs == [1] else "paged", attention_kernel=args.attention_kernel, + block_seq_stride=args.block_seq_stride, ) llama_config.fake_quant = args.fake_quant