From f75e9aab2701d9715ed4703f058a1209a605d355 Mon Sep 17 00:00:00 2001 From: RalfG Date: Thu, 16 May 2024 16:21:09 +0200 Subject: [PATCH] Add caching for encoding individual residues in PeptideGraphEncoder --- .gitignore | 1 + molexpress/datasets/encoders.py | 56 +++++++++++++++++++-------------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index eb1ab5d..e8060b3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +lightning_logs/ notebooks/_*.ipynb # vscode diff --git a/molexpress/datasets/encoders.py b/molexpress/datasets/encoders.py index a3253a6..779389d 100644 --- a/molexpress/datasets/encoders.py +++ b/molexpress/datasets/encoders.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Union +from functools import lru_cache +from typing import Dict, Tuple, Union import numpy as np @@ -23,18 +24,28 @@ def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI]) residue_graphs = [] residue_sizes = [] for residue in residues: - residue = chem_ops.get_molecule(residue) - residue_graph = { - **self.node_encoder(residue), - **self.edge_encoder(residue) - } + residue_graph, residue_size = self._encode_residue( + residue, self.node_encoder, self.edge_encoder + ) residue_graphs.append(residue_graph) - residue_sizes.append(residue.GetNumAtoms()) + residue_sizes.append(residue_size) + disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs) disjoint_peptide_graph["residue_size"] = np.array(residue_sizes) disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32") return disjoint_peptide_graph + @staticmethod + @lru_cache(maxsize=None) + def _encode_residue( + residue: types.Molecule | types.SMILES | types.InChI, + node_encoder: MolecularNodeEncoder, + edge_encoder: MolecularEdgeEncoder, + ) -> Tuple[Dict[str, np.ndarray], int]: + residue = chem_ops.get_molecule(residue) + residue_graph = {**node_encoder(residue), **edge_encoder(residue)} + return residue_graph, residue.GetNumAtoms() + @staticmethod def collate_fn( data: list[Union[types.MolecularGraph, tuple[types.MolecularGraph, np.ndarray]]], @@ -55,12 +66,12 @@ def collate_fn( disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs( disjoint_peptide_graphs ) - disjoint_peptide_batch_graph["peptide_size"] = np.concatenate([ - g["residue_size"].shape[:1] for g in disjoint_peptide_graphs - ]).astype("int32") - disjoint_peptide_batch_graph["residue_size"] = np.concatenate([ - g["residue_size"] for g in disjoint_peptide_graphs - ]).astype("int32") + disjoint_peptide_batch_graph["peptide_size"] = np.concatenate( + [g["residue_size"].shape[:1] for g in disjoint_peptide_graphs] + ).astype("int32") + disjoint_peptide_batch_graph["residue_size"] = np.concatenate( + [g["residue_size"] for g in disjoint_peptide_graphs] + ).astype("int32") if y is None: return disjoint_peptide_batch_graph @@ -71,21 +82,18 @@ def collate_fn( def _merge_molecular_graphs( molecular_graphs: list[types.MolecularGraph], ) -> types.MolecularGraph: - - num_nodes = np.array([ - g["node_state"].shape[0] for g in molecular_graphs - ]) + num_nodes = np.array([g["node_state"].shape[0] for g in molecular_graphs]) disjoint_molecular_graph = {} - disjoint_molecular_graph["node_state"] = np.concatenate([ - g["node_state"] for g in molecular_graphs - ]) + disjoint_molecular_graph["node_state"] = np.concatenate( + [g["node_state"] for g in molecular_graphs] + ) if "edge_state" in molecular_graphs[0]: - disjoint_molecular_graph["edge_state"] = np.concatenate([ - g["edge_state"] for g in molecular_graphs - ]) + disjoint_molecular_graph["edge_state"] = np.concatenate( + [g["edge_state"] for g in molecular_graphs] + ) edge_src = np.concatenate([graph["edge_src"] for graph in molecular_graphs]) edge_dst = np.concatenate([graph["edge_dst"] for graph in molecular_graphs]) @@ -147,7 +155,7 @@ def __call__(self, molecule: types.Molecule) -> np.ndarray: if molecule.GetNumBonds() == 0: edge_state = np.zeros( shape=(int(self.self_loops), self.output_dim + int(self.self_loops)), - dtype=self.output_dtype + dtype=self.output_dtype, ) return { "edge_src": edge_src,