diff --git a/esm/__init__.py b/esm/__init__.py
index 6e3ccdf..64aeb32 100644
--- a/esm/__init__.py
+++ b/esm/__init__.py
@@ -1,2 +1,2 @@
-__version__ = "3.1.0"
+__version__ = "3.1.1"
 
diff --git a/esm/layers/transformer_stack.py b/esm/layers/transformer_stack.py
index 37ccc14..0922d06 100644
--- a/esm/layers/transformer_stack.py
+++ b/esm/layers/transformer_stack.py
@@ -66,7 +66,7 @@ def forward(
         affine: Affine3D | None = None,
         affine_mask: torch.Tensor | None = None,
         chain_id: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Forward pass of the TransformerStack.
 
@@ -85,6 +85,9 @@
         *batch_dims, _ = x.shape
         if chain_id is None:
             chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
+        hiddens = []
         for block in self.blocks:
             x = block(x, sequence_id, affine, affine_mask, chain_id)
-        return self.norm(x), x
+            hiddens.append(x)
+        hiddens = torch.stack(hiddens, dim=0)
+        return self.norm(x), x, hiddens
diff --git a/esm/models/esm3.py b/esm/models/esm3.py
index ecc994a..cbe02dd 100644
--- a/esm/models/esm3.py
+++ b/esm/models/esm3.py
@@ -376,7 +376,9 @@ def forward(
             function_tokens,
             residue_annotation_tokens,
         )
-        x, embedding = self.transformer(x, sequence_id, affine, affine_mask, chain_id)
+        x, embedding, _ = self.transformer(
+            x, sequence_id, affine, affine_mask, chain_id
+        )
         return self.output_heads(x, embedding)
 
     # The following methods are for the ESM3InferenceClient interface
diff --git a/esm/models/esmc.py b/esm/models/esmc.py
index c153a5c..6329786 100644
--- a/esm/models/esmc.py
+++ b/esm/models/esmc.py
@@ -21,6 +21,7 @@
 from esm.utils import encoding
 from esm.utils.constants.models import ESMC_600M
 from esm.utils.decoding import decode_sequence
+from esm.utils.misc import stack_variable_length_tensors
 from esm.utils.sampling import _BatchedESMProteinTensor
 
 
@@ -28,6 +29,7 @@
 class ESMCOutput:
     sequence_logits: torch.Tensor
     embeddings: torch.Tensor | None
+    hidden_states: torch.Tensor | None
 
 
 class ESMC(nn.Module, ESMCInferenceClient):
@@ -73,6 +75,23 @@ def device(self):
     def raw_model(self):
         return self
 
+    def _tokenize(self, sequence: list[str]) -> torch.Tensor:
+        pad = self.tokenizer.pad_token_id
+        assert pad is not None
+        return stack_variable_length_tensors(
+            [
+                encoding.tokenize_sequence(x, self.tokenizer, add_special_tokens=True)
+                for x in sequence
+            ],
+            constant_value=pad,
+        ).to(next(self.parameters()).device)
+
+    def _detokenize(self, sequence: torch.Tensor) -> list[str]:
+        pad = self.tokenizer.pad_token_id
+        assert pad is not None
+        assert sequence.ndim == 2
+        return [decode_sequence(x[x != pad][1:-1], self.tokenizer) for x in sequence]
+
     def forward(
         self,
         sequence_tokens: torch.Tensor | None = None,
@@ -93,9 +112,11 @@
             sequence_id = sequence_tokens == self.tokenizer.pad_token_id
 
         x = self.embed(sequence_tokens)
-        x, _ = self.transformer(x, sequence_id=sequence_id)
+        x, _, hiddens = self.transformer(x, sequence_id=sequence_id)
         sequence_logits = self.sequence_head(x)
-        output = ESMCOutput(sequence_logits=sequence_logits, embeddings=x)
+        output = ESMCOutput(
+            sequence_logits=sequence_logits, embeddings=x, hidden_states=hiddens
+        )
         return output
 
     def encode(self, input: ESMProtein) -> ESMProteinTensor:
@@ -103,9 +124,7 @@
         sequence_tokens = None
 
         if input.sequence is not None:
-            sequence_tokens = encoding.tokenize_sequence(
-                input.sequence, self.tokenizer, add_special_tokens=True
-            )
+            sequence_tokens = self._tokenize([input.sequence])[0]
         return ESMProteinTensor(sequence=sequence_tokens).to(
             next(self.parameters()).device
         )
@@ -114,7 +133,7 @@ def decode(self, input: ESMProteinTensor) -> ESMProtein:
         input = attr.evolve(input)  # Make a copy
 
         assert input.sequence is not None
-        sequence = decode_sequence(input.sequence[1:-1], self.tokenizer)
+        sequence = self._detokenize(input.sequence)[0]
 
         return ESMProtein(sequence=sequence)
 
diff --git a/esm/models/function_decoder.py b/esm/models/function_decoder.py
index 913af17..b918798 100644
--- a/esm/models/function_decoder.py
+++ b/esm/models/function_decoder.py
@@ -172,7 +172,7 @@ def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]:
         inputs = token_ids + vocab_offsets[None, :]
 
         embed = self.embedding(inputs)
-        encoding, _ = self.decoder(embed)
+        encoding, _, _ = self.decoder(embed)
         pooled = torch.mean(encoding, dim=1)
 
         return {name: head(pooled) for name, head in self.heads.items()}
diff --git a/esm/models/vqvae.py b/esm/models/vqvae.py
index b7bd149..0f5226a 100644
--- a/esm/models/vqvae.py
+++ b/esm/models/vqvae.py
@@ -250,7 +250,7 @@ def encode_local_structure(
 
         z = self.relative_positional_embedding(res_idxs[:, 0], res_idxs)
 
-        z, _ = self.transformer.forward(
+        z, _, _ = self.transformer.forward(
             x=z,
             sequence_id=knn_sequence_id,
             affine=affine,
@@ -397,7 +397,7 @@ def decode(
 
         x = self.embed(structure_tokens)
         # !!! NOTE: Attention mask is actually unused here so watch out
-        x, _ = self.decoder_stack.forward(
+        x, _, _ = self.decoder_stack.forward(
             x, affine=None, affine_mask=None, sequence_id=sequence_id, chain_id=chain_id
         )
 
diff --git a/esm/pretrained.py b/esm/pretrained.py
index df2686c..10d6e36 100644
--- a/esm/pretrained.py
+++ b/esm/pretrained.py
@@ -10,7 +10,10 @@
     StructureTokenDecoder,
     StructureTokenEncoder,
 )
-from esm.tokenization import get_model_tokenizers
+from esm.tokenization import (
+    get_esm3_model_tokenizers,
+    get_esmc_model_tokenizers,
+)
 from esm.utils.constants.esm3 import data_root
 from esm.utils.constants.models import (
     ESM3_FUNCTION_DECODER_V0,
@@ -62,10 +65,7 @@
 def ESMC_300M_202412(device: torch.device | str = "cpu"):
     with torch.device(device):
         model = ESMC(
-            d_model=960,
-            n_heads=15,
-            n_layers=30,
-            tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
+            d_model=960, n_heads=15, n_layers=30, tokenizer=get_esmc_model_tokenizers()
         ).eval()
     state_dict = torch.load(
         data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
@@ -79,10 +79,7 @@
 def ESMC_600M_202412(device: torch.device | str = "cpu"):
     with torch.device(device):
         model = ESMC(
-            d_model=1152,
-            n_heads=18,
-            n_layers=36,
-            tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
+            d_model=1152, n_heads=18, n_layers=36, tokenizer=get_esmc_model_tokenizers()
         ).eval()
     state_dict = torch.load(
         data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
@@ -103,7 +100,7 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
             structure_encoder_fn=ESM3_structure_encoder_v0,
             structure_decoder_fn=ESM3_structure_decoder_v0,
             function_decoder_fn=ESM3_function_decoder_v0,
-            tokenizers=get_model_tokenizers(ESM3_OPEN_SMALL),
+            tokenizers=get_esm3_model_tokenizers(ESM3_OPEN_SMALL),
         ).eval()
     state_dict = torch.load(
         data_root("esm3") / "data/weights/esm3_sm_open_v1.pth", map_location=device
diff --git a/esm/sdk/api.py b/esm/sdk/api.py
index ae96b4c..60619c8 100644
--- a/esm/sdk/api.py
+++ b/esm/sdk/api.py
@@ -10,7 +10,7 @@
 import esm.utils.constants.api as C
 from esm.tokenization import (
     TokenizerCollectionProtocol,
-    get_model_tokenizers,
+    get_esm3_model_tokenizers,
 )
 from esm.utils import encoding
 from esm.utils.constants.models import ESM3_OPEN_SMALL
@@ -226,7 +226,7 @@ def empty(
         device: torch.device | str = "cpu",
     ) -> ESMProteinTensor:
         if tokenizers is None:
-            tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
+            tokenizers = get_esm3_model_tokenizers(ESM3_OPEN_SMALL)
 
         return ESMProteinTensor(
             sequence=encoding.get_default_sequence_tokens(
diff --git a/esm/tokenization/__init__.py b/esm/tokenization/__init__.py
index 70b9885..ea60922 100644
--- a/esm/tokenization/__init__.py
+++ b/esm/tokenization/__init__.py
@@ -34,7 +34,7 @@ class TokenizerCollection:
     residue_annotations: ResidueAnnotationsTokenizer
 
 
-def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
+def get_esm3_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
     if normalize_model_name(model) == ESM3_OPEN_SMALL:
         return TokenizerCollection(
             sequence=EsmSequenceTokenizer(),
@@ -48,6 +48,10 @@
         raise ValueError(f"Unknown model: {model}")
 
 
+def get_esmc_model_tokenizers() -> EsmSequenceTokenizer:
+    return EsmSequenceTokenizer()
+
+
 def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]:
     if isinstance(tokenizer, EsmSequenceTokenizer):
         return [
diff --git a/esm/utils/generation_test.py b/esm/utils/generation_test.py
index fddc7f9..1041b6d 100644
--- a/esm/utils/generation_test.py
+++ b/esm/utils/generation_test.py
@@ -10,12 +10,12 @@
 from evolutionaryscale.utils.remote_inference.api_v1 import (
     ESM3RemoteModelInferenceClient,
 )
-from projects.forge.fastapi.utils.model import _load_esm3
+from projects.forge.fastapi.utils.model import _load_esm_model
 
 
 @pytest.fixture()
 def esm3_remote_inference_client():
-    model = _load_esm3(ModelName.ESM3_TINY_DEV, distributed_model=False)
+    model = _load_esm_model(ModelName.ESM3_TINY_DEV, distributed_model=False)
     client = ESM3RemoteModelInferenceClient(
         model,
         tokenizers=model.tokenizers,
diff --git a/examples/esmc_examples.py b/examples/esmc_examples.py
index 2fe423d..cf6a911 100644
--- a/examples/esmc_examples.py
+++ b/examples/esmc_examples.py
@@ -1,34 +1,48 @@
 from esm.models.esmc import ESMC
-from examples.local_generate import get_sample_protein
-from esm.sdk.api import (
-    ESMCInferenceClient,
-    LogitsConfig,
-    LogitsOutput,
-)
+from esm.sdk.api import ESMCInferenceClient, ESMProtein, LogitsConfig, LogitsOutput
 
 
 def main(client: ESMCInferenceClient):
     # ================================================================
     # Example usage: one single protein
     # ================================================================
-    protein = get_sample_protein()
-    protein.coordinates = None
-    protein.function_annotations = None
-    protein.sasa = None
+    protein = ESMProtein(sequence="AAAAA")
 
     # Use logits endpoint. Using bf16 for inference optimization
    protein_tensor = client.encode(protein)
-    logits_output = client.logits(
+    output = client.logits(
         protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
     )
     assert isinstance(
-        logits_output, LogitsOutput
-    ), f"LogitsOutput was expected but got {logits_output}"
-    assert (
-        logits_output.logits is not None and logits_output.logits.sequence is not None
+        output, LogitsOutput
+    ), f"LogitsOutput was expected but got {output}"
+    assert output.logits is not None and output.logits.sequence is not None
+    assert output.embeddings is not None and output.embeddings is not None
+    print(
+        f"Client returned logits with shape: {output.logits.sequence.shape} and embeddings with shape: {output.embeddings.shape}"
+    )
+
+
+def raw_forward(model: ESMC):
+    protein = ESMProtein(sequence="AAAAA")
+    sequences = [protein.sequence, protein.sequence]
+
+    # ================================================================
+    # Example usage: directly use the model
+    # ================================================================
+    input_ids = model._tokenize(sequences)
+    output = model(input_ids)
+    logits, embeddings, hiddens = (
+        output.sequence_logits,
+        output.embeddings,
+        output.hidden_states,
+    )
+    print(
+        f"Raw model returned logits with shape: {logits.shape}, embeddings with shape: {embeddings.shape} and hidden states with shape {hiddens.shape}"
     )
-    assert logits_output.embeddings is not None and logits_output.embeddings is not None
 
 
 if __name__ == "__main__":
-    main(ESMC.from_pretrained("esmc_300m"))
+    model = ESMC.from_pretrained("esmc_300m")
+    main(model)
+    raw_forward(model)
diff --git a/examples/raw_forwards.py b/examples/raw_forwards.py
index af41dbb..c62f173 100644
--- a/examples/raw_forwards.py
+++ b/examples/raw_forwards.py
@@ -9,7 +9,7 @@
     ESM3_structure_decoder_v0,
     ESM3_structure_encoder_v0,
 )
-from esm.tokenization import get_model_tokenizers
+from esm.tokenization import get_esm3_model_tokenizers
 from esm.tokenization.function_tokenizer import (
     InterProQuantizedTokenizer as EsmFunctionTokenizer,
 )
@@ -50,7 +50,7 @@ def inverse_folding_example():
 
 @torch.no_grad()
 def conditioned_prediction_example():
-    tokenizers = get_model_tokenizers()
+    tokenizers = get_esm3_model_tokenizers()
 
     model = ESM3_sm_open_v0("cuda")
 
diff --git a/pyproject.toml b/pyproject.toml
index f7ff9db..92212e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "esm"
-version = "3.1.0"
+version = "3.1.1"
 description = "EvolutionaryScale open model repository"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -24,7 +24,7 @@ dependencies = [
     "torch>=2.2.0",
     "torchvision",
     "torchtext",
-    "transformers",
+    "transformers<4.47.0",
     "ipython",
     "einops",
     "biotite==0.41.2",