Make cache_length an argument of the Sampler
`max_cache_length` is currently defined in the `TransformerConfig`, but it really should be a sampler argument.

* This allows the same transformer instance to be used for both training and inference (a usage sketch follows after this message).
* From an API perspective, creating the Gemma configs for the 2b,... models should not require any arguments.

PiperOrigin-RevId: 708307065
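
For illustration, here is a minimal sketch of the usage this change enables. Only the `cache_length=` argument to `Sampler` comes from this commit; the config classmethod, the loading helpers, and the paths are assumptions about the surrounding library:

# Sketch only: TransformerConfig.gemma_2b(), load_and_format_params and the
# paths below are assumptions, not part of this commit.
import sentencepiece as spm

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib

params = params_lib.load_and_format_params('/path/to/checkpoint')  # assumed helper
vocab = spm.SentencePieceProcessor(model_file='/path/to/tokenizer.model')

config = transformer_lib.TransformerConfig.gemma_2b()  # no cache length required here
transformer = transformer_lib.Transformer(config=config)

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],  # nesting depends on the checkpoint layout
    cache_length=1024,             # the cache size now lives on the Sampler
)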
Conchylicultor authored and The gemma Authors committed Jan 6, 2025
1 parent 0d6ae85 commit 399500e
Showing 2 changed files with 28 additions and 5 deletions.
22 changes: 20 additions & 2 deletions gemma/sampler.py
@@ -19,6 +19,7 @@
 
 from collections.abc import Sequence
 import dataclasses
+import warnings
 
 import chex
 from gemma import modules
@@ -110,18 +111,31 @@ def __init__(
       transformer: transformer_lib.Transformer,
       vocab: spm.SentencePieceProcessor,
       params: params_lib.Params,
+      *,
+      cache_length: int | None = None,
   ):
     """Initializes a sampler for a Gemma model.
 
     Args:
       transformer: an instance of the Gemma transformer.
       vocab: vocabulary of the given model.
       params: weights of the model.
+      cache_length: Max length of the cache.
     """
     self.transformer = transformer
     self.vocab = vocab
     self.params = params
     self._compiled_sample_fn = jax.jit(self._sample_fn)
+    if cache_length is None:
+      warnings.warn(
+          'TransformerConfig.max_cache_length is deprecated and will be'
+          ' REMOVED!!! Instead, set the `cache_length` in the `Sampler` class.',
+          DeprecationWarning,
+          stacklevel=2,
+      )
+    self.cache_length = cache_length or transformer.config.max_cache_length
+    if self.cache_length is None:
+      raise ValueError('Sampler `cache_length` should be set.')
 
   @property
   def dtype(self) -> jnp.dtype:
@@ -136,7 +150,7 @@ def _sample_step(
     last_token = sampler_state.token_buffer[:, decoding_step]
     input_mask = sampler_state.token_buffer != self.vocab.pad_id()
     attention_mask = _compute_attention_masks(
-        decoding_step, self.transformer.config.max_cache_length, input_mask
+        decoding_step, self.cache_length, input_mask
     )
     step_positions = jnp.expand_dims(
         sampler_state.positions[:, decoding_step], -1
@@ -192,7 +206,11 @@ def _sample_step(
 
   def init_cache(self, bsz) -> dict[str, modules.LayerCache]:
     """Initializes the attention cache for each layer."""
-    return self.transformer.config.init_cache(bsz, dtype=self.dtype)
+    return self.transformer.config.init_cache(
+        bsz,
+        dtype=self.dtype,
+        cache_length=self.cache_length,
+    )
 
   def init_sample_state(
       self,
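For backward compatibility, a sampler constructed without `cache_length` still falls back to `transformer.config.max_cache_length`, but it now emits a `DeprecationWarning`. A sketch of the legacy path, assuming `transformer`, `vocab`, and `params` are set up as in the sketch above and that `max_cache_length` is still set on the config:

# Sketch only: `transformer`, `vocab` and `params` are assumed to be set up as
# in the sketch above, with transformer.config.max_cache_length still set.
import warnings

from gemma import sampler as sampler_lib

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter('always')
  legacy_sampler = sampler_lib.Sampler(
      transformer=transformer,
      vocab=vocab,
      params=params['transformer'],
      # No cache_length passed: falls back to config.max_cache_length.
  )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)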
11 changes: 8 additions & 3 deletions gemma/transformer.py
@@ -227,13 +227,18 @@ def init_cache(
       self,
       batch_size: int,
       dtype: jnp.dtype = jnp.bfloat16,
+      *,
+      cache_length: int | None = None,
   ) -> Cache:
     """Initializes a new Transformer cache."""
-    if self.max_cache_length is None:
-      raise ValueError('max_cache_length must be set to initialize cache.')
+    cache_length = cache_length or self.max_cache_length
+    if cache_length is None:
+      raise ValueError(
+          'Missing `cache_length=` kwarg when calling `init_cache()`.'
+      )
     cache = {
         f'layer_{i}': modules.Attention.init_cache(
-            self.max_cache_length,
+            cache_length,
             self.num_kv_heads,
             self.head_dim,
             batch_size,
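On the config side, `init_cache` can now be sized per call instead of from `max_cache_length`. A sketch, assuming an existing `TransformerConfig` instance named `config`; the batch size and cache length are arbitrary illustration values:

# Sketch only: `config` is an existing TransformerConfig instance built
# elsewhere; the sizes below are arbitrary.
import jax.numpy as jnp

cache = config.init_cache(
    4,                  # batch_size
    dtype=jnp.bfloat16,
    cache_length=1024,  # no longer requires config.max_cache_length to be set
)
# `cache` maps 'layer_0', 'layer_1', ... to per-layer attention caches of
# length 1024, which is what Sampler.init_cache now requests internally.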
