Skip to content

Commit

Permalink
update for new eqx
Browse files: browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 16, 2024
1 parent 2629e77 commit abe0a89
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 9 deletions.
9 changes: 2 additions & 7 deletions src/levanter/compat/hf_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,7 @@
import haliax
from haliax import Axis
from haliax.partitioning import ResourceMapping
-from haliax.state_dict import (
-    ModuleWithStateDictSerialization,
-    from_torch_compatible_state_dict,
-    save_state_dict,
-    to_torch_compatible_state_dict,
-)
+from haliax.state_dict import from_torch_compatible_state_dict, save_state_dict, to_torch_compatible_state_dict

from levanter.logging import silence_transformer_nag
from levanter.models.asr_model import ASRMixin
Expand Down Expand Up @@ -133,7 +128,7 @@ def hf_checkpoint_converter(cls) -> "HFCheckpointConverter":
MConfig = TypeVar("MConfig", bound=HFCompatConfig)


-class ModelWithHfSerializationMixin(Generic[MConfig], ModuleWithStateDictSerialization):
+class ModelWithHfSerializationMixin(Generic[MConfig]):
def get_hf_config(self):
return self.config.to_hf_config(self.Vocab.size)

Expand Down
2 changes: 1 addition & 1 deletion src/levanter/models/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def __call__(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], la
return x


-class Gpt2Transformer(ModuleWithStateDictSerialization, eqx.Module):
+class Gpt2Transformer(ModuleWithStateDictSerialization):
config: Gpt2Config = eqx.static_field()
blocks: Stacked[Gpt2Block]
ln_f: hnn.LayerNorm
Expand Down
2 changes: 1 addition & 1 deletion src/levanter/models/lm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def build(self, Vocab: Axis, *, key: PRNGKey) -> "LmT":
return self.model_type.init(Vocab, self, key=key) # type: ignore


-class LmHeadModel(Generic[LmConfigT], abc.ABC):
+class LmHeadModel(eqx.Module, Generic[LmConfigT]):
"""
Superclass for models with a language modeling head.
"""
Expand Down

0 comments on commit abe0a89

Please sign in to comment.