esmc release

tina-z-jia committed Dec 4, 2024
1 parent 1561962 commit a816981

Showing 14 changed files with 351 additions and 114 deletions.
95 changes: 13 additions & 82 deletions LICENSE.md

Large diffs are not rendered by default.

93 changes: 73 additions & 20 deletions README.md
@@ -1,4 +1,72 @@
# ESM3
# Table of Contents

1. [Installation](#installation)
2. [ESM C](#esm-c)
3. [ESM 3](#esm-3)
4. [Responsible Development](#responsible-development)
5. [License](#license)

## Installation

To get started with ESM, install the library using pip:

```bash
pip install esm
```

## ESM C
[ESM Cambrian](https://www.evolutionaryscale.ai/blog/esm-cambrian) is a parallel model family to our flagship ESM3 generative models. While ESM3 focuses on controllable generation of proteins for therapeutic and many other applications, ESM C focuses on creating representations of the underlying biology of proteins.

ESM C comes with major performance benefits over ESM2. The 300M parameter ESM C delivers similar performance to ESM2 650M with dramatically reduced memory requirements and faster inference. The 600M parameter ESM C rivals the 3B parameter ESM2 and approaches the capabilities of the 15B model, delivering frontier performance with far greater efficiency. At the leading edge, the 6B parameter ESM C sets a new benchmark, outperforming all prior protein language models by a wide margin.

ESM C models are available immediately for academic and commercial use under a new license structure designed to promote openness and enable scientists and builders. See our [open](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement) and [non-commercial](https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement) license agreements.

You can use the following guides to start using ESM C models today through [Hugging Face](https://huggingface.co/EvolutionaryScale), [the Forge API](https://forge.evolutionaryscale.ai/), and [AWS SageMaker](https://aws.amazon.com/sagemaker/).

### Using ESM C 300M and 600M via GitHub
ESM C model weights are stored on the Hugging Face Hub under https://huggingface.co/EvolutionaryScale/.
```py
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

protein = ESMProtein(sequence="AAAAA")
client = ESMC.from_pretrained("esmc_300m").to("cuda") # or "cpu"
protein_tensor = client.encode(protein)
logits_output = client.logits(
    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(logits_output.logits, logits_output.embeddings)
```
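
The returned `embeddings` tensor holds one vector per token (special tokens included). To get a single sequence-level representation you can pool over the length dimension; a minimal sketch, with mean pooling as just one common choice:

```py
# logits_output.embeddings has shape (batch, length, d_model); special tokens are included.
per_residue_embeddings = logits_output.embeddings
# A simple sequence-level representation: average over the length dimension.
sequence_embedding = per_residue_embeddings.mean(dim=1)
print(sequence_embedding.shape)
```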

### Using ESM C 6B via Forge API

ESM C models, including ESM C 6B, are accessible via EvolutionaryScale Forge. You can request access and use these models through [forge.evolutionaryscale.ai](https://forge.evolutionaryscale.ai), as demonstrated in the example below.
```py
from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.sdk.api import ESMProtein, LogitsConfig

# Apply for Forge access and get an access token
forge_client = ESM3ForgeInferenceClient(model="esmc-6b-2024-12", url="https://forge.evolutionaryscale.ai", token="<your forge token>")
protein = ESMProtein(sequence="AAAAA")
protein_tensor = forge_client.encode(protein)
logits_output = forge_client.logits(
    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(logits_output.logits, logits_output.embeddings)
```

### Using ESM C 6B via SageMaker

ESM C models are also available on Amazon SageMaker. They function similarly to the ESM3 model family, and you can refer to the sample notebooks provided in this repository for examples.

After creating the endpoint, you can create a SageMaker client and use it the same way as a Forge client; they share the same API.

```py
# ESM3SageMakerClient is part of the esm SDK; see the SageMaker sample notebooks for the full setup.
sagemaker_client = ESM3SageMakerClient(
    endpoint_name=SAGE_ENDPOINT_NAME, model=<model_name>
)
```
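
Because the SageMaker and Forge clients share the same API, usage mirrors the Forge example above. A minimal sketch, assuming the endpoint and model name above have been filled in:

```py
from esm.sdk.api import ESMProtein, LogitsConfig

protein = ESMProtein(sequence="AAAAA")
protein_tensor = sagemaker_client.encode(protein)
logits_output = sagemaker_client.logits(
    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(logits_output.logits, logits_output.embeddings)
```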

## ESM 3

[ESM3](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model) is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.

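As a concrete example, prompting ESM3 with a partially masked sequence and asking it to complete the sequence track looks roughly like the sketch below (the toy prompt, device, and sampling settings are placeholders; see the quickstart and `examples/local_generate.py` for the full workflow):

```py
from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, GenerationConfig

# Weights for the open model are downloaded from the Hugging Face Hub on first use.
model = ESM3.from_pretrained().to("cuda")  # or "cpu"

# "_" marks masked positions for the model to fill in on the sequence track.
protein = ESMProtein(sequence="__________KVFGRCELAAAMKRHGLDNYRGYSLGNWVCAAK__________")
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
print(protein.sequence)
```
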
@@ -11,10 +79,10 @@ The ESM3 architecture is highly scalable due to its transformer backbone and all
Learn more by reading the [blog post](https://www.evolutionaryscale.ai/blog/esm3-release) and [the pre-print (Hayes et al., 2024)](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model).

Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family.
ESM3-open is available under a [non-commercial license](https://www.evolutionaryscale.ai/policies/community-license-agreement), reproduced under `LICENSE.md`.
ESM3-open is available under the [Cambrian non-commercial license agreement](https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement), as outlined in `LICENSE.md` (note: updated with ESM C release).
Visit our [Discussions page](https://github.com/evolutionaryscale/esm/discussions) to get in touch, provide feedback, ask questions or share your experience with ESM3!

## Quickstart for ESM3-open
### Quickstart for ESM3-open

```
pip install esm
```

@@ -65,7 +133,7 @@ We also provide example scripts that show common workflows under `examples/`:
- [local_generate.py](./examples/local_generate.py) shows how simple common tasks are: folding, inverse folding, and chain-of-thought generation, all by calling just `model.generate()` for iterative decoding.
- [seqfun_struct.py](./examples/seqfun_struct.py) shows direct use of the model as a standard PyTorch model with a simple `forward` call.

## Forge: Access to larger ESM3 models
### Forge: Access to larger ESM3 models

You can apply for beta access to the full family of larger and higher capability ESM3 models at [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai).

@@ -101,19 +169,4 @@ The core tenets of our framework are
With this in mind, we have performed a variety of mitigations for `esm3-sm-open-v1`, detailed in our [paper](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model).

## License

**The Big Picture:**

1. The EvolutionaryScale AI Model is **only** available under this Community License Agreement for **non-commercial use** by **individuals** or **non-commercial organizations** (including universities, non-profit organizations and research institutes, educational and government bodies).

2. You **may not** use the EvolutionaryScale AI Model or any derivative works of the EvolutionaryScale AI Model or its outputs:

1. in connection with **any commercial activities**, for example, any activities **by, on behalf of or for a commercial entity** or to develop **any product or service** such as hosting the AI Model behind an API; or

2. without attribution to EvolutionaryScale and this Community License Agreement; or

    3. to **train** an AI-powered third-party model **similar to EvolutionaryScale’s AI Model**, even for non-commercial usage. You may, however, create **Derivative Works** of ESM3, for example by finetuning or adding model layers.

3. You **can publish, share and adapt** the EvolutionaryScale AI Model and its outputs for **non-commercial purposes** in accordance with the Community License Agreement, including a **non-commercial restriction** on the adapted model.

Please read our non-commercial [Community License Agreement](https://www.evolutionaryscale.ai/policies/community-license-agreement) reproduced under [./LICENSE.md](LICENSE.md) before using ESM3.
The code and model weights of ESM3 and ESM C are available under a mixture of non-commercial and more permissive licenses, fully outlined in [LICENSE.md](LICENSE.md).
2 changes: 1 addition & 1 deletion esm/__init__.py
@@ -1,2 +1,2 @@
__version__ = "3.0.8"
__version__ = "3.1.0"

5 changes: 4 additions & 1 deletion esm/layers/attention.py
@@ -43,7 +43,10 @@ def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
    def forward(self, x, seq_id):
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = self.q_ln(query_BLD), self.k_ln(key_BLD)
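        # Cast the normalized q/k back to the input dtype (the q/k layernorms may
        # run in higher precision under mixed precision).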
        query_BLD, key_BLD = (
            self.q_ln(query_BLD).to(query_BLD.dtype),
            self.k_ln(key_BLD).to(query_BLD.dtype),
        )
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)

        n_heads = self.n_heads
145 changes: 145 additions & 0 deletions esm/models/esmc.py
@@ -0,0 +1,145 @@
from __future__ import annotations

import contextlib

import attr
import torch
import torch.nn as nn
from attr import dataclass

from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.sdk.api import (
    ESMCInferenceClient,
    ESMProtein,
    ESMProteinTensor,
    ForwardTrackData,
    LogitsConfig,
    LogitsOutput,
)
from esm.tokenization import EsmSequenceTokenizer
from esm.utils import encoding
from esm.utils.constants.models import ESMC_600M
from esm.utils.decoding import decode_sequence
from esm.utils.sampling import _BatchedESMProteinTensor


@dataclass
class ESMCOutput:
    sequence_logits: torch.Tensor
    embeddings: torch.Tensor | None


class ESMC(nn.Module, ESMCInferenceClient):
    """
    ESMC model implementation.
    Args:
        d_model (int): The dimensionality of the input and output feature vectors.
        n_heads (int): The number of attention heads in the transformer layers.
        n_layers (int): The number of transformer layers.
    """

    def __init__(
        self, d_model: int, n_heads: int, n_layers: int, tokenizer: EsmSequenceTokenizer
    ):
        super().__init__()
        self.embed = nn.Embedding(64, d_model)
        self.transformer = TransformerStack(
            d_model, n_heads, None, n_layers, n_layers_geom=0
        )
        self.sequence_head = RegressionHead(d_model, 64)
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(
        cls, model_name: str = ESMC_600M, device: torch.device | None = None
    ) -> ESMC:
        from esm.pretrained import load_local_model

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = load_local_model(model_name, device=device)
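        # Run in bfloat16 on accelerators; keep full precision on CPU.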
        if device.type != "cpu":
            model = model.to(torch.bfloat16)
        assert isinstance(model, ESMC)
        return model

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def raw_model(self):
        return self

    def forward(
        self,
        sequence_tokens: torch.Tensor | None = None,
        sequence_id: torch.Tensor | None = None,
    ) -> ESMCOutput:
        """
        Performs forward pass through the ESMC model. Check utils to see how to tokenize inputs from raw data.
        Args:
            sequence_tokens (torch.Tensor, optional): The amino acid tokens.
            sequence_id (torch.Tensor, optional): The sequence ID.
        Returns:
            ESMCOutput: The output of the ESMC model.
        """
        if sequence_id is None:
            sequence_id = sequence_tokens == self.tokenizer.pad_token_id

        x = self.embed(sequence_tokens)
        x, _ = self.transformer(x, sequence_id=sequence_id)
        sequence_logits = self.sequence_head(x)
        output = ESMCOutput(sequence_logits=sequence_logits, embeddings=x)
        return output

    def encode(self, input: ESMProtein) -> ESMProteinTensor:
        input = attr.evolve(input)  # Make a copy
        sequence_tokens = None

        if input.sequence is not None:
            sequence_tokens = encoding.tokenize_sequence(
                input.sequence, self.tokenizer, add_special_tokens=True
            )
        return ESMProteinTensor(sequence=sequence_tokens).to(
            next(self.parameters()).device
        )

    def decode(self, input: ESMProteinTensor) -> ESMProtein:
        input = attr.evolve(input)  # Make a copy

        assert input.sequence is not None
        sequence = decode_sequence(input.sequence[1:-1], self.tokenizer)

        return ESMProtein(sequence=sequence)

    def logits(
        self,
        input: ESMProteinTensor | _BatchedESMProteinTensor,
        config: LogitsConfig = LogitsConfig(),
    ) -> LogitsOutput:
        if not isinstance(input, _BatchedESMProteinTensor):
            # Create batch dimension if necessary.
            input = _BatchedESMProteinTensor.from_protein_tensor(input)

        device = torch.device(input.device)

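        # Inference only: disable gradients; use bfloat16 autocast on CUDA, a no-op context elsewhere.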
        with (
            torch.no_grad(),
            torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16)  # type: ignore
            if device.type == "cuda"
            else contextlib.nullcontext(),
        ):
            output = self.forward(sequence_tokens=input.sequence)

        return LogitsOutput(
            logits=ForwardTrackData(
                sequence=output.sequence_logits if config.sequence else None
            ),
            embeddings=output.embeddings if config.return_embeddings else None,
        )
2 changes: 1 addition & 1 deletion esm/models/function_decoder.py
@@ -42,7 +42,7 @@ class FunctionTokenDecoderConfig:
    interpro_entry_list: str = field(default_factory=lambda: str(C.INTERPRO_ENTRY))
    # Path to keywords vocabulary.
    keyword_vocabulary_path: str = field(
        default_factory=lambda: str(C.data_root() / C.KEYWORDS_VOCABULARY)
        default_factory=lambda: str(C.data_root("esm3") / C.KEYWORDS_VOCABULARY)
    )
    # Whether to unpack LSH bits into single-bit tokens.
    unpack_lsh_bits: bool = True
50 changes: 46 additions & 4 deletions esm/pretrained.py
@@ -4,6 +4,7 @@
import torch.nn as nn

from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import (
    StructureTokenDecoder,
@@ -16,6 +17,8 @@
    ESM3_OPEN_SMALL,
    ESM3_STRUCTURE_DECODER_V0,
    ESM3_STRUCTURE_ENCODER_V0,
    ESMC_300M,
    ESMC_600M,
)

ModelBuilder = Callable[[torch.device | str], nn.Module]
@@ -27,7 +30,8 @@ def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
            d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
        ).eval()
    state_dict = torch.load(
        data_root() / "data/weights/esm3_structure_encoder_v0.pth", map_location=device
        data_root("esm3") / "data/weights/esm3_structure_encoder_v0.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    return model
@@ -37,7 +41,8 @@ def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"):
    with torch.device(device):
        model = StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).eval()
    state_dict = torch.load(
        data_root() / "data/weights/esm3_structure_decoder_v0.pth", map_location=device
        data_root("esm3") / "data/weights/esm3_structure_decoder_v0.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    return model
@@ -47,12 +52,47 @@ def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
    with torch.device(device):
        model = FunctionTokenDecoder().eval()
    state_dict = torch.load(
        data_root() / "data/weights/esm3_function_decoder_v0.pth", map_location=device
        data_root("esm3") / "data/weights/esm3_function_decoder_v0.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    return model


def ESMC_300M_202412(device: torch.device | str = "cpu"):
    with torch.device(device):
        model = ESMC(
            d_model=960,
            n_heads=15,
            n_layers=30,
            tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
        ).eval()
    state_dict = torch.load(
        data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)

    return model


def ESMC_600M_202412(device: torch.device | str = "cpu"):
    with torch.device(device):
        model = ESMC(
            d_model=1152,
            n_heads=18,
            n_layers=36,
            tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
        ).eval()
    state_dict = torch.load(
        data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)

    return model


def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
    with torch.device(device):
        model = ESM3(
@@ -66,7 +106,7 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
            tokenizers=get_model_tokenizers(ESM3_OPEN_SMALL),
        ).eval()
    state_dict = torch.load(
        data_root() / "data/weights/esm3_sm_open_v1.pth", map_location=device
        data_root("esm3") / "data/weights/esm3_sm_open_v1.pth", map_location=device
    )
    model.load_state_dict(state_dict)
    return model
Expand All @@ -77,6 +117,8 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
    ESM3_STRUCTURE_ENCODER_V0: ESM3_structure_encoder_v0,
    ESM3_STRUCTURE_DECODER_V0: ESM3_structure_decoder_v0,
    ESM3_FUNCTION_DECODER_V0: ESM3_function_decoder_v0,
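    # Register the new ESM C builders alongside the ESM3 components so they can be
    # looked up by model name (e.g. via load_local_model / ESMC.from_pretrained).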
    ESMC_600M: ESMC_600M_202412,
    ESMC_300M: ESMC_300M_202412,
}


Diffs for the remaining changed files are not rendered.
