diff --git a/molexpress/datasets/encoders.py b/molexpress/datasets/encoders.py index d40d5b0..a3253a6 100644 --- a/molexpress/datasets/encoders.py +++ b/molexpress/datasets/encoders.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Union + import numpy as np from molexpress import types @@ -23,27 +25,32 @@ def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI]) for residue in residues: residue = chem_ops.get_molecule(residue) residue_graph = { - **self.node_encoder(residue), + **self.node_encoder(residue), **self.edge_encoder(residue) } residue_graphs.append(residue_graph) residue_sizes.append(residue.GetNumAtoms()) disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs) disjoint_peptide_graph["residue_size"] = np.array(residue_sizes) - return disjoint_peptide_graph - + disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32") + return disjoint_peptide_graph + @staticmethod - def _collate_fn( - data: list[tuple[types.MolecularGraph, np.ndarray]], + def collate_fn( + data: list[Union[types.MolecularGraph, tuple[types.MolecularGraph, np.ndarray]]], ) -> tuple[types.MolecularGraph, np.ndarray]: - """TODO: Not sure where to implement this collate function. - Temporarily putting it here. - - Procedure: - Merges list of graphs into a single disjoint graph. """ + Merge list of graphs into a single disjoint graph. - disjoint_peptide_graphs, y = list(zip(*data)) + Data can be a list of MolecularGraphs or a list of tuples where the first element is a + MolecularGraph and the second element is a label. + + """ + if isinstance(data[0], tuple): + disjoint_peptide_graphs, y = list(zip(*data)) + else: + disjoint_peptide_graphs = data + y = None disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs( disjoint_peptide_graphs @@ -54,7 +61,11 @@ def _collate_fn( disjoint_peptide_batch_graph["residue_size"] = np.concatenate([ g["residue_size"] for g in disjoint_peptide_graphs ]).astype("int32") - return disjoint_peptide_batch_graph, np.stack(y) + + if y is None: + return disjoint_peptide_batch_graph + else: + return disjoint_peptide_batch_graph, np.stack(y) @staticmethod def _merge_molecular_graphs(