Skip to content

Commit

Permalink
Fix collate_fn when batch also includes label
Browse files Browse the repository at this point in the history
  • Loading branch information
RalfG committed May 13, 2024
1 parent e890279 commit 782e1e7
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions molexpress/datasets/encoders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Union

import numpy as np

from molexpress import types
Expand All @@ -23,27 +25,32 @@ def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI])
for residue in residues:
residue = chem_ops.get_molecule(residue)
residue_graph = {
**self.node_encoder(residue),
**self.node_encoder(residue),
**self.edge_encoder(residue)
}
residue_graphs.append(residue_graph)
residue_sizes.append(residue.GetNumAtoms())
disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs)
disjoint_peptide_graph["residue_size"] = np.array(residue_sizes)
return disjoint_peptide_graph

disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32")
return disjoint_peptide_graph

@staticmethod
def _collate_fn(
data: list[tuple[types.MolecularGraph, np.ndarray]],
def collate_fn(
data: list[Union[types.MolecularGraph, tuple[types.MolecularGraph, np.ndarray]]],
) -> tuple[types.MolecularGraph, np.ndarray]:
"""TODO: Not sure where to implement this collate function.
Temporarily putting it here.
Procedure:
Merges list of graphs into a single disjoint graph.
"""
Merge list of graphs into a single disjoint graph.
disjoint_peptide_graphs, y = list(zip(*data))
Data can be a list of MolecularGraphs or a list of tuples where the first element is a
MolecularGraph and the second element is a label.
"""
if isinstance(data[0], tuple):
disjoint_peptide_graphs, y = list(zip(*data))
else:
disjoint_peptide_graphs = data
y = None

disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs(
disjoint_peptide_graphs
Expand All @@ -54,7 +61,11 @@ def _collate_fn(
disjoint_peptide_batch_graph["residue_size"] = np.concatenate([
g["residue_size"] for g in disjoint_peptide_graphs
]).astype("int32")
return disjoint_peptide_batch_graph, np.stack(y)

if y is None:
return disjoint_peptide_batch_graph
else:
return disjoint_peptide_batch_graph, np.stack(y)

@staticmethod
def _merge_molecular_graphs(
Expand Down

0 comments on commit 782e1e7

Please sign in to comment.