diff --git a/gemma/sampler.py b/gemma/sampler.py
index 1ec029c..750fed5 100644
--- a/gemma/sampler.py
+++ b/gemma/sampler.py
@@ -19,6 +19,7 @@
 
 from collections.abc import Sequence
 import dataclasses
+import warnings
 
 import chex
 from gemma import modules
@@ -110,6 +111,8 @@ def __init__(
       transformer: transformer_lib.Transformer,
       vocab: spm.SentencePieceProcessor,
       params: params_lib.Params,
+      *,
+      cache_length: int | None = None,
   ):
     """Initializes a sampler for a Gemma model.
 
@@ -117,11 +120,22 @@
       transformer: an instance of the Gemma transformer.
       vocab: vocabulary of the given model.
       params: weights of the model.
+      cache_length: Max length of the cache.
     """
     self.transformer = transformer
     self.vocab = vocab
     self.params = params
     self._compiled_sample_fn = jax.jit(self._sample_fn)
+    if cache_length is None:
+      warnings.warn(
+          'TransformerConfig.max_cache_length is deprecated and will be'
+          ' REMOVED!!! Instead, set the `cache_length` in the `Sampler` class.',
+          DeprecationWarning,
+          stacklevel=2,
+      )
+    self.cache_length = cache_length or transformer.config.max_cache_length
+    if self.cache_length is None:
+      raise ValueError('Sampler `cache_length` should be set.')
 
   @property
   def dtype(self) -> jnp.dtype:
@@ -136,7 +150,7 @@ def _sample_step(
     last_token = sampler_state.token_buffer[:, decoding_step]
     input_mask = sampler_state.token_buffer != self.vocab.pad_id()
     attention_mask = _compute_attention_masks(
-        decoding_step, self.transformer.config.max_cache_length, input_mask
+        decoding_step, self.cache_length, input_mask
     )
     step_positions = jnp.expand_dims(
         sampler_state.positions[:, decoding_step], -1
@@ -192,7 +206,11 @@ def _sample_step(
 
   def init_cache(self, bsz) -> dict[str, modules.LayerCache]:
     """Initializes the attention cache for each layer."""
-    return self.transformer.config.init_cache(bsz, dtype=self.dtype)
+    return self.transformer.config.init_cache(
+        bsz,
+        dtype=self.dtype,
+        cache_length=self.cache_length,
+    )
 
   def init_sample_state(
       self,
diff --git a/gemma/transformer.py b/gemma/transformer.py
index 630bf82..a03d9c1 100644
--- a/gemma/transformer.py
+++ b/gemma/transformer.py
@@ -227,13 +227,18 @@ def init_cache(
       self,
       batch_size: int,
       dtype: jnp.dtype = jnp.bfloat16,
+      *,
+      cache_length: int | None = None,
   ) -> Cache:
     """Initializes a new Transformer cache."""
-    if self.max_cache_length is None:
-      raise ValueError('max_cache_length must be set to initialize cache.')
+    cache_length = cache_length or self.max_cache_length
+    if cache_length is None:
+      raise ValueError(
+          'Missing `cache_length=` kwarg when calling `init_cache()`.'
+      )
     cache = {
         f'layer_{i}': modules.Attention.init_cache(
-            self.max_cache_length,
+            cache_length,
             self.num_kv_heads,
             self.head_dim,
             batch_size,
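
For reference, a minimal usage sketch of the new keyword under this change. It assumes `from gemma import sampler as sampler_lib` and that `transformer`, `vocab`, and `params` have already been loaded elsewhere; the length 1024 is an arbitrary placeholder, not a value from the diff:

    # Pass the cache size to the Sampler instead of setting
    # TransformerConfig.max_cache_length (now deprecated).
    sampler = sampler_lib.Sampler(
        transformer=transformer,
        vocab=vocab,
        params=params,
        cache_length=1024,
    )
    # The sampler forwards its cache_length to config.init_cache(),
    # so the transformer config no longer needs max_cache_length set.
    cache = sampler.init_cache(bsz=1)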