Skip to content

Commit

Permalink
Merge pull request #3 from compomics/fixes
Browse files Browse the repository at this point in the history
Various fixes
  • Loading branch information
akensert authored May 13, 2024
2 parents bafbdf2 + 782e1e7 commit 950bc83
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 29 deletions.
35 changes: 23 additions & 12 deletions molexpress/datasets/encoders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Union

import numpy as np

from molexpress import types
Expand All @@ -23,27 +25,32 @@ def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI])
for residue in residues:
residue = chem_ops.get_molecule(residue)
residue_graph = {
**self.node_encoder(residue),
**self.node_encoder(residue),
**self.edge_encoder(residue)
}
residue_graphs.append(residue_graph)
residue_sizes.append(residue.GetNumAtoms())
disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs)
disjoint_peptide_graph["residue_size"] = np.array(residue_sizes)
return disjoint_peptide_graph

disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32")
return disjoint_peptide_graph

@staticmethod
def _collate_fn(
data: list[tuple[types.MolecularGraph, np.ndarray]],
def collate_fn(
data: list[Union[types.MolecularGraph, tuple[types.MolecularGraph, np.ndarray]]],
) -> tuple[types.MolecularGraph, np.ndarray]:
"""TODO: Not sure where to implement this collate function.
Temporarily putting it here.
Procedure:
Merges list of graphs into a single disjoint graph.
"""
Merge list of graphs into a single disjoint graph.
disjoint_peptide_graphs, y = list(zip(*data))
Data can be a list of MolecularGraphs or a list of tuples where the first element is a
MolecularGraph and the second element is a label.
"""
if isinstance(data[0], tuple):
disjoint_peptide_graphs, y = list(zip(*data))
else:
disjoint_peptide_graphs = data
y = None

disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs(
disjoint_peptide_graphs
Expand All @@ -54,7 +61,11 @@ def _collate_fn(
disjoint_peptide_batch_graph["residue_size"] = np.concatenate([
g["residue_size"] for g in disjoint_peptide_graphs
]).astype("int32")
return disjoint_peptide_batch_graph, np.stack(y)

if y is None:
return disjoint_peptide_batch_graph
else:
return disjoint_peptide_batch_graph, np.stack(y)

@staticmethod
def _merge_molecular_graphs(
Expand Down
30 changes: 13 additions & 17 deletions molexpress/layers/residue_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def build(self, input_shape: dict[str, tuple[int, ...]]) -> None:
raise ValueError("Cannot perform readout: 'residue_size' not found.")

def call(self, inputs: types.MolecularGraph) -> types.Array:
peptide_size = keras.ops.cast(inputs['peptide_size'], 'int32')
residue_size = keras.ops.cast(inputs['residue_size'], 'int32')
peptide_size = keras.ops.cast(inputs["peptide_size"], "int32")
residue_size = keras.ops.cast(inputs["residue_size"], "int32")
n_residues = keras.ops.shape(residue_size)[0]
segment_ids = keras.ops.repeat(range(n_residues), residue_size)
residue_state = self._readout_fn(
Expand All @@ -34,25 +34,21 @@ def call(self, inputs: types.MolecularGraph) -> types.Array:
)
# Make shape known
residue_state = keras.ops.reshape(
residue_state,
(
keras.ops.shape(residue_size)[0],
keras.ops.shape(inputs['node_state'])[-1]
)
residue_state,
(keras.ops.shape(residue_size)[0], keras.ops.shape(inputs["node_state"])[-1]),
)

if keras.ops.shape(peptide_size)[0] == 1:
# Single peptide in batch
return residue_state[None]

# Split and stack (with padding in the second dim)
# Resulting shape: (n_peptides, n_residues, n_features)
residues = keras.ops.split(residue_state, peptide_size[:-1])
residues = keras.ops.split(residue_state, keras.ops.cumsum(peptide_size)[:-1])
max_residue_size = keras.ops.max([len(r) for r in residues])
return keras.ops.stack([
keras.ops.pad(r, [(0, max_residue_size-keras.ops.shape(r)[0]), (0, 0)])
for r in residues
])



return keras.ops.stack(
[
keras.ops.pad(r, [(0, max_residue_size - keras.ops.shape(r)[0]), (0, 0)])
for r in residues
]
)

0 comments on commit 950bc83

Please sign in to comment.