diff --git a/gemma/sampler.py b/gemma/sampler.py
index 1ec029c..b2826bd 100644
--- a/gemma/sampler.py
+++ b/gemma/sampler.py
@@ -110,6 +110,7 @@ def __init__(
       transformer: transformer_lib.Transformer,
       vocab: spm.SentencePieceProcessor,
       params: params_lib.Params,
+      cache_length: int | None = None,
   ):
     """Initializes a sampler for a Gemma model.
 
@@ -122,6 +123,9 @@ def __init__(
     self.vocab = vocab
     self.params = params
     self._compiled_sample_fn = jax.jit(self._sample_fn)
+    self.cache_length = cache_length or transformer.config.max_cache_length
+    if self.cache_length is None:
+      raise ValueError('Sampler `cache_length` should be set.')
 
   @property
   def dtype(self) -> jnp.dtype:
@@ -192,7 +196,11 @@ def _sample_step(
 
   def init_cache(self, bsz) -> dict[str, modules.LayerCache]:
     """Initializes the attention cache for each layer."""
-    return self.transformer.config.init_cache(bsz, dtype=self.dtype)
+    return self.transformer.config.init_cache(
+        bsz,
+        dtype=self.dtype,
+        cache_length=self.cache_length,
+    )
 
   def init_sample_state(
       self,
diff --git a/gemma/transformer.py b/gemma/transformer.py
index 630bf82..b3e2d30 100644
--- a/gemma/transformer.py
+++ b/gemma/transformer.py
@@ -17,6 +17,7 @@
 import dataclasses
 import enum
 from typing import Iterable
+import warnings
 
 from flax import linen as nn
 from gemma import layers
@@ -71,6 +72,15 @@ class TransformerConfig:
   sliding_window_size: int | None = None
   transpose_gating_einsum: bool = False
 
+  def __post_init__(self):
+    if self.max_cache_length is not None:
+      warnings.warn(
+          'TransformerConfig.max_cache_length is deprecated and will be'
+          ' REMOVED!!! Instead, set the `cache_length` in the `Sampler` class.',
+          DeprecationWarning,
+          stacklevel=2,
+      )
+
   def query_pre_attn_scalar(self) -> float:
     """Returns the scalar to multiply the query by before attention."""
     match self.query_pre_attn_norm:
@@ -227,10 +237,15 @@ def init_cache(
       self,
       batch_size: int,
       dtype: jnp.dtype = jnp.bfloat16,
+      *,
+      cache_length: int | None = None,
   ) -> Cache:
     """Initializes a new Transformer cache."""
-    if self.max_cache_length is None:
-      raise ValueError('max_cache_length must be set to initialize cache.')
+    cache_length = cache_length or self.max_cache_length
+    if cache_length is None:
+      raise ValueError(
+          'Missing `cache_length=` kwarg when calling `init_cache()`.'
+      )
     cache = {
         f'layer_{i}': modules.Attention.init_cache(
             self.max_cache_length,
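
For reference, a minimal usage sketch of the new keyword. The `cache_length=` arguments below come from the patch; the module alias `sampler_lib` and the pre-loaded `transformer`, `vocab`, and `params` objects are illustrative placeholders, not part of this change.

# Sketch only: assumes `transformer`, `vocab` and `params` were already
# loaded with the existing gemma helpers.
import jax.numpy as jnp
from gemma import sampler as sampler_lib

# Preferred path after this change: size the KV cache when building the
# sampler instead of setting the deprecated TransformerConfig.max_cache_length.
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params,
    cache_length=1024,
)

# init_cache() now also accepts the length directly as a keyword-only argument.
cache = transformer.config.init_cache(
    batch_size=2,
    dtype=jnp.bfloat16,
    cache_length=1024,
)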