diff --git a/molexpress/datasets/encoders.py b/molexpress/datasets/encoders.py index 442cc44..d40d5b0 100644 --- a/molexpress/datasets/encoders.py +++ b/molexpress/datasets/encoders.py @@ -17,20 +17,20 @@ def __init__( self.node_encoder = MolecularNodeEncoder(atom_featurizers) self.edge_encoder = MolecularEdgeEncoder(bond_featurizers, self_loops=self_loops) - def __call__(self, molecules: list[types.Molecule | types.SMILES | types.InChI]) -> np.ndarray: - molecular_graphs = [] + def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI]) -> np.ndarray: + residue_graphs = [] residue_sizes = [] - for molecule in molecules: - molecule = chem_ops.get_molecule(molecule) - molecular_graph = { - **self.node_encoder(molecule), - **self.edge_encoder(molecule) + for residue in residues: + residue = chem_ops.get_molecule(residue) + residue_graph = { + **self.node_encoder(residue), + **self.edge_encoder(residue) } - molecular_graphs.append(molecular_graph) - residue_sizes.append(molecule.GetNumAtoms()) - graph = self._merge_molecular_graphs(molecular_graphs) - graph["residue_size"] = np.array(residue_sizes) - return graph + residue_graphs.append(residue_graph) + residue_sizes.append(residue.GetNumAtoms()) + disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs) + disjoint_peptide_graph["residue_size"] = np.array(residue_sizes) + return disjoint_peptide_graph @staticmethod def _collate_fn( @@ -43,16 +43,18 @@ def _collate_fn( Merges list of graphs into a single disjoint graph. """ - x, y = list(zip(*data)) + disjoint_peptide_graphs, y = list(zip(*data)) - disjoint_graph = PeptideGraphEncoder._merge_molecular_graphs(x) - disjoint_graph["peptide_size"] = np.concatenate([ - graph["residue_size"].shape[:1] for graph in x + disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs( + disjoint_peptide_graphs + ) + disjoint_peptide_batch_graph["peptide_size"] = np.concatenate([ + g["residue_size"].shape[:1] for g in disjoint_peptide_graphs ]).astype("int32") - disjoint_graph["residue_size"] = np.concatenate([ - graph["residue_size"] for graph in x + disjoint_peptide_batch_graph["residue_size"] = np.concatenate([ + g["residue_size"] for g in disjoint_peptide_graphs ]).astype("int32") - return disjoint_graph, np.stack(y) + return disjoint_peptide_batch_graph, np.stack(y) @staticmethod def _merge_molecular_graphs(