diff --git a/molexpress/datasets/encoders.py b/molexpress/datasets/encoders.py
index 7743c52..442cc44 100644
--- a/molexpress/datasets/encoders.py
+++ b/molexpress/datasets/encoders.py
@@ -7,7 +7,7 @@
 from molexpress.ops import chem_ops
 
 
-class MolecularGraphEncoder:
+class PeptideGraphEncoder:
     def __init__(
         self,
         atom_featurizers: list[featurizers.Featurizer],
@@ -17,10 +17,21 @@ def __init__(
         self.node_encoder = MolecularNodeEncoder(atom_featurizers)
         self.edge_encoder = MolecularEdgeEncoder(bond_featurizers, self_loops=self_loops)
 
-    def __call__(self, molecule: types.Molecule | types.SMILES | types.InChI) -> np.ndarray:
-        molecule = chem_ops.get_molecule(molecule)
-        return {**self.node_encoder(molecule), **self.edge_encoder(molecule)}
-
+    def __call__(self, molecules: list[types.Molecule | types.SMILES | types.InChI]) -> np.ndarray:
+        molecular_graphs = []
+        residue_sizes = []
+        for molecule in molecules:
+            molecule = chem_ops.get_molecule(molecule)
+            molecular_graph = {
+                **self.node_encoder(molecule),
+                **self.edge_encoder(molecule)
+            }
+            molecular_graphs.append(molecular_graph)
+            residue_sizes.append(molecule.GetNumAtoms())
+        graph = self._merge_molecular_graphs(molecular_graphs)
+        graph["residue_size"] = np.array(residue_sizes)
+        return graph
+
     @staticmethod
     def _collate_fn(
         data: list[tuple[types.MolecularGraph, np.ndarray]],
@@ -34,27 +45,46 @@ def _collate_fn(
         x, y = list(zip(*data))
 
-        num_nodes = np.array([graph["node_state"].shape[0] for graph in x])
+        disjoint_graph = PeptideGraphEncoder._merge_molecular_graphs(x)
+        disjoint_graph["peptide_size"] = np.concatenate([
+            graph["residue_size"].shape[:1] for graph in x
+        ]).astype("int32")
+        disjoint_graph["residue_size"] = np.concatenate([
+            graph["residue_size"] for graph in x
+        ]).astype("int32")
+        return disjoint_graph, np.stack(y)
+
+    @staticmethod
+    def _merge_molecular_graphs(
+        molecular_graphs: list[types.MolecularGraph],
+    ) -> types.MolecularGraph:
 
-        disjoint_graph = {}
+        num_nodes = np.array([
+            g["node_state"].shape[0] for g in molecular_graphs
+        ])
 
-        disjoint_graph["node_state"] = np.concatenate([graph["node_state"] for graph in x])
+        disjoint_molecular_graph = {}
 
-        if "edge_state" in x[0]:
-            disjoint_graph["edge_state"] = np.concatenate([graph["edge_state"] for graph in x])
+        disjoint_molecular_graph["node_state"] = np.concatenate([
+            g["node_state"] for g in molecular_graphs
+        ])
 
-        edge_src = np.concatenate([graph["edge_src"] for graph in x])
-        edge_dst = np.concatenate([graph["edge_dst"] for graph in x])
-        num_edges = np.array([graph["edge_src"].shape[0] for graph in x])
-        indices = np.repeat(range(len(x)), num_edges)
+        if "edge_state" in molecular_graphs[0]:
+            disjoint_molecular_graph["edge_state"] = np.concatenate([
+                g["edge_state"] for g in molecular_graphs
+            ])
+
+        edge_src = np.concatenate([graph["edge_src"] for graph in molecular_graphs])
+        edge_dst = np.concatenate([graph["edge_dst"] for graph in molecular_graphs])
+        num_edges = np.array([graph["edge_src"].shape[0] for graph in molecular_graphs])
+        indices = np.repeat(range(len(molecular_graphs)), num_edges)
 
         edge_incr = np.concatenate([[0], num_nodes[:-1]])
         edge_incr = np.take_along_axis(edge_incr, indices, axis=0)
 
-        disjoint_graph["edge_src"] = edge_src + edge_incr
-        disjoint_graph["edge_dst"] = edge_dst + edge_incr
-        disjoint_graph["graph_indicator"] = np.repeat(range(len(x)), num_nodes)
+        disjoint_molecular_graph["edge_src"] = edge_src + edge_incr
+        disjoint_molecular_graph["edge_dst"] = edge_dst + edge_incr
 
-        return disjoint_graph, np.stack(y)
+        return disjoint_molecular_graph
 
 
 class Composer:
@@ -103,7 +133,7 @@ def __call__(self, molecule: types.Molecule) -> np.ndarray:
 
         if molecule.GetNumBonds() == 0:
             edge_state = np.zeros(
-                shape=(0, self.output_dim + int(self.self_loops)),
+                shape=(int(self.self_loops), self.output_dim + int(self.self_loops)),
                 dtype=self.output_dtype
             )
             return {
@@ -144,4 +174,6 @@ def __init__(
 
     def __call__(self, molecule: types.Molecule) -> np.ndarray:
         node_encodings = np.stack([self.featurizer(atom) for atom in molecule.GetAtoms()], axis=0)
-        return {"node_state": np.stack(node_encodings)}
+        return {
+            "node_state": np.stack(node_encodings),
+        }
diff --git a/molexpress/layers/__init__.py b/molexpress/layers/__init__.py
index d48b332..141294c 100644
--- a/molexpress/layers/__init__.py
+++ b/molexpress/layers/__init__.py
@@ -1,4 +1,5 @@
 from molexpress.layers.base_layer import BaseLayer as BaseLayer
 from molexpress.layers.gcn_conv import GCNConv as GCNConv
 from molexpress.layers.gin_conv import GINConv as GINConv
-from molexpress.layers.readout import Readout as Readout
+from molexpress.layers.peptide_readout import PeptideReadout as PeptideReadout
+from molexpress.layers.residue_readout import ResidueReadout as ResidueReadout
\ No newline at end of file
diff --git a/molexpress/layers/readout.py b/molexpress/layers/peptide_readout.py
similarity index 55%
rename from molexpress/layers/readout.py
rename to molexpress/layers/peptide_readout.py
index 59fa7d5..51c8418 100644
--- a/molexpress/layers/readout.py
+++ b/molexpress/layers/peptide_readout.py
@@ -6,7 +6,7 @@
 from molexpress.ops import gnn_ops
 
 
-class Readout(keras.layers.Layer):
+class PeptideReadout(keras.layers.Layer):
     def __init__(self, mode: str = "mean", **kwargs) -> None:
         super().__init__(**kwargs)
         self.mode = mode
@@ -18,14 +18,21 @@ def __init__(self, mode: str = "mean", **kwargs) -> None:
             self._readout_fn = gnn_ops.segment_mean
 
     def build(self, input_shape: dict[str, tuple[int, ...]]) -> None:
-        if "graph_indicator" not in input_shape:
-            raise ValueError("Cannot perform readout: 'graph_indicator' not found.")
+        if "peptide_size" not in input_shape:
+            raise ValueError("Cannot perform readout: 'peptide_size' not found.")
 
     def call(self, inputs: types.MolecularGraph) -> types.Array:
-        graph_indicator = keras.ops.cast(inputs["graph_indicator"], "int32")
+        peptide_size = keras.ops.cast(inputs['peptide_size'], 'int32')
+        residue_size = keras.ops.cast(inputs['residue_size'], 'int32')
+        n_peptides = keras.ops.shape(peptide_size)[0]
+        repeats = keras.ops.segment_sum(
+            residue_size,
+            keras.ops.repeat(range(n_peptides), peptide_size)
+        )
+        segment_ids = keras.ops.repeat(range(n_peptides), repeats)
         return self._readout_fn(
             data=inputs["node_state"],
-            segment_ids=graph_indicator,
+            segment_ids=segment_ids,
             num_segments=None,
             sorted=False,
         )
diff --git a/molexpress/layers/residue_readout.py b/molexpress/layers/residue_readout.py
new file mode 100644
index 0000000..c0cad78
--- /dev/null
+++ b/molexpress/layers/residue_readout.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import keras
+
+from molexpress import types
+from molexpress.ops import gnn_ops
+
+
+class ResidueReadout(keras.layers.Layer):
+    def __init__(self, mode: str = "mean", **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.mode = mode
+        if self.mode == "max":
+            self._readout_fn = keras.ops.segment_max
+        elif self.mode == "sum":
+            self._readout_fn = keras.ops.segment_sum
+        else:
+            self._readout_fn = gnn_ops.segment_mean
+
+    def build(self, input_shape: dict[str, tuple[int, ...]]) -> None:
+        if "residue_size" not in input_shape:
+            raise ValueError("Cannot perform readout: 'residue_size' not found.")
+
+    def call(self, inputs: types.MolecularGraph) -> types.Array:
+        peptide_size = keras.ops.cast(inputs['peptide_size'], 'int32')
+        residue_size = keras.ops.cast(inputs['residue_size'], 'int32')
+        n_residues = keras.ops.shape(residue_size)[0]
+        segment_ids = keras.ops.repeat(range(n_residues), residue_size)
+        residue_state = self._readout_fn(
+            data=inputs["node_state"],
+            segment_ids=segment_ids,
+            num_segments=None,
+            sorted=False,
+        )
+        # Make shape known
+        residue_state = keras.ops.reshape(
+            residue_state,
+            (
+                keras.ops.shape(residue_size)[0],
+                keras.ops.shape(inputs['node_state'])[-1]
+            )
+        )
+
+        if keras.ops.shape(peptide_size)[0] == 1:
+            # Single peptide in batch
+            return residue_state[None]
+
+        # Split and stack (with padding in the second dim)
+        # Resulting shape: (n_peptides, n_residues, n_features)
+        residues = keras.ops.split(residue_state, peptide_size[:-1])
+        max_residue_size = keras.ops.max([len(r) for r in residues])
+        return keras.ops.stack([
+            keras.ops.pad(r, [(0, max_residue_size-keras.ops.shape(r)[0]), (0, 0)])
+            for r in residues
+        ])
+
+
+
diff --git a/molexpress/ops/gnn_ops.py b/molexpress/ops/gnn_ops.py
index c26c654..dc38eab 100644
--- a/molexpress/ops/gnn_ops.py
+++ b/molexpress/ops/gnn_ops.py
@@ -59,6 +59,10 @@ def aggregate(
     """
     num_nodes = keras.ops.shape(node_state)[0]
 
+    # Instead of casting to int, throw an error if not int?
+    edge_src = keras.ops.cast(edge_src, "int32")
+    edge_dst = keras.ops.cast(edge_dst, "int32")
+
     expected_rank = 2
     current_rank = len(keras.ops.shape(edge_src))
     for _ in range(expected_rank - current_rank):
diff --git a/notebooks/examples.ipynb b/notebooks/examples.ipynb
index 83fc2e1..913fd69 100644
--- a/notebooks/examples.ipynb
+++ b/notebooks/examples.ipynb
@@ -13,8 +13,7 @@
     "from molexpress import layers\n",
     "from molexpress.datasets import featurizers\n",
     "from molexpress.datasets import encoders\n",
-    "\n",
-    "from rdkit import Chem\n",
+    "from molexpress.ops.chem_ops import get_molecule\n",
     "\n",
     "import torch"
    ]
@@ -34,7 +33,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "mol = Chem.MolFromSmiles('CCO')\n",
+    "mol = get_molecule('C(C(=O)O)N')\n",
     "\n",
     "print(featurizers.AtomType(vocab={'O'}, oov=False)(mol.GetAtoms()[0]))\n",
     "print(featurizers.AtomType(vocab={'O'}, oov=True)(mol.GetAtoms()[0]))\n",
@@ -67,13 +66,15 @@
     "    featurizers.BondType()\n",
     "]\n",
     "\n",
-    "encoder = encoders.MolecularGraphEncoder(\n",
+    "peptide_graph_encoder = encoders.PeptideGraphEncoder(\n",
     "    atom_featurizers=atom_featurizers, \n",
     "    bond_featurizers=bond_featurizers,\n",
     "    self_loops=True # adds one dim to edge state\n",
     ")\n",
     "\n",
-    "encoder(mol)"
+    "mol2 = get_molecule('CC(C(=O)O)N')\n",
+    "\n",
+    "peptide_graph_encoder([mol, mol2])"
    ]
   },
   {
@@ -91,8 +92,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "x_dummy = ['CC', 'CC', 'CCO', 'CCCN']\n",
-    "y_dummy = [1., 2., 3., 4.]\n",
+    "x_dummy = [\n",
+    "    ['CC(C)C(C(=O)O)N', 'C(C(=O)O)N'], \n",
+    "    ['C(C(=O)O)N', 'CC(C(=O)O)N', 'C(C(=O)O)N'], \n",
+    "    ['CC(C(=O)O)N']\n",
+    "]\n",
+    "y_dummy = [1., 2., 3.]\n",
     "\n",
     "\n",
     "class TinyDataset(torch.utils.data.Dataset):\n",
@@ -107,13 +112,13 @@
     "    def __getitem__(self, index):\n",
     "        x = self.x[index]\n",
     "        y = self.y[index]\n",
-    "        x = encoder(x)\n",
-    "        return x, y\n",
+    "        x = peptide_graph_encoder(x)\n",
+    "        return x, [y]\n",
     "\n",
     "torch_dataset = TinyDataset(x_dummy, y_dummy)\n",
     "\n",
     "dataset = torch.utils.data.DataLoader(\n",
-    "    torch_dataset, batch_size=2, collate_fn=encoder._collate_fn)\n",
+    "    torch_dataset, batch_size=2, collate_fn=peptide_graph_encoder._collate_fn)\n",
     "\n",
     "for x, y in dataset:\n",
     "    print(f'x = {x}\\ny = {y}', end='\\n' + '---' * 30 + '\\n')"
@@ -141,14 +146,16 @@
     "\n",
     "        self.gcn1 = layers.GINConv(32)\n",
     "        self.gcn2 = layers.GINConv(32)\n",
-    "        self.readout = layers.Readout()\n",
+    "        self.readout = layers.ResidueReadout()\n",
+    "        self.lstm = torch.nn.LSTM(32, 32, 1, batch_first=True)\n",
     "        self.linear = torch.nn.Linear(32, 1)\n",
     "\n",
     "    def forward(self, x):\n",
     "        x = self.gcn1(x)\n",
     "        x = self.gcn2(x)\n",
     "        x = self.readout(x)\n",
-    "        x = self.linear(x)\n",
+    "        x, (_, _) = self.lstm(x)\n",
+    "        x = self.linear(x[:, -1, :])\n",
     "        return x\n",
     "\n",
     "model = TinyGCNModel().to('cuda')"
@@ -169,18 +176,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "optimizer = torch.optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)\n",
+    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
     "loss_fn = torch.nn.MSELoss()\n",
     "\n",
     "for _ in range(30):\n",
     "    loss_sum = 0.\n",
     "    for x, y in dataset:\n",
     "        optimizer.zero_grad()\n",
-    "        \n",
     "        outputs = model(x)\n",
-    "        \n",
     "        y = torch.tensor(y, dtype=torch.float32).to('cuda')\n",
-    "        loss = loss_fn(outputs, y[:, None])\n",
+    "        loss = loss_fn(outputs, y)\n",
     "        loss.backward()\n",
     "        optimizer.step()\n",
     "\n",
@@ -192,7 +197,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ad047588-9926-4838-a264-193476897b4b",
+   "id": "9fe0fe29-34d1-445a-9ea7-81e2e3aa0046",
    "metadata": {},
    "outputs": [],
    "source": []
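
For reference, a minimal usage sketch of the renamed encoder, pieced together from the updated notebook cells in this patch. The featurizer list below is an assumption for illustration; any atom/bond featurizers supported by the package should work, and the SMILES are the glycine/alanine residues used in the notebook.

# Usage sketch only (not part of the patch).
from molexpress.datasets import encoders, featurizers
from molexpress.ops.chem_ops import get_molecule

encoder = encoders.PeptideGraphEncoder(
    atom_featurizers=[featurizers.AtomType(vocab={'C', 'N', 'O'}, oov=True)],  # assumed choice
    bond_featurizers=[featurizers.BondType()],
    self_loops=True,
)

# A peptide is now passed as a list of residue molecules; the encoder merges
# them into one disjoint graph and records per-residue atom counts under
# the new "residue_size" key.
glycine = get_molecule('C(C(=O)O)N')
alanine = get_molecule('CC(C(=O)O)N')
graph = encoder([glycine, alanine])
print(graph['residue_size'])  # one atom count per residue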