From 14bbec9dc0e4724b7d56450a3a2d6c531af70bd9 Mon Sep 17 00:00:00 2001 From: Otto Brinkhaus Date: Wed, 31 May 2023 15:28:37 +0200 Subject: [PATCH 1/8] restructuring of CDK depiction code --- RanDepict/randepict.py | 293 +++++++++++++----- Tests/test_functions.py | 12 +- ...h_structure_dataset_from_smiles_dataset.py | 4 +- ..._dataset_with_and_without_augmentations.py | 4 +- 4 files changed, 220 insertions(+), 93 deletions(-) diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index d7df0f3..71414a4 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -1096,7 +1096,7 @@ def has_r_group(self, smiles: str) -> bool: if re.search("\[.*[RXYZ].*\]", smiles): return True - def get_random_cdk_rendering_settings(self, rendererModel, molecule, smiles: str): + def _cdk_get_random_rendering_settings(self, rendererModel, molecule, smiles: str): """ This function defines random rendering options for the structure depictions created using CDK. @@ -1194,6 +1194,7 @@ def get_random_cdk_rendering_settings(self, rendererModel, molecule, smiles: str [True, False, False, False], log_attribute="cdk_add_atom_indices" ): if not self.has_r_group(smiles): + # Avoid confusion with R group indices and atom numbering labels = True for atom in molecule.atoms(): label = JClass("java.lang.Integer")( @@ -1235,35 +1236,20 @@ def get_random_cdk_rendering_settings(self, rendererModel, molecule, smiles: str cdk_superatom_abrv.apply(molecule) return rendererModel, molecule - def depict_and_resize_cdk( - self, smiles: str, shape: Tuple[int, int] = (299, 299) - ) -> np.array: + def _cdk_add_explicite_hydrogens_to_iatomcontainer( + self, molecule, + ): """ - This function takes a smiles str and an image shape. - It renders the chemical structures using CDK with random - rendering/depiction settings and returns an RGB image (np.array) - with the given image shape. 
- The general workflow here is a JPype adaptation of code published - by Egon Willighagen in 'Groovy Cheminformatics with the Chemistry - Development Kit': - https://egonw.github.io/cdkbook/ctr.html#depict-a-compound-as-an-image - with additional adaptations to create all the different depiction - types from - https://github.com/cdk/cdk/wiki/Standard-Generator + This function takes an IAtomContainer and returns an IAtomContainer with added + explicite hydrogen atoms. Args: - smiles (str): SMILES representation of molecule - shape (Tuple[int, int], optional): im shape. Defaults to (299, 299) + molecule: IAtomContainer (JClass object) Returns: - np.array: Chemical structure depiction + molecule: IAtomContainer (JClass object) """ cdk_base = "org.openscience.cdk" - # Read molecule from SMILES str - molecule = self.cdk_smiles_to_IAtomContainer(smiles) - - # Add hydrogens for coordinate generation (to make it look nicer/ - # avoid overlaps) matcher = JClass(cdk_base + ".atomtype.CDKAtomTypeMatcher").getInstance( molecule.getBuilder() ) @@ -1280,29 +1266,35 @@ def depict_and_resize_cdk( cdk_base + ".tools.manipulator.AtomContainerManipulator" ) AtomContainerManipulator.convertImplicitToExplicitHydrogens(molecule) + return molecule - # Instantiate StructureDiagramGenerator, determine coordinates - sdg = JClass(cdk_base + ".layout.StructureDiagramGenerator")() - sdg.setMolecule(molecule) - sdg.generateCoordinates(molecule) - molecule = sdg.getMolecule() + def _cdk_remove_explicite_hydrogens_from_iatomcontainer( + self, molecule, + ): + """ + This function takes an IAtomContainer and returns an IAtomContainer with added + explicite hydrogen atoms. 
- # Remove explicit hydrogens again - AtomContainerManipulator.suppressHydrogens(molecule) + Args: + molecule: IAtomContainer (JClass object) - # Rotate molecule randomly - point = JClass(cdk_base + ".geometry.GeometryTools").get2DCenter(molecule) - rot_degrees = self.random_choice(range(360)) - JClass(cdk_base + ".geometry.GeometryTools").rotate( - molecule, point, rot_degrees + Returns: + molecule: IAtomContainer (JClass object) + """ + cdk_base = "org.openscience.cdk" + AtomContainerManipulator = JClass( + cdk_base + ".tools.manipulator.AtomContainerManipulator" ) + AtomContainerManipulator.suppressHydrogens(molecule) + return molecule - # Get Generators - generators = JClass("java.util.ArrayList")() - BasicSceneGenerator = JClass( - "org.openscience.cdk.renderer.generators.BasicSceneGenerator" - )() - generators.add(BasicSceneGenerator) + def _cdk_get_random_java_font(self): + """ + This function returns a random java.awt.Font (JClass) object + + Returns: + font: java.awt.Font (JClass object) + """ font_size = self.random_choice( range(10, 20), log_attribute="cdk_atom_label_font_size" ) @@ -1325,39 +1317,157 @@ def depict_and_resize_cdk( # log_attribute='cdk_atom_label_font_style' ) font = Font(font_name, font_style, font_size) + return font + + def _cdk_rotate_coordinates(self, molecule): + """ + Given an IAtomContainer (JClass object), this function rotates the molecule + and adapts the coordinates of accordingly. 
The IAtomContainer is then returned.# + + Args: + molecule: IAtomContainer (JClass object) + + Returns: + molecule: IAtomContainer (JClass object) + """ + cdk_base = "org.openscience.cdk" + point = JClass(cdk_base + ".geometry.GeometryTools").get2DCenter(molecule) + rot_degrees = self.random_choice(range(360)) + JClass(cdk_base + ".geometry.GeometryTools").rotate( + molecule, point, rot_degrees + ) + return molecule + + def _cdk_generate_2d_coordinates(self, molecule): + """ + Given an IAtomContainer (JClass object), this function adds 2D coordinate to + the molecule. The modified IAtomContainer is then returned. + + Args: + molecule: IAtomContainer (JClass object) + + Returns: + molecule: IAtomContainer (JClass object) + """ + cdk_base = "org.openscience.cdk" + sdg = JClass(cdk_base + ".layout.StructureDiagramGenerator")() + sdg.setMolecule(molecule) + sdg.generateCoordinates(molecule) + molecule = sdg.getMolecule() + return molecule + + def _cdk_create_generators(self,): + """ + This function returns a java.util.ArrayList (JClass object) that contains the + BasicSceneGenerator and the StandardGenerator. A random font is for the + instantiation of the StandardGenerator. + + Returns: + generators: java.util.ArrayList (JClass object) + """ + generators = JClass("java.util.ArrayList")() + BasicSceneGenerator = JClass( + "org.openscience.cdk.renderer.generators.BasicSceneGenerator" + )() + generators.add(BasicSceneGenerator) + StandardGenerator = JClass( - cdk_base + ".renderer.generators.standard.StandardGenerator" - )(font) + "org.openscience.cdk.renderer.generators.standard.StandardGenerator" + )(self._cdk_get_random_java_font()) generators.add(StandardGenerator) + return generators + + def _cdk_get_atomcontainer_renderer( + self, + molecule, + shape: Tuple[int, int] + ): + """pytest + This function takes an IAtomContainer (JClass object) and returns an + AtomContainerRenderer (JClass object) that CDK uses for the depiction of the + molecule. 
+ + Args: + molecule (IAtomContainer (JClass object)): molecule + shape (Tuple[int, int]): y, x - # Instantiate renderer - AWTFontManager = JClass(cdk_base + ".renderer.font.AWTFontManager") - renderer = JClass(cdk_base + ".renderer.AtomContainerRenderer")( + Returns: + AtomContainerRenderer (JClass object) + """ + generators = self._cdk_create_generators() + AWTFontManager = JClass("org.openscience.cdk.renderer.font.AWTFontManager") + renderer = JClass("org.openscience.cdk.renderer.AtomContainerRenderer")( generators, AWTFontManager() ) - - # Create an empty image of the right size - y, x = self.random_image_size(shape) + y, x = shape # Workaround for structures that are cut off at edged of images: # Make image twice as big, reduce Zoom factor, then remove white # areas at borders and resize to originally desired shape # TODO: Find out why the structures are cut off in the first place y = y * 4 x = x * 4 - drawArea = JClass("java.awt.Rectangle")(x, y) - BufferedImage = JClass("java.awt.image.BufferedImage") - image = BufferedImage(x, y, BufferedImage.TYPE_INT_RGB) - # Draw the molecule renderer.setup(molecule, drawArea) + return renderer + + def _cdk_bufferedimage_to_numpyarray( + self, + image + ) -> np.ndarray: + """ + This function converts a BufferedImage (JClass object) into a numpy array. 
+ + Args: + image (BufferedImage (JClass object)) + + Returns: + image (np.ndarray) + """ + # Write the image into a format that can be read by skimage + ImageIO = JClass("javax.imageio.ImageIO") + os = JClass("java.io.ByteArrayOutputStream")() + Base64 = JClass("java.util.Base64") + ImageIO.write( + image, JClass("java.lang.String")("PNG"), Base64.getEncoder().wrap(os) + ) + image = bytes(os.toString("UTF-8")) + image = base64.b64decode(image) + + # Read image in skimage + image = sk_io.imread(image, plugin="imageio") + image = img_as_ubyte(image) + return image + + def _cdk_render_molecule( + self, + molecule, + smiles: str, + shape: Tuple[int, int] + ): + """ + This function takes an IAtomContainer (JClass object), the corresponding SMILES + string and an image shape and returns a BufferedImage (JClass object) with the + rendered molecule. + + Args: + molecule (IAtomContainer (JClass object)): molecule + smiles (str): SMILES string + shape (Tuple[int, int]): y, x + Returns: + depiction (np.ndarray): chemical structure depiction + """ + cdk_base = "org.openscience.cdk" + renderer = self._cdk_get_atomcontainer_renderer(molecule, shape) model = renderer.getRenderer2DModel() + BufferedImage = JClass("java.awt.image.BufferedImage") + y, x = shape + image = BufferedImage(x, y, BufferedImage.TYPE_INT_RGB) # Get random rendering settings - model, molecule = self.get_random_cdk_rendering_settings( + model, molecule = self._cdk_get_random_rendering_settings( model, molecule, smiles ) - double = JClass("java.lang.Double") model.set( JClass(cdk_base + ".renderer.generators.BasicSceneGenerator.ZoomFactor"), @@ -1367,29 +1477,48 @@ def depict_and_resize_cdk( g2.setColor(JClass("java.awt.Color").WHITE) g2.fillRect(0, 0, x, y) AWTDrawVisitor = JClass("org.openscience.cdk.renderer.visitor.AWTDrawVisitor") - renderer.paint(molecule, AWTDrawVisitor(g2)) - - # Write the image into a format that can be read by skimage - ImageIO = JClass("javax.imageio.ImageIO") - os = 
JClass("java.io.ByteArrayOutputStream")() - Base64 = JClass("java.util.Base64") - ImageIO.write( - image, JClass("java.lang.String")("PNG"), Base64.getEncoder().wrap(os) - ) - depiction = bytes(os.toString("UTF-8")) - depiction = base64.b64decode(depiction) - - # Read image in skimage - depiction = sk_io.imread(depiction, plugin="imageio") + depiction = self._cdk_bufferedimage_to_numpyarray(image) # Normalise padding and get non-distorted image of right size - depiction = self.normalise_padding(depiction) - depiction = self.central_square_image(depiction) + # TODO: Get rid of this nonsense with adding padding and then removing it + # depiction = self.normalise_padding(depiction) + # depiction = self.central_square_image(depiction) depiction = self.resize(depiction, shape, HQ=True) - depiction = img_as_ubyte(depiction) return depiction - def cdk_smiles_to_IAtomContainer(self, smiles: str): + def cdk_depict( + self, smiles: str, shape: Tuple[int, int] = (299, 299) + ) -> np.array: + """ + This function takes a smiles str and an image shape. + It renders the chemical structures using CDK with random + rendering/depiction settings and returns an RGB image (np.array) + with the given image shape. + The general workflow here is a JPype adaptation of code published + by Egon Willighagen in 'Groovy Cheminformatics with the Chemistry + Development Kit': + https://egonw.github.io/cdkbook/ctr.html#depict-a-compound-as-an-image + with additional adaptations to create all the different depiction + types from + https://github.com/cdk/cdk/wiki/Standard-Generator + + Args: + smiles (str): SMILES representation of molecule + shape (Tuple[int, int], optional): im shape. 
Defaults to (299, 299) + + Returns: + np.array: Chemical structure depiction + """ + # TODO: Find out if adding and removing hydrogens is really necessary + molecule = self._cdk_smiles_to_IAtomContainer(smiles) + molecule = self._cdk_add_explicite_hydrogens_to_iatomcontainer(molecule) + molecule = self._cdk_generate_2d_coordinates(molecule) + molecule = self._cdk_remove_explicite_hydrogens_from_iatomcontainer(molecule) + molecule = self._cdk_rotate_coordinates(molecule) + depiction = self._cdk_render_molecule(molecule, smiles, shape) + return depiction + + def _cdk_smiles_to_IAtomContainer(self, smiles: str): """ This function takes a SMILES representation of a molecule and returns the corresponding IAtomContainer object. @@ -1423,7 +1552,7 @@ def smiles_to_mol_str(self, smiles: str) -> str: Returns: str: content of SD file of input molecule """ - i_atom_container = self.cdk_smiles_to_IAtomContainer(smiles) + i_atom_container = self._cdk_smiles_to_IAtomContainer(smiles) mol_str = self.cdk_IAtomContainer_to_mol_str(i_atom_container) return mol_str @@ -1503,10 +1632,9 @@ def random_depiction( self, smiles: str, shape: Tuple[int, int] = (299, 299), - # path_bkg="./backgrounds/", ) -> np.array: """ - This function takes a SMILES and depicts it using Rdkit, Indigo or CDK. + This function takes a SMILES and depicts it using Rdkit, Indigo, CDK or PIKACHU. The depiction method and the specific parameters for the depiction are chosen completely randomly. The purpose of this function is to enable depicting a diverse variety of chemical structure depictions. 
@@ -1519,8 +1647,7 @@ def random_depiction( np.array: Chemical structure depiction """ depiction_functions = self.get_depiction_functions(smiles) - # If nothing is returned, try different function - # FIXME: depictions_functions could be an empty list + for _ in range(3): if len(depiction_functions) != 0: # Pick random depiction function and call it @@ -1579,7 +1706,7 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]: depiction_functions_registry = { 'rdkit': self.depict_and_resize_rdkit, 'indigo': self.depict_and_resize_indigo, - 'cdk': self.depict_and_resize_cdk, + 'cdk': self.cdk_depict, 'pikachu': self.depict_and_resize_pikachu, } depiction_functions = [depiction_functions_registry[k] @@ -2443,7 +2570,7 @@ def depict_from_fingerprint( elif "pikachu" in list(schemes[0].keys())[0]: depiction = depictor.depict_and_resize_pikachu(smiles, shape) elif "cdk" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_cdk(smiles, shape) + depiction = depictor.cdk_depict(smiles, shape) # Add augmentations if len(fingerprints) == 2: @@ -2673,7 +2800,7 @@ def __init__(self): # Call every depiction function depiction = self(smiles) - depiction = self.depict_and_resize_cdk(smiles) + depiction = self.cdk_depict(smiles) depiction = self.depict_and_resize_rdkit(smiles) depiction = self.depict_and_resize_indigo(smiles) depiction = self.depict_and_resize_pikachu(smiles) @@ -3361,7 +3488,7 @@ def add_explicite_hydrogen_to_smiles(self, smiles: str) -> str: Returns: smiles (str): SMILES representation of a molecule with explicite H """ - i_atom_container = self.depictor.cdk_smiles_to_IAtomContainer(smiles) + i_atom_container = self.depictor._cdk_smiles_to_IAtomContainer(smiles) # Add explicite hydrogen atoms cdk_base = "org.openscience.cdk." 
@@ -3387,7 +3514,7 @@ def remove_explicite_hydrogen_from_smiles(self, smiles: str) -> str: Returns: smiles (str): SMILES representation of a molecule with explicite H """ - i_atom_container = self.depictor.cdk_smiles_to_IAtomContainer(smiles) + i_atom_container = self.depictor._cdk_smiles_to_IAtomContainer(smiles) # Remove explicite hydrogen atoms cdk_base = "org.openscience.cdk." manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") diff --git a/Tests/test_functions.py b/Tests/test_functions.py index ba4bb2a..cf274c3 100644 --- a/Tests/test_functions.py +++ b/Tests/test_functions.py @@ -367,13 +367,13 @@ def test_depict_and_resize_rdkit(self): im = self.depictor.depict_and_resize_rdkit(smiles) assert type(im) == np.ndarray - def test_depict_and_resize_cdk(self): + def test_cdk_depict(self): # Assert that an image is returned with different types # of input SMILES str test_smiles = ['c1ccccc1', '[Otto]C1=C([XYZ123])C([R1])=C([Y])C([X])=C1[R]'] for smiles in test_smiles: - im = self.depictor.depict_and_resize_cdk(smiles) + im = self.depictor.cdk_depict(smiles) assert type(im) == np.ndarray def test_depict_and_resize_pikachu(self): @@ -391,7 +391,7 @@ def test_get_depiction_functions_normal(self): expected = [ self.depictor.depict_and_resize_rdkit, self.depictor.depict_and_resize_indigo, - self.depictor.depict_and_resize_cdk, + self.depictor.cdk_depict, self.depictor.depict_and_resize_pikachu, ] # symmetric_difference @@ -404,7 +404,7 @@ def test_get_depiction_functions_isotopes(self): expected = [ self.depictor.depict_and_resize_rdkit, self.depictor.depict_and_resize_indigo, - self.depictor.depict_and_resize_cdk, + self.depictor.cdk_depict, ] difference = set(observed) ^ set(expected) assert not difference @@ -414,7 +414,7 @@ def test_get_depiction_functions_R(self): observed = self.depictor.get_depiction_functions("[R]N1C=NC2=C1C(=O)N(C(=O)N2C)C") expected = [ self.depictor.depict_and_resize_indigo, - self.depictor.depict_and_resize_cdk, + 
self.depictor.cdk_depict, self.depictor.depict_and_resize_pikachu, ] difference = set(observed) ^ set(expected) @@ -424,7 +424,7 @@ def test_get_depiction_functions_X(self): # RDKit and Indigo don't depict "X" observed = self.depictor.get_depiction_functions("[X]N1C=NC2=C1C(=O)N(C(=O)N2C)C") expected = [ - self.depictor.depict_and_resize_cdk, + self.depictor.cdk_depict, self.depictor.depict_and_resize_pikachu, ] difference = set(observed) ^ set(expected) diff --git a/examples/create_markush_structure_dataset_from_smiles_dataset.py b/examples/create_markush_structure_dataset_from_smiles_dataset.py index 096d9a1..3f58a70 100644 --- a/examples/create_markush_structure_dataset_from_smiles_dataset.py +++ b/examples/create_markush_structure_dataset_from_smiles_dataset.py @@ -41,7 +41,7 @@ def split_id_from_smiles(lines: List[str]) -> Tuple[List[str], List[str]]: smiles_list = [] for line in lines: line = line[:-1] - id, smiles = line.split(",") + id, smiles = line.split("\t") id_list.append(id) smiles_list.append(smiles) return id_list, smiles_list @@ -74,7 +74,7 @@ def main(): smiles_lists = split_smiles_list(input_file.readlines(), 1000) starmap_tuples = [(smiles_lists[index], index * 1000000) for index in range(len(smiles_lists))] - with Pool(15) as pool: + with Pool(40) as pool: _ = pool.starmap(helper, starmap_tuples) diff --git a/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py b/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py index baaeb17..bbd7ba3 100644 --- a/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py +++ b/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py @@ -183,7 +183,7 @@ def depict_from_fingerprint( elif "rdkit" in list(schemes[0].keys())[0]: depiction = depictor.depict_and_resize_rdkit(smiles, shape) elif "cdk" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_cdk(smiles, shape) + depiction = 
depictor.cdk_depict(smiles, shape) elif "pikachu" in list(schemes[0].keys())[0]: depiction = depictor.depict_and_resize_pikachu(smiles, shape) except IndexError: @@ -196,7 +196,7 @@ def depict_from_fingerprint( False, False, ) - depiction = depictor.depict_and_resize_cdk(smiles, shape) + depiction = depictor.cdk_depict(smiles, shape) with open(os.path.join(self.output_dir, 'error_log.txt'), 'a') as error_log: error_log.write(f'Failed depicting SMILES: {smiles}\n') error_log.write('It was depicted using CDK WITHOUT fingerprints.\n') From 965e57ce5f6f7145eeb50208dd9b1066af943274 Mon Sep 17 00:00:00 2001 From: Otto Brinkhaus Date: Thu, 1 Jun 2023 16:42:36 +0200 Subject: [PATCH 2/8] rewrite outdated CDK depiction functions --- RanDepict/randepict.py | 194 +++++++---------------------------------- 1 file changed, 33 insertions(+), 161 deletions(-) diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index 71414a4..141a359 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -1096,26 +1096,27 @@ def has_r_group(self, smiles: str) -> bool: if re.search("\[.*[RXYZ].*\]", smiles): return True - def _cdk_get_random_rendering_settings(self, rendererModel, molecule, smiles: str): + def _cdk_get_depiction_generator(self, molecule, smiles: str): """ This function defines random rendering options for the structure depictions created using CDK. - It takes a cdk.renderer.AtomContainerRenderer.2DModel - and a cdk.AtomContainer and returns the 2DModel object with random - rendering settings and the AtomContainer. - I followed https://github.com/cdk/cdk/wiki/Standard-Generator while - creating this. + It takes an iAtomContainer and a SMILES string and returns the iAtomContainer + and the DepictionGenerator + with random rendering settings and the AtomContainer. + I followed https://github.com/cdk/cdk/wiki/Standard-Generator to adjust the + depiction parameters. 
Args: - rendererModel (cdk.renderer.AtomContainerRenderer.2DModel) molecule (cdk.AtomContainer): Atom container smiles (str): smiles representation of molecule Returns: - rendererModel, molecule: Objects that hold depiction parameters + DepictionGenerator, molecule: Objects that hold depiction parameters """ cdk_base = "org.openscience.cdk" - + dep_gen = JClass("org.openscience.cdk.depict.DepictionGenerator")( + self._cdk_get_random_java_font() + ) StandardGenerator = JClass( cdk_base + ".renderer.generators.standard.StandardGenerator" ) @@ -1127,18 +1128,18 @@ def _cdk_get_random_rendering_settings(self, rendererModel, molecule, smiles: st ) SymbolVisibility = JClass("org.openscience.cdk.renderer.SymbolVisibility") if symbol_visibility == "iupac_recommendation": - rendererModel.set( + dep_gen = dep_gen.withParam( StandardGenerator.Visibility.class_, SymbolVisibility.iupacRecommendations(), ) elif symbol_visibility == "no_terminal_methyl": # only hetero atoms, no terminal alkyl groups - rendererModel.set( + dep_gen = dep_gen.withParam( StandardGenerator.Visibility.class_, SymbolVisibility.iupacRecommendationsWithoutTerminalCarbon(), ) elif symbol_visibility == "show_all_atom_labels": - rendererModel.set( + dep_gen = dep_gen.withParam( StandardGenerator.Visibility.class_, SymbolVisibility.all() ) # show all atom labels @@ -1146,12 +1147,13 @@ def _cdk_get_random_rendering_settings(self, rendererModel, molecule, smiles: st stroke_width = self.random_choice( np.arange(0.8, 2.0, 0.1), log_attribute="cdk_stroke_width" ) - rendererModel.set(StandardGenerator.StrokeRatio.class_, stroke_width) + dep_gen = dep_gen.withParam(StandardGenerator.StrokeRatio.class_, + stroke_width) # Define symbol margin ratio margin_ratio = self.random_choice( [0, 1, 2, 2, 2, 3, 4], log_attribute="cdk_margin_ratio" ) - rendererModel.set( + dep_gen = dep_gen.withParam( StandardGenerator.SymbolMarginRatio.class_, JClass("java.lang.Double")(margin_ratio), ) @@ -1159,21 +1161,23 @@ def 
_cdk_get_random_rendering_settings(self, rendererModel, molecule, smiles: st double_bond_dist = self.random_choice( np.arange(0.11, 0.25, 0.01), log_attribute="cdk_double_bond_dist" ) - rendererModel.set(StandardGenerator.BondSeparation.class_, double_bond_dist) + dep_gen = dep_gen.withParam(StandardGenerator.BondSeparation.class_, + double_bond_dist) wedge_ratio = self.random_choice( np.arange(4.5, 7.5, 0.1), log_attribute="cdk_wedge_ratio" ) - rendererModel.set( + dep_gen = dep_gen.withParam( StandardGenerator.WedgeRatio.class_, JClass("java.lang.Double")(wedge_ratio) ) if self.random_choice([True, False], log_attribute="cdk_fancy_bold_wedges"): - rendererModel.set(StandardGenerator.FancyBoldWedges.class_, True) + dep_gen = dep_gen.withParam(StandardGenerator.FancyBoldWedges.class_, True) if self.random_choice([True, False], log_attribute="cdk_fancy_hashed_wedges"): - rendererModel.set(StandardGenerator.FancyHashedWedges.class_, True) + dep_gen = dep_gen.withParam(StandardGenerator.FancyHashedWedges.class_, + True) hash_spacing = self.random_choice( np.arange(4.0, 6.0, 0.2), log_attribute="cdk_hash_spacing" ) - rendererModel.set(StandardGenerator.HashSpacing.class_, hash_spacing) + dep_gen = dep_gen.withParam(StandardGenerator.HashSpacing.class_, hash_spacing) # Add CIP labels labels = False if self.random_choice([True, False], log_attribute="cdk_add_CIP_labels"): @@ -1203,7 +1207,7 @@ def _cdk_get_random_rendering_settings(self, rendererModel, molecule, smiles: st atom.setProperty(StandardGenerator.ANNOTATION_LABEL, label) if labels: # We only need black - rendererModel.set( + dep_gen = dep_gen.withParam( StandardGenerator.AnnotationColor.class_, JClass("java.awt.Color")(0x000000), ) @@ -1211,12 +1215,14 @@ def _cdk_get_random_rendering_settings(self, rendererModel, molecule, smiles: st font_scale = self.random_choice( np.arange(0.5, 0.8, 0.1), log_attribute="cdk_label_font_scale" ) - rendererModel.set(StandardGenerator.AnnotationFontScale.class_, font_scale) + 
dep_gen = dep_gen.withParam( + StandardGenerator.AnnotationFontScale.class_, + font_scale) # Distance between atom numbering and depiction annotation_distance = self.random_choice( np.arange(0.15, 0.30, 0.05), log_attribute="cdk_annotation_distance" ) - rendererModel.set( + dep_gen = dep_gen.withParam( StandardGenerator.AnnotationDistance.class_, annotation_distance ) # Abbreviate superatom labels in half of the cases @@ -1234,59 +1240,7 @@ def _cdk_get_random_rendering_settings(self, rendererModel, molecule, smiles: st abbreviation_path = JClass("java.lang.String")(abbreviation_path) cdk_superatom_abrv.loadFromFile(abbreviation_path) cdk_superatom_abrv.apply(molecule) - return rendererModel, molecule - - def _cdk_add_explicite_hydrogens_to_iatomcontainer( - self, molecule, - ): - """ - This function takes an IAtomContainer and returns an IAtomContainer with added - explicite hydrogen atoms. - - Args: - molecule: IAtomContainer (JClass object) - - Returns: - molecule: IAtomContainer (JClass object) - """ - cdk_base = "org.openscience.cdk" - matcher = JClass(cdk_base + ".atomtype.CDKAtomTypeMatcher").getInstance( - molecule.getBuilder() - ) - for atom in molecule.atoms(): - atom_type = matcher.findMatchingAtomType(molecule, atom) - JClass(cdk_base + ".tools.manipulator.AtomTypeManipulator").configure( - atom, atom_type - ) - adder = JClass(cdk_base + ".tools.CDKHydrogenAdder").getInstance( - molecule.getBuilder() - ) - adder.addImplicitHydrogens(molecule) - AtomContainerManipulator = JClass( - cdk_base + ".tools.manipulator.AtomContainerManipulator" - ) - AtomContainerManipulator.convertImplicitToExplicitHydrogens(molecule) - return molecule - - def _cdk_remove_explicite_hydrogens_from_iatomcontainer( - self, molecule, - ): - """ - This function takes an IAtomContainer and returns an IAtomContainer with added - explicite hydrogen atoms. 
- - Args: - molecule: IAtomContainer (JClass object) - - Returns: - molecule: IAtomContainer (JClass object) - """ - cdk_base = "org.openscience.cdk" - AtomContainerManipulator = JClass( - cdk_base + ".tools.manipulator.AtomContainerManipulator" - ) - AtomContainerManipulator.suppressHydrogens(molecule) - return molecule + return dep_gen, molecule def _cdk_get_random_java_font(self): """ @@ -1356,61 +1310,6 @@ def _cdk_generate_2d_coordinates(self, molecule): molecule = sdg.getMolecule() return molecule - def _cdk_create_generators(self,): - """ - This function returns a java.util.ArrayList (JClass object) that contains the - BasicSceneGenerator and the StandardGenerator. A random font is for the - instantiation of the StandardGenerator. - - Returns: - generators: java.util.ArrayList (JClass object) - """ - generators = JClass("java.util.ArrayList")() - BasicSceneGenerator = JClass( - "org.openscience.cdk.renderer.generators.BasicSceneGenerator" - )() - generators.add(BasicSceneGenerator) - - StandardGenerator = JClass( - "org.openscience.cdk.renderer.generators.standard.StandardGenerator" - )(self._cdk_get_random_java_font()) - generators.add(StandardGenerator) - return generators - - def _cdk_get_atomcontainer_renderer( - self, - molecule, - shape: Tuple[int, int] - ): - """pytest - This function takes an IAtomContainer (JClass object) and returns an - AtomContainerRenderer (JClass object) that CDK uses for the depiction of the - molecule. 
- - Args: - molecule (IAtomContainer (JClass object)): molecule - shape (Tuple[int, int]): y, x - - Returns: - AtomContainerRenderer (JClass object) - """ - generators = self._cdk_create_generators() - AWTFontManager = JClass("org.openscience.cdk.renderer.font.AWTFontManager") - renderer = JClass("org.openscience.cdk.renderer.AtomContainerRenderer")( - generators, AWTFontManager() - ) - y, x = shape - # Workaround for structures that are cut off at edged of images: - # Make image twice as big, reduce Zoom factor, then remove white - # areas at borders and resize to originally desired shape - # TODO: Find out why the structures are cut off in the first place - y = y * 4 - x = x * 4 - drawArea = JClass("java.awt.Rectangle")(x, y) - # Draw the molecule - renderer.setup(molecule, drawArea) - return renderer - def _cdk_bufferedimage_to_numpyarray( self, image @@ -1433,8 +1332,6 @@ def _cdk_bufferedimage_to_numpyarray( ) image = bytes(os.toString("UTF-8")) image = base64.b64decode(image) - - # Read image in skimage image = sk_io.imread(image, plugin="imageio") image = img_as_ubyte(image) return image @@ -1457,33 +1354,10 @@ def _cdk_render_molecule( Returns: depiction (np.ndarray): chemical structure depiction """ - cdk_base = "org.openscience.cdk" - renderer = self._cdk_get_atomcontainer_renderer(molecule, shape) - model = renderer.getRenderer2DModel() - BufferedImage = JClass("java.awt.image.BufferedImage") - y, x = shape - image = BufferedImage(x, y, BufferedImage.TYPE_INT_RGB) - - # Get random rendering settings - model, molecule = self._cdk_get_random_rendering_settings( - model, molecule, smiles - ) - double = JClass("java.lang.Double") - model.set( - JClass(cdk_base + ".renderer.generators.BasicSceneGenerator.ZoomFactor"), - double(1.0), - ) - g2 = image.getGraphics() - g2.setColor(JClass("java.awt.Color").WHITE) - g2.fillRect(0, 0, x, y) - AWTDrawVisitor = JClass("org.openscience.cdk.renderer.visitor.AWTDrawVisitor") - renderer.paint(molecule, AWTDrawVisitor(g2)) 
- depiction = self._cdk_bufferedimage_to_numpyarray(image) - # Normalise padding and get non-distorted image of right size - # TODO: Get rid of this nonsense with adding padding and then removing it - # depiction = self.normalise_padding(depiction) - # depiction = self.central_square_image(depiction) - depiction = self.resize(depiction, shape, HQ=True) + dep_gen, molecule = self._cdk_get_depiction_generator(molecule, smiles) + dep_gen = dep_gen.withSize(shape[1], shape[0]) + depiction = dep_gen.depict(molecule).toImg() + depiction = self._cdk_bufferedimage_to_numpyarray(depiction) return depiction def cdk_depict( @@ -1511,9 +1385,7 @@ def cdk_depict( """ # TODO: Find out if adding and removing hydrogens is really necessary molecule = self._cdk_smiles_to_IAtomContainer(smiles) - molecule = self._cdk_add_explicite_hydrogens_to_iatomcontainer(molecule) molecule = self._cdk_generate_2d_coordinates(molecule) - molecule = self._cdk_remove_explicite_hydrogens_from_iatomcontainer(molecule) molecule = self._cdk_rotate_coordinates(molecule) depiction = self._cdk_render_molecule(molecule, smiles, shape) return depiction From e1095c2304e1e595a29664e296dbd746670a2290 Mon Sep 17 00:00:00 2001 From: Otto Brinkhaus Date: Thu, 1 Jun 2023 16:59:41 +0200 Subject: [PATCH 3/8] add coordinate generation to cdk smi->molblock method --- RanDepict/randepict.py | 47 +++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index 141a359..80cc8f1 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -941,7 +941,7 @@ def depict_and_resize_indigo( if not self.has_r_group(smiles): molecule = indigo.loadMolecule(smiles) else: - mol_str = self.smiles_to_mol_str(smiles) + mol_str = self._cdk_smiles_to_mol_str(smiles) molecule = indigo.loadMolecule(mol_str) except IndigoException: return None @@ -1034,8 +1034,8 @@ def get_random_rdkit_rendering_settings( 
depiction_settings.drawOptions().useBWAtomPalette() return depiction_settings - def depict_and_resize_rdkit( - self, smiles: str, shape: Tuple[int, int] = (299, 299) + def rdkit_depict( + self, smiles: str, shape: Tuple[int, int] = (512, 512) ) -> np.array: """ This function takes a smiles str and an image shape. @@ -1054,7 +1054,7 @@ def depict_and_resize_rdkit( if not self.has_r_group(smiles): mol = Chem.MolFromSmiles(smiles) if self.has_r_group(smiles) or not mol: - mol_str = self.smiles_to_mol_str(smiles) + mol_str = self._cdk_smiles_to_mol_str(smiles) mol = Chem.MolFromMolBlock(mol_str) if mol: AllChem.Compute2DCoords(mol) @@ -1066,13 +1066,6 @@ def depict_and_resize_rdkit( mol = CondenseMolAbbreviations(mol, abbrevs) # Get random depiction settings depiction_settings = self.get_random_rdkit_rendering_settings(smiles=smiles) - # Create depiction - # TODO: Figure out how to depict without kekulization here - # The following line does not prevent the molecule from being - # depicted kekulized: - # mol = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize = False) - # The molecule must get kekulized somewhere "by accident" - rdMolDraw2D.PrepareAndDrawMolecule(depiction_settings, mol) depiction = depiction_settings.GetDrawingText() depiction = sk_io.imread(io.BytesIO(depiction)) @@ -1358,7 +1351,7 @@ def _cdk_render_molecule( dep_gen = dep_gen.withSize(shape[1], shape[0]) depiction = dep_gen.depict(molecule).toImg() depiction = self._cdk_bufferedimage_to_numpyarray(depiction) - return depiction + return depiction def cdk_depict( self, smiles: str, shape: Tuple[int, int] = (299, 299) @@ -1409,7 +1402,11 @@ def _cdk_smiles_to_IAtomContainer(self, smiles: str): molecule = SmilesParser.parseSmiles(smiles) return molecule - def smiles_to_mol_str(self, smiles: str) -> str: + def _cdk_smiles_to_mol_str( + self, + smiles: str, + generate_2d: bool = False, + ) -> str: """ This function takes a SMILES representation of a molecule and returns the content of the corresponding SD 
file using the CDK. @@ -1420,15 +1417,19 @@ def smiles_to_mol_str(self, smiles: str) -> str: Args: smiles (str): SMILES representation of a molecule + generate_2d (bool, optional): Whether to generate 2D coordinates. Returns: - str: content of SD file of input molecule + mol_block (str): content of SD file of input molecule """ - i_atom_container = self._cdk_smiles_to_IAtomContainer(smiles) - mol_str = self.cdk_IAtomContainer_to_mol_str(i_atom_container) - return mol_str + molecule = self._cdk_smiles_to_IAtomContainer(smiles) + if generate_2d: + molecule = self._cdk_generate_2d_coordinates(molecule) + molecule = self._cdk_rotate_coordinates(molecule) + mol_block = self._cdk_IAtomContainer_to_mol_str(molecule) + return mol_block - def cdk_IAtomContainer_to_mol_str(self, i_atom_container) -> str: + def _cdk_IAtomContainer_to_mol_str(self, i_atom_container) -> str: """ This function takes an IAtomContainer object and returns the content of the corresponding MDL MOL file as a string. @@ -1576,7 +1577,7 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]: """ depiction_functions_registry = { - 'rdkit': self.depict_and_resize_rdkit, + 'rdkit': self.rdkit_depict, 'indigo': self.depict_and_resize_indigo, 'cdk': self.cdk_depict, 'pikachu': self.depict_and_resize_pikachu, @@ -1596,9 +1597,9 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]: depiction_functions.remove(self.depict_and_resize_pikachu) # "R", "X", "Z" are not depicted by RDKit # The same is valid for X,Y,Z and a number - if self.depict_and_resize_rdkit in depiction_functions: + if self.rdkit_depict in depiction_functions: if re.search("\[[RXZ]\]|\[[XYZ]\d+", smiles): - depiction_functions.remove(self.depict_and_resize_rdkit) + depiction_functions.remove(self.rdkit_depict) # "X", "R0", [RXYZ]\d+[a-f] and indices above 32 are not depicted by Indigo if self.depict_and_resize_indigo in depiction_functions: if re.search("\[R0\]|\[X\]|[4-9][0-9]+|3[3-9]|[XYZR]\d+[a-f]", smiles): @@ 
-2438,7 +2439,7 @@ def depict_from_fingerprint( if "indigo" in list(schemes[0].keys())[0]: depiction = depictor.depict_and_resize_indigo(smiles, shape) elif "rdkit" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_rdkit(smiles, shape) + depiction = depictor.rdkit_depict(smiles, shape) elif "pikachu" in list(schemes[0].keys())[0]: depiction = depictor.depict_and_resize_pikachu(smiles, shape) elif "cdk" in list(schemes[0].keys())[0]: @@ -2673,7 +2674,7 @@ def __init__(self): # Call every depiction function depiction = self(smiles) depiction = self.cdk_depict(smiles) - depiction = self.depict_and_resize_rdkit(smiles) + depiction = self.rdkit_depict(smiles) depiction = self.depict_and_resize_indigo(smiles) depiction = self.depict_and_resize_pikachu(smiles) # Call augmentation function From efdb29402ddadbd5c303cb1965abda43d8a30744 Mon Sep 17 00:00:00 2001 From: Otto Brinkhaus Date: Fri, 2 Jun 2023 15:24:51 +0200 Subject: [PATCH 4/8] - work in progress - ongoing work on depicitions with cxSMILES --- RanDepict/randepict.py | 174 ++++++++++++++++++++++++++-------------- Tests/test_functions.py | 24 +++--- 2 files changed, 128 insertions(+), 70 deletions(-) diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index 80cc8f1..f7c4dec 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -33,7 +33,8 @@ from jpype import startJVM, getDefaultJVMPath from jpype import JClass, JVMNotFoundException, isJVMStarted from pikachu.drawing import drawing -from pikachu.smiles.smiles import read_smiles +from pikachu.chem.molfile.read_molfile import MolFileReader + import base64 import cv2 @@ -828,28 +829,30 @@ def get_random_pikachu_rendering_settings( # options.font_size_small = 3 return options - def depict_and_resize_pikachu( - self, smiles: str, shape: Tuple[int, int] = (299, 299) + def pikachu_depict( + self, mol_block: str, shape: Tuple[int, int] = (299, 299) ) -> np.array: """ - This function takes a smiles str and an image shape. 
+ This function takes a mol block str and an image shape. It renders the chemical structures using PIKAChU with random rendering/depiction settings and returns an RGB image (np.array) with the given image shape. Args: - smiles (str): SMILES representation of molecule + mol_block (str): mol block representation of molecule shape (Tuple[int, int], optional): im shape. Defaults to (299, 299) Returns: np.array: Chemical structure depiction """ - structure = read_smiles(smiles) + reader = MolFileReader(mol_block) + structure = reader.molfile_to_structure() + # structure = read_smiles(smiles) depiction_settings = self.get_random_pikachu_rendering_settings() - if "." in smiles: - drawer = drawing.draw_multiple(structure, options=depiction_settings) - else: - drawer = drawing.Drawer(structure, options=depiction_settings) + # if "." in smiles: + # drawer = drawing.draw_multiple(structure, options=depiction_settings) + # else: + drawer = drawing.Drawer(structure, options=depiction_settings) depiction = drawer.get_image_as_array() depiction = self.central_square_image(depiction) depiction = self.resize(depiction, (shape[0], shape[1])) @@ -918,17 +921,17 @@ def get_random_indigo_rendering_settings( indigo.setOption("render-superatom-mode", "collapse") return indigo, renderer - def depict_and_resize_indigo( - self, smiles: str, shape: Tuple[int, int] = (299, 299) + def indigo_depict( + self, mol_block: str, shape: Tuple[int, int] = (299, 299) ) -> np.array: """ - This function takes a smiles str and an image shape. + This function takes a mol block str and an image shape. It renders the chemical structures using Indigo with random rendering/depiction settings and returns an RGB image (np.array) with the given image shape. Args: - smiles (str): SMILES representation of molecule + mol_block (str): mol block representation of molecule shape (Tuple[int, int], optional): im shape. 
Defaults to (299, 299) Returns: @@ -937,12 +940,14 @@ def depict_and_resize_indigo( # Instantiate Indigo with random settings and IndigoRenderer indigo, renderer = self.get_random_indigo_rendering_settings() # Load molecule + # try: + # if not self.has_r_group(smiles): + # molecule = indigo.loadMolecule(smiles) + # else: + # mol_str = self._smiles_to_mol_block(smiles) + # molecule = indigo.loadMolecule(mol_str) try: - if not self.has_r_group(smiles): - molecule = indigo.loadMolecule(smiles) - else: - mol_str = self._cdk_smiles_to_mol_str(smiles) - molecule = indigo.loadMolecule(mol_str) + molecule = indigo.loadMolecule(mol_block) except IndigoException: return None # Kekulize in 67% of cases @@ -1035,29 +1040,30 @@ def get_random_rdkit_rendering_settings( return depiction_settings def rdkit_depict( - self, smiles: str, shape: Tuple[int, int] = (512, 512) + self, mol_block: str, shape: Tuple[int, int] = (512, 512) ) -> np.array: """ - This function takes a smiles str and an image shape. - It renders the chemical structuresusing Rdkit with random + This function takes a mol_block str and an image shape. + It renders the chemical structures using Rdkit with random rendering/depiction settings and returns an RGB image (np.array) with the given image shape. Args: - smiles (str): SMILES representation of molecule_ + mol block (str): mol block representation of molecule_ shape (Tuple[int, int], optional): im shape. 
Defaults to (299, 299) Returns: np.array: Chemical structure depiction """ # Load molecule - if not self.has_r_group(smiles): - mol = Chem.MolFromSmiles(smiles) - if self.has_r_group(smiles) or not mol: - mol_str = self._cdk_smiles_to_mol_str(smiles) - mol = Chem.MolFromMolBlock(mol_str) + # if not self.has_r_group(smiles): + # mol = Chem.MolFromSmiles(smiles) + # if self.has_r_group(smiles) or not mol: + # mol_str = self._smiles_to_mol_block(smiles) + # mol = Chem.MolFromMolBlock(mol_str) + mol = Chem.MolFromMolBlock(mol_block, sanitize=False) if mol: - AllChem.Compute2DCoords(mol) + # AllChem.Compute2DCoords(mol) # Abbreviate superatoms if self.random_choice( [True, False], log_attribute="rdkit_collapse_superatoms" @@ -1065,6 +1071,8 @@ def rdkit_depict( abbrevs = self.random_choice(self.get_all_rdkit_abbreviations()) mol = CondenseMolAbbreviations(mol, abbrevs) # Get random depiction settings + # TODO: Fix this provisory nonsense + smiles = "CCCCC" depiction_settings = self.get_random_rdkit_rendering_settings(smiles=smiles) rdMolDraw2D.PrepareAndDrawMolecule(depiction_settings, mol) depiction = depiction_settings.GetDrawingText() @@ -1074,7 +1082,8 @@ def rdkit_depict( depiction = img_as_ubyte(depiction) return np.asarray(depiction) else: - print("RDKit was unable to read input SMILES: {}".format(smiles)) + pass + # print("RDKit was unable to read input SMILES: {}".format(smiles)) def has_r_group(self, smiles: str) -> bool: """ @@ -1351,13 +1360,13 @@ def _cdk_render_molecule( dep_gen = dep_gen.withSize(shape[1], shape[0]) depiction = dep_gen.depict(molecule).toImg() depiction = self._cdk_bufferedimage_to_numpyarray(depiction) - return depiction + return depiction def cdk_depict( - self, smiles: str, shape: Tuple[int, int] = (299, 299) + self, mol_block: str, shape: Tuple[int, int] = (299, 299) ) -> np.array: """ - This function takes a smiles str and an image shape. + This function takes a mol block str and an image shape. 
It renders the chemical structures using CDK with random rendering/depiction settings and returns an RGB image (np.array) with the given image shape. @@ -1370,16 +1379,17 @@ def cdk_depict( https://github.com/cdk/cdk/wiki/Standard-Generator Args: - smiles (str): SMILES representation of molecule + mol_block (str): SMILES representation of molecule shape (Tuple[int, int], optional): im shape. Defaults to (299, 299) Returns: np.array: Chemical structure depiction """ - # TODO: Find out if adding and removing hydrogens is really necessary - molecule = self._cdk_smiles_to_IAtomContainer(smiles) - molecule = self._cdk_generate_2d_coordinates(molecule) - molecule = self._cdk_rotate_coordinates(molecule) + molecule = self._cdk_mol_block_to_iatomcontainer(mol_block) + # molecule = self._cdk_smiles_to_IAtomContainer(smiles) + # molecule = self._cdk_generate_2d_coordinates(molecule) + # molecule = self._cdk_rotate_coordinates(molecule) + smiles = "C1=CC=CC=C1" depiction = self._cdk_render_molecule(molecule, smiles, shape) return depiction @@ -1402,7 +1412,7 @@ def _cdk_smiles_to_IAtomContainer(self, smiles: str): molecule = SmilesParser.parseSmiles(smiles) return molecule - def _cdk_smiles_to_mol_str( + def _smiles_to_mol_block( self, smiles: str, generate_2d: bool = False, @@ -1417,25 +1427,73 @@ def _cdk_smiles_to_mol_str( Args: smiles (str): SMILES representation of a molecule - generate_2d (bool, optional): Whether to generate 2D coordinates. + generate_2d (bool or str, optional): False if no coordinates are created + Otherwise pick tool for coordinate + generation: + "rdkit", "cdk", "indigo" or "pikachu". 
Returns: mol_block (str): content of SD file of input molecule """ - molecule = self._cdk_smiles_to_IAtomContainer(smiles) - if generate_2d: + if not generate_2d: + molecule = self._cdk_smiles_to_IAtomContainer(smiles) + mol_block = self._cdk_iatomcontainer_to_mol_block(molecule) + elif generate_2d == "cdk": + molecule = self._cdk_smiles_to_IAtomContainer(smiles) molecule = self._cdk_generate_2d_coordinates(molecule) molecule = self._cdk_rotate_coordinates(molecule) - mol_block = self._cdk_IAtomContainer_to_mol_str(molecule) + mol_block = self._cdk_iatomcontainer_to_mol_block(molecule) + elif generate_2d == "rdkit": + mol_block = self._smiles_to_mol_block(smiles) + molecule = Chem.MolFromMolBlock(mol_block, sanitize=False) + if molecule: + AllChem.Compute2DCoords(molecule) + mol_block = Chem.MolToMolBlock(molecule) + atom_container = self._cdk_mol_block_to_iatomcontainer(mol_block) + atom_container = self._cdk_rotate_coordinates(atom_container) + mol_block = self._cdk_iatomcontainer_to_mol_block(atom_container) + else: + raise ValueError(f"RDKit could not read molecule: {smiles}") + elif generate_2d == "indigo": + indigo = Indigo() + mol_block = self._smiles_to_mol_block(smiles) + molecule = indigo.loadMolecule(mol_block) + molecule.layout() + buf = indigo.writeBuffer() + buf.sdfAppend(molecule) + mol_block = buf.toString() + atom_container = self._cdk_mol_block_to_iatomcontainer(mol_block) + atom_container = self._cdk_rotate_coordinates(atom_container) + mol_block = self._cdk_iatomcontainer_to_mol_block(atom_container) + elif generate_2d == "pikachu": + pass return mol_block - def _cdk_IAtomContainer_to_mol_str(self, i_atom_container) -> str: + def _cdk_mol_block_to_iatomcontainer(self, mol_block: str): + """ + Given a mol block, this function returns an IAtomContainer (JClass) object. 
+ + Args: + mol_block (str): content of MDL MOL file + + Returns: + IAtomContainer: CDK IAtomContainer object that represents the molecule + """ + xyz_reader = JClass("org.openscience.cdk.io.XYZReader") + string_reader = JClass("java.io.StringReader")(mol_block) + reader = xyz_reader(string_reader) + chemfile = reader.read(JClass("org.openscience.cdk.ChemFile")()) + manip = JClass("org.openscience.cdk.tools.manipulator.ChemFileManipulator") + iatomcontainer = manip.getAllAtomContainers(chemfile).get(0) + return iatomcontainer + + def _cdk_iatomcontainer_to_mol_block(self, i_atom_container) -> str: """ This function takes an IAtomContainer object and returns the content of the corresponding MDL MOL file as a string. Args: - i_atom_container (CDK IAtomContainer) + i_atom_container (CDK IAtomContainer (JClass object)) Returns: str: string content of MDL MOL file @@ -1578,38 +1636,38 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]: depiction_functions_registry = { 'rdkit': self.rdkit_depict, - 'indigo': self.depict_and_resize_indigo, + 'indigo': self.indigo_depict, 'cdk': self.cdk_depict, - 'pikachu': self.depict_and_resize_pikachu, + 'pikachu': self.pikachu_depict, } depiction_functions = [depiction_functions_registry[k] for k in self._config.styles] # Remove PIKAChU if there is an isotope if re.search("(\[\d\d\d?[A-Z])|(\[2H\])|(\[3H\])|(D)|(T)", smiles): - depiction_functions.remove(self.depict_and_resize_pikachu) + depiction_functions.remove(self.pikachu_depict) if self.has_r_group(smiles): # PIKAChU only accepts \[[RXZ]\d*\] squared_bracket_content = re.findall("\[.+?\]", smiles) for r_group in squared_bracket_content: if not re.search("\[[RXZ]\d*\]", r_group): - if self.depict_and_resize_pikachu in depiction_functions: - depiction_functions.remove(self.depict_and_resize_pikachu) + if self.pikachu_depict in depiction_functions: + depiction_functions.remove(self.pikachu_depict) # "R", "X", "Z" are not depicted by RDKit # The same is valid for 
X,Y,Z and a number if self.rdkit_depict in depiction_functions: if re.search("\[[RXZ]\]|\[[XYZ]\d+", smiles): depiction_functions.remove(self.rdkit_depict) # "X", "R0", [RXYZ]\d+[a-f] and indices above 32 are not depicted by Indigo - if self.depict_and_resize_indigo in depiction_functions: + if self.indigo_depict in depiction_functions: if re.search("\[R0\]|\[X\]|[4-9][0-9]+|3[3-9]|[XYZR]\d+[a-f]", smiles): - depiction_functions.remove(self.depict_and_resize_indigo) + depiction_functions.remove(self.indigo_depict) # Workaround because PIKAChU fails to depict large structures # TODO: Delete workaround when problem is fixed in PIKAChU # https://github.com/BTheDragonMaster/pikachu/issues/11 if len(smiles) > 100: - if self.depict_and_resize_pikachu in depiction_functions: - depiction_functions.remove(self.depict_and_resize_pikachu) + if self.pikachu_depict in depiction_functions: + depiction_functions.remove(self.pikachu_depict) return depiction_functions def resize(self, image: np.array, shape: Tuple[int], HQ: bool = False) -> np.array: @@ -2437,11 +2495,11 @@ def depict_from_fingerprint( self.active_scheme = schemes[0] # Depict molecule if "indigo" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_indigo(smiles, shape) + depiction = depictor.indigo_depict(smiles, shape) elif "rdkit" in list(schemes[0].keys())[0]: depiction = depictor.rdkit_depict(smiles, shape) elif "pikachu" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_pikachu(smiles, shape) + depiction = depictor.pikachu_depict(smiles, shape) elif "cdk" in list(schemes[0].keys())[0]: depiction = depictor.cdk_depict(smiles, shape) @@ -2675,8 +2733,8 @@ def __init__(self): depiction = self(smiles) depiction = self.cdk_depict(smiles) depiction = self.rdkit_depict(smiles) - depiction = self.depict_and_resize_indigo(smiles) - depiction = self.depict_and_resize_pikachu(smiles) + depiction = self.indigo_depict(smiles) + depiction = self.pikachu_depict(smiles) # Call 
augmentation function depiction = self.add_augmentations(depiction) # Generate schemes for Fingerprint creation diff --git a/Tests/test_functions.py b/Tests/test_functions.py index cf274c3..2fe40f4 100644 --- a/Tests/test_functions.py +++ b/Tests/test_functions.py @@ -355,7 +355,7 @@ def test_depict_and_resize_indigo(self): '[Otto]C1=C([XYZ123])C([R1])=C([Y])C([X1])=C1[R]', 'c1ccccc1[R1]'] for smiles in test_smiles: - im = self.depictor.depict_and_resize_indigo(smiles) + im = self.depictor.indigo_depict(smiles) assert type(im) == np.ndarray def test_depict_and_resize_rdkit(self): @@ -364,7 +364,7 @@ def test_depict_and_resize_rdkit(self): test_smiles = ['c1ccccc1', '[Otto]C1=C([XYZ123])C([R1])=C([Y])C([X])=C1[R]'] for smiles in test_smiles: - im = self.depictor.depict_and_resize_rdkit(smiles) + im = self.depictor.rdkit_depict(smiles) assert type(im) == np.ndarray def test_cdk_depict(self): @@ -382,17 +382,17 @@ def test_depict_and_resize_pikachu(self): test_smiles = ['c1ccccc1', '[R1]C1=C([X23])C([R])=C([Z])C([X])=C1[R]'] for smiles in test_smiles: - im = self.depictor.depict_and_resize_pikachu(smiles) + im = self.depictor.pikachu_depict(smiles) assert type(im) == np.ndarray def test_get_depiction_functions_normal(self): # For a molecule without isotopes or R groups, all toolkits can be used observed = self.depictor.get_depiction_functions('c1ccccc1C(O)=O') expected = [ - self.depictor.depict_and_resize_rdkit, - self.depictor.depict_and_resize_indigo, + self.depictor.rdkit_depict, + self.depictor.indigo_depict, self.depictor.cdk_depict, - self.depictor.depict_and_resize_pikachu, + self.depictor.pikachu_depict, ] # symmetric_difference difference = set(observed) ^ set(expected) @@ -402,8 +402,8 @@ def test_get_depiction_functions_isotopes(self): # PIKAChU can't handle isotopes observed = self.depictor.get_depiction_functions("[13CH3]N1C=NC2=C1C(=O)N(C(=O)N2C)C") expected = [ - self.depictor.depict_and_resize_rdkit, - self.depictor.depict_and_resize_indigo, + 
self.depictor.rdkit_depict, + self.depictor.indigo_depict, self.depictor.cdk_depict, ] difference = set(observed) ^ set(expected) @@ -413,9 +413,9 @@ def test_get_depiction_functions_R(self): # RDKit depicts "R" without indices as '*' (which is not desired) observed = self.depictor.get_depiction_functions("[R]N1C=NC2=C1C(=O)N(C(=O)N2C)C") expected = [ - self.depictor.depict_and_resize_indigo, + self.depictor.indigo_depict, self.depictor.cdk_depict, - self.depictor.depict_and_resize_pikachu, + self.depictor.pikachu_depict, ] difference = set(observed) ^ set(expected) assert not difference @@ -425,14 +425,14 @@ def test_get_depiction_functions_X(self): observed = self.depictor.get_depiction_functions("[X]N1C=NC2=C1C(=O)N(C(=O)N2C)C") expected = [ self.depictor.cdk_depict, - self.depictor.depict_and_resize_pikachu, + self.depictor.pikachu_depict, ] difference = set(observed) ^ set(expected) assert not difference def test_smiles_to_mol_str(self): # Compare generated mol file str with reference string - mol_str = self.depictor.smiles_to_mol_str("CC") + mol_str = self.depictor._cdk_smiles_to_mol_block("CC") mol_str_lines = mol_str.split('\n') with open('Tests/test.mol', 'r') as ref_mol_file: ref_lines = ref_mol_file.readlines() From fb555ad1f24935c860d9eb52053bc04782709802 Mon Sep 17 00:00:00 2001 From: Otto Brinkhaus Date: Mon, 5 Jun 2023 16:48:41 +0200 Subject: [PATCH 5/8] depiction from mol block with coordinates (does not work for PIKAChU) --- RanDepict/randepict.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index f7c4dec..07b6029 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -845,7 +845,7 @@ def pikachu_depict( Returns: np.array: Chemical structure depiction """ - reader = MolFileReader(mol_block) + reader = MolFileReader(molfile_str=mol_block) structure = reader.molfile_to_structure() # structure = read_smiles(smiles) depiction_settings = 
self.get_random_pikachu_rendering_settings() @@ -955,7 +955,7 @@ def indigo_depict( [True, True, False], log_attribute="indigo_kekulized" ): molecule.aromatize() - molecule.layout() + # molecule.layout() # Write to buffer temp = renderer.renderToBuffer(molecule) temp = io.BytesIO(temp) @@ -1024,7 +1024,7 @@ def get_random_rdkit_rendering_settings( depiction_settings.drawOptions().minFontSize = min_font_size depiction_settings.drawOptions().maxFontSize = 30 # Rotate the molecule - depiction_settings.drawOptions().rotate = self.random_choice(range(360)) + # depiction_settings.drawOptions().rotate = self.random_choice(range(360)) # Fixed bond length fixed_bond_length = self.random_choice( range(30, 45), log_attribute="rdkit_fixed_bond_length" @@ -1061,7 +1061,7 @@ def rdkit_depict( # if self.has_r_group(smiles) or not mol: # mol_str = self._smiles_to_mol_block(smiles) # mol = Chem.MolFromMolBlock(mol_str) - mol = Chem.MolFromMolBlock(mol_block, sanitize=False) + mol = Chem.MolFromMolBlock(mol_block, sanitize=True) if mol: # AllChem.Compute2DCoords(mol) # Abbreviate superatoms @@ -1358,6 +1358,7 @@ def _cdk_render_molecule( """ dep_gen, molecule = self._cdk_get_depiction_generator(molecule, smiles) dep_gen = dep_gen.withSize(shape[1], shape[0]) + dep_gen = dep_gen.withFillToFit() depiction = dep_gen.depict(molecule).toImg() depiction = self._cdk_bufferedimage_to_numpyarray(depiction) return depiction @@ -1479,12 +1480,18 @@ def _cdk_mol_block_to_iatomcontainer(self, mol_block: str): Returns: IAtomContainer: CDK IAtomContainer object that represents the molecule """ - xyz_reader = JClass("org.openscience.cdk.io.XYZReader") + # xyz_reader = JClass("org.openscience.cdk.io.XYZReader") + scob = JClass("org.openscience.cdk.silent.SilentChemObjectBuilder") + bldr = scob.getInstance() + iac_class = JClass("org.openscience.cdk.interfaces.IAtomContainer").class_ string_reader = JClass("java.io.StringReader")(mol_block) - reader = xyz_reader(string_reader) - chemfile = 
reader.read(JClass("org.openscience.cdk.ChemFile")()) - manip = JClass("org.openscience.cdk.tools.manipulator.ChemFileManipulator") - iatomcontainer = manip.getAllAtomContainers(chemfile).get(0) + mdlr = JClass("org.openscience.cdk.io.MDLV2000Reader")(string_reader) + iatomcontainer = mdlr.read(bldr.newInstance(iac_class)) + mdlr.close() + # reader = xyz_reader(string_reader) + # chemfile = reader.read(JClass("org.openscience.cdk.ChemFile")()) + # manip = JClass("org.openscience.cdk.tools.manipulator.ChemFileManipulator") + # iatomcontainer = manip.getAllAtomContainers(chemfile).get(0) return iatomcontainer def _cdk_iatomcontainer_to_mol_block(self, i_atom_container) -> str: From f7bca573ce57812c3bc463b5d74485a3cacbfe81 Mon Sep 17 00:00:00 2001 From: Otto Brinkhaus Date: Thu, 15 Jun 2023 14:43:21 +0200 Subject: [PATCH 6/8] major restructuring + cxsmiles coordinate generation --- RanDepict/__init__.py | 6 +- RanDepict/augmentations.py | 1152 +++ RanDepict/cdk_functionalities.py | 441 ++ RanDepict/config.py | 38 + RanDepict/depiction_feature_ranges.py | 565 ++ RanDepict/import augmentations.py | 1152 +++ RanDepict/indigo_functionalities.py | 126 + RanDepict/pikachu_functionalities.py | 75 + RanDepict/randepict.py | 3108 +------- .../random_markush_structure_generator.py | 188 + RanDepict/rdkit_functionalities.py | 242 + Tests/test_functions.py | 18 +- examples/RanDepictNotebook.ipynb | 6633 ++++++++++------- ..._dataset_with_and_without_augmentations.py | 6 +- .../randepict_batch_run_tfrecord_output.py | 40 +- 15 files changed, 8189 insertions(+), 5601 deletions(-) create mode 100644 RanDepict/augmentations.py create mode 100644 RanDepict/cdk_functionalities.py create mode 100644 RanDepict/config.py create mode 100644 RanDepict/depiction_feature_ranges.py create mode 100644 RanDepict/import augmentations.py create mode 100644 RanDepict/indigo_functionalities.py create mode 100644 RanDepict/pikachu_functionalities.py create mode 100644 
RanDepict/random_markush_structure_generator.py create mode 100644 RanDepict/rdkit_functionalities.py diff --git a/RanDepict/__init__.py b/RanDepict/__init__.py index a4d03cc..d65b6e7 100644 --- a/RanDepict/__init__.py +++ b/RanDepict/__init__.py @@ -27,5 +27,7 @@ "RanDepict", ] - -from .randepict import RandomDepictor, RandomDepictorConfig, DepictionFeatureRanges, RandomMarkushStructureCreator +from .config import RandomDepictorConfig +from .depiction_feature_ranges import DepictionFeatureRanges +from .randepict import RandomDepictor +from .random_markush_structure_generator import RandomMarkushStructureCreator diff --git a/RanDepict/augmentations.py b/RanDepict/augmentations.py new file mode 100644 index 0000000..d3bf7eb --- /dev/null +++ b/RanDepict/augmentations.py @@ -0,0 +1,1152 @@ +from copy import deepcopy +import cv2 +import imgaug.augmenters as iaa +import numpy as np +import os +from PIL import Image, ImageEnhance, ImageFont, ImageDraw, ImageStat +from scipy.ndimage import gaussian_filter +from scipy.ndimage import map_coordinates +from skimage.color import rgb2gray +from skimage.util import img_as_float +from typing import Tuple + + +class Augmentations: + def resize(self, image: np.array, shape: Tuple[int], HQ: bool = False) -> np.array: + """ + This function takes an image (np.array) and a shape and returns + the resized image (np.array). It uses Pillow to do this, as it + seems to have a bigger variety of scaling methods than skimage. + The up/downscaling method is chosen randomly. + + Args: + image (np.array): the input image + shape (Tuple[int, int], optional): im shape. 
Defaults to (299, 299) + HQ (bool): if true, only choose from Image.BICUBIC, Image.LANCZOS + ___ + Returns: + np.array: the resized image + + """ + image = Image.fromarray(image) + shape = (shape[0], shape[1]) + if not HQ: + image = image.resize( + shape, resample=self.random_choice(self.PIL_resize_methods) + ) + else: + image = image = image.resize( + shape, resample=self.random_choice(self.PIL_HQ_resize_methods) + ) + + return np.asarray(image) + + def imgaug_augment( + self, + image: np.array, + ) -> np.array: + """ + This function applies a random amount of augmentations to + a given image (np.array) using and returns the augmented image + (np.array). + + Args: + image (np.array): input image + + Returns: + np.array: output image (augmented) + """ + original_shape = image.shape + + # Choose number of augmentations to apply (0-2); + # return image if nothing needs to be done. + aug_number = self.random_choice(range(0, 3)) + if not aug_number: + return image + + # Add some padding to avoid weird artifacts after rotation + image = np.pad( + image, ((1, 1), (1, 1), (0, 0)), mode="constant", constant_values=255 + ) + + def imgaug_rotation(): + # Rotation between -10 and 10 degrees + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_rotation" + ): + return False + rot_angle = self.random_choice(np.arange(-10, 10, 1)) + aug = iaa.Affine(rotate=rot_angle, mode="edge", fit_output=True) + return aug + + def imgaug_black_and_white_noise(): + # Black and white noise + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_salt_pepper" + ): + return False + coarse_dropout_p = self.random_choice(np.arange(0.0002, 0.0015, 0.0001)) + coarse_dropout_size_percent = self.random_choice(np.arange(1.0, 1.1, 0.01)) + replace_elementwise_p = self.random_choice(np.arange(0.01, 0.3, 0.01)) + aug = iaa.Sequential( + [ + iaa.CoarseDropout( + coarse_dropout_p, size_percent=coarse_dropout_size_percent + ), + 
iaa.ReplaceElementwise(replace_elementwise_p, 255), + ] + ) + return aug + + '''def imgaug_shearing(): + # Shearing + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_shearing" + ): + return False + shear_param = self.random_choice(np.arange(-5, 5, 1)) + aug = self.random_choice( + [ + iaa.geometric.ShearX(shear_param, mode="edge", fit_output=True), + iaa.geometric.ShearY(shear_param, mode="edge", fit_output=True), + ] + ) + return aug''' + + def imgaug_imgcorruption(): + # Jpeg compression or pixelation + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_corruption" + ): + return False + imgcorrupt_severity = self.random_choice(np.arange(1, 2, 1)) + aug = self.random_choice( + [ + iaa.imgcorruptlike.JpegCompression(severity=imgcorrupt_severity), + iaa.imgcorruptlike.Pixelate(severity=imgcorrupt_severity), + ] + ) + return aug + + def imgaug_brightness_adjustment(): + # Brightness adjustment + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_brightness_adj" + ): + return False + brightness_adj_param = self.random_choice(np.arange(-50, 50, 1)) + aug = iaa.WithBrightnessChannels(iaa.Add(brightness_adj_param)) + return aug + + def imgaug_colour_temp_adjustment(): + # Colour temperature adjustment + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_col_adj" + ): + return False + colour_temp = self.random_choice(np.arange(1100, 10000, 1)) + aug = iaa.ChangeColorTemperature(colour_temp) + return aug + + # Define list of available augmentations + aug_list = [ + imgaug_rotation, + imgaug_black_and_white_noise, + # imgaug_shearing, # Disabled as is shifts the coordinates in the image + imgaug_imgcorruption, + imgaug_brightness_adjustment, + imgaug_colour_temp_adjustment, + ] + + # Every one of them has a 1/3 chance of returning False + aug_list = [fun() for fun in aug_list] + aug_list = [fun for fun in aug_list if fun] + aug = iaa.Sequential(aug_list) + augmented_image 
= aug.augment_images([image])[0] + augmented_image = self.resize(augmented_image, original_shape) + augmented_image = augmented_image.astype(np.uint8) + return augmented_image + + def add_augmentations(self, depiction: np.array) -> np.array: + """ + This function takes a chemical structure depiction (np.array) + and returns the same image with added augmentation elements + + Args: + depiction (np.array): chemical structure depiction + + Returns: + np.array: chemical structure depiction with added augmentations + """ + if self.random_choice( + [True, False, False, False, False, False], log_attribute="has_curved_arrows" + ): + depiction = self.add_curved_arrows_to_structure(depiction) + if self.random_choice( + [True, False, False], log_attribute="has_straight_arrows" + ): + depiction = self.add_straight_arrows_to_structure(depiction) + if self.random_choice( + [True, False, False, False, False, False], log_attribute="has_id_label" + ): + depiction = self.add_chemical_label(depiction, "ID") + if self.random_choice( + [True, False, False, False, False, False], log_attribute="has_R_group_label" + ): + depiction = self.add_chemical_label(depiction, "R_GROUP") + if self.random_choice( + [True, False, False, False, False, False], + log_attribute="has_reaction_label", + ): + depiction = self.add_chemical_label(depiction, "REACTION") + depiction = self.imgaug_augment(depiction) + return depiction + + def get_random_label_position(self, width: int, height: int) -> Tuple[int, int]: + """ + Given the width and height of an image (int), this function + determines a random position in the outer 15% of the image and + returns a tuple that contain the coordinates (y,x) of that position. 
+ + Args: + width (int): image width + height (int): image height + + Returns: + Tuple[int, int]: Random label position + """ + if self.random_choice([True, False]): + y_range = range(0, height) + x_range = list(range(0, int(0.15 * width))) + list( + range(int(0.85 * width), width) + ) + else: + y_range = list(range(0, int(0.15 * height))) + list( + range(int(0.85 * height), height) + ) + x_range = range(0, width) + return self.random_choice(y_range), self.random_choice(x_range) + + def ID_label_text(self) -> str: + """ + This function returns a string that resembles a typical + chemical ID label + + Returns: + str: Label text + """ + label_num = range(1, 50) + label_letters = [ + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + ] + options = [ + "only_number", + "num_letter_combination", + "numtonum", + "numcombtonumcomb", + ] + option = self.random_choice(options) + if option == "only_number": + return str(self.random_choice(label_num)) + if option == "num_letter_combination": + return str(self.random_choice(label_num)) + self.random_choice( + label_letters + ) + if option == "numtonum": + return ( + str(self.random_choice(label_num)) + + "-" + + str(self.random_choice(label_num)) + ) + if option == "numcombtonumcomb": + return ( + str(self.random_choice(label_num)) + + self.random_choice(label_letters) + + "-" + + self.random_choice(label_letters) + ) + + def new_reaction_condition_elements(self) -> Tuple[str, str, str]: + """ + Randomly redefine reaction_time, solvent and other_reactand. 
+ + Returns: + Tuple[str, str, str]: Reaction time, solvent, reactand + """ + reaction_time = self.random_choice( + [str(num) for num in range(30)] + ) + self.random_choice([" h", " min"]) + solvent = self.random_choice( + [ + "MeOH", + "EtOH", + "CHCl3", + "DCM", + "iPrOH", + "MeCN", + "DMSO", + "pentane", + "hexane", + "benzene", + "Et2O", + "THF", + "DMF", + ] + ) + other_reactand = self.random_choice( + [ + "HF", + "HCl", + "HBr", + "NaOH", + "Et3N", + "TEA", + "Ac2O", + "DIBAL", + "DIBAL-H", + "DIPEA", + "DMAP", + "EDTA", + "HOBT", + "HOAt", + "TMEDA", + "p-TsOH", + "Tf2O", + ] + ) + return reaction_time, solvent, other_reactand + + def reaction_condition_label_text(self) -> str: + """ + This function returns a random string that looks like a + reaction condition label. + + Returns: + str: Reaction condition label text + """ + reaction_condition_label = "" + label_type = self.random_choice(["A", "B", "C", "D"]) + if label_type in ["A", "B"]: + for n in range(self.random_choice(range(1, 5))): + ( + reaction_time, + solvent, + other_reactand, + ) = self.new_reaction_condition_elements() + if label_type == "A": + reaction_condition_label += ( + str(n + 1) + + " " + + other_reactand + + ", " + + solvent + + ", " + + reaction_time + + "\n" + ) + elif label_type == "B": + reaction_condition_label += ( + str(n + 1) + + " " + + other_reactand + + ", " + + solvent + + " (" + + reaction_time + + ")\n" + ) + elif label_type == "C": + ( + reaction_time, + solvent, + other_reactand, + ) = self.new_reaction_condition_elements() + reaction_condition_label += ( + other_reactand + "\n" + solvent + "\n" + reaction_time + ) + elif label_type == "D": + reaction_condition_label += self.random_choice( + self.new_reaction_condition_elements() + ) + return reaction_condition_label + + def make_R_group_str(self) -> str: + """ + This function returns a random string that looks like an R group label. + It generates them by inserting randomly chosen elements into one of + five templates. 
+ + Returns: + str: R group label text + """ + rest_variables = [ + "X", + "Y", + "Z", + "R", + "R1", + "R2", + "R3", + "R4", + "R5", + "R6", + "R7", + "R8", + "R9", + "R10", + "Y2", + "D", + ] + # Load list of superatoms (from OSRA) + superatoms = self.superatoms + label_type = self.random_choice(["A", "B", "C", "D", "E"]) + R_group_label = "" + if label_type == "A": + for _ in range(1, self.random_choice(range(2, 6))): + R_group_label += ( + self.random_choice(rest_variables) + + " = " + + self.random_choice(superatoms) + + "\n" + ) + elif label_type == "B": + R_group_label += " " + self.random_choice(rest_variables) + "\n" + for n in range(1, self.random_choice(range(2, 6))): + R_group_label += str(n) + " " + self.random_choice(superatoms) + "\n" + elif label_type == "C": + R_group_label += ( + " " + + self.random_choice(rest_variables) + + " " + + self.random_choice(rest_variables) + + "\n" + ) + for n in range(1, self.random_choice(range(2, 6))): + R_group_label += ( + str(n) + + " " + + self.random_choice(superatoms) + + " " + + self.random_choice(superatoms) + + "\n" + ) + elif label_type == "D": + R_group_label += ( + " " + + self.random_choice(rest_variables) + + " " + + self.random_choice(rest_variables) + + " " + + self.random_choice(rest_variables) + + "\n" + ) + for n in range(1, self.random_choice(range(2, 6))): + R_group_label += ( + str(n) + + " " + + self.random_choice(superatoms) + + " " + + self.random_choice(superatoms) + + " " + + self.random_choice(superatoms) + + "\n" + ) + if label_type == "E": + for n in range(1, self.random_choice(range(2, 6))): + R_group_label += ( + str(n) + + " " + + self.random_choice(rest_variables) + + " = " + + self.random_choice(superatoms) + + "\n" + ) + return R_group_label + + def add_chemical_label( + self, image: np.array, label_type: str, foreign_fonts: bool = True + ) -> np.array: + """ + This function takes an image (np.array) and adds random text that + looks like a chemical ID label, an R group label or a 
reaction + condition label around the structure. It returns the modified image. + The label type is determined by the parameter label_type (str), + which needs to be 'ID', 'R_GROUP' or 'REACTION' + + Args: + image (np.array): Chemical structure depiction + label_type (str): 'ID', 'R_GROUP' or 'REACTION' + foreign_fonts (bool, optional): Defaults to True. + + Returns: + np.array: Chemical structure depiction with label + """ + im = Image.fromarray(image) + orig_image = deepcopy(im) + width, height = im.size + # Choose random font + if self.random_choice([True, False]) or not foreign_fonts: + font_dir = self.HERE.joinpath("fonts/") + # In half of the cases: Use foreign-looking font to generate + # bigger noise variety + else: + font_dir = self.HERE.joinpath("foreign_fonts/") + + fonts = os.listdir(str(font_dir)) + # Choose random font size + font_sizes = range(10, 20) + size = self.random_choice(font_sizes) + # Generate random string that resembles the desired type of label + if label_type == "ID": + label_text = self.ID_label_text() + if label_type == "R_GROUP": + label_text = self.make_R_group_str() + if label_type == "REACTION": + label_text = self.reaction_condition_label_text() + + try: + font = ImageFont.truetype( + str(os.path.join(str(font_dir), self.random_choice(fonts))), size=size + ) + except OSError: + font = ImageFont.load_default() + + draw = ImageDraw.Draw(im, "RGBA") + + # Try different positions with the condition that the label´does not + # overlap with non-white pixels (the structure) + for _ in range(50): + y_pos, x_pos = self.get_random_label_position(width, height) + bounding_box = draw.textbbox( + (x_pos, y_pos), label_text, font=font + ) # left, up, right, low + paste_region = orig_image.crop(bounding_box) + try: + mean = ImageStat.Stat(paste_region).mean + except ZeroDivisionError: + return np.asarray(im) + if sum(mean) / len(mean) == 255: + draw.text((x_pos, y_pos), label_text, font=font, fill=(0, 0, 0, 255)) + break + return np.asarray(im) 
+ + def add_curved_arrows_to_structure(self, image: np.array) -> np.array: + """ + This function takes an image of a chemical structure (np.array) + and adds between 2 and 4 curved arrows in random positions in the + central part of the image. + + Args: + image (np.array): Chemical structure depiction + + Returns: + np.array: Chemical structure depiction with curved arrows + """ + height, width, _ = image.shape + image = Image.fromarray(image) + orig_image = deepcopy(image) + # Determine area where arrows are pasted. + x_min, x_max = (int(0.1 * width), int(0.9 * width)) + y_min, y_max = (int(0.1 * height), int(0.9 * height)) + + arrow_dir = os.path.normpath( + str(self.HERE.joinpath("arrow_images/curved_arrows/")) + ) + + for _ in range(self.random_choice(range(2, 4))): + # Load random curved arrow image, resize and rotate it randomly. + arrow_image = Image.open( + os.path.join( + str(arrow_dir), self.random_choice(os.listdir(str(arrow_dir))) + ) + ) + new_arrow_image_shape = int( + (x_max - x_min) / self.random_choice(range(3, 6)) + ), int((y_max - y_min) / self.random_choice(range(3, 6))) + arrow_image = self.resize(np.asarray(arrow_image), new_arrow_image_shape) + arrow_image = Image.fromarray(arrow_image) + arrow_image = arrow_image.rotate( + self.random_choice(range(360)), + resample=self.random_choice( + [Image.BICUBIC, Image.NEAREST, Image.BILINEAR] + ), + expand=True, + ) + # Try different positions with the condition that the arrows are + # overlapping with non-white pixels (the structure) + for _ in range(50): + x_position = self.random_choice( + range(x_min, x_max - new_arrow_image_shape[0]) + ) + y_position = self.random_choice( + range(y_min, y_max - new_arrow_image_shape[1]) + ) + paste_region = orig_image.crop( + ( + x_position, + y_position, + x_position + new_arrow_image_shape[0], + y_position + new_arrow_image_shape[1], + ) + ) + mean = ImageStat.Stat(paste_region).mean + if sum(mean) / len(mean) < 252: + image.paste(arrow_image, (x_position, 
y_position), arrow_image) + + break + return np.asarray(image) + + def get_random_arrow_position(self, width: int, height: int) -> Tuple[int, int]: + """ + Given the width and height of an image (int), this function determines + a random position to paste a reaction arrow in the outer 15% frame of + the image + + Args: + width (_type_): image width + height (_type_): image height + + Returns: + Tuple[int, int]: Random arrow position + """ + if self.random_choice([True, False]): + y_range = range(0, height) + x_range = list(range(0, int(0.15 * width))) + list( + range(int(0.85 * width), width) + ) + else: + y_range = list(range(0, int(0.15 * height))) + list( + range(int(0.85 * height), height) + ) + x_range = range(0, int(0.5 * width)) + return self.random_choice(y_range), self.random_choice(x_range) + + def add_straight_arrows_to_structure(self, image: np.array) -> np.array: + """ + This function takes an image of a chemical structure (np.array) + and adds between 1 and 2 straight arrows in random positions in the + image (no overlap with other elements) + + Args: + image (np.array): Chemical structure depiction + + Returns: + np.array: Chemical structure depiction with straight arrow + """ + height, width, _ = image.shape + image = Image.fromarray(image) + + arrow_dir = os.path.normpath( + str(self.HERE.joinpath("arrow_images/straight_arrows/")) + ) + + for _ in range(self.random_choice(range(1, 3))): + # Load random curved arrow image, resize and rotate it randomly. 
+ arrow_image = Image.open( + os.path.join( + str(arrow_dir), self.random_choice(os.listdir(str(arrow_dir))) + ) + ) + # new_arrow_image_shape = (int(width * + # self.random_choice(np.arange(0.9, 1.5, 0.1))), + # int(height/10 * self.random_choice(np.arange(0.7, 1.2, 0.1)))) + + # arrow_image = arrow_image.resize(new_arrow_image_shape, + # resample=Image.BICUBIC) + # Rotate completely randomly in half of the cases and in 180° steps + # in the other cases (higher probability that pasting works) + if self.random_choice([True, False]): + arrow_image = arrow_image.rotate( + self.random_choice(range(360)), + resample=self.random_choice( + [Image.Resampling.BICUBIC, Image.Resampling.NEAREST, Image.Resampling.BILINEAR] + ), + expand=True, + ) + else: + arrow_image = arrow_image.rotate(self.random_choice([180, 360])) + new_arrow_image_shape = arrow_image.size + # Try different positions with the condition that the arrows are + # overlapping with non-white pixels (the structure) + for _ in range(50): + y_position, x_position = self.get_random_arrow_position(width, height) + x2_position = x_position + new_arrow_image_shape[0] + y2_position = y_position + new_arrow_image_shape[1] + # Make sure we only check a region inside of the image + if x2_position > width: + x2_position = width - 1 + if y2_position > height: + y2_position = height - 1 + paste_region = image.crop( + (x_position, y_position, x2_position, y2_position) + ) + try: + mean = ImageStat.Stat(paste_region).mean + if sum(mean) / len(mean) == 255: + image.paste(arrow_image, (x_position, y_position), arrow_image) + break + except ZeroDivisionError: + pass + return np.asarray(image) + + def to_grayscale_float_img(self, image: np.array) -> np.array: + """ + This function takes an image (np.array), converts it to grayscale + and returns it. 
+ + Args: + image (np.array): image + + Returns: + np.array: grayscale float image + """ + return img_as_float(rgb2gray(image)) + + def hand_drawn_augment(self, img) -> np.array: + """ + This function randomly applies different image augmentations with + different probabilities to the input image. + + It has been modified from the original augment.py present on + https://github.com/mtzgroup/ChemPixCH + + From the publication: + https://pubs.rsc.org/en/content/articlelanding/2021/SC/D1SC02957F + + Args: + img: the image to modify in array format. + Returns: + img: the augmented image. + """ + # resize + if self.random_choice(np.arange(0, 1, 0.01)) < 0.5: + img = self.resize_hd(img) + # blur + if self.random_choice(np.arange(0, 1, 0.01)) < 0.4: + img = self.blur(img) + # erode + if self.random_choice(np.arange(0, 1, 0.01)) < 0.4: + img = self.erode(img) + # dilate + if self.random_choice(np.arange(0, 1, 0.01)) < 0.4: + img = self.dilate(img) + # aspect_ratio + if self.random_choice(np.arange(0, 1, 0.01)) < 0.7: + img = self.aspect_ratio(img, "mol") + # affine + if self.random_choice(np.arange(0, 1, 0.01)) < 0.7: + img = self.affine(img, "mol") + # distort + if self.random_choice(np.arange(0, 1, 0.01)) < 0.8: + img = self.distort(img) + if img.shape != (255, 255, 3): + img = cv2.resize(img, (256, 256)) + return img + + def augment_bkg(self, img) -> np.array: + """ + This function randomly applies different image augmentations with + different probabilities to the input image. + Args: + img: the image to modify in array format. + Returns: + img: the augmented image. 
+ """ + # rotate + rows, cols, _ = img.shape + angle = self.random_choice(np.arange(0, 360)) + M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1) + img = cv2.warpAffine(img, M, (cols, rows), borderMode=cv2.BORDER_REFLECT) + # resize + if self.random_choice(np.arange(0, 1, 0.01)) < 0.5: + img = self.resize_hd(img) + # blur + if self.random_choice(np.arange(0, 1, 0.01)) < 0.4: + img = self.blur(img) + # erode + if self.random_choice(np.arange(0, 1, 0.01)) < 0.2: + img = self.erode(img) + # dilate + if self.random_choice(np.arange(0, 1, 0.01)) < 0.2: + img = self.dilate(img) + # aspect_ratio + if self.random_choice(np.arange(0, 1, 0.01)) < 0.3: + img = self.aspect_ratio(img, "bkg") + # affine + if self.random_choice(np.arange(0, 1, 0.01)) < 0.3: + img = self.affine(img, "bkg") + # distort + if self.random_choice(np.arange(0, 1, 0.01)) < 0.8: + img = self.distort(img) + if img.shape != (255, 255, 3): + img = cv2.resize(img, (256, 256)) + return img + + def resize_hd(self, img) -> np.array: + """ + This function resizes the image randomly from between (200-300, 200-300) + and then resizes it back to 256x256. + Args: + img: the image to modify in array format. + Returns: + img: the resized image. + """ + interpolations = [ + cv2.INTER_NEAREST, + cv2.INTER_AREA, + cv2.INTER_LINEAR, + cv2.INTER_CUBIC, + cv2.INTER_LANCZOS4, + ] + + img = cv2.resize( + img, + ( + self.random_choice(np.arange(200, 300)), + self.random_choice(np.arange(200, 300)), + ), + interpolation=self.random_choice(interpolations), + ) + img = cv2.resize( + img, (256, 256), interpolation=self.random_choice(interpolations) + ) + + return img + + def blur(self, img) -> np.array: + """ + This function blurs the image randomly between 1-3. + Args: + img: the image to modify in array format. + Returns: + img: the blurred image. 
+ """ + n = self.random_choice(np.arange(1, 4)) + kernel = np.ones((n, n), np.float32) / n**2 + img = cv2.filter2D(img, -1, kernel) + return img + + def erode(self, img) -> np.array: + """ + This function bolds the image randomly between 1-2. + Args: + img: the image to modify in array format. + Returns: + img: the bold image. + """ + n = self.random_choice(np.arange(1, 3)) + kernel = np.ones((n, n), np.float32) / n**2 + img = cv2.erode(img, kernel, iterations=1) + return img + + def dilate(self, img) -> np.array: + """ + This function dilates the image with a factor of 2. + Args: + img: the image to modify in array format. + Returns: + img: the dilated image. + """ + n = 2 + kernel = np.ones((n, n), np.float32) / n**2 + img = cv2.dilate(img, kernel, iterations=1) + return img + + def aspect_ratio(self, img, obj=None) -> np.array: + """ + This function irregularly changes the size of the image + and converts it back to (256,256). + Args: + img: the image to modify in array format. + obj: "mol" or "bkg" to modify a chemical structure image or + a background image. + Returns: + image: the resized image. + """ + n1 = self.random_choice(np.arange(0, 50)) + n2 = self.random_choice(np.arange(0, 50)) + n3 = self.random_choice(np.arange(0, 50)) + n4 = self.random_choice(np.arange(0, 50)) + if obj == "mol": + image = cv2.copyMakeBorder( + img, n1, n2, n3, n4, cv2.BORDER_CONSTANT, value=[255, 255, 255] + ) + elif obj == "bkg": + image = cv2.copyMakeBorder(img, n1, n2, n3, n4, cv2.BORDER_REFLECT) + + image = cv2.resize(image, (256, 256)) + return image + + def affine(self, img, obj=None) -> np.array: + """ + This function randomly applies affine transformation which consists + of matrix rotations, translations and scale operations and converts + it back to (256,256). + Args: + img: the image to modify in array format. + obj: "mol" or "bkg" to modify a chemical structure image or + a background image. + Returns: + skewed: the transformed image. 
+ """ + rows, cols, _ = img.shape + n = 20 + pts1 = np.float32([[5, 50], [200, 50], [50, 200]]) + pts2 = np.float32( + [ + [ + 5 + self.random_choice(np.arange(-n, n)), + 50 + self.random_choice(np.arange(-n, n)), + ], + [ + 200 + self.random_choice(np.arange(-n, n)), + 50 + self.random_choice(np.arange(-n, n)), + ], + [ + 50 + self.random_choice(np.arange(-n, n)), + 200 + self.random_choice(np.arange(-n, n)), + ], + ] + ) + + M = cv2.getAffineTransform(pts1, pts2) + + if obj == "mol": + skewed = cv2.warpAffine(img, M, (cols, rows), borderValue=[255, 255, 255]) + elif obj == "bkg": + skewed = cv2.warpAffine(img, M, (cols, rows), borderMode=cv2.BORDER_REFLECT) + + skewed = cv2.resize(skewed, (256, 256)) + return skewed + + def elastic_transform(self, image, alpha_sigma) -> np.array: + """ + Elastic deformation of images as described in [Simard2003]_. + .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for + Convolutional Neural Networks applied to Visual Document Analysis", in + Proc. of the International Conference on Document Analysis and + Recognition, 2003. + https://gist.github.com/erniejunior/601cdf56d2b424757de5 + This function distords an image randomly changing the alpha and gamma + values. + Args: + image: the image to modify in array format. + alpha_sigma: alpha and sigma values randomly selected as a list. + Returns: + distored_image: the image after the transformation with the same size + as it had originally. 
+ """ + alpha = alpha_sigma[0] + sigma = alpha_sigma[1] + random_state = np.random.RandomState(self.random_choice(np.arange(1, 1000))) + + shape = image.shape + dx = ( + gaussian_filter( + (random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0 + ) + * alpha + ) + random_state = np.random.RandomState(self.random_choice(np.arange(1, 1000))) + dy = ( + gaussian_filter( + (random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0 + ) + * alpha + ) + + x, y, z = np.meshgrid( + np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]) + ) + indices = ( + np.reshape(y + dy, (-1, 1)), + np.reshape(x + dx, (-1, 1)), + np.reshape(z, (-1, 1)), + ) + + distored_image = map_coordinates( + image, indices, order=self.random_choice(np.arange(1, 5)), mode="reflect" + ) + return distored_image.reshape(image.shape) + + def distort(self, img) -> np.array: + """ + This function randomly selects a list with the shape [a, g] where + a=alpha and g=gamma and passes them along with the input image + to the elastic_transform function that will do the image distorsion. + Args: + img: the image to modify in array format. + Returns: + the output from elastic_transform function which is the image + after the transformation with the same size as it had originally. 
+ """ + sigma_alpha = [ + (self.random_choice(np.arange(9, 11)), self.random_choice(np.arange(2, 4))), + (self.random_choice(np.arange(80, 100)), 4), + (self.random_choice(np.arange(150, 300)), 5), + ( + self.random_choice(np.arange(800, 1200)), + self.random_choice(np.arange(8, 10)), + ), + ( + self.random_choice(np.arange(1500, 2000)), + self.random_choice(np.arange(10, 15)), + ), + ( + self.random_choice(np.arange(5000, 8000)), + self.random_choice(np.arange(15, 25)), + ), + ( + self.random_choice(np.arange(10000, 15000)), + self.random_choice(np.arange(20, 25)), + ), + ( + self.random_choice(np.arange(45000, 55000)), + self.random_choice(np.arange(30, 35)), + ), + ] + choice = self.random_choice(range(len(sigma_alpha))) + sigma_alpha_chosen = sigma_alpha[choice] + return self.elastic_transform(img, sigma_alpha_chosen) + + def degrade_img(self, img) -> np.array: + """ + This function randomly degrades the input image by applying different + degradation steps with different robabilities. + Args: + img: the image to modify in array format. + Returns: + img: the degraded image. + """ + # s+p + if self.random_choice(np.arange(0, 1, 0.01)) < 0.1: + img = self.s_and_p(img) + + # scale + if self.random_choice(np.arange(0, 1, 0.01)) < 0.5: + img = self.scale(img) + + # brightness + if self.random_choice(np.arange(0, 1, 0.01)) < 0.7: + img = self.brightness(img) + + # contrast + if self.random_choice(np.arange(0, 1, 0.01)) < 0.7: + img = self.contrast(img) + + # sharpness + if self.random_choice(np.arange(0, 1, 0.01)) < 0.5: + img = self.sharpness(img) + + return img + + def contrast(self, img) -> np.array: + """ + This function randomly changes the input image contrast. + Args: + img: the image to modify in array format. + Returns: + img: the image with the contrast changes. 
+ """ + if self.random_choice(np.arange(0, 1, 0.01)) < 0.8: # increase contrast + f = self.random_choice(np.arange(1, 2, 0.01)) + else: # decrease contrast + f = self.random_choice(np.arange(0.5, 1, 0.01)) + im_pil = Image.fromarray(img) + enhancer = ImageEnhance.Contrast(im_pil) + im = enhancer.enhance(f) + img = np.asarray(im) + return np.asarray(im) + + def brightness(self, img) -> np.array: + """ + This function randomly changes the input image brightness. + Args: + img: the image to modify in array format. + Returns: + img: the image with the brightness changes. + """ + f = self.random_choice(np.arange(0.4, 1.1, 0.01)) + im_pil = Image.fromarray(img) + enhancer = ImageEnhance.Brightness(im_pil) + im = enhancer.enhance(f) + img = np.asarray(im) + return np.asarray(im) + + def sharpness(self, img) -> np.array: + """ + This function randomly changes the input image sharpness. + Args: + img: the image to modify in array format. + Returns: + img: the image with the sharpness changes. + """ + if self.random_choice(np.arange(0, 1, 0.01)) < 0.5: # increase sharpness + f = self.random_choice(np.arange(0.1, 1, 0.01)) + else: # decrease sharpness + f = self.random_choice(np.arange(1, 10)) + im_pil = Image.fromarray(img) + enhancer = ImageEnhance.Sharpness(im_pil) + im = enhancer.enhance(f) + img = np.asarray(im) + return np.asarray(im) + + def s_and_p(self, img) -> np.array: + """ + This function randomly adds salt and pepper to the input image. + Args: + img: the image to modify in array format. + Returns: + out: the image with the s&p changes. 
+ """ + amount = self.random_choice(np.arange(0.001, 0.01)) + # add some s&p + s_vs_p = 0.5 + out = np.copy(img) + # Salt mode + num_salt = int(np.ceil(amount * img.size * s_vs_p)) + coords = [] + for i in img.shape: + coordinates = [] + for _ in range(num_salt): + coordinates.append(self.random_choice(np.arange(0, i - 1))) + coords.append(np.array(coordinates)) + out[tuple(coords)] = 1 + # pepper + num_pepper = int(np.ceil(amount * img.size * (1.0 - s_vs_p))) + coords = [] + for i in img.shape: + coordinates = [] + for _ in range(num_pepper): + coordinates.append(self.random_choice(np.arange(0, i - 1))) + coords.append(np.array(coordinates)) + out[tuple(coords)] = 0 + return out + + def scale(self, img) -> np.array: + """ + This function randomly scales the input image. + Args: + img: the image to modify in array format. + Returns: + res: the scaled image. + """ + f = self.random_choice(np.arange(0.5, 1.5, 0.01)) + res = cv2.resize(img, None, fx=f, fy=f, interpolation=cv2.INTER_CUBIC) + res = cv2.resize( + res, None, fx=1.0 / f, fy=1.0 / f, interpolation=cv2.INTER_CUBIC + ) + return res diff --git a/RanDepict/cdk_functionalities.py b/RanDepict/cdk_functionalities.py new file mode 100644 index 0000000..0bf17bc --- /dev/null +++ b/RanDepict/cdk_functionalities.py @@ -0,0 +1,441 @@ +from __future__ import annotations +import base64 +from jpype import JClass +import numpy as np +from skimage import io as sk_io +from skimage.util import img_as_ubyte +from typing import Tuple + + +class CDKFunctionalities: + """ + Child class of RandomDepictor that contains all CDK-related functions. + ___ + This class does not work on its own. It is meant to be used as a child class. + """ + def cdk_depict( + self, + smiles: str = None, + mol_block: str = None, + has_R_group: bool = False, + shape: Tuple[int, int] = (512, 512) + ) -> np.array: + """ + This function takes a mol block str and an image shape. 
+ It renders the chemical structures using CDK with random + rendering/depiction settings and returns an RGB image (np.array) + with the given image shape. + The general workflow here is a JPype adaptation of code published + by Egon Willighagen in 'Groovy Cheminformatics with the Chemistry + Development Kit': + https://egonw.github.io/cdkbook/ctr.html#depict-a-compound-as-an-image + with additional adaptations to create all the different depiction + types from + https://github.com/cdk/cdk/wiki/Standard-Generator + + Args: + smiles (str, Optional): SMILES representation of molecule + mol block (str, Optional): mol block representation of molecule + has_R_group (bool): Whether the molecule has R groups (used to determine + whether or not to use atom numbering as it can be + confusing with R groups indices) for SMILES, this is + checked using a simple regex. This argument only has + an effect if the mol_block is provided. + # TODO: check this in mol_block + shape (Tuple[int, int], optional): im shape. Defaults to (512, 512) + + Returns: + np.array: Chemical structure depiction + """ + if not smiles and not mol_block: + raise ValueError("Either smiles or mol_block must be provided") + if smiles: + has_R_group = self.has_r_group(smiles) + mol_block = self._smiles_to_mol_block(smiles, + generate_2d=self.random_choice( + ["rdkit", "cdk", "indigo"] + )) + molecule = self._cdk_mol_block_to_iatomcontainer(mol_block) + depiction = self._cdk_render_molecule(molecule, has_R_group, shape) + return depiction + + def _cdk_mol_block_to_cxsmiles(self, mol_block: str) -> str: + """ + This function takes a mol block str and returns the corresponding CXSMILES + with coordinates using the CDK. 
+ + Args: + mol_block (str): mol block str + + Returns: + str: CXSMILES + """ + atom_container = self._cdk_mol_block_to_iatomcontainer(mol_block) + smi_gen = JClass("org.openscience.cdk.smiles.SmilesGenerator") + flavor = JClass("org.openscience.cdk.smiles.SmiFlavor") + smi_gen = smi_gen(flavor.CxSmilesWithCoords) + cxsmiles = smi_gen.create(atom_container) + return cxsmiles + + def _cdk_smiles_to_IAtomContainer(self, smiles: str): + """ + This function takes a SMILES representation of a molecule and + returns the corresponding IAtomContainer object. + + Args: + smiles (str): SMILES representation of the molecule + + Returns: + IAtomContainer: CDK IAtomContainer object that represents the molecule + """ + cdk_base = "org.openscience.cdk" + SCOB = JClass(cdk_base + ".silent.SilentChemObjectBuilder") + SmilesParser = JClass(cdk_base + ".smiles.SmilesParser")(SCOB.getInstance()) + if self.random_choice([True, False, False], log_attribute="cdk_kekulized"): + SmilesParser.kekulise(False) + molecule = SmilesParser.parseSmiles(smiles) + return molecule + + def _cdk_mol_block_to_iatomcontainer(self, mol_block: str): + """ + Given a mol block, this function returns an IAtomContainer (JClass) object. + + Args: + mol_block (str): content of MDL MOL file + + Returns: + IAtomContainer: CDK IAtomContainer object that represents the molecule + """ + scob = JClass("org.openscience.cdk.silent.SilentChemObjectBuilder") + bldr = scob.getInstance() + iac_class = JClass("org.openscience.cdk.interfaces.IAtomContainer").class_ + string_reader = JClass("java.io.StringReader")(mol_block) + mdlr = JClass("org.openscience.cdk.io.MDLV2000Reader")(string_reader) + iatomcontainer = mdlr.read(bldr.newInstance(iac_class)) + mdlr.close() + return iatomcontainer + + def _cdk_iatomcontainer_to_mol_block(self, i_atom_container) -> str: + """ + This function takes an IAtomContainer object and returns the content + of the corresponding MDL MOL file as a string. 
+ + Args: + i_atom_container (CDK IAtomContainer (JClass object)) + + Returns: + str: string content of MDL MOL file + """ + string_writer = JClass("java.io.StringWriter")() + mol_writer = JClass("org.openscience.cdk.io.MDLV2000Writer")(string_writer) + mol_writer.setWriteAromaticBondTypes(True) + mol_writer.write(i_atom_container) + mol_writer.close() + mol_str = string_writer.toString() + return str(mol_str) + + def _cdk_get_depiction_generator(self, molecule, has_R_group: bool = False): + """ + This function defines random rendering options for the structure + depictions created using CDK. + It takes an iAtomContainer and a SMILES string and returns the iAtomContainer + and the DepictionGenerator + with random rendering settings and the AtomContainer. + I followed https://github.com/cdk/cdk/wiki/Standard-Generator to adjust the + depiction parameters. + + Args: + molecule (cdk.AtomContainer): Atom container + smiles (str): smiles representation of molecule + has_R_group (bool): Whether the molecule has R groups (used to determine + whether or not to use atom numbering as it can be + confusing with R groups indices) + # TODO: check this in atomcontainer + + Returns: + DepictionGenerator, molecule: Objects that hold depiction parameters + """ + cdk_base = "org.openscience.cdk" + dep_gen = JClass("org.openscience.cdk.depict.DepictionGenerator")( + self._cdk_get_random_java_font() + ) + StandardGenerator = JClass( + cdk_base + ".renderer.generators.standard.StandardGenerator" + ) + + # Define visibility of atom/superatom labels + symbol_visibility = self.random_choice( + ["iupac_recommendation", "no_terminal_methyl", "show_all_atom_labels"], + log_attribute="cdk_symbol_visibility", + ) + SymbolVisibility = JClass("org.openscience.cdk.renderer.SymbolVisibility") + if symbol_visibility == "iupac_recommendation": + dep_gen = dep_gen.withParam( + StandardGenerator.Visibility.class_, + SymbolVisibility.iupacRecommendations(), + ) + elif symbol_visibility == 
"no_terminal_methyl": + # only hetero atoms, no terminal alkyl groups + dep_gen = dep_gen.withParam( + StandardGenerator.Visibility.class_, + SymbolVisibility.iupacRecommendationsWithoutTerminalCarbon(), + ) + elif symbol_visibility == "show_all_atom_labels": + dep_gen = dep_gen.withParam( + StandardGenerator.Visibility.class_, SymbolVisibility.all() + ) # show all atom labels + + # Define bond line stroke width + stroke_width = self.random_choice( + np.arange(0.8, 2.0, 0.1), log_attribute="cdk_stroke_width" + ) + dep_gen = dep_gen.withParam(StandardGenerator.StrokeRatio.class_, + stroke_width) + # Define symbol margin ratio + margin_ratio = self.random_choice( + [0, 1, 2, 2, 2, 3, 4], log_attribute="cdk_margin_ratio" + ) + dep_gen = dep_gen.withParam( + StandardGenerator.SymbolMarginRatio.class_, + JClass("java.lang.Double")(margin_ratio), + ) + # Define bond properties + double_bond_dist = self.random_choice( + np.arange(0.11, 0.25, 0.01), log_attribute="cdk_double_bond_dist" + ) + dep_gen = dep_gen.withParam(StandardGenerator.BondSeparation.class_, + double_bond_dist) + wedge_ratio = self.random_choice( + np.arange(4.5, 7.5, 0.1), log_attribute="cdk_wedge_ratio" + ) + dep_gen = dep_gen.withParam( + StandardGenerator.WedgeRatio.class_, JClass("java.lang.Double")(wedge_ratio) + ) + if self.random_choice([True, False], log_attribute="cdk_fancy_bold_wedges"): + dep_gen = dep_gen.withParam(StandardGenerator.FancyBoldWedges.class_, True) + if self.random_choice([True, False], log_attribute="cdk_fancy_hashed_wedges"): + dep_gen = dep_gen.withParam(StandardGenerator.FancyHashedWedges.class_, + True) + hash_spacing = self.random_choice( + np.arange(4.0, 6.0, 0.2), log_attribute="cdk_hash_spacing" + ) + dep_gen = dep_gen.withParam(StandardGenerator.HashSpacing.class_, hash_spacing) + # Add CIP labels + labels = False + if self.random_choice([True, False], log_attribute="cdk_add_CIP_labels"): + labels = True + 
JClass("org.openscience.cdk.geometry.cip.CIPTool").label(molecule) + for atom in molecule.atoms(): + label = atom.getProperty( + JClass("org.openscience.cdk.CDKConstants").CIP_DESCRIPTOR + ) + atom.setProperty(StandardGenerator.ANNOTATION_LABEL, label) + for bond in molecule.bonds(): + label = bond.getProperty( + JClass("org.openscience.cdk.CDKConstants").CIP_DESCRIPTOR + ) + bond.setProperty(StandardGenerator.ANNOTATION_LABEL, label) + # Add atom indices to the depictions + if self.random_choice( + [True, False, False, False], log_attribute="cdk_add_atom_indices" + ): + if not has_R_group: + # Avoid confusion with R group indices and atom numbering + labels = True + for atom in molecule.atoms(): + label = JClass("java.lang.Integer")( + 1 + molecule.getAtomNumber(atom) + ) + atom.setProperty(StandardGenerator.ANNOTATION_LABEL, label) + if labels: + # We only need black + dep_gen = dep_gen.withParam( + StandardGenerator.AnnotationColor.class_, + JClass("java.awt.Color")(0x000000), + ) + # Font size of labels + font_scale = self.random_choice( + np.arange(0.5, 0.8, 0.1), log_attribute="cdk_label_font_scale" + ) + dep_gen = dep_gen.withParam( + StandardGenerator.AnnotationFontScale.class_, + font_scale) + # Distance between atom numbering and depiction + annotation_distance = self.random_choice( + np.arange(0.15, 0.30, 0.05), log_attribute="cdk_annotation_distance" + ) + dep_gen = dep_gen.withParam( + StandardGenerator.AnnotationDistance.class_, annotation_distance + ) + # Abbreviate superatom labels in half of the cases + # TODO: Find a way to define Abbreviations object as a class attribute. + # Problem: can't be pickled. + # Right now, this is loaded every time when a structure is depicted. + # That seems inefficient. 
+ if self.random_choice([True, False], log_attribute="cdk_collapse_superatoms"): + cdk_superatom_abrv = JClass("org.openscience.cdk.depict.Abbreviations")() + abbr_filename = self.random_choice([ + "cdk_superatom_abbreviations.smi", + "cdk_alt_superatom_abbreviations.smi"]) + abbreviation_path = str(self.HERE.joinpath(abbr_filename)) + abbreviation_path = abbreviation_path.replace("\\", "/") + abbreviation_path = JClass("java.lang.String")(abbreviation_path) + cdk_superatom_abrv.loadFromFile(abbreviation_path) + cdk_superatom_abrv.apply(molecule) + return dep_gen, molecule + + def _cdk_get_random_java_font(self): + """ + This function returns a random java.awt.Font (JClass) object + + Returns: + font: java.awt.Font (JClass object) + """ + font_size = self.random_choice( + range(10, 20), log_attribute="cdk_atom_label_font_size" + ) + Font = JClass("java.awt.Font") + font_name = self.random_choice( + ["Verdana", + "Times New Roman", + "Arial", + "Gulliver Regular", + "Helvetica", + "Courier", + "architectural", + "Geneva", + "Lucida Sans", + "Teletype"], + # log_attribute='cdk_atom_label_font' + ) + font_style = self.random_choice( + [Font.PLAIN, Font.BOLD], + # log_attribute='cdk_atom_label_font_style' + ) + font = Font(font_name, font_style, font_size) + return font + + def _cdk_rotate_coordinates(self, molecule): + """ + Given an IAtomContainer (JClass object), this function rotates the molecule + and adapts the coordinates of accordingly. 
The IAtomContainer is then returned.# + + Args: + molecule: IAtomContainer (JClass object) + + Returns: + molecule: IAtomContainer (JClass object) + """ + cdk_base = "org.openscience.cdk" + point = JClass(cdk_base + ".geometry.GeometryTools").get2DCenter(molecule) + rot_degrees = self.random_choice(range(360)) + JClass(cdk_base + ".geometry.GeometryTools").rotate( + molecule, point, rot_degrees + ) + return molecule + + def _cdk_generate_2d_coordinates(self, molecule): + """ + Given an IAtomContainer (JClass object), this function adds 2D coordinate to + the molecule. The modified IAtomContainer is then returned. + + Args: + molecule: IAtomContainer (JClass object) + + Returns: + molecule: IAtomContainer (JClass object) + """ + cdk_base = "org.openscience.cdk" + sdg = JClass(cdk_base + ".layout.StructureDiagramGenerator")() + sdg.setMolecule(molecule) + sdg.generateCoordinates(molecule) + molecule = sdg.getMolecule() + return molecule + + def _convert_rgba2rgb(self, rgba: np.array, background=(255, 255, 255)): + """ + Convert an RGBA image (np.array) to an RGB image (np.array). + https://stackoverflow.com/questions/50331463/convert-rgba-to-rgb-in-python + + Args: + rgba (np.array): RGBA image + background (tuple, optional): . Defaults to (255, 255, 255). + + Returns: + np.array: RGB image + """ + row, col, ch = rgba.shape + + if ch == 3: + return rgba + + assert ch == 4, 'RGBA image has 4 channels.' + + rgb = np.zeros((row, col, 3), dtype='float32') + r, g, b, a = rgba[:, :, 0], rgba[:, :, 1], rgba[:, :, 2], rgba[:, :, 3] + + a = np.asarray(a, dtype='float32') / 255.0 + + R, G, B = background + + rgb[:, :, 0] = r * a + (1.0 - a) * R + rgb[:, :, 1] = g * a + (1.0 - a) * G + rgb[:, :, 2] = b * a + (1.0 - a) * B + + return np.asarray(rgb, dtype='uint8') + + def _cdk_bufferedimage_to_numpyarray( + self, + image + ) -> np.ndarray: + """ + This function converts a BufferedImage (JClass object) into a numpy array. 
+ + Args: + image (BufferedImage (JClass object)) + + Returns: + image (np.ndarray) + """ + # Write the image into a format that can be read by skimage + ImageIO = JClass("javax.imageio.ImageIO") + os = JClass("java.io.ByteArrayOutputStream")() + Base64 = JClass("java.util.Base64") + ImageIO.write( + image, JClass("java.lang.String")("PNG"), Base64.getEncoder().wrap(os) + ) + image = bytes(os.toString("UTF-8")) + image = base64.b64decode(image) + image = sk_io.imread(image, plugin="imageio") + image = img_as_ubyte(image) + image = self._convert_rgba2rgb(image) + return image + + def _cdk_render_molecule( + self, + molecule, + has_R_group: bool = False, + shape: Tuple[int, int] = (512, 512) + ): + """ + This function takes an IAtomContainer (JClass object), the corresponding SMILES + string and an image shape and returns a BufferedImage (JClass object) with the + rendered molecule. + + Args: + molecule (IAtomContainer (JClass object)): molecule + has_R_group (bool): Whether the molecule has R groups (used to determine + whether or not to use atom numbering as it can be + confusing with R groups indices) + # TODO: check this in atomcontainer + smiles (str): SMILES string + shape (Tuple[int, int]): y, x + Returns: + depiction (np.ndarray): chemical structure depiction + """ + dep_gen, molecule = self._cdk_get_depiction_generator(molecule, has_R_group) + dep_gen = dep_gen.withSize(shape[1], shape[0]) + dep_gen = dep_gen.withFillToFit() + depiction = dep_gen.depict(molecule).toImg() + depiction = self._cdk_bufferedimage_to_numpyarray(depiction) + return depiction diff --git a/RanDepict/config.py b/RanDepict/config.py new file mode 100644 index 0000000..cb257d6 --- /dev/null +++ b/RanDepict/config.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass, field +from omegaconf import OmegaConf, DictConfig +from typing import List, Optional + +@dataclass +class RandomDepictorConfig: + """ + Examples + -------- + >>> c1 = RandomDepictorConfig(seed=24, styles=["cdk", 
"indigo"]) + >>> c1 + RandomDepictorConfig(seed=24, hand_drawn=False, augment=True, styles=['cdk', 'indigo']) + >>> c2 = RandomDepictorConfig(styles=["cdk", "indigo", "pikachu", "rdkit"]) + >>> c2 + RandomDepictorConfig(seed=42, hand_drawn=False, augment=True, styles=['cdk', 'indigo', 'pikachu', 'rdkit']) + """ + seed: int = 42 + hand_drawn: bool = False + augment: bool = True + # unions of containers are not supported yet + # https://github.com/omry/omegaconf/issues/144 + # styles: Union[str, List[str]] = field(default_factory=lambda: ["cdk", "indigo", "pikachu", "rdkit"]) + styles: List[str] = field(default_factory=lambda: ["cdk", "indigo", "pikachu", "rdkit"]) + + @classmethod + def from_config(cls, dict_config: Optional[DictConfig] = None) -> 'RandomDepictorConfig': + return OmegaConf.structured(cls(**dict_config)) + + def __post_init__(self): + # Ensure styles are always List[str] when "cdk, indigo" is passed + if isinstance(self.styles, str): + self.styles = [v.strip() for v in self.styles.split(",")] + if len(self.styles) == 0: + raise ValueError("Empty list of styles was supplied.") + # Not sure if this is the best way in order to not repeat the list of styles + ss = set(self.__dataclass_fields__['styles'].default_factory()) + if any([s not in ss for s in self.styles]): + raise ValueError(f"Use only {', '.join(ss)}") diff --git a/RanDepict/depiction_feature_ranges.py b/RanDepict/depiction_feature_ranges.py new file mode 100644 index 0000000..9c3b314 --- /dev/null +++ b/RanDepict/depiction_feature_ranges.py @@ -0,0 +1,565 @@ +from itertools import product +import numpy as np +import os +import random +from rdkit import DataStructs +from rdkit.SimDivFilters.rdSimDivPickers import MaxMinPicker +from typing import Any, Dict, List, Tuple + +from .randepict import RandomDepictor + +class DepictionFeatureRanges(RandomDepictor): + """Class for depiction feature fingerprint generation""" + + def __init__(self): + super().__init__() + # Fill ranges. 
By simply using all the depiction and augmentation + # functions, the available features are saved by the overwritten + # random_choice function. We just have to make sure to run through + # every available decision once to get all the information about the + # feature space that we need. + smiles = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C" + + # Call every depiction function + depiction = self(smiles) + depiction = self.cdk_depict(smiles) + depiction = self.rdkit_depict(smiles) + depiction = self.indigo_depict(smiles) + depiction = self.pikachu_depict(smiles) + # Call augmentation function + depiction = self.add_augmentations(depiction) + # Generate schemes for Fingerprint creation + self.schemes = self.generate_fingerprint_schemes() + ( + self.CDK_scheme, + self.RDKit_scheme, + self.Indigo_scheme, + self.PIKAChU_scheme, + self.augmentation_scheme, + ) = self.schemes + # Generate the pool of all valid fingerprint combinations + + self.generate_all_possible_fingerprints() + self.FP_length_scheme_dict = { + len(self.CDK_fingerprints[0]): self.CDK_scheme, + len(self.RDKit_fingerprints[0]): self.RDKit_scheme, + len(self.Indigo_fingerprints[0]): self.Indigo_scheme, + len(self.PIKAChU_fingerprints[0]): self.PIKAChU_scheme, + len(self.augmentation_fingerprints[0]): self.augmentation_scheme, + } + + def random_choice(self, iterable: List, log_attribute: str = False) -> Any: + """ + In RandomDepictor, this function would take an iterable, call + random_choice() on it, increase the seed attribute by 1 and return + the result. + ___ + Here, this function is overwritten, so that it also sets the class + attribute $log_attribute_range to contain the iterable. + This way, a DepictionFeatureRanges object can easily be filled with + all the iterables that define the complete depiction feature space + (for fingerprint generation). + ___ + Args: + iterable (List): iterable to pick from + log_attribute (str, optional): ID for fingerprint. + Defaults to False. 
+ + Returns: + Any: "Randomly" picked element + """ + # Save iterables as class attributes (for fingerprint generation) + if log_attribute: + setattr(self, f"{log_attribute}_range", iterable) + # Pseudo-randomly pick element from iterable + self.seed += 1 + random.seed(self.seed) + result = random.choice(iterable) + return result + + def generate_fingerprint_schemes(self) -> List[Dict]: + """ + Generates fingerprint schemes (see generate_fingerprint_scheme()) + for the depictions with CDK, RDKit and Indigo as well as the + augmentations. + ___ + Returns: + List[Dict]: [cdk_scheme: Dict, rdkit_scheme: Dict, + indigo_scheme: Dict, augmentation_scheme: Dict] + """ + fingerprint_schemes = [] + range_IDs = [att for att in dir(self) if "range" in att] + # Generate fingerprint scheme for our cdk, indigo and rdkit depictions + depiction_toolkits = ["cdk", "rdkit", "indigo", "pikachu", ""] + for toolkit in depiction_toolkits: + toolkit_range_IDs = [att for att in range_IDs if toolkit in att] + # Delete toolkit-specific ranges + # (The last time this loop runs, only augmentation-related ranges + # are left) + for ID in toolkit_range_IDs: + range_IDs.remove(ID) + # [:-6] --> remove "_range" at the end + toolkit_range_dict = { + attr[:-6]: list(set(getattr(self, attr))) for attr in toolkit_range_IDs + } + fingerprint_scheme = self.generate_fingerprint_scheme(toolkit_range_dict) + fingerprint_schemes.append(fingerprint_scheme) + return fingerprint_schemes + + def generate_fingerprint_scheme(self, ID_range_map: Dict) -> Dict: + """ + This function takes the ID_range_map and returns a dictionary that + defines where each feature is represented in the depiction feature + fingerprint. 
+ ___ + Example: + >> example_ID_range_map = {'thickness': [0, 1, 2, 3], + 'kekulized': [True, False]} + >> generate_fingerprint_scheme(example_ID_range_map) + >>>> {'thickness': [{'position': 0, 'one_if': 0}, + {'position': 1, 'one_if': 1}, + {'position': 2, 'one_if': 2}, + {'position': 3, 'one_if': 3}], + 'kekulized': [{'position': 4, 'one_if': True}]} + Args: + ID_range_map (Dict): dict that maps an ID (str) of a feature range + to the feature range itself (iterable) + + Returns: + Dict: Map of feature ID (str) and dictionaries that define the + fingerprint position and a condition + """ + fingerprint_scheme = {} + position = 0 + for feature_ID in ID_range_map.keys(): + feature_range = ID_range_map[feature_ID] + # Make sure numeric ranges don't take up more than 5 positions + # in the fingerprint + if ( + type(feature_range[0]) in [int, float, np.float64, np.float32] + and len(feature_range) > 5 + ): + subranges = self.split_into_n_sublists(feature_range, n=3) + position_dicts = [] + for subrange in subranges: + subrange_minmax = (min(subrange), max(subrange)) + position_dict = {"position": position, "one_if": subrange_minmax} + position_dicts.append(position_dict) + position += 1 + fingerprint_scheme[feature_ID] = position_dicts + # Bools take up only one position in the fingerprint + elif isinstance(feature_range[0], bool): + assert len(feature_range) == 2 + position_dicts = [{"position": position, "one_if": True}] + position += 1 + fingerprint_scheme[feature_ID] = position_dicts + else: + # For other types of categorical data: Each category gets one + # position in the FP + position_dicts = [] + for feature in feature_range: + position_dict = {"position": position, "one_if": feature} + position_dicts.append(position_dict) + position += 1 + fingerprint_scheme[feature_ID] = position_dicts + return fingerprint_scheme + + def split_into_n_sublists(self, iterable, n: int) -> List[List]: + """ + Takes an iterable, sorts it, splits it evenly into n lists + and 
returns the split lists. + + Args: + iterable ([type]): Iterable that is supposed to be split + n (int): Amount of sublists to return + Returns: + List[List]: Split list + """ + iterable = sorted(iterable) + iter_len = len(iterable) + sublists = [] + for i in range(0, iter_len, int(np.ceil(iter_len / n))): + sublists.append(iterable[i: i + int(np.ceil(iter_len / n))]) + return sublists + + def get_number_of_possible_fingerprints(self, scheme: Dict) -> int: + """ + This function takes a fingerprint scheme (Dict) as returned by + generate_fingerprint_scheme() + and returns the number of possible fingerprints for that scheme. + + Args: + scheme (Dict): Output of generate_fingerprint_scheme() + + Returns: + int: Number of possible fingerprints + """ + comb_count = 1 + for feature_key in scheme.keys(): + if len(scheme[feature_key]) != 1: + # n fingerprint positions -> n options + # (because only one position can be [1]) + # n = 3 --> [1][0][0] or [0][1][0] or [0][0][1] + comb_count *= len(scheme[feature_key]) + else: + # One fingerprint position -> two options: [0] or [1] + comb_count *= 2 + return comb_count + + def get_FP_building_blocks(self, scheme: Dict) -> List[List[List]]: + """ + This function takes a fingerprint scheme (Dict) as returned by + generate_fingerprint_scheme() + and returns a list of possible building blocks. 
+ Example: + scheme = {'thickness': [{'position': 0, 'one_if': 0}, + {'position': 1, 'one_if': 1}, + {'position': 2, 'one_if': 2}, + {'position': 3, 'one_if': 3}], + 'kekulized': [{'position': 4, 'one_if': True}]} + + --> Output: [[[1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1]], + [[1], [0]]] + + Args: + scheme (Dict): Output of generate_fingerprint_scheme() + + Returns: + List that contains the valid fingerprint parts that represent the + different features + + """ + FP_building_blocks = [] + for feature_key in scheme.keys(): + position_condition_dicts = scheme[feature_key] + FP_building_blocks.append([]) + # Add every single valid option to the building block + for position_index in range(len(position_condition_dicts)): + # Add list of zeros + FP_building_blocks[-1].append([0] * len(position_condition_dicts)) + # Replace one zero with a one + FP_building_blocks[-1][-1][position_index] = 1 + # If a feature is described by only one position in the FP, + # make sure that 0 and 1 are listed options + if FP_building_blocks[-1] == [[1]]: + FP_building_blocks[-1].append([0]) + return FP_building_blocks + + def flatten_fingerprint( + self, + unflattened_list: List[List], + ) -> List: + """ + This function takes a list of lists and returns a list. + ___ + Looks like this could be one line elsewhere but this function used for + parallelisation of FP generation and consequently needs to be wrapped + up in a separate function. + + Args: + unflattened_list (List[List[X,Y,Z]]) + + Returns: + flattened_list (List[X,Y,Z]): + """ + flattened_list = [ + element for sublist in unflattened_list for element in sublist + ] + return flattened_list + + def generate_all_possible_fingerprints_per_scheme( + self, + scheme: Dict, + ) -> List[List[int]]: + """ + This function takes a fingerprint scheme (Dict) as returned by + generate_fingerprint_scheme() + and returns a List of all possible fingerprints for that scheme. 
+ + Args: + scheme (Dict): Output of generate_fingerprint_scheme() + name (str): name that is used for filename of saved FPs + + Returns: + List[List[int]]: List of fingerprints + """ + # Determine valid building blocks for fingerprints + FP_building_blocks = self.get_FP_building_blocks(scheme) + # Determine cartesian product of valid building blocks to get all + # valid fingerprints + FP_generator = product(*FP_building_blocks) + flattened_fingerprints = list(map(self.flatten_fingerprint, FP_generator)) + return flattened_fingerprints + + def generate_all_possible_fingerprints(self) -> None: + """ + This function generates all possible valid fingerprint combinations + for the four available fingerprint schemes if they have not been + created already. Otherwise, they are loaded from files. + This function returns None but saves the fingerprint pools as a + class attribute $ID_fingerprints + """ + FP_names = ["CDK", "RDKit", "Indigo", "PIKAChU", "augmentation"] + for scheme_index in range(len(self.schemes)): + exists_already = False + n_FP = self.get_number_of_possible_fingerprints(self.schemes[scheme_index]) + # Load fingerprint pool from file (if it exists) + FP_filename = "{}_fingerprints.npz".format(FP_names[scheme_index]) + FP_file_path = self.HERE.joinpath(FP_filename) + if os.path.exists(FP_file_path): + fps = np.load(FP_file_path)["arr_0"] + if len(fps) == n_FP: + exists_already = True + # Otherwise, generate the fingerprint pool + if not exists_already: + print("No saved fingerprints found. 
This may take a minute.") + fps = self.generate_all_possible_fingerprints_per_scheme( + self.schemes[scheme_index] + ) + np.savez_compressed(FP_file_path, fps) + print( + "{} fingerprints were saved in {}.".format( + FP_names[scheme_index], FP_file_path + ) + ) + setattr(self, "{}_fingerprints".format(FP_names[scheme_index]), fps) + return + + def convert_to_int_arr( + self, fingerprints: List[List[int]] + ) -> List[DataStructs.cDataStructs.ExplicitBitVect]: + """ + Takes a list of fingerprints (List[int]) and returns them as a list of + rdkit.DataStructs.cDataStructs.ExplicitBitVect so that they can be + processed by RDKit's MaxMinPicker. + + Args: + fingerprints (List[List[int]]): List of fingerprints + + Returns: + List[DataStructs.cDataStructs.ExplicitBitVect]: Converted arrays + """ + converted_fingerprints = [] + for fp in fingerprints: + bitstring = "".join(np.array(fp).astype(str)) + fp_converted = DataStructs.cDataStructs.CreateFromBitString(bitstring) + converted_fingerprints.append(fp_converted) + return converted_fingerprints + + def pick_fingerprints( + self, + fingerprints: List[List[int]], + n: int, + ) -> np.array: + """ + Given a list of fingerprints and a number n of fingerprints to pick, + this function uses RDKit's MaxMin Picker to pick n fingerprints and + returns them. + + Args: + fingerprints (List[List[int]]): List of fingerprints + n (int): Number of fingerprints to pick + + Returns: + np.array: Picked fingerprints + """ + + converted_fingerprints = self.convert_to_int_arr(fingerprints) + + """TODO: I don't like this function definition in the function but + according to the RDKit Documentation, the fingerprints need to be + given in the distance function as the default value.""" + + def dice_dist( + fp_index_1: int, + fp_index_2: int, + fingerprints: List[ + DataStructs.cDataStructs.ExplicitBitVect + ] = converted_fingerprints, + ) -> float: + """ + Returns the dice similarity between two fingerprints. 
+ Args: + fp_index_1 (int): index of first fingerprint in fingerprints + fp_index_2 (int): index of second fingerprint in fingerprints + fingerprints (List[cDataStructs.ExplicitBitVect]): fingerprints + + Returns: + float: Dice similarity between the two fingerprints + """ + return 1 - DataStructs.DiceSimilarity( + fingerprints[fp_index_1], fingerprints[fp_index_2] + ) + + # If we want to pick more fingerprints than there are in the pool, + # simply distribute the complete pool as often as possible and pick + # the amount that is not dividable by the size of the pool + picked_fingerprints, n = self.correct_amount_of_FP_to_pick(fingerprints, n) + + picker = MaxMinPicker() + pick_indices = picker.LazyPick(dice_dist, len(fingerprints), n, seed=42) + if isinstance(picked_fingerprints, bool): + picked_fingerprints = np.array([fingerprints[i] for i in pick_indices]) + else: + picked_fingerprints = np.concatenate( + (np.array(picked_fingerprints), np.array(([fingerprints[i] for i in pick_indices]))) + ) + return picked_fingerprints + + def correct_amount_of_FP_to_pick(self, fingerprints: List, n: int) -> Tuple[List, int]: + """ + When picking n elements from a list of fingerprints, if the amount of fingerprints is + bigger than n, there is no need to pick n fingerprints. Instead, the complete fingerprint + list is added to the picked fingerprints as often as possible while only the amount + that is not dividable by the fingerprint pool size is picked. 
+ ___ + Given a list of fingerprints and the amount of fingerprints to pick n, this function + returns a list of "picked" fingerprints and (in the ideal case) a corrected lower number + of fingerprints to be picked + + Args: + fingerprints (List): _description_ + n (int): _description_ + + Returns: + Tuple[List, int]: _description_ + """ + if n > len(fingerprints): + oversize_factor = int(n / len(fingerprints)) + picked_fingerprints = np.concatenate([fingerprints for _ + in range(oversize_factor)]) + n = n - len(fingerprints) * oversize_factor + else: + picked_fingerprints = False + return picked_fingerprints, n + + def generate_fingerprints_for_dataset( + self, + size: int, + indigo_proportion: float = 0.15, + rdkit_proportion: float = 0.25, + pikachu_proportion: float = 0.25, + cdk_proportion: float = 0.35, + aug_proportion: float = 0.5, + ) -> List[List[int]]: + """ + Given a dataset size (int) and (optional) proportions for the + different types of fingerprints, this function returns + + Args: + size (int): Desired dataset size, number of returned fingerprints + indigo_proportion (float): Indigo proportion. Defaults to 0.15. + rdkit_proportion (float): RDKit proportion. Defaults to 0.25. + pikachu_proportion (float): PIKAChU proportion. Defaults to 0.25. + cdk_proportion (float): CDK proportion. Defaults to 0.35. + aug_proportion (float): Augmentation proportion. Defaults to 0.5. + + Raises: + ValueError: + - If the sum of Indigo, RDKit, PIKAChU and CDK proportions is not 1 + - If the augmentation proportion is > 1 + + Returns: + List[List[int]]: List of lists containing the fingerprints. + ___ + Depending on augmentation_proportion, the depiction fingerprints + are paired with augmentation fingerprints or not. 
+ + Example output: + [[$some_depiction_fingerprint, $some augmentation_fingerprint], + [$another_depiction_fingerprint] + [$yet_another_depiction_fingerprint]] + + """ + # Make sure that the given proportion arguments make sense + if sum((indigo_proportion, rdkit_proportion, pikachu_proportion, cdk_proportion)) != 1: + raise ValueError( + "Sum of Indigo, CDK, PIKAChU and RDKit proportions arguments has to be 1" + ) + if aug_proportion > 1: + raise ValueError( + "The proportion of augmentation fingerprints can't be > 1." + ) + # Pick and return diverse fingerprints + picked_Indigo_fingerprints = self.pick_fingerprints( + self.Indigo_fingerprints, int(size * indigo_proportion) + ) + picked_RDKit_fingerprints = self.pick_fingerprints( + self.RDKit_fingerprints, int(size * rdkit_proportion) + ) + picked_PIKAChU_fingerprints = self.pick_fingerprints( + self.PIKAChU_fingerprints, int(size * pikachu_proportion) + ) + picked_CDK_fingerprints = self.pick_fingerprints( + self.CDK_fingerprints, int(size * cdk_proportion) + ) + picked_augmentation_fingerprints = self.pick_fingerprints( + self.augmentation_fingerprints, int(size * aug_proportion) + ) + # Distribute augmentation_fingerprints over depiction fingerprints + fingerprint_tuples = self.distribute_elements_evenly( + picked_augmentation_fingerprints, + picked_Indigo_fingerprints, + picked_RDKit_fingerprints, + picked_PIKAChU_fingerprints, + picked_CDK_fingerprints, + ) + # Shuffle fingerprint tuples randomly to avoid the same smiles + # always being depicted with the same cheminformatics toolkit + random.seed(self.seed) + random.shuffle(fingerprint_tuples) + return fingerprint_tuples + + def distribute_elements_evenly( + self, elements_to_be_distributed: List[Any], *args: List[Any] + ) -> List[List[Any]]: + """ + This function distributes the elements from elements_to_be_distributed + evenly over the lists of elements given in args. 
It can be used to link + augmentation fingerprints to given lists of depiction fingerprints. + + Example: + distribute_elements_evenly(["A", "B", "C", "D"], [1, 2, 3], [4, 5, 6]) + Output: [[1, "A"], [2, "B"], [3], [4, "C"], [5, "D"], [6]] + --> see test_distribute_elements_evenly() in ../Tests/test_functions.py + + Args: + elements_to_be_distributed (List[Any]): elements to be distributed + args: Every arg is a list of elements (B) + + Returns: + List[List[Any]]: List of Lists (B, A) where the elements A are + distributed evenly over the elements B according + to the length of the list of elements B + """ + # Make sure that the input is valid + args_total_len = len([element for sublist in args for element in sublist]) + if len(elements_to_be_distributed) > args_total_len: + raise ValueError("Can't take more elements to be distributed than in args.") + + output = [] + start_index = 0 + for element_list in args: + # Define part of elements_to_be_distributed that belongs to this + # element_sublist + sublist_len = len(element_list) + end_index = start_index + int( + sublist_len / args_total_len * len(elements_to_be_distributed) + ) + select_elements_to_be_distributed = elements_to_be_distributed[ + start_index:end_index + ] + for element_index in range(len(element_list)): + if element_index < len(select_elements_to_be_distributed): + output.append( + [ + element_list[element_index], + select_elements_to_be_distributed[element_index], + ] + ) + else: + output.append([element_list[element_index]]) + start_index = start_index + int( + sublist_len / args_total_len * len(elements_to_be_distributed) + ) + return output diff --git a/RanDepict/import augmentations.py b/RanDepict/import augmentations.py new file mode 100644 index 0000000..2dd9974 --- /dev/null +++ b/RanDepict/import augmentations.py @@ -0,0 +1,1152 @@ +from copy import deepcopy +import cv2 +import imgaug.augmenters as iaa +import numpy as np +import os +from PIL import Image, ImageEnhance, ImageFont, 
ImageDraw, ImageStat +from scipy.ndimage import gaussian_filter +from scipy.ndimage import map_coordinates +from skimage.color import rgb2gray +from skimage.util import img_as_float +from typing import Tuple + + +class Augmentations: + def resize(self, image: np.array, shape: Tuple[int], HQ: bool = False) -> np.array: + """ + This function takes an image (np.array) and a shape and returns + the resized image (np.array). It uses Pillow to do this, as it + seems to have a bigger variety of scaling methods than skimage. + The up/downscaling method is chosen randomly. + + Args: + image (np.array): the input image + shape (Tuple[int, int], optional): im shape. Defaults to (299, 299) + HQ (bool): if true, only choose from Image.BICUBIC, Image.LANCZOS + ___ + Returns: + np.array: the resized image + + """ + image = Image.fromarray(image) + shape = (shape[0], shape[1]) + if not HQ: + image = image.resize( + shape, resample=self.random_choice(self.PIL_resize_methods) + ) + else: + image = image = image.resize( + shape, resample=self.random_choice(self.PIL_HQ_resize_methods) + ) + + return np.asarray(image) + + def imgaug_augment( + self, + image: np.array, + ) -> np.array: + """ + This function applies a random amount of augmentations to + a given image (np.array) using and returns the augmented image + (np.array). + + Args: + image (np.array): input image + + Returns: + np.array: output image (augmented) + """ + original_shape = image.shape + + # Choose number of augmentations to apply (0-2); + # return image if nothing needs to be done. 
+ aug_number = self.random_choice(range(0, 3)) + if not aug_number: + return image + + # Add some padding to avoid weird artifacts after rotation + image = np.pad( + image, ((1, 1), (1, 1), (0, 0)), mode="constant", constant_values=255 + ) + + def imgaug_rotation(): + # Rotation between -10 and 10 degrees + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_rotation" + ): + return False + rot_angle = self.random_choice(np.arange(-10, 10, 1)) + aug = iaa.Affine(rotate=rot_angle, mode="edge", fit_output=True) + return aug + + def imgaug_black_and_white_noise(): + # Black and white noise + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_salt_pepper" + ): + return False + coarse_dropout_p = self.random_choice(np.arange(0.0002, 0.0015, 0.0001)) + coarse_dropout_size_percent = self.random_choice(np.arange(1.0, 1.1, 0.01)) + replace_elementwise_p = self.random_choice(np.arange(0.01, 0.3, 0.01)) + aug = iaa.Sequential( + [ + iaa.CoarseDropout( + coarse_dropout_p, size_percent=coarse_dropout_size_percent + ), + iaa.ReplaceElementwise(replace_elementwise_p, 255), + ] + ) + return aug + + def imgaug_shearing(): + # Shearing + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_shearing" + ): + return False + shear_param = self.random_choice(np.arange(-5, 5, 1)) + aug = self.random_choice( + [ + iaa.geometric.ShearX(shear_param, mode="edge", fit_output=True), + iaa.geometric.ShearY(shear_param, mode="edge", fit_output=True), + ] + ) + return aug + + def imgaug_imgcorruption(): + # Jpeg compression or pixelation + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_corruption" + ): + return False + imgcorrupt_severity = self.random_choice(np.arange(1, 2, 1)) + aug = self.random_choice( + [ + iaa.imgcorruptlike.JpegCompression(severity=imgcorrupt_severity), + iaa.imgcorruptlike.Pixelate(severity=imgcorrupt_severity), + ] + ) + return aug + + def imgaug_brightness_adjustment(): + 
# Brightness adjustment + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_brightness_adj" + ): + return False + brightness_adj_param = self.random_choice(np.arange(-50, 50, 1)) + aug = iaa.WithBrightnessChannels(iaa.Add(brightness_adj_param)) + return aug + + def imgaug_colour_temp_adjustment(): + # Colour temperature adjustment + if not self.random_choice( + [True, True, False], log_attribute="has_imgaug_col_adj" + ): + return False + colour_temp = self.random_choice(np.arange(1100, 10000, 1)) + aug = iaa.ChangeColorTemperature(colour_temp) + return aug + + # Define list of available augmentations + aug_list = [ + imgaug_rotation, + imgaug_black_and_white_noise, + imgaug_shearing, + imgaug_imgcorruption, + imgaug_brightness_adjustment, + imgaug_colour_temp_adjustment, + ] + + # Every one of them has a 1/3 chance of returning False + aug_list = [fun() for fun in aug_list] + aug_list = [fun for fun in aug_list if fun] + aug = iaa.Sequential(aug_list) + augmented_image = aug.augment_images([image])[0] + augmented_image = self.resize(augmented_image, original_shape) + augmented_image = augmented_image.astype(np.uint8) + return augmented_image + + def add_augmentations(self, depiction: np.array) -> np.array: + """ + This function takes a chemical structure depiction (np.array) + and returns the same image with added augmentation elements + + Args: + depiction (np.array): chemical structure depiction + + Returns: + np.array: chemical structure depiction with added augmentations + """ + if self.random_choice( + [True, False, False, False, False, False], log_attribute="has_curved_arrows" + ): + depiction = self.add_curved_arrows_to_structure(depiction) + if self.random_choice( + [True, False, False], log_attribute="has_straight_arrows" + ): + depiction = self.add_straight_arrows_to_structure(depiction) + if self.random_choice( + [True, False, False, False, False, False], log_attribute="has_id_label" + ): + depiction = 
self.add_chemical_label(depiction, "ID") + if self.random_choice( + [True, False, False, False, False, False], log_attribute="has_R_group_label" + ): + depiction = self.add_chemical_label(depiction, "R_GROUP") + if self.random_choice( + [True, False, False, False, False, False], + log_attribute="has_reaction_label", + ): + depiction = self.add_chemical_label(depiction, "REACTION") + depiction = self.imgaug_augment(depiction) + return depiction + + def get_random_label_position(self, width: int, height: int) -> Tuple[int, int]: + """ + Given the width and height of an image (int), this function + determines a random position in the outer 15% of the image and + returns a tuple that contain the coordinates (y,x) of that position. + + Args: + width (int): image width + height (int): image height + + Returns: + Tuple[int, int]: Random label position + """ + if self.random_choice([True, False]): + y_range = range(0, height) + x_range = list(range(0, int(0.15 * width))) + list( + range(int(0.85 * width), width) + ) + else: + y_range = list(range(0, int(0.15 * height))) + list( + range(int(0.85 * height), height) + ) + x_range = range(0, width) + return self.random_choice(y_range), self.random_choice(x_range) + + def ID_label_text(self) -> str: + """ + This function returns a string that resembles a typical + chemical ID label + + Returns: + str: Label text + """ + label_num = range(1, 50) + label_letters = [ + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + ] + options = [ + "only_number", + "num_letter_combination", + "numtonum", + "numcombtonumcomb", + ] + option = self.random_choice(options) + if option == "only_number": + return str(self.random_choice(label_num)) + if option == "num_letter_combination": + return str(self.random_choice(label_num)) + self.random_choice( + label_letters + ) + if option == "numtonum": + return ( + str(self.random_choice(label_num)) + + "-" + + str(self.random_choice(label_num)) + ) + if 
option == "numcombtonumcomb": + return ( + str(self.random_choice(label_num)) + + self.random_choice(label_letters) + + "-" + + self.random_choice(label_letters) + ) + + def new_reaction_condition_elements(self) -> Tuple[str, str, str]: + """ + Randomly redefine reaction_time, solvent and other_reactand. + + Returns: + Tuple[str, str, str]: Reaction time, solvent, reactand + """ + reaction_time = self.random_choice( + [str(num) for num in range(30)] + ) + self.random_choice([" h", " min"]) + solvent = self.random_choice( + [ + "MeOH", + "EtOH", + "CHCl3", + "DCM", + "iPrOH", + "MeCN", + "DMSO", + "pentane", + "hexane", + "benzene", + "Et2O", + "THF", + "DMF", + ] + ) + other_reactand = self.random_choice( + [ + "HF", + "HCl", + "HBr", + "NaOH", + "Et3N", + "TEA", + "Ac2O", + "DIBAL", + "DIBAL-H", + "DIPEA", + "DMAP", + "EDTA", + "HOBT", + "HOAt", + "TMEDA", + "p-TsOH", + "Tf2O", + ] + ) + return reaction_time, solvent, other_reactand + + def reaction_condition_label_text(self) -> str: + """ + This function returns a random string that looks like a + reaction condition label. 
+ + Returns: + str: Reaction condition label text + """ + reaction_condition_label = "" + label_type = self.random_choice(["A", "B", "C", "D"]) + if label_type in ["A", "B"]: + for n in range(self.random_choice(range(1, 5))): + ( + reaction_time, + solvent, + other_reactand, + ) = self.new_reaction_condition_elements() + if label_type == "A": + reaction_condition_label += ( + str(n + 1) + + " " + + other_reactand + + ", " + + solvent + + ", " + + reaction_time + + "\n" + ) + elif label_type == "B": + reaction_condition_label += ( + str(n + 1) + + " " + + other_reactand + + ", " + + solvent + + " (" + + reaction_time + + ")\n" + ) + elif label_type == "C": + ( + reaction_time, + solvent, + other_reactand, + ) = self.new_reaction_condition_elements() + reaction_condition_label += ( + other_reactand + "\n" + solvent + "\n" + reaction_time + ) + elif label_type == "D": + reaction_condition_label += self.random_choice( + self.new_reaction_condition_elements() + ) + return reaction_condition_label + + def make_R_group_str(self) -> str: + """ + This function returns a random string that looks like an R group label. + It generates them by inserting randomly chosen elements into one of + five templates. 
+ + Returns: + str: R group label text + """ + rest_variables = [ + "X", + "Y", + "Z", + "R", + "R1", + "R2", + "R3", + "R4", + "R5", + "R6", + "R7", + "R8", + "R9", + "R10", + "Y2", + "D", + ] + # Load list of superatoms (from OSRA) + superatoms = self.superatoms + label_type = self.random_choice(["A", "B", "C", "D", "E"]) + R_group_label = "" + if label_type == "A": + for _ in range(1, self.random_choice(range(2, 6))): + R_group_label += ( + self.random_choice(rest_variables) + + " = " + + self.random_choice(superatoms) + + "\n" + ) + elif label_type == "B": + R_group_label += " " + self.random_choice(rest_variables) + "\n" + for n in range(1, self.random_choice(range(2, 6))): + R_group_label += str(n) + " " + self.random_choice(superatoms) + "\n" + elif label_type == "C": + R_group_label += ( + " " + + self.random_choice(rest_variables) + + " " + + self.random_choice(rest_variables) + + "\n" + ) + for n in range(1, self.random_choice(range(2, 6))): + R_group_label += ( + str(n) + + " " + + self.random_choice(superatoms) + + " " + + self.random_choice(superatoms) + + "\n" + ) + elif label_type == "D": + R_group_label += ( + " " + + self.random_choice(rest_variables) + + " " + + self.random_choice(rest_variables) + + " " + + self.random_choice(rest_variables) + + "\n" + ) + for n in range(1, self.random_choice(range(2, 6))): + R_group_label += ( + str(n) + + " " + + self.random_choice(superatoms) + + " " + + self.random_choice(superatoms) + + " " + + self.random_choice(superatoms) + + "\n" + ) + if label_type == "E": + for n in range(1, self.random_choice(range(2, 6))): + R_group_label += ( + str(n) + + " " + + self.random_choice(rest_variables) + + " = " + + self.random_choice(superatoms) + + "\n" + ) + return R_group_label + + def add_chemical_label( + self, image: np.array, label_type: str, foreign_fonts: bool = True + ) -> np.array: + """ + This function takes an image (np.array) and adds random text that + looks like a chemical ID label, an R group label or a 
reaction + condition label around the structure. It returns the modified image. + The label type is determined by the parameter label_type (str), + which needs to be 'ID', 'R_GROUP' or 'REACTION' + + Args: + image (np.array): Chemical structure depiction + label_type (str): 'ID', 'R_GROUP' or 'REACTION' + foreign_fonts (bool, optional): Defaults to True. + + Returns: + np.array: Chemical structure depiction with label + """ + im = Image.fromarray(image) + orig_image = deepcopy(im) + width, height = im.size + # Choose random font + if self.random_choice([True, False]) or not foreign_fonts: + font_dir = self.HERE.joinpath("fonts/") + # In half of the cases: Use foreign-looking font to generate + # bigger noise variety + else: + font_dir = self.HERE.joinpath("foreign_fonts/") + + fonts = os.listdir(str(font_dir)) + # Choose random font size + font_sizes = range(10, 20) + size = self.random_choice(font_sizes) + # Generate random string that resembles the desired type of label + if label_type == "ID": + label_text = self.ID_label_text() + if label_type == "R_GROUP": + label_text = self.make_R_group_str() + if label_type == "REACTION": + label_text = self.reaction_condition_label_text() + + try: + font = ImageFont.truetype( + str(os.path.join(str(font_dir), self.random_choice(fonts))), size=size + ) + except OSError: + font = ImageFont.load_default() + + draw = ImageDraw.Draw(im, "RGBA") + + # Try different positions with the condition that the label´does not + # overlap with non-white pixels (the structure) + for _ in range(50): + y_pos, x_pos = self.get_random_label_position(width, height) + bounding_box = draw.textbbox( + (x_pos, y_pos), label_text, font=font + ) # left, up, right, low + paste_region = orig_image.crop(bounding_box) + try: + mean = ImageStat.Stat(paste_region).mean + except ZeroDivisionError: + return np.asarray(im) + if sum(mean) / len(mean) == 255: + draw.text((x_pos, y_pos), label_text, font=font, fill=(0, 0, 0, 255)) + break + return np.asarray(im) 
+ + def add_curved_arrows_to_structure(self, image: np.array) -> np.array: + """ + This function takes an image of a chemical structure (np.array) + and adds between 2 and 4 curved arrows in random positions in the + central part of the image. + + Args: + image (np.array): Chemical structure depiction + + Returns: + np.array: Chemical structure depiction with curved arrows + """ + height, width, _ = image.shape + image = Image.fromarray(image) + orig_image = deepcopy(image) + # Determine area where arrows are pasted. + x_min, x_max = (int(0.1 * width), int(0.9 * width)) + y_min, y_max = (int(0.1 * height), int(0.9 * height)) + + arrow_dir = os.path.normpath( + str(self.HERE.joinpath("arrow_images/curved_arrows/")) + ) + + for _ in range(self.random_choice(range(2, 4))): + # Load random curved arrow image, resize and rotate it randomly. + arrow_image = Image.open( + os.path.join( + str(arrow_dir), self.random_choice(os.listdir(str(arrow_dir))) + ) + ) + new_arrow_image_shape = int( + (x_max - x_min) / self.random_choice(range(3, 6)) + ), int((y_max - y_min) / self.random_choice(range(3, 6))) + arrow_image = self.resize(np.asarray(arrow_image), new_arrow_image_shape) + arrow_image = Image.fromarray(arrow_image) + arrow_image = arrow_image.rotate( + self.random_choice(range(360)), + resample=self.random_choice( + [Image.BICUBIC, Image.NEAREST, Image.BILINEAR] + ), + expand=True, + ) + # Try different positions with the condition that the arrows are + # overlapping with non-white pixels (the structure) + for _ in range(50): + x_position = self.random_choice( + range(x_min, x_max - new_arrow_image_shape[0]) + ) + y_position = self.random_choice( + range(y_min, y_max - new_arrow_image_shape[1]) + ) + paste_region = orig_image.crop( + ( + x_position, + y_position, + x_position + new_arrow_image_shape[0], + y_position + new_arrow_image_shape[1], + ) + ) + mean = ImageStat.Stat(paste_region).mean + if sum(mean) / len(mean) < 252: + image.paste(arrow_image, (x_position, 
y_position), arrow_image) + + break + return np.asarray(image) + + def get_random_arrow_position(self, width: int, height: int) -> Tuple[int, int]: + """ + Given the width and height of an image (int), this function determines + a random position to paste a reaction arrow in the outer 15% frame of + the image + + Args: + width (_type_): image width + height (_type_): image height + + Returns: + Tuple[int, int]: Random arrow position + """ + if self.random_choice([True, False]): + y_range = range(0, height) + x_range = list(range(0, int(0.15 * width))) + list( + range(int(0.85 * width), width) + ) + else: + y_range = list(range(0, int(0.15 * height))) + list( + range(int(0.85 * height), height) + ) + x_range = range(0, int(0.5 * width)) + return self.random_choice(y_range), self.random_choice(x_range) + + def add_straight_arrows_to_structure(self, image: np.array) -> np.array: + """ + This function takes an image of a chemical structure (np.array) + and adds between 1 and 2 straight arrows in random positions in the + image (no overlap with other elements) + + Args: + image (np.array): Chemical structure depiction + + Returns: + np.array: Chemical structure depiction with straight arrow + """ + height, width, _ = image.shape + image = Image.fromarray(image) + + arrow_dir = os.path.normpath( + str(self.HERE.joinpath("arrow_images/straight_arrows/")) + ) + + for _ in range(self.random_choice(range(1, 3))): + # Load random curved arrow image, resize and rotate it randomly. 
+ arrow_image = Image.open( + os.path.join( + str(arrow_dir), self.random_choice(os.listdir(str(arrow_dir))) + ) + ) + # new_arrow_image_shape = (int(width * + # self.random_choice(np.arange(0.9, 1.5, 0.1))), + # int(height/10 * self.random_choice(np.arange(0.7, 1.2, 0.1)))) + + # arrow_image = arrow_image.resize(new_arrow_image_shape, + # resample=Image.BICUBIC) + # Rotate completely randomly in half of the cases and in 180° steps + # in the other cases (higher probability that pasting works) + if self.random_choice([True, False]): + arrow_image = arrow_image.rotate( + self.random_choice(range(360)), + resample=self.random_choice( + [Image.Resampling.BICUBIC, Image.Resampling.NEAREST, Image.Resampling.BILINEAR] + ), + expand=True, + ) + else: + arrow_image = arrow_image.rotate(self.random_choice([180, 360])) + new_arrow_image_shape = arrow_image.size + # Try different positions with the condition that the arrows are + # overlapping with non-white pixels (the structure) + for _ in range(50): + y_position, x_position = self.get_random_arrow_position(width, height) + x2_position = x_position + new_arrow_image_shape[0] + y2_position = y_position + new_arrow_image_shape[1] + # Make sure we only check a region inside of the image + if x2_position > width: + x2_position = width - 1 + if y2_position > height: + y2_position = height - 1 + paste_region = image.crop( + (x_position, y_position, x2_position, y2_position) + ) + try: + mean = ImageStat.Stat(paste_region).mean + if sum(mean) / len(mean) == 255: + image.paste(arrow_image, (x_position, y_position), arrow_image) + break + except ZeroDivisionError: + pass + return np.asarray(image) + + def to_grayscale_float_img(self, image: np.array) -> np.array: + """ + This function takes an image (np.array), converts it to grayscale + and returns it. 
def hand_drawn_augment(self, img) -> np.array:
    """
    This function randomly applies different image augmentations with
    different probabilities to the input image.

    It has been modified from the original augment.py present on
    https://github.com/mtzgroup/ChemPixCH

    From the publication:
    https://pubs.rsc.org/en/content/articlelanding/2021/SC/D1SC02957F

    Each step draws a value from [0, 0.99] and is applied if the value is
    below the step's probability threshold. The steps are attempted in a
    fixed order: resize, blur, erode, dilate, aspect ratio, affine, distort.

    Args:
        img: the image to modify in array format.
    Returns:
        img: the augmented image with a height and width of 256 pixels.
    """
    # resize
    if self.random_choice(np.arange(0, 1, 0.01)) < 0.5:
        img = self.resize_hd(img)
    # blur
    if self.random_choice(np.arange(0, 1, 0.01)) < 0.4:
        img = self.blur(img)
    # erode
    if self.random_choice(np.arange(0, 1, 0.01)) < 0.4:
        img = self.erode(img)
    # dilate
    if self.random_choice(np.arange(0, 1, 0.01)) < 0.4:
        img = self.dilate(img)
    # aspect_ratio
    if self.random_choice(np.arange(0, 1, 0.01)) < 0.7:
        img = self.aspect_ratio(img, "mol")
    # affine
    if self.random_choice(np.arange(0, 1, 0.01)) < 0.7:
        img = self.affine(img, "mol")
    # distort
    if self.random_choice(np.arange(0, 1, 0.01)) < 0.8:
        img = self.distort(img)
    # Guarantee the 256x256 output size. The previous check compared the
    # shape with (255, 255, 3), which never matches a 256x256 image and
    # skipped the resize for images that really were 255x255.
    if img.shape[:2] != (256, 256):
        img = cv2.resize(img, (256, 256))
    return img
def resize_hd(self, img) -> np.array:
    """
    Resize the image to a random intermediate size (each dimension between
    200 and 299 pixels) and then back to 256x256. Both resize operations use
    a randomly chosen interpolation method.

    Args:
        img: the image to modify in array format.
    Returns:
        img: the resized image.
    """
    interpolation_modes = [
        cv2.INTER_NEAREST,
        cv2.INTER_AREA,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_LANCZOS4,
    ]

    # Random draws happen in the order: width, height, first interpolation
    tmp_width = self.random_choice(np.arange(200, 300))
    tmp_height = self.random_choice(np.arange(200, 300))
    first_mode = self.random_choice(interpolation_modes)
    intermediate = cv2.resize(
        img, (tmp_width, tmp_height), interpolation=first_mode
    )

    # The interpolation for the way back is drawn independently
    final_mode = self.random_choice(interpolation_modes)
    return cv2.resize(intermediate, (256, 256), interpolation=final_mode)
def aspect_ratio(self, img, obj=None) -> np.array:
    """
    This function irregularly changes the size of the image by padding it
    with a random border (0 to 49 pixels per side) and converts it back to
    (256,256).

    Args:
        img: the image to modify in array format.
        obj: "mol" or "bkg" to modify a chemical structure image or
             a background image. "mol" pads with white pixels, "bkg" pads
             by reflecting the image content.
    Returns:
        image: the resized image.
    Raises:
        ValueError: if obj is neither "mol" nor "bkg".
    """
    # Without this check, an unsupported value (including the default None)
    # skipped both padding branches and failed with an UnboundLocalError.
    if obj not in ("mol", "bkg"):
        raise ValueError(f'obj must be "mol" or "bkg", got {obj!r}')

    # Border widths in the order expected by cv2.copyMakeBorder:
    # top, bottom, left, right
    n1 = self.random_choice(np.arange(0, 50))
    n2 = self.random_choice(np.arange(0, 50))
    n3 = self.random_choice(np.arange(0, 50))
    n4 = self.random_choice(np.arange(0, 50))
    if obj == "mol":
        image = cv2.copyMakeBorder(
            img, n1, n2, n3, n4, cv2.BORDER_CONSTANT, value=[255, 255, 255]
        )
    else:
        image = cv2.copyMakeBorder(img, n1, n2, n3, n4, cv2.BORDER_REFLECT)

    image = cv2.resize(image, (256, 256))
    return image
def elastic_transform(self, image, alpha_sigma) -> np.array:
    """
    Elastic deformation of images as described in [Simard2003]_.
    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
    Convolutional Neural Networks applied to Visual Document Analysis", in
    Proc. of the International Conference on Document Analysis and
    Recognition, 2003.
    https://gist.github.com/erniejunior/601cdf56d2b424757de5

    Two smoothed random displacement fields are generated (scaled by alpha,
    smoothed with a Gaussian of width sigma) and the image is resampled at
    the displaced coordinates with a randomly chosen spline order (1 to 4).

    Args:
        image: the image to modify in array format (indexed with three axes).
        alpha_sigma: sequence whose first element is alpha and whose second
                     element is sigma.
    Returns:
        distorted image with the same shape as the input image.
    """
    alpha, sigma = alpha_sigma[0], alpha_sigma[1]
    shape = image.shape

    def displacement_field():
        # Uniform noise in [-1, 1) from a freshly seeded generator, smoothed
        # with a Gaussian filter and scaled by alpha
        seed = self.random_choice(np.arange(1, 1000))
        noise = np.random.RandomState(seed).rand(*shape) * 2 - 1
        return gaussian_filter(noise, sigma, mode="constant", cval=0) * alpha

    dx = displacement_field()
    dy = displacement_field()

    # NOTE(review): np.meshgrid uses its default "xy" indexing here, so the
    # grids have shape (shape[1], shape[0], shape[2]); adding the
    # displacement fields below only lines up for square images.
    grid_x, grid_y, grid_z = np.meshgrid(
        np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2])
    )
    sample_points = (
        (grid_y + dy).reshape(-1, 1),
        (grid_x + dx).reshape(-1, 1),
        grid_z.reshape(-1, 1),
    )

    spline_order = self.random_choice(np.arange(1, 5))
    resampled = map_coordinates(
        image, sample_points, order=spline_order, mode="reflect"
    )
    return resampled.reshape(image.shape)
def degrade_img(self, img) -> np.array:
    """
    This function randomly degrades the input image by applying different
    degradation steps with different probabilities.

    For every step, a value from [0, 0.99] is drawn and the step is applied
    if the value is below the step's threshold.

    Args:
        img: the image to modify in array format.
    Returns:
        img: the degraded image.
    """
    # (threshold, method name) in the order in which the steps are attempted
    pipeline = (
        (0.1, "s_and_p"),
        (0.5, "scale"),
        (0.7, "brightness"),
        (0.7, "contrast"),
        (0.5, "sharpness"),
    )
    for threshold, step_name in pipeline:
        draw = self.random_choice(np.arange(0, 1, 0.01))
        if draw < threshold:
            # The method is only looked up when the step is actually applied
            img = getattr(self, step_name)(img)
    return img
def sharpness(self, img) -> np.array:
    """
    This function randomly changes the input image sharpness.

    In half of the cases, an enhancement factor between 0.1 and 0.99 is
    used, in the other half an integer factor between 1 and 9.

    Args:
        img: the image to modify in array format.
    Returns:
        img: the image with the sharpness changes.
    """
    if self.random_choice(np.arange(0, 1, 0.01)) < 0.5:
        # Factors below 1 soften (blur) the image
        f = self.random_choice(np.arange(0.1, 1, 0.01))
    else:
        # A factor of 1 keeps the image as it is, larger factors sharpen it
        f = self.random_choice(np.arange(1, 10))
    enhancer = ImageEnhance.Sharpness(Image.fromarray(img))
    return np.asarray(enhancer.enhance(f))
def scale(self, img) -> np.array:
    """
    This function randomly scales the input image by a factor between 0.5
    and 1.49 and scales it back to its original size. The round trip through
    a different resolution degrades the image quality.

    Args:
        img: the image to modify in array format.
    Returns:
        res: the scaled image (same height and width as the input image).
    """
    f = self.random_choice(np.arange(0.5, 1.5, 0.01))
    height, width = img.shape[:2]
    res = cv2.resize(img, None, fx=f, fy=f, interpolation=cv2.INTER_CUBIC)
    # Scale back to the exact input size. Scaling by 1/f instead rounds a
    # second time and can change the size (e.g. 256 -> 131 -> 257 for
    # f=0.51). cv2 expects the target size as (width, height).
    res = cv2.resize(res, (width, height), interpolation=cv2.INTER_CUBIC)
    return res
def get_random_indigo_rendering_settings(
    self, shape: Tuple[int, int] = (299, 299)
) -> Tuple[Indigo, IndigoRenderer]:
    """
    This function defines random rendering options for the structure
    depictions created using Indigo.
    It returns an Indigo object with the settings and an IndigoRenderer
    that has been created for this Indigo object.

    Args:
        shape (Tuple[int, int], optional): im shape (height, width).
                                           Defaults to (299, 299)

    Returns:
        Tuple[Indigo, IndigoRenderer]: Indigo object that contains the
        depiction settings and the corresponding renderer
    """
    indigo = Indigo()
    renderer = IndigoRenderer(indigo)
    # The rendered image has exactly the requested size
    y, x = shape
    indigo.setOption("render-image-width", x)
    indigo.setOption("render-image-height", y)
    # Set random bond line width
    bond_line_width = float(
        self.random_choice(
            np.arange(0.5, 2.5, 0.1), log_attribute="indigo_bond_line_width"
        )
    )
    indigo.setOption("render-bond-line-width", bond_line_width)
    # Set random relative thickness
    relative_thickness = float(
        self.random_choice(
            np.arange(0.5, 1.5, 0.1), log_attribute="indigo_relative_thickness"
        )
    )
    indigo.setOption("render-relative-thickness", relative_thickness)
    # Output_format: PNG
    indigo.setOption("render-output-format", "png")
    # Set random atom label rendering model
    # (standard is rendering terminal groups)
    if self.random_choice([True] + [False] * 19, log_attribute="indigo_labels_all"):
        # show all atom labels (1 in 20 cases)
        indigo.setOption("render-label-mode", "all")
    elif self.random_choice(
        [True] + [False] * 3, log_attribute="indigo_labels_hetero"
    ):
        # only hetero atoms, no terminal groups (1 in 4 of the remaining
        # cases; this choice is only drawn if the first one was False)
        indigo.setOption("render-label-mode", "hetero")
    # Render bold bond for Haworth projection
    if self.random_choice([True, False], log_attribute="indigo_render_bold_bond"):
        indigo.setOption("render-bold-bond-detection", "True")
    # Render labels for stereobonds
    stereo_style = self.random_choice(
        ["ext", "old", "none"], log_attribute="indigo_stereo_label_style"
    )
    indigo.setOption("render-stereo-style", stereo_style)
    # Collapse superatoms (default: expand)
    if self.random_choice(
        [True, False], log_attribute="indigo_collapse_superatoms"
    ):
        indigo.setOption("render-superatom-mode", "collapse")
    return indigo, renderer
class PikachuFunctionalities:
    """
    Collection of the PIKAChU-related depiction functions.
    ___
    This class does not work on its own. Its methods call
    self.random_choice, self.central_square_image and self.resize, which
    must be provided by the class it is combined with.
    """

    def get_random_pikachu_rendering_settings(
        self, shape: Tuple[int, int] = (512, 512)
    ) -> drawing.Options:
        """
        This function defines random rendering options for the structure
        depictions created using PIKAChU.
        It returns a pikachu.drawing.drawing.Options object with the settings.

        Args:
            shape (Tuple[int, int], optional): im shape (height, width).
                                               Defaults to (512, 512)

        Returns:
            options: Options object that contains depictions settings
        """
        options = drawing.Options()
        options.height, options.width = shape
        options.bond_thickness = self.random_choice(
            np.arange(0.5, 2.2, 0.1), log_attribute="pikachu_bond_line_width"
        )
        options.bond_length = self.random_choice(
            np.arange(10, 25, 1), log_attribute="pikachu_bond_length"
        )
        # Width of chiral bonds and bond spacing are defined relative to the
        # bond length that has been chosen above
        options.chiral_bond_width = options.bond_length * self.random_choice(
            np.arange(0.05, 0.2, 0.01)
        )
        options.short_bond_length = self.random_choice(
            np.arange(0.2, 0.6, 0.05), log_attribute="pikachu_short_bond_length"
        )
        options.double_bond_length = self.random_choice(
            np.arange(0.6, 0.8, 0.05), log_attribute="pikachu_double_bond_length"
        )
        options.bond_spacing = options.bond_length * self.random_choice(
            np.arange(0.15, 0.28, 0.01), log_attribute="pikachu_bond_spacing"
        )
        options.padding = self.random_choice(
            np.arange(10, 50, 5), log_attribute="pikachu_padding"
        )
        return options

    def pikachu_depict(
        self,
        smiles: str = None,
        shape: Tuple[int, int] = (512, 512)
    ) -> np.array:
        """
        This function takes a SMILES str and an image shape.
        It renders the chemical structures using PIKAChU with random
        rendering/depiction settings and returns an image (np.array)
        with the given image shape.

        Args:
            smiles (str, optional): smiles representation of molecule
            shape (Tuple[int, int], optional): im shape. Defaults to (512, 512)

        Returns:
            np.array: Chemical structure depiction

        Raises:
            ValueError: if no SMILES str is given
        """
        # Fail with a clear message instead of an error from deep inside the
        # SMILES parser when the default (None) or an empty str is passed
        if not smiles:
            raise ValueError("smiles must be provided")
        structure = read_smiles(smiles)

        # The depiction is rendered with the default size and brought to the
        # requested shape by the resize step below
        depiction_settings = self.get_random_pikachu_rendering_settings()
        if "." in smiles:
            # SMILES with multiple disconnected structures
            drawer = drawing.draw_multiple(structure, options=depiction_settings)
        else:
            drawer = drawing.Drawer(structure, options=depiction_settings)
        depiction = drawer.get_image_as_array()
        depiction = self.central_square_image(depiction)
        depiction = self.resize(depiction, (shape[0], shape[1]))
        return depiction
import MaxMinPicker -from itertools import product - -from omegaconf import OmegaConf, DictConfig # configuration package -from dataclasses import dataclass, field - -from indigo import Indigo -from indigo import IndigoException -from indigo.renderer import IndigoRenderer -from jpype import startJVM, getDefaultJVMPath -from jpype import JClass, JVMNotFoundException, isJVMStarted -from pikachu.drawing import drawing -from pikachu.chem.molfile.read_molfile import MolFileReader - -import base64 +from skimage import io as sk_io +from skimage.util import img_as_ubyte +from typing import Callable, Dict, List, Optional, Tuple -import cv2 -from scipy.ndimage import gaussian_filter -from scipy.ndimage import map_coordinates +from .augmentations import Augmentations +from .cdk_functionalities import CDKFunctionalities +from .config import RandomDepictorConfig +from .indigo_functionalities import IndigoFunctionalities +from .pikachu_functionalities import PikachuFunctionalities +from .rdkit_functionalities import RDKitFuntionalities -from PIL import Image, ImageFont, ImageDraw, ImageStat, ImageEnhance # Below version 9.0, PIL stores resampling methods differently if not hasattr(Image, 'Resampling'): Image.Resampling = Image -@dataclass -class RandomDepictorConfig: - """ - Examples - -------- - >>> c1 = RandomDepictorConfig(seed=24, styles=["cdk", "indigo"]) - >>> c1 - RandomDepictorConfig(seed=24, hand_drawn=False, augment=True, styles=['cdk', 'indigo']) - >>> c2 = RandomDepictorConfig(styles=["cdk", "indigo", "pikachu", "rdkit"]) - >>> c2 - RandomDepictorConfig(seed=42, hand_drawn=False, augment=True, styles=['cdk', 'indigo', 'pikachu', 'rdkit']) - """ - seed: int = 42 - hand_drawn: bool = False - augment: bool = True - # unions of containers are not supported yet - # https://github.com/omry/omegaconf/issues/144 - # styles: Union[str, List[str]] = field(default_factory=lambda: ["cdk", "indigo", "pikachu", "rdkit"]) - styles: List[str] = field(default_factory=lambda: ["cdk", 
"indigo", "pikachu", "rdkit"]) - @classmethod - def from_config(cls, dict_config: Optional[DictConfig] = None) -> 'RandomDepictorConfig': - return OmegaConf.structured(cls(**dict_config)) - - def __post_init__(self): - # Ensure styles are always List[str] when "cdk, indigo" is passed - if isinstance(self.styles, str): - self.styles = [v.strip() for v in self.styles.split(",")] - if len(self.styles) == 0: - raise ValueError("Empty list of styles was supplied.") - # Not sure if this is the best way in order to not repeat the list of styles - ss = set(self.__dataclass_fields__['styles'].default_factory()) - if any([s not in ss for s in self.styles]): - raise ValueError(f"Use only {', '.join(ss)}") - - -class RandomDepictor: + +class RandomDepictor(Augmentations, + CDKFunctionalities, + IndigoFunctionalities, + PikachuFunctionalities, + RDKitFuntionalities): """ This class contains everything necessary to generate a variety of random depictions with given SMILES strings. An instance of RandomDepictor @@ -224,1431 +177,128 @@ def __exit__(self, type, value, tb): # shutdownJVM() pass - def get_all_rdkit_abbreviations( - self, - ) -> List[Chem.rdAbbreviations._vectstruct]: - """ - This function returns the Default abbreviations for superatom and functional - group collapsing in RDKit as well as alternative abbreviations defined in - rdkit_alternative_superatom_abbreviations.txt. 
- - Returns: - Chem.rdAbbreviations._vectstruct: RDKit's data structure that contains the - abbreviations - """ - abbreviations = [] - abbreviations.append(GetDefaultAbbreviations()) - abbr_path = self.HERE.joinpath("rdkit_alternative_superatom_abbreviations.txt") - - with open(abbr_path) as alternative_abbreviations: - split_lines = [line[:-1].split(",") - for line in alternative_abbreviations.readlines()] - swap_dict = {line[0]: line[1:] for line in split_lines} - - abbreviations.append(self.get_modified_rdkit_abbreviations(swap_dict)) - for key in swap_dict.keys(): - new_labels = [] - for label in swap_dict[key]: - if label[:2] in ["n-", "i-", "t-"]: - label = f"{label[2:]}-{label[0]}" - elif label[-2:] in ["-n", "-i", "-t"]: - label = f"{label[-1]}-{label[:-2]}" - new_labels.append(label) - swap_dict[key] = new_labels - abbreviations.append(self.get_modified_rdkit_abbreviations(swap_dict)) - return abbreviations - - def get_modified_rdkit_abbreviations( + def random_depiction( self, - swap_dict: Dict - ) -> Chem.rdAbbreviations._vectstruct: - """ - This function takes a dictionary that maps the original superatom/FG label in - the RDKit abbreviations to the desired labels, replaces them as defined in the - dictionary and returns the abbreviations in RDKit's preferred format. - - Args: - swap_dict (Dict): Dictionary that maps the original label (eg. "Et") to the - desired label (eg. "C2H5"), a displayed label (eg. - "C2H5") and a reversed display label - (eg. "H5C2"). 
- Example: - {"Et": [ - "C2H5", - "C2H5" - "H5C2" - ]} - - Returns: - Chem.rdAbbreviations._vectstruct: Modified abbreviations - """ - alt_abbreviations = GetDefaultAbbreviations() - for abbr in alt_abbreviations: - alt_label = swap_dict.get(abbr.label) - if alt_label: - abbr.label, abbr.displayLabel, abbr.displayLabelW = alt_label - return alt_abbreviations - - def random_choice(self, iterable: List, log_attribute: str = False): - """ - This function takes an iterable, calls random.choice() on it, - increases random.seed by 1 and returns the result. This way, results - produced by RanDepict are replicable. - - Additionally, this function handles the generation of depictions and - augmentations from given fingerprints by handling all random decisions - according to the fingerprint template. - - Args: - iterable (List): iterable to pick from - log_attribute (str, optional): ID for fingerprint. - Defaults to False. - - Returns: - Any: "Randomly" picked element - """ - # Keep track of seed and change it with every pseudo-random decision. 
- self.seed += 1 - random.seed(self.seed) - - # Generation from fingerprint: - if self.from_fingerprint and log_attribute: - # Get dictionaries that define positions and linked conditions - pos_cond_dicts = self.active_scheme[log_attribute] - for pos_cond_dict in pos_cond_dicts: - pos = pos_cond_dict["position"] - cond = pos_cond_dict["one_if"] - if self.active_fingerprint[pos]: - # If the condition is a range: adapt iterable and go on - if isinstance(cond, tuple): - iterable = [ - item - for item in iterable - if item > cond[0] - 0.001 - if item < cond[1] + 0.001 - ] - break - # Otherwise, simply return the condition value - else: - return cond - # Pseudo-randomly pick an element from the iterable - result = random.choice(iterable) - - return result - - def random_image_size(self, shape: Tuple[int, int]) -> Tuple[int, int]: - """ - This function takes a random image shape and returns an image shape - where the first two dimensions are slightly distorted - (90-110% of original value). - - Args: - shape (Tuple[int, int]): original shape - - Returns: - Tuple[int, int]: distorted shape - """ - # Set random depiction image shape (to cause a slight distortion) - y = int(shape[0] * self.random_choice(np.arange(0.9, 1.1, 0.02))) - x = int(shape[1] * self.random_choice(np.arange(0.9, 1.1, 0.02))) - return y, x - - def hand_drawn_augment(self, img) -> np.array: - """ - This function randomly applies different image augmentations with - different probabilities to the input image. - - It has been modified from the original augment.py present on - https://github.com/mtzgroup/ChemPixCH - - From the publication: - https://pubs.rsc.org/en/content/articlelanding/2021/SC/D1SC02957F - - Args: - img: the image to modify in array format. - Returns: - img: the augmented image. 
- """ - # resize - if self.random_choice(np.arange(0, 1, 0.01)) < 0.5: - img = self.resize_hd(img) - # blur - if self.random_choice(np.arange(0, 1, 0.01)) < 0.4: - img = self.blur(img) - # erode - if self.random_choice(np.arange(0, 1, 0.01)) < 0.4: - img = self.erode(img) - # dilate - if self.random_choice(np.arange(0, 1, 0.01)) < 0.4: - img = self.dilate(img) - # aspect_ratio - if self.random_choice(np.arange(0, 1, 0.01)) < 0.7: - img = self.aspect_ratio(img, "mol") - # affine - if self.random_choice(np.arange(0, 1, 0.01)) < 0.7: - img = self.affine(img, "mol") - # distort - if self.random_choice(np.arange(0, 1, 0.01)) < 0.8: - img = self.distort(img) - if img.shape != (255, 255, 3): - img = cv2.resize(img, (256, 256)) - return img - - def augment_bkg(self, img) -> np.array: - """ - This function randomly applies different image augmentations with - different probabilities to the input image. - Args: - img: the image to modify in array format. - Returns: - img: the augmented image. - """ - # rotate - rows, cols, _ = img.shape - angle = self.random_choice(np.arange(0, 360)) - M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1) - img = cv2.warpAffine(img, M, (cols, rows), borderMode=cv2.BORDER_REFLECT) - # resize - if self.random_choice(np.arange(0, 1, 0.01)) < 0.5: - img = self.resize_hd(img) - # blur - if self.random_choice(np.arange(0, 1, 0.01)) < 0.4: - img = self.blur(img) - # erode - if self.random_choice(np.arange(0, 1, 0.01)) < 0.2: - img = self.erode(img) - # dilate - if self.random_choice(np.arange(0, 1, 0.01)) < 0.2: - img = self.dilate(img) - # aspect_ratio - if self.random_choice(np.arange(0, 1, 0.01)) < 0.3: - img = self.aspect_ratio(img, "bkg") - # affine - if self.random_choice(np.arange(0, 1, 0.01)) < 0.3: - img = self.affine(img, "bkg") - # distort - if self.random_choice(np.arange(0, 1, 0.01)) < 0.8: - img = self.distort(img) - if img.shape != (255, 255, 3): - img = cv2.resize(img, (256, 256)) - return img - - def resize_hd(self, img) -> 
np.array: - """ - This function resizes the image randomly from between (200-300, 200-300) - and then resizes it back to 256x256. - Args: - img: the image to modify in array format. - Returns: - img: the resized image. - """ - interpolations = [ - cv2.INTER_NEAREST, - cv2.INTER_AREA, - cv2.INTER_LINEAR, - cv2.INTER_CUBIC, - cv2.INTER_LANCZOS4, - ] - - img = cv2.resize( - img, - ( - self.random_choice(np.arange(200, 300)), - self.random_choice(np.arange(200, 300)), - ), - interpolation=self.random_choice(interpolations), - ) - img = cv2.resize( - img, (256, 256), interpolation=self.random_choice(interpolations) - ) - - return img - - def blur(self, img) -> np.array: - """ - This function blurs the image randomly between 1-3. - Args: - img: the image to modify in array format. - Returns: - img: the blurred image. + smiles: str, + shape: Tuple[int, int] = (299, 299), + ) -> np.array: """ - n = self.random_choice(np.arange(1, 4)) - kernel = np.ones((n, n), np.float32) / n**2 - img = cv2.filter2D(img, -1, kernel) - return img + This function takes a SMILES and depicts it using Rdkit, Indigo, CDK or PIKACHU. + The depiction method and the specific parameters for the depiction are + chosen completely randomly. The purpose of this function is to enable + depicting a diverse variety of chemical structure depictions. - def erode(self, img) -> np.array: - """ - This function bolds the image randomly between 1-2. Args: - img: the image to modify in array format. - Returns: - img: the bold image. - """ - n = self.random_choice(np.arange(1, 3)) - kernel = np.ones((n, n), np.float32) / n**2 - img = cv2.erode(img, kernel, iterations=1) - return img + smiles (str): SMILES representation of molecule + shape (Tuple[int, int], optional): im shape. Defaults to (299, 299) - def dilate(self, img) -> np.array: - """ - This function dilates the image with a factor of 2. - Args: - img: the image to modify in array format. Returns: - img: the dilated image. 
+ np.array: Chemical structure depiction """ - n = 2 - kernel = np.ones((n, n), np.float32) / n**2 - img = cv2.dilate(img, kernel, iterations=1) - return img + depiction_functions = self.get_depiction_functions(smiles) - def aspect_ratio(self, img, obj=None) -> np.array: - """ - This function irregularly changes the size of the image - and converts it back to (256,256). - Args: - img: the image to modify in array format. - obj: "mol" or "bkg" to modify a chemical structure image or - a background image. - Returns: - image: the resized image. - """ - n1 = self.random_choice(np.arange(0, 50)) - n2 = self.random_choice(np.arange(0, 50)) - n3 = self.random_choice(np.arange(0, 50)) - n4 = self.random_choice(np.arange(0, 50)) - if obj == "mol": - image = cv2.copyMakeBorder( - img, n1, n2, n3, n4, cv2.BORDER_CONSTANT, value=[255, 255, 255] - ) - elif obj == "bkg": - image = cv2.copyMakeBorder(img, n1, n2, n3, n4, cv2.BORDER_REFLECT) + for _ in range(3): + if len(depiction_functions) != 0: + # Pick random depiction function and call it + depiction_function = self.random_choice(depiction_functions) + depiction = depiction_function(smiles=smiles, shape=shape) + if depiction is False or depiction is None: + depiction_functions.remove(depiction_function) + else: + break + else: + return None - image = cv2.resize(image, (256, 256)) - return image + if self.hand_drawn: + path_bkg = self.HERE.joinpath("backgrounds/") + # Augment molecule image + mol_aug = self.hand_drawn_augment(depiction) - def affine(self, img, obj=None) -> np.array: - """ - This function randomly applies affine transformation which consists - of matrix rotations, translations and scale operations and converts - it back to (256,256). - Args: - img: the image to modify in array format. - obj: "mol" or "bkg" to modify a chemical structure image or - a background image. - Returns: - skewed: the transformed image. 
- """ - rows, cols, _ = img.shape - n = 20 - pts1 = np.float32([[5, 50], [200, 50], [50, 200]]) - pts2 = np.float32( - [ - [ - 5 + self.random_choice(np.arange(-n, n)), - 50 + self.random_choice(np.arange(-n, n)), - ], - [ - 200 + self.random_choice(np.arange(-n, n)), - 50 + self.random_choice(np.arange(-n, n)), - ], - [ - 50 + self.random_choice(np.arange(-n, n)), - 200 + self.random_choice(np.arange(-n, n)), - ], - ] - ) + # Randomly select background image and use is as it is + backgroud_selected = self.random_choice(os.listdir(path_bkg)) + bkg = cv2.imread(os.path.join(os.path.normpath(path_bkg), backgroud_selected)) + bkg = cv2.resize(bkg, (256, 256)) + # Combine augmented molecule and augmented background + p = 0.7 + mol_bkg = cv2.addWeighted(mol_aug, p, bkg, 1 - p, gamma=0) - M = cv2.getAffineTransform(pts1, pts2) + """ + If you want to randomly augment the background as well, + simply comment the previous section and uncomment the next one. + """ - if obj == "mol": - skewed = cv2.warpAffine(img, M, (cols, rows), borderValue=[255, 255, 255]) - elif obj == "bkg": - skewed = cv2.warpAffine(img, M, (cols, rows), borderMode=cv2.BORDER_REFLECT) + """# Randomly select background image and augment it + bkg_aug = self.augment_bkg(bkg) + bkg_aug = cv2.resize(bkg_aug,(256,256)) + # Combine augmented molecule and augmented background + p=0.7 + mol_bkg = cv2.addWeighted(mol_aug, p, bkg_aug, 1-p, gamma=0)""" - skewed = cv2.resize(skewed, (256, 256)) - return skewed + # Degrade total image + depiction = self.degrade_img(mol_bkg) + return depiction - def elastic_transform(self, image, alpha_sigma) -> np.array: - """ - Elastic deformation of images as described in [Simard2003]_. - .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for - Convolutional Neural Networks applied to Visual Document Analysis", in - Proc. of the International Conference on Document Analysis and - Recognition, 2003. 
- https://gist.github.com/erniejunior/601cdf56d2b424757de5 - This function distords an image randomly changing the alpha and gamma - values. - Args: - image: the image to modify in array format. - alpha_sigma: alpha and sigma values randomly selected as a list. - Returns: - distored_image: the image after the transformation with the same size - as it had originally. + def random_depiction_with_coordinates( + self, + smiles: str, + augment: bool = False, + shape: Tuple[int, int] = (512, 512), + ) -> Tuple[np.array, str]: """ - alpha = alpha_sigma[0] - sigma = alpha_sigma[1] - random_state = np.random.RandomState(self.random_choice(np.arange(1, 1000))) - - shape = image.shape - dx = ( - gaussian_filter( - (random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0 - ) - * alpha - ) - random_state = np.random.RandomState(self.random_choice(np.arange(1, 1000))) - dy = ( - gaussian_filter( - (random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0 - ) - * alpha - ) - - x, y, z = np.meshgrid( - np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]) - ) - indices = ( - np.reshape(y + dy, (-1, 1)), - np.reshape(x + dx, (-1, 1)), - np.reshape(z, (-1, 1)), - ) - - distored_image = map_coordinates( - image, indices, order=self.random_choice(np.arange(1, 5)), mode="reflect" - ) - return distored_image.reshape(image.shape) + This function takes a SMILES and depicts it using Rdkit, Indigo or CDK. + We cannot use PIKAChU here, as it does not depict given coordinates, but it + always generates them during the prediction process. + The depiction method and the specific parameters for the depiction are + chosen completely randomly. The purpose of this function is to enable + depicting a diverse variety of chemical structure depictions. 
- def distort(self, img) -> np.array: - """ - This function randomly selects a list with the shape [a, g] where - a=alpha and g=gamma and passes them along with the input image - to the elastic_transform function that will do the image distorsion. - Args: - img: the image to modify in array format. - Returns: - the output from elastic_transform function which is the image - after the transformation with the same size as it had originally. - """ - sigma_alpha = [ - (self.random_choice(np.arange(9, 11)), self.random_choice(np.arange(2, 4))), - (self.random_choice(np.arange(80, 100)), 4), - (self.random_choice(np.arange(150, 300)), 5), - ( - self.random_choice(np.arange(800, 1200)), - self.random_choice(np.arange(8, 10)), - ), - ( - self.random_choice(np.arange(1500, 2000)), - self.random_choice(np.arange(10, 15)), - ), - ( - self.random_choice(np.arange(5000, 8000)), - self.random_choice(np.arange(15, 25)), - ), - ( - self.random_choice(np.arange(10000, 15000)), - self.random_choice(np.arange(20, 25)), - ), - ( - self.random_choice(np.arange(45000, 55000)), - self.random_choice(np.arange(30, 35)), - ), - ] - choice = self.random_choice(range(len(sigma_alpha))) - sigma_alpha_chosen = sigma_alpha[choice] - return self.elastic_transform(img, sigma_alpha_chosen) + The depiction (np.array) and the cxSMILES (str) that encodes the coordinates of + the depicted molecule are returned. - def degrade_img(self, img) -> np.array: - """ - This function randomly degrades the input image by applying different - degradation steps with different robabilities. Args: - img: the image to modify in array format. - Returns: - img: the degraded image. 
- """ - # s+p - if self.random_choice(np.arange(0, 1, 0.01)) < 0.1: - img = self.s_and_p(img) - - # scale - if self.random_choice(np.arange(0, 1, 0.01)) < 0.5: - img = self.scale(img) - - # brightness - if self.random_choice(np.arange(0, 1, 0.01)) < 0.7: - img = self.brightness(img) - - # contrast - if self.random_choice(np.arange(0, 1, 0.01)) < 0.7: - img = self.contrast(img) - - # sharpness - if self.random_choice(np.arange(0, 1, 0.01)) < 0.5: - img = self.sharpness(img) - - # Modify the next line if you want a particular image size as output - # img = cv2.resize(img, (256, 256)) - return img + smiles (str): SMILES representation of a molecule + augment (bool, optional): Whether add augmentations to the image. + Defaults to False. + shape (Tuple[int, int], optional): Image shape. Defaults to (512, 512). - def contrast(self, img) -> np.array: - """ - This function randomly changes the input image contrast. - Args: - img: the image to modify in array format. - Returns: - img: the image with the contrast changes. - """ - if self.random_choice(np.arange(0, 1, 0.01)) < 0.8: # increase contrast - f = self.random_choice(np.arange(1, 2, 0.01)) - else: # decrease contrast - f = self.random_choice(np.arange(0.5, 1, 0.01)) - im_pil = Image.fromarray(img) - enhancer = ImageEnhance.Contrast(im_pil) - im = enhancer.enhance(f) - img = np.asarray(im) - return np.asarray(im) - - def brightness(self, img) -> np.array: - """ - This function randomly changes the input image brightness. - Args: - img: the image to modify in array format. - Returns: - img: the image with the brightness changes. - """ - f = self.random_choice(np.arange(0.4, 1.1, 0.01)) - im_pil = Image.fromarray(img) - enhancer = ImageEnhance.Brightness(im_pil) - im = enhancer.enhance(f) - img = np.asarray(im) - return np.asarray(im) - - def sharpness(self, img) -> np.array: - """ - This function randomly changes the input image sharpness. - Args: - img: the image to modify in array format. 
Returns: - img: the image with the sharpness changes. + Tuple[np.array, str]: structure depiction, cxSMILES """ - if self.random_choice(np.arange(0, 1, 0.01)) < 0.5: # increase sharpness - f = self.random_choice(np.arange(0.1, 1, 0.01)) - else: # decrease sharpness - f = self.random_choice(np.arange(1, 10)) - im_pil = Image.fromarray(img) - enhancer = ImageEnhance.Sharpness(im_pil) - im = enhancer.enhance(f) - img = np.asarray(im) - return np.asarray(im) - - def s_and_p(self, img) -> np.array: - """ - This function randomly adds salt and pepper to the input image. - Args: - img: the image to modify in array format. - Returns: - out: the image with the s&p changes. - """ - amount = self.random_choice(np.arange(0.001, 0.01)) - # add some s&p - s_vs_p = 0.5 - out = np.copy(img) - # Salt mode - num_salt = int(np.ceil(amount * img.size * s_vs_p)) - coords = [] - for i in img.shape: - coordinates = [] - for _ in range(num_salt): - coordinates.append(self.random_choice(np.arange(0, i - 1))) - coords.append(np.array(coordinates)) - out[tuple(coords)] = 1 - # pepper - num_pepper = int(np.ceil(amount * img.size * (1.0 - s_vs_p))) - coords = [] - for i in img.shape: - coordinates = [] - for _ in range(num_pepper): - coordinates.append(self.random_choice(np.arange(0, i - 1))) - coords.append(np.array(coordinates)) - out[tuple(coords)] = 0 - return out - - def scale(self, img) -> np.array: - """ - This function randomly scales the input image. - Args: - img: the image to modify in array format. - Returns: - res: the scaled image. 
- """ - f = self.random_choice(np.arange(0.5, 1.5, 0.01)) - res = cv2.resize(img, None, fx=f, fy=f, interpolation=cv2.INTER_CUBIC) - res = cv2.resize( - res, None, fx=1.0 / f, fy=1.0 / f, interpolation=cv2.INTER_CUBIC - ) - return res + orig_styles = self._config.styles + self._config.styles = [style for style in orig_styles if style != 'pikachu'] + depiction_functions = self.get_depiction_functions(smiles) + self._config.styles = orig_styles + mol_block = self._smiles_to_mol_block(smiles, + self.random_choice(['rdkit', 'indigo', 'cdk'])) + cxsmiles = self._cdk_mol_block_to_cxsmiles(mol_block) + fun = self.random_choice(depiction_functions) + depiction = fun(mol_block=mol_block, shape=shape) + if augment: + depiction = self.add_augmentations(depiction) + return depiction, cxsmiles - def get_random_pikachu_rendering_settings( - self, shape: Tuple[int, int] = (299, 299) - ) -> drawing.Options: + def get_depiction_functions(self, smiles: str) -> List[Callable]: """ - This function defines random rendering options for the structure - depictions created using PIKAChU. - It returns an pikachu.drawing.drawing.Options object with the settings. + PIKAChU, RDKit and Indigo can run into problems if certain R group variables + are present in the input molecule, and PIKAChU cannot handle isotopes. + Hence, the depiction functions that use their functionalities need to + be removed based on the input smiles str (if necessary). Args: - shape (Tuple[int, int], optional): im shape. 
Defaults to (299, 299) + smiles (str): SMILES representation of a molecule Returns: - options: Options object that contains depictions settings + List[Callable]: List of depiction functions """ - options = drawing.Options() - options.height, options.width = shape - options.bond_thickness = self.random_choice(np.arange(0.5, 2.2, 0.1), - log_attribute="pikachu_bond_line_width") - options.bond_length = self.random_choice(np.arange(10, 25, 1), - log_attribute="pikachu_bond_length") - options.chiral_bond_width = options.bond_length * self.random_choice( - np.arange(0.05, 0.2, 0.01) - ) - options.short_bond_length = self.random_choice(np.arange(0.2, 0.6, 0.05), - log_attribute="pikachu_short_bond_length") - options.double_bond_length = self.random_choice(np.arange(0.6, 0.8, 0.05), - log_attribute="pikachu_double_bond_length") - options.bond_spacing = options.bond_length * self.random_choice( - np.arange(0.15, 0.28, 0.01), - log_attribute="pikachu_bond_spacing" - ) - options.padding = self.random_choice(np.arange(10, 50, 5), - log_attribute="pikachu_padding") - # options.font_size_large = 5 - # options.font_size_small = 3 - return options - - def pikachu_depict( - self, mol_block: str, shape: Tuple[int, int] = (299, 299) - ) -> np.array: - """ - This function takes a mol block str and an image shape. - It renders the chemical structures using PIKAChU with random - rendering/depiction settings and returns an RGB image (np.array) - with the given image shape. - - Args: - mol_block (str): mol block representation of molecule - shape (Tuple[int, int], optional): im shape. Defaults to (299, 299) - - Returns: - np.array: Chemical structure depiction - """ - reader = MolFileReader(molfile_str=mol_block) - structure = reader.molfile_to_structure() - # structure = read_smiles(smiles) - depiction_settings = self.get_random_pikachu_rendering_settings() - # if "." 
in smiles: - # drawer = drawing.draw_multiple(structure, options=depiction_settings) - # else: - drawer = drawing.Drawer(structure, options=depiction_settings) - depiction = drawer.get_image_as_array() - depiction = self.central_square_image(depiction) - depiction = self.resize(depiction, (shape[0], shape[1])) - return depiction - - def get_random_indigo_rendering_settings( - self, shape: Tuple[int, int] = (299, 299) - ) -> Indigo: - """ - This function defines random rendering options for the structure - depictions created using Indigo. - It returns an Indigo object with the settings. - - Args: - shape (Tuple[int, int], optional): im shape. Defaults to (299, 299) - - Returns: - Indigo: Indigo object that contains depictions settings - """ - # Define random shape for depiction (within boundaries);) - indigo = Indigo() - renderer = IndigoRenderer(indigo) - # Get slightly distorted shape - y, x = self.random_image_size(shape) - indigo.setOption("render-image-width", x) - indigo.setOption("render-image-height", y) - # Set random bond line width - bond_line_width = float( - self.random_choice( - np.arange(0.5, 2.5, 0.1), log_attribute="indigo_bond_line_width" - ) - ) - indigo.setOption("render-bond-line-width", bond_line_width) - # Set random relative thickness - relative_thickness = float( - self.random_choice( - np.arange(0.5, 1.5, 0.1), log_attribute="indigo_relative_thickness" - ) - ) - indigo.setOption("render-relative-thickness", relative_thickness) - # Output_format: PNG - indigo.setOption("render-output-format", "png") - # Set random atom label rendering model - # (standard is rendering terminal groups) - if self.random_choice([True] + [False] * 19, log_attribute="indigo_labels_all"): - # show all atom labels - indigo.setOption("render-label-mode", "all") - elif self.random_choice( - [True] + [False] * 3, log_attribute="indigo_labels_hetero" - ): - indigo.setOption( - "render-label-mode", "hetero" - ) # only hetero atoms, no terminal groups - # Render bold bond 
for Haworth projection - if self.random_choice([True, False], log_attribute="indigo_render_bold_bond"): - indigo.setOption("render-bold-bond-detection", "True") - # Render labels for stereobonds - stereo_style = self.random_choice( - ["ext", "old", "none"], log_attribute="indigo_stereo_label_style" - ) - indigo.setOption("render-stereo-style", stereo_style) - # Collapse superatoms (default: expand) - if self.random_choice( - [True, False], log_attribute="indigo_collapse_superatoms" - ): - indigo.setOption("render-superatom-mode", "collapse") - return indigo, renderer - - def indigo_depict( - self, mol_block: str, shape: Tuple[int, int] = (299, 299) - ) -> np.array: - """ - This function takes a mol block str and an image shape. - It renders the chemical structures using Indigo with random - rendering/depiction settings and returns an RGB image (np.array) - with the given image shape. - - Args: - mol_block (str): mol block representation of molecule - shape (Tuple[int, int], optional): im shape. 
Defaults to (299, 299) - - Returns: - np.array: Chemical structure depiction - """ - # Instantiate Indigo with random settings and IndigoRenderer - indigo, renderer = self.get_random_indigo_rendering_settings() - # Load molecule - # try: - # if not self.has_r_group(smiles): - # molecule = indigo.loadMolecule(smiles) - # else: - # mol_str = self._smiles_to_mol_block(smiles) - # molecule = indigo.loadMolecule(mol_str) - try: - molecule = indigo.loadMolecule(mol_block) - except IndigoException: - return None - # Kekulize in 67% of cases - if not self.random_choice( - [True, True, False], log_attribute="indigo_kekulized" - ): - molecule.aromatize() - # molecule.layout() - # Write to buffer - temp = renderer.renderToBuffer(molecule) - temp = io.BytesIO(temp) - depiction = sk_io.imread(temp) - depiction = self.resize(depiction, (shape[0], shape[1])) - depiction = rgba2rgb(depiction) - depiction = img_as_ubyte(depiction) - return depiction - - def get_random_rdkit_rendering_settings( - self, smiles: str, shape: Tuple[int, int] = (299, 299) - ) -> rdMolDraw2D.MolDraw2DCairo: - """ - This function defines random rendering options for the structure - depictions created using rdkit. It returns an MolDraw2DCairo object - with the settings. - - Args: - smiles (str): SMILES representation of molecule - shape (Tuple[int, int], optional): im_shape. 
Defaults to (299, 299) - - Returns: - rdMolDraw2D.MolDraw2DCairo: Object that contains depiction settings - """ - # Get slightly distorted shape - y, x = self.random_image_size(shape) - # Instantiate object that saves the settings - depiction_settings = rdMolDraw2D.MolDraw2DCairo(y, x) - # Stereo bond annotation - if self.random_choice( - [True, False], log_attribute="rdkit_add_stereo_annotation" - ): - depiction_settings.drawOptions().addStereoAnnotation = True - if self.random_choice( - [True, False], log_attribute="rdkit_add_chiral_flag_labels" - ): - depiction_settings.drawOptions().includeChiralFlagLabel = True - # Atom indices - if self.random_choice( - [True, False, False, False], log_attribute="rdkit_add_atom_indices" - ): - if not self.has_r_group(smiles): - depiction_settings.drawOptions().addAtomIndices = True - # Bond line width - bond_line_width = self.random_choice( - range(1, 5), log_attribute="rdkit_bond_line_width" - ) - depiction_settings.drawOptions().bondLineWidth = bond_line_width - # Draw terminal methyl groups - if self.random_choice( - [True, False], log_attribute="rdkit_draw_terminal_methyl" - ): - depiction_settings.drawOptions().explicitMethyl = True - # Label font type and size - font_dir = self.HERE.joinpath("fonts/") - font_path = os.path.join( - str(font_dir), - self.random_choice( - os.listdir(str(font_dir)), log_attribute="rdkit_label_font" - ), - ) - depiction_settings.drawOptions().fontFile = font_path - min_font_size = self.random_choice( - range(10, 20), log_attribute="rdkit_min_font_size" - ) - depiction_settings.drawOptions().minFontSize = min_font_size - depiction_settings.drawOptions().maxFontSize = 30 - # Rotate the molecule - # depiction_settings.drawOptions().rotate = self.random_choice(range(360)) - # Fixed bond length - fixed_bond_length = self.random_choice( - range(30, 45), log_attribute="rdkit_fixed_bond_length" - ) - depiction_settings.drawOptions().fixedBondLength = fixed_bond_length - # Comic mode (looks a bit 
hand drawn) - if self.random_choice( - [True, False, False, False, False], log_attribute="rdkit_comic_style" - ): - depiction_settings.drawOptions().comicMode = True - # Keep it black and white - depiction_settings.drawOptions().useBWAtomPalette() - return depiction_settings - - def rdkit_depict( - self, mol_block: str, shape: Tuple[int, int] = (512, 512) - ) -> np.array: - """ - This function takes a mol_block str and an image shape. - It renders the chemical structures using Rdkit with random - rendering/depiction settings and returns an RGB image (np.array) - with the given image shape. - - Args: - mol block (str): mol block representation of molecule_ - shape (Tuple[int, int], optional): im shape. Defaults to (299, 299) - - Returns: - np.array: Chemical structure depiction - """ - # Load molecule - # if not self.has_r_group(smiles): - # mol = Chem.MolFromSmiles(smiles) - # if self.has_r_group(smiles) or not mol: - # mol_str = self._smiles_to_mol_block(smiles) - # mol = Chem.MolFromMolBlock(mol_str) - mol = Chem.MolFromMolBlock(mol_block, sanitize=True) - if mol: - # AllChem.Compute2DCoords(mol) - # Abbreviate superatoms - if self.random_choice( - [True, False], log_attribute="rdkit_collapse_superatoms" - ): - abbrevs = self.random_choice(self.get_all_rdkit_abbreviations()) - mol = CondenseMolAbbreviations(mol, abbrevs) - # Get random depiction settings - # TODO: Fix this provisory nonsense - smiles = "CCCCC" - depiction_settings = self.get_random_rdkit_rendering_settings(smiles=smiles) - rdMolDraw2D.PrepareAndDrawMolecule(depiction_settings, mol) - depiction = depiction_settings.GetDrawingText() - depiction = sk_io.imread(io.BytesIO(depiction)) - # Resize image to desired shape - depiction = self.resize(depiction, shape) - depiction = img_as_ubyte(depiction) - return np.asarray(depiction) - else: - pass - # print("RDKit was unable to read input SMILES: {}".format(smiles)) - - def has_r_group(self, smiles: str) -> bool: - """ - Determines whether or not a given 
SMILES str contains an R group - - Args: - smiles (str): SMILES representation of molecule - - Returns: - bool - """ - if re.search("\[.*[RXYZ].*\]", smiles): - return True - - def _cdk_get_depiction_generator(self, molecule, smiles: str): - """ - This function defines random rendering options for the structure - depictions created using CDK. - It takes an iAtomContainer and a SMILES string and returns the iAtomContainer - and the DepictionGenerator - with random rendering settings and the AtomContainer. - I followed https://github.com/cdk/cdk/wiki/Standard-Generator to adjust the - depiction parameters. - - Args: - molecule (cdk.AtomContainer): Atom container - smiles (str): smiles representation of molecule - - Returns: - DepictionGenerator, molecule: Objects that hold depiction parameters - """ - cdk_base = "org.openscience.cdk" - dep_gen = JClass("org.openscience.cdk.depict.DepictionGenerator")( - self._cdk_get_random_java_font() - ) - StandardGenerator = JClass( - cdk_base + ".renderer.generators.standard.StandardGenerator" - ) - - # Define visibility of atom/superatom labels - symbol_visibility = self.random_choice( - ["iupac_recommendation", "no_terminal_methyl", "show_all_atom_labels"], - log_attribute="cdk_symbol_visibility", - ) - SymbolVisibility = JClass("org.openscience.cdk.renderer.SymbolVisibility") - if symbol_visibility == "iupac_recommendation": - dep_gen = dep_gen.withParam( - StandardGenerator.Visibility.class_, - SymbolVisibility.iupacRecommendations(), - ) - elif symbol_visibility == "no_terminal_methyl": - # only hetero atoms, no terminal alkyl groups - dep_gen = dep_gen.withParam( - StandardGenerator.Visibility.class_, - SymbolVisibility.iupacRecommendationsWithoutTerminalCarbon(), - ) - elif symbol_visibility == "show_all_atom_labels": - dep_gen = dep_gen.withParam( - StandardGenerator.Visibility.class_, SymbolVisibility.all() - ) # show all atom labels - - # Define bond line stroke width - stroke_width = self.random_choice( - 
np.arange(0.8, 2.0, 0.1), log_attribute="cdk_stroke_width" - ) - dep_gen = dep_gen.withParam(StandardGenerator.StrokeRatio.class_, - stroke_width) - # Define symbol margin ratio - margin_ratio = self.random_choice( - [0, 1, 2, 2, 2, 3, 4], log_attribute="cdk_margin_ratio" - ) - dep_gen = dep_gen.withParam( - StandardGenerator.SymbolMarginRatio.class_, - JClass("java.lang.Double")(margin_ratio), - ) - # Define bond properties - double_bond_dist = self.random_choice( - np.arange(0.11, 0.25, 0.01), log_attribute="cdk_double_bond_dist" - ) - dep_gen = dep_gen.withParam(StandardGenerator.BondSeparation.class_, - double_bond_dist) - wedge_ratio = self.random_choice( - np.arange(4.5, 7.5, 0.1), log_attribute="cdk_wedge_ratio" - ) - dep_gen = dep_gen.withParam( - StandardGenerator.WedgeRatio.class_, JClass("java.lang.Double")(wedge_ratio) - ) - if self.random_choice([True, False], log_attribute="cdk_fancy_bold_wedges"): - dep_gen = dep_gen.withParam(StandardGenerator.FancyBoldWedges.class_, True) - if self.random_choice([True, False], log_attribute="cdk_fancy_hashed_wedges"): - dep_gen = dep_gen.withParam(StandardGenerator.FancyHashedWedges.class_, - True) - hash_spacing = self.random_choice( - np.arange(4.0, 6.0, 0.2), log_attribute="cdk_hash_spacing" - ) - dep_gen = dep_gen.withParam(StandardGenerator.HashSpacing.class_, hash_spacing) - # Add CIP labels - labels = False - if self.random_choice([True, False], log_attribute="cdk_add_CIP_labels"): - labels = True - JClass("org.openscience.cdk.geometry.cip.CIPTool").label(molecule) - for atom in molecule.atoms(): - label = atom.getProperty( - JClass("org.openscience.cdk.CDKConstants").CIP_DESCRIPTOR - ) - atom.setProperty(StandardGenerator.ANNOTATION_LABEL, label) - for bond in molecule.bonds(): - label = bond.getProperty( - JClass("org.openscience.cdk.CDKConstants").CIP_DESCRIPTOR - ) - bond.setProperty(StandardGenerator.ANNOTATION_LABEL, label) - # Add atom indices to the depictions - if self.random_choice( - [True, False, 
False, False], log_attribute="cdk_add_atom_indices" - ): - if not self.has_r_group(smiles): - # Avoid confusion with R group indices and atom numbering - labels = True - for atom in molecule.atoms(): - label = JClass("java.lang.Integer")( - 1 + molecule.getAtomNumber(atom) - ) - atom.setProperty(StandardGenerator.ANNOTATION_LABEL, label) - if labels: - # We only need black - dep_gen = dep_gen.withParam( - StandardGenerator.AnnotationColor.class_, - JClass("java.awt.Color")(0x000000), - ) - # Font size of labels - font_scale = self.random_choice( - np.arange(0.5, 0.8, 0.1), log_attribute="cdk_label_font_scale" - ) - dep_gen = dep_gen.withParam( - StandardGenerator.AnnotationFontScale.class_, - font_scale) - # Distance between atom numbering and depiction - annotation_distance = self.random_choice( - np.arange(0.15, 0.30, 0.05), log_attribute="cdk_annotation_distance" - ) - dep_gen = dep_gen.withParam( - StandardGenerator.AnnotationDistance.class_, annotation_distance - ) - # Abbreviate superatom labels in half of the cases - # TODO: Find a way to define Abbreviations object as a class attribute. - # Problem: can't be pickled. - # Right now, this is loaded every time when a structure is depicted. - # That seems inefficient. 
- if self.random_choice([True, False], log_attribute="cdk_collapse_superatoms"): - cdk_superatom_abrv = JClass("org.openscience.cdk.depict.Abbreviations")() - abbr_filename = self.random_choice([ - "cdk_superatom_abbreviations.smi", - "cdk_alt_superatom_abbreviations.smi"]) - abbreviation_path = str(self.HERE.joinpath(abbr_filename)) - abbreviation_path = abbreviation_path.replace("\\", "/") - abbreviation_path = JClass("java.lang.String")(abbreviation_path) - cdk_superatom_abrv.loadFromFile(abbreviation_path) - cdk_superatom_abrv.apply(molecule) - return dep_gen, molecule - - def _cdk_get_random_java_font(self): - """ - This function returns a random java.awt.Font (JClass) object - - Returns: - font: java.awt.Font (JClass object) - """ - font_size = self.random_choice( - range(10, 20), log_attribute="cdk_atom_label_font_size" - ) - Font = JClass("java.awt.Font") - font_name = self.random_choice( - ["Verdana", - "Times New Roman", - "Arial", - "Gulliver Regular", - "Helvetica", - "Courier", - "architectural", - "Geneva", - "Lucida Sans", - "Teletype"], - # log_attribute='cdk_atom_label_font' - ) - font_style = self.random_choice( - [Font.PLAIN, Font.BOLD], - # log_attribute='cdk_atom_label_font_style' - ) - font = Font(font_name, font_style, font_size) - return font - - def _cdk_rotate_coordinates(self, molecule): - """ - Given an IAtomContainer (JClass object), this function rotates the molecule - and adapts the coordinates of accordingly. 
The IAtomContainer is then returned.# - - Args: - molecule: IAtomContainer (JClass object) - - Returns: - molecule: IAtomContainer (JClass object) - """ - cdk_base = "org.openscience.cdk" - point = JClass(cdk_base + ".geometry.GeometryTools").get2DCenter(molecule) - rot_degrees = self.random_choice(range(360)) - JClass(cdk_base + ".geometry.GeometryTools").rotate( - molecule, point, rot_degrees - ) - return molecule - - def _cdk_generate_2d_coordinates(self, molecule): - """ - Given an IAtomContainer (JClass object), this function adds 2D coordinate to - the molecule. The modified IAtomContainer is then returned. - - Args: - molecule: IAtomContainer (JClass object) - - Returns: - molecule: IAtomContainer (JClass object) - """ - cdk_base = "org.openscience.cdk" - sdg = JClass(cdk_base + ".layout.StructureDiagramGenerator")() - sdg.setMolecule(molecule) - sdg.generateCoordinates(molecule) - molecule = sdg.getMolecule() - return molecule - - def _cdk_bufferedimage_to_numpyarray( - self, - image - ) -> np.ndarray: - """ - This function converts a BufferedImage (JClass object) into a numpy array. - - Args: - image (BufferedImage (JClass object)) - - Returns: - image (np.ndarray) - """ - # Write the image into a format that can be read by skimage - ImageIO = JClass("javax.imageio.ImageIO") - os = JClass("java.io.ByteArrayOutputStream")() - Base64 = JClass("java.util.Base64") - ImageIO.write( - image, JClass("java.lang.String")("PNG"), Base64.getEncoder().wrap(os) - ) - image = bytes(os.toString("UTF-8")) - image = base64.b64decode(image) - image = sk_io.imread(image, plugin="imageio") - image = img_as_ubyte(image) - return image - - def _cdk_render_molecule( - self, - molecule, - smiles: str, - shape: Tuple[int, int] - ): - """ - This function takes an IAtomContainer (JClass object), the corresponding SMILES - string and an image shape and returns a BufferedImage (JClass object) with the - rendered molecule. 
- - Args: - molecule (IAtomContainer (JClass object)): molecule - smiles (str): SMILES string - shape (Tuple[int, int]): y, x - Returns: - depiction (np.ndarray): chemical structure depiction - """ - dep_gen, molecule = self._cdk_get_depiction_generator(molecule, smiles) - dep_gen = dep_gen.withSize(shape[1], shape[0]) - dep_gen = dep_gen.withFillToFit() - depiction = dep_gen.depict(molecule).toImg() - depiction = self._cdk_bufferedimage_to_numpyarray(depiction) - return depiction - - def cdk_depict( - self, mol_block: str, shape: Tuple[int, int] = (299, 299) - ) -> np.array: - """ - This function takes a mol block str and an image shape. - It renders the chemical structures using CDK with random - rendering/depiction settings and returns an RGB image (np.array) - with the given image shape. - The general workflow here is a JPype adaptation of code published - by Egon Willighagen in 'Groovy Cheminformatics with the Chemistry - Development Kit': - https://egonw.github.io/cdkbook/ctr.html#depict-a-compound-as-an-image - with additional adaptations to create all the different depiction - types from - https://github.com/cdk/cdk/wiki/Standard-Generator - - Args: - mol_block (str): SMILES representation of molecule - shape (Tuple[int, int], optional): im shape. Defaults to (299, 299) - - Returns: - np.array: Chemical structure depiction - """ - molecule = self._cdk_mol_block_to_iatomcontainer(mol_block) - # molecule = self._cdk_smiles_to_IAtomContainer(smiles) - # molecule = self._cdk_generate_2d_coordinates(molecule) - # molecule = self._cdk_rotate_coordinates(molecule) - smiles = "C1=CC=CC=C1" - depiction = self._cdk_render_molecule(molecule, smiles, shape) - return depiction - - def _cdk_smiles_to_IAtomContainer(self, smiles: str): - """ - This function takes a SMILES representation of a molecule and - returns the corresponding IAtomContainer object. 
- - Args: - smiles (str): SMILES representation of the molecule - - Returns: - IAtomContainer: CDK IAtomContainer object that represents the molecule - """ - cdk_base = "org.openscience.cdk" - SCOB = JClass(cdk_base + ".silent.SilentChemObjectBuilder") - SmilesParser = JClass(cdk_base + ".smiles.SmilesParser")(SCOB.getInstance()) - if self.random_choice([True, False, False], log_attribute="cdk_kekulized"): - SmilesParser.kekulise(False) - molecule = SmilesParser.parseSmiles(smiles) - return molecule - - def _smiles_to_mol_block( - self, - smiles: str, - generate_2d: bool = False, - ) -> str: - """ - This function takes a SMILES representation of a molecule and returns - the content of the corresponding SD file using the CDK. - ___ - The SMILES parser of the CDK is much more tolerant than the parsers of - RDKit and Indigo. - ___ - - Args: - smiles (str): SMILES representation of a molecule - generate_2d (bool or str, optional): False if no coordinates are created - Otherwise pick tool for coordinate - generation: - "rdkit", "cdk", "indigo" or "pikachu". 
- - Returns: - mol_block (str): content of SD file of input molecule - """ - if not generate_2d: - molecule = self._cdk_smiles_to_IAtomContainer(smiles) - mol_block = self._cdk_iatomcontainer_to_mol_block(molecule) - elif generate_2d == "cdk": - molecule = self._cdk_smiles_to_IAtomContainer(smiles) - molecule = self._cdk_generate_2d_coordinates(molecule) - molecule = self._cdk_rotate_coordinates(molecule) - mol_block = self._cdk_iatomcontainer_to_mol_block(molecule) - elif generate_2d == "rdkit": - mol_block = self._smiles_to_mol_block(smiles) - molecule = Chem.MolFromMolBlock(mol_block, sanitize=False) - if molecule: - AllChem.Compute2DCoords(molecule) - mol_block = Chem.MolToMolBlock(molecule) - atom_container = self._cdk_mol_block_to_iatomcontainer(mol_block) - atom_container = self._cdk_rotate_coordinates(atom_container) - mol_block = self._cdk_iatomcontainer_to_mol_block(atom_container) - else: - raise ValueError(f"RDKit could not read molecule: {smiles}") - elif generate_2d == "indigo": - indigo = Indigo() - mol_block = self._smiles_to_mol_block(smiles) - molecule = indigo.loadMolecule(mol_block) - molecule.layout() - buf = indigo.writeBuffer() - buf.sdfAppend(molecule) - mol_block = buf.toString() - atom_container = self._cdk_mol_block_to_iatomcontainer(mol_block) - atom_container = self._cdk_rotate_coordinates(atom_container) - mol_block = self._cdk_iatomcontainer_to_mol_block(atom_container) - elif generate_2d == "pikachu": - pass - return mol_block - - def _cdk_mol_block_to_iatomcontainer(self, mol_block: str): - """ - Given a mol block, this function returns an IAtomContainer (JClass) object. 
- - Args: - mol_block (str): content of MDL MOL file - - Returns: - IAtomContainer: CDK IAtomContainer object that represents the molecule - """ - # xyz_reader = JClass("org.openscience.cdk.io.XYZReader") - scob = JClass("org.openscience.cdk.silent.SilentChemObjectBuilder") - bldr = scob.getInstance() - iac_class = JClass("org.openscience.cdk.interfaces.IAtomContainer").class_ - string_reader = JClass("java.io.StringReader")(mol_block) - mdlr = JClass("org.openscience.cdk.io.MDLV2000Reader")(string_reader) - iatomcontainer = mdlr.read(bldr.newInstance(iac_class)) - mdlr.close() - # reader = xyz_reader(string_reader) - # chemfile = reader.read(JClass("org.openscience.cdk.ChemFile")()) - # manip = JClass("org.openscience.cdk.tools.manipulator.ChemFileManipulator") - # iatomcontainer = manip.getAllAtomContainers(chemfile).get(0) - return iatomcontainer - - def _cdk_iatomcontainer_to_mol_block(self, i_atom_container) -> str: - """ - This function takes an IAtomContainer object and returns the content - of the corresponding MDL MOL file as a string. - - Args: - i_atom_container (CDK IAtomContainer (JClass object)) - - Returns: - str: string content of MDL MOL file - """ - string_writer = JClass("java.io.StringWriter")() - mol_writer = JClass("org.openscience.cdk.io.MDLV2000Writer")(string_writer) - mol_writer.write(i_atom_container) - mol_writer.close() - mol_str = string_writer.toString() - return str(mol_str) - - def normalise_padding(self, im: np.array) -> np.array: - """This function takes an RGB image (np.array) and deletes white space at - the borders. Then 0-10% of the image width/height is added as padding - again. 
The modified image is returned - - Args: - im: input image (np.array) - - Returns: - output: the modified image (np.array) - """ - # Remove white space at borders - mask = im > 200 - all_white = mask.sum(axis=2) > 0 - rows = np.flatnonzero((~all_white).sum(axis=1)) - cols = np.flatnonzero((~all_white).sum(axis=0)) - crop = im[rows.min(): rows.max() + 1, cols.min(): cols.max() + 1, :] - # Add padding again. - pad_range = np.arange(5, int(crop.shape[0] * 0.2), 1) - if len(pad_range) > 0: - pad = self.random_choice(np.arange(5, int(crop.shape[0] * 0.2), 1)) - else: - pad = 5 - crop = np.pad( - crop, - pad_width=((pad, pad), (pad, pad), (0, 0)), - mode="constant", - constant_values=255, - ) - return crop - - def central_square_image(self, im: np.array) -> np.array: - """ - This function takes image (np.array) and will add white padding - so that the image has a square shape with the width/height of the - longest side of the original image. - - Args: - im (np.array): Input image - - Returns: - np.array: Output image - """ - # Create new blank white image - max_wh = max(im.shape) - new_im = 255 * np.ones((max_wh, max_wh, 3), np.uint8) - # Determine paste coordinates and paste image - upper = int((new_im.shape[0] - im.shape[0]) / 2) - lower = int((new_im.shape[0] - im.shape[0]) / 2) + im.shape[0] - left = int((new_im.shape[1] - im.shape[1]) / 2) - right = int((new_im.shape[1] - im.shape[1]) / 2) + im.shape[1] - new_im[upper:lower, left:right] = im - return new_im - - def random_depiction( - self, - smiles: str, - shape: Tuple[int, int] = (299, 299), - ) -> np.array: - """ - This function takes a SMILES and depicts it using Rdkit, Indigo, CDK or PIKACHU. - The depiction method and the specific parameters for the depiction are - chosen completely randomly. The purpose of this function is to enable - depicting a diverse variety of chemical structure depictions. - - Args: - smiles (str): SMILES representation of molecule - shape (Tuple[int, int], optional): im shape. 
Defaults to (299, 299) - - Returns: - np.array: Chemical structure depiction - """ - depiction_functions = self.get_depiction_functions(smiles) - - for _ in range(3): - if len(depiction_functions) != 0: - # Pick random depiction function and call it - depiction_function = self.random_choice(depiction_functions) - depiction = depiction_function(smiles, shape) - if depiction is False or depiction is None: - depiction_functions.remove(depiction_function) - else: - break - else: - break - - if self.hand_drawn: - path_bkg = self.HERE.joinpath("backgrounds/") - # Augment molecule image - mol_aug = self.hand_drawn_augment(depiction) - - # Randomly select background image and use is as it is - backgroud_selected = self.random_choice(os.listdir(path_bkg)) - bkg = cv2.imread(os.path.join(os.path.normpath(path_bkg), backgroud_selected)) - bkg = cv2.resize(bkg, (256, 256)) - # Combine augmented molecule and augmented background - p = 0.7 - mol_bkg = cv2.addWeighted(mol_aug, p, bkg, 1 - p, gamma=0) - - """ - If you want to randomly augment the background as well, - simply comment the previous section and uncomment the next one. - """ - - """# Randomly select background image and augment it - bkg_aug = self.augment_bkg(bkg) - bkg_aug = cv2.resize(bkg_aug,(256,256)) - # Combine augmented molecule and augmented background - p=0.7 - mol_bkg = cv2.addWeighted(mol_aug, p, bkg_aug, 1-p, gamma=0)""" - - # Degrade total image - depiction = self.degrade_img(mol_bkg) - return depiction - - def get_depiction_functions(self, smiles: str) -> List[Callable]: - """ - PIKAChU, RDKit and Indigo can run into problems if certain R group variables - are present in the input molecule, and PIKAChU cannot handle isotopes. - Hence, the depiction functions that use their functionalities need to - be removed based on the input smiles str (if necessary). 
- - Args: - smiles (str): SMILES representation of a molecule - - Returns: - List[Callable]: List of depiction functions - """ - - depiction_functions_registry = { - 'rdkit': self.rdkit_depict, - 'indigo': self.indigo_depict, - 'cdk': self.cdk_depict, - 'pikachu': self.pikachu_depict, - } - depiction_functions = [depiction_functions_registry[k] - for k in self._config.styles] + depiction_functions_registry = { + 'rdkit': self.rdkit_depict, + 'indigo': self.indigo_depict, + 'cdk': self.cdk_depict, + 'pikachu': self.pikachu_depict, + } + depiction_functions = [depiction_functions_registry[k] + for k in self._config.styles] # Remove PIKAChU if there is an isotope if re.search("(\[\d\d\d?[A-Z])|(\[2H\])|(\[3H\])|(D)|(T)", smiles): @@ -1669,716 +319,13 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]: if self.indigo_depict in depiction_functions: if re.search("\[R0\]|\[X\]|[4-9][0-9]+|3[3-9]|[XYZR]\d+[a-f]", smiles): depiction_functions.remove(self.indigo_depict) - # Workaround because PIKAChU fails to depict large structures - # TODO: Delete workaround when problem is fixed in PIKAChU - # https://github.com/BTheDragonMaster/pikachu/issues/11 - if len(smiles) > 100: - if self.pikachu_depict in depiction_functions: - depiction_functions.remove(self.pikachu_depict) - return depiction_functions - - def resize(self, image: np.array, shape: Tuple[int], HQ: bool = False) -> np.array: - """ - This function takes an image (np.array) and a shape and returns - the resized image (np.array). It uses Pillow to do this, as it - seems to have a bigger variety of scaling methods than skimage. - The up/downscaling method is chosen randomly. - - Args: - image (np.array): the input image - shape (Tuple[int, int], optional): im shape. 
Defaults to (299, 299) - HQ (bool): if true, only choose from Image.BICUBIC, Image.LANCZOS - ___ - Returns: - np.array: the resized image - - """ - image = Image.fromarray(image) - shape = (shape[0], shape[1]) - if not HQ: - image = image.resize( - shape, resample=self.random_choice(self.PIL_resize_methods) - ) - else: - image = image = image.resize( - shape, resample=self.random_choice(self.PIL_HQ_resize_methods) - ) - - return np.asarray(image) - - def imgaug_augment( - self, - image: np.array, - ) -> np.array: - """ - This function applies a random amount of augmentations to - a given image (np.array) using and returns the augmented image - (np.array). - - Args: - image (np.array): input image - - Returns: - np.array: output image (augmented) - """ - original_shape = image.shape - - # Choose number of augmentations to apply (0-2); - # return image if nothing needs to be done. - aug_number = self.random_choice(range(0, 3)) - if not aug_number: - return image - - # Add some padding to avoid weird artifacts after rotation - image = np.pad( - image, ((1, 1), (1, 1), (0, 0)), mode="constant", constant_values=255 - ) - - def imgaug_rotation(): - # Rotation between -10 and 10 degrees - if not self.random_choice( - [True, True, False], log_attribute="has_imgaug_rotation" - ): - return False - rot_angle = self.random_choice(np.arange(-10, 10, 1)) - aug = iaa.Affine(rotate=rot_angle, mode="edge", fit_output=True) - return aug - - def imgaug_black_and_white_noise(): - # Black and white noise - if not self.random_choice( - [True, True, False], log_attribute="has_imgaug_salt_pepper" - ): - return False - coarse_dropout_p = self.random_choice(np.arange(0.0002, 0.0015, 0.0001)) - coarse_dropout_size_percent = self.random_choice(np.arange(1.0, 1.1, 0.01)) - replace_elementwise_p = self.random_choice(np.arange(0.01, 0.3, 0.01)) - aug = iaa.Sequential( - [ - iaa.CoarseDropout( - coarse_dropout_p, size_percent=coarse_dropout_size_percent - ), - 
iaa.ReplaceElementwise(replace_elementwise_p, 255), - ] - ) - return aug - - def imgaug_shearing(): - # Shearing - if not self.random_choice( - [True, True, False], log_attribute="has_imgaug_shearing" - ): - return False - shear_param = self.random_choice(np.arange(-5, 5, 1)) - aug = self.random_choice( - [ - iaa.geometric.ShearX(shear_param, mode="edge", fit_output=True), - iaa.geometric.ShearY(shear_param, mode="edge", fit_output=True), - ] - ) - return aug - - def imgaug_imgcorruption(): - # Jpeg compression or pixelation - if not self.random_choice( - [True, True, False], log_attribute="has_imgaug_corruption" - ): - return False - imgcorrupt_severity = self.random_choice(np.arange(1, 2, 1)) - aug = self.random_choice( - [ - iaa.imgcorruptlike.JpegCompression(severity=imgcorrupt_severity), - iaa.imgcorruptlike.Pixelate(severity=imgcorrupt_severity), - ] - ) - return aug - - def imgaug_brightness_adjustment(): - # Brightness adjustment - if not self.random_choice( - [True, True, False], log_attribute="has_imgaug_brightness_adj" - ): - return False - brightness_adj_param = self.random_choice(np.arange(-50, 50, 1)) - aug = iaa.WithBrightnessChannels(iaa.Add(brightness_adj_param)) - return aug - - def imgaug_colour_temp_adjustment(): - # Colour temperature adjustment - if not self.random_choice( - [True, True, False], log_attribute="has_imgaug_col_adj" - ): - return False - colour_temp = self.random_choice(np.arange(1100, 10000, 1)) - aug = iaa.ChangeColorTemperature(colour_temp) - return aug - - # Define list of available augmentations - aug_list = [ - imgaug_rotation, - imgaug_black_and_white_noise, - imgaug_shearing, - imgaug_imgcorruption, - imgaug_brightness_adjustment, - imgaug_colour_temp_adjustment, - ] - - # Every one of them has a 1/3 chance of returning False - aug_list = [fun() for fun in aug_list] - aug_list = [fun for fun in aug_list if fun] - aug = iaa.Sequential(aug_list) - augmented_image = aug.augment_images([image])[0] - augmented_image = 
self.resize(augmented_image, original_shape) - augmented_image = augmented_image.astype(np.uint8) - return augmented_image - - def add_augmentations(self, depiction: np.array) -> np.array: - """ - This function takes a chemical structure depiction (np.array) - and returns the same image with added augmentation elements - - Args: - depiction (np.array): chemical structure depiction - - Returns: - np.array: chemical structure depiction with added augmentations - """ - if self.random_choice( - [True, False, False, False, False, False], log_attribute="has_curved_arrows" - ): - depiction = self.add_curved_arrows_to_structure(depiction) - if self.random_choice( - [True, False, False], log_attribute="has_straight_arrows" - ): - depiction = self.add_straight_arrows_to_structure(depiction) - if self.random_choice( - [True, False, False, False, False, False], log_attribute="has_id_label" - ): - depiction = self.add_chemical_label(depiction, "ID") - if self.random_choice( - [True, False, False, False, False, False], log_attribute="has_R_group_label" - ): - depiction = self.add_chemical_label(depiction, "R_GROUP") - if self.random_choice( - [True, False, False, False, False, False], - log_attribute="has_reaction_label", - ): - depiction = self.add_chemical_label(depiction, "REACTION") - depiction = self.imgaug_augment(depiction) - return depiction - - def get_random_label_position(self, width: int, height: int) -> Tuple[int, int]: - """ - Given the width and height of an image (int), this function - determines a random position in the outer 15% of the image and - returns a tuple that contain the coordinates (y,x) of that position. 
- - Args: - width (int): image width - height (int): image height - - Returns: - Tuple[int, int]: Random label position - """ - if self.random_choice([True, False]): - y_range = range(0, height) - x_range = list(range(0, int(0.15 * width))) + list( - range(int(0.85 * width), width) - ) - else: - y_range = list(range(0, int(0.15 * height))) + list( - range(int(0.85 * height), height) - ) - x_range = range(0, width) - return self.random_choice(y_range), self.random_choice(x_range) - - def ID_label_text(self) -> str: - """ - This function returns a string that resembles a typical - chemical ID label - - Returns: - str: Label text - """ - label_num = range(1, 50) - label_letters = [ - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - ] - options = [ - "only_number", - "num_letter_combination", - "numtonum", - "numcombtonumcomb", - ] - option = self.random_choice(options) - if option == "only_number": - return str(self.random_choice(label_num)) - if option == "num_letter_combination": - return str(self.random_choice(label_num)) + self.random_choice( - label_letters - ) - if option == "numtonum": - return ( - str(self.random_choice(label_num)) - + "-" - + str(self.random_choice(label_num)) - ) - if option == "numcombtonumcomb": - return ( - str(self.random_choice(label_num)) - + self.random_choice(label_letters) - + "-" - + self.random_choice(label_letters) - ) - - def new_reaction_condition_elements(self) -> Tuple[str, str, str]: - """ - Randomly redefine reaction_time, solvent and other_reactand. 
- - Returns: - Tuple[str, str, str]: Reaction time, solvent, reactand - """ - reaction_time = self.random_choice( - [str(num) for num in range(30)] - ) + self.random_choice([" h", " min"]) - solvent = self.random_choice( - [ - "MeOH", - "EtOH", - "CHCl3", - "DCM", - "iPrOH", - "MeCN", - "DMSO", - "pentane", - "hexane", - "benzene", - "Et2O", - "THF", - "DMF", - ] - ) - other_reactand = self.random_choice( - [ - "HF", - "HCl", - "HBr", - "NaOH", - "Et3N", - "TEA", - "Ac2O", - "DIBAL", - "DIBAL-H", - "DIPEA", - "DMAP", - "EDTA", - "HOBT", - "HOAt", - "TMEDA", - "p-TsOH", - "Tf2O", - ] - ) - return reaction_time, solvent, other_reactand - - def reaction_condition_label_text(self) -> str: - """ - This function returns a random string that looks like a - reaction condition label. - - Returns: - str: Reaction condition label text - """ - reaction_condition_label = "" - label_type = self.random_choice(["A", "B", "C", "D"]) - if label_type in ["A", "B"]: - for n in range(self.random_choice(range(1, 5))): - ( - reaction_time, - solvent, - other_reactand, - ) = self.new_reaction_condition_elements() - if label_type == "A": - reaction_condition_label += ( - str(n + 1) - + " " - + other_reactand - + ", " - + solvent - + ", " - + reaction_time - + "\n" - ) - elif label_type == "B": - reaction_condition_label += ( - str(n + 1) - + " " - + other_reactand - + ", " - + solvent - + " (" - + reaction_time - + ")\n" - ) - elif label_type == "C": - ( - reaction_time, - solvent, - other_reactand, - ) = self.new_reaction_condition_elements() - reaction_condition_label += ( - other_reactand + "\n" + solvent + "\n" + reaction_time - ) - elif label_type == "D": - reaction_condition_label += self.random_choice( - self.new_reaction_condition_elements() - ) - return reaction_condition_label - - def make_R_group_str(self) -> str: - """ - This function returns a random string that looks like an R group label. - It generates them by inserting randomly chosen elements into one of - five templates. 
- - Returns: - str: R group label text - """ - rest_variables = [ - "X", - "Y", - "Z", - "R", - "R1", - "R2", - "R3", - "R4", - "R5", - "R6", - "R7", - "R8", - "R9", - "R10", - "Y2", - "D", - ] - # Load list of superatoms (from OSRA) - superatoms = self.superatoms - label_type = self.random_choice(["A", "B", "C", "D", "E"]) - R_group_label = "" - if label_type == "A": - for _ in range(1, self.random_choice(range(2, 6))): - R_group_label += ( - self.random_choice(rest_variables) - + " = " - + self.random_choice(superatoms) - + "\n" - ) - elif label_type == "B": - R_group_label += " " + self.random_choice(rest_variables) + "\n" - for n in range(1, self.random_choice(range(2, 6))): - R_group_label += str(n) + " " + self.random_choice(superatoms) + "\n" - elif label_type == "C": - R_group_label += ( - " " - + self.random_choice(rest_variables) - + " " - + self.random_choice(rest_variables) - + "\n" - ) - for n in range(1, self.random_choice(range(2, 6))): - R_group_label += ( - str(n) - + " " - + self.random_choice(superatoms) - + " " - + self.random_choice(superatoms) - + "\n" - ) - elif label_type == "D": - R_group_label += ( - " " - + self.random_choice(rest_variables) - + " " - + self.random_choice(rest_variables) - + " " - + self.random_choice(rest_variables) - + "\n" - ) - for n in range(1, self.random_choice(range(2, 6))): - R_group_label += ( - str(n) - + " " - + self.random_choice(superatoms) - + " " - + self.random_choice(superatoms) - + " " - + self.random_choice(superatoms) - + "\n" - ) - if label_type == "E": - for n in range(1, self.random_choice(range(2, 6))): - R_group_label += ( - str(n) - + " " - + self.random_choice(rest_variables) - + " = " - + self.random_choice(superatoms) - + "\n" - ) - return R_group_label - - def add_chemical_label( - self, image: np.array, label_type: str, foreign_fonts: bool = True - ) -> np.array: - """ - This function takes an image (np.array) and adds random text that - looks like a chemical ID label, an R group label or a 
reaction - condition label around the structure. It returns the modified image. - The label type is determined by the parameter label_type (str), - which needs to be 'ID', 'R_GROUP' or 'REACTION' - - Args: - image (np.array): Chemical structure depiction - label_type (str): 'ID', 'R_GROUP' or 'REACTION' - foreign_fonts (bool, optional): Defaults to True. - - Returns: - np.array: Chemical structure depiction with label - """ - im = Image.fromarray(image) - orig_image = deepcopy(im) - width, height = im.size - # Choose random font - if self.random_choice([True, False]) or not foreign_fonts: - font_dir = self.HERE.joinpath("fonts/") - # In half of the cases: Use foreign-looking font to generate - # bigger noise variety - else: - font_dir = self.HERE.joinpath("foreign_fonts/") - - fonts = os.listdir(str(font_dir)) - # Choose random font size - font_sizes = range(10, 20) - size = self.random_choice(font_sizes) - # Generate random string that resembles the desired type of label - if label_type == "ID": - label_text = self.ID_label_text() - if label_type == "R_GROUP": - label_text = self.make_R_group_str() - if label_type == "REACTION": - label_text = self.reaction_condition_label_text() - - try: - font = ImageFont.truetype( - str(os.path.join(str(font_dir), self.random_choice(fonts))), size=size - ) - except OSError: - font = ImageFont.load_default() - - draw = ImageDraw.Draw(im, "RGBA") - - # Try different positions with the condition that the label´does not - # overlap with non-white pixels (the structure) - for _ in range(50): - y_pos, x_pos = self.get_random_label_position(width, height) - bounding_box = draw.textbbox( - (x_pos, y_pos), label_text, font=font - ) # left, up, right, low - paste_region = orig_image.crop(bounding_box) - try: - mean = ImageStat.Stat(paste_region).mean - except ZeroDivisionError: - return np.asarray(im) - if sum(mean) / len(mean) == 255: - draw.text((x_pos, y_pos), label_text, font=font, fill=(0, 0, 0, 255)) - break - return np.asarray(im) 
- - def add_curved_arrows_to_structure(self, image: np.array) -> np.array: - """ - This function takes an image of a chemical structure (np.array) - and adds between 2 and 4 curved arrows in random positions in the - central part of the image. - - Args: - image (np.array): Chemical structure depiction - - Returns: - np.array: Chemical structure depiction with curved arrows - """ - height, width, _ = image.shape - image = Image.fromarray(image) - orig_image = deepcopy(image) - # Determine area where arrows are pasted. - x_min, x_max = (int(0.1 * width), int(0.9 * width)) - y_min, y_max = (int(0.1 * height), int(0.9 * height)) - - arrow_dir = os.path.normpath( - str(self.HERE.joinpath("arrow_images/curved_arrows/")) - ) - - for _ in range(self.random_choice(range(2, 4))): - # Load random curved arrow image, resize and rotate it randomly. - arrow_image = Image.open( - os.path.join( - str(arrow_dir), self.random_choice(os.listdir(str(arrow_dir))) - ) - ) - new_arrow_image_shape = int( - (x_max - x_min) / self.random_choice(range(3, 6)) - ), int((y_max - y_min) / self.random_choice(range(3, 6))) - arrow_image = self.resize(np.asarray(arrow_image), new_arrow_image_shape) - arrow_image = Image.fromarray(arrow_image) - arrow_image = arrow_image.rotate( - self.random_choice(range(360)), - resample=self.random_choice( - [Image.BICUBIC, Image.NEAREST, Image.BILINEAR] - ), - expand=True, - ) - # Try different positions with the condition that the arrows are - # overlapping with non-white pixels (the structure) - for _ in range(50): - x_position = self.random_choice( - range(x_min, x_max - new_arrow_image_shape[0]) - ) - y_position = self.random_choice( - range(y_min, y_max - new_arrow_image_shape[1]) - ) - paste_region = orig_image.crop( - ( - x_position, - y_position, - x_position + new_arrow_image_shape[0], - y_position + new_arrow_image_shape[1], - ) - ) - mean = ImageStat.Stat(paste_region).mean - if sum(mean) / len(mean) < 252: - image.paste(arrow_image, (x_position, 
y_position), arrow_image) - - break - return np.asarray(image) - - def get_random_arrow_position(self, width: int, height: int) -> Tuple[int, int]: - """ - Given the width and height of an image (int), this function determines - a random position to paste a reaction arrow in the outer 15% frame of - the image - - Args: - width (_type_): image width - height (_type_): image height - - Returns: - Tuple[int, int]: Random arrow position - """ - if self.random_choice([True, False]): - y_range = range(0, height) - x_range = list(range(0, int(0.15 * width))) + list( - range(int(0.85 * width), width) - ) - else: - y_range = list(range(0, int(0.15 * height))) + list( - range(int(0.85 * height), height) - ) - x_range = range(0, int(0.5 * width)) - return self.random_choice(y_range), self.random_choice(x_range) - - def add_straight_arrows_to_structure(self, image: np.array) -> np.array: - """ - This function takes an image of a chemical structure (np.array) - and adds between 1 and 2 straight arrows in random positions in the - image (no overlap with other elements) - - Args: - image (np.array): Chemical structure depiction - - Returns: - np.array: Chemical structure depiction with straight arrow - """ - height, width, _ = image.shape - image = Image.fromarray(image) - - arrow_dir = os.path.normpath( - str(self.HERE.joinpath("arrow_images/straight_arrows/")) - ) - - for _ in range(self.random_choice(range(1, 3))): - # Load random curved arrow image, resize and rotate it randomly. 
- arrow_image = Image.open( - os.path.join( - str(arrow_dir), self.random_choice(os.listdir(str(arrow_dir))) - ) - ) - # new_arrow_image_shape = (int(width * - # self.random_choice(np.arange(0.9, 1.5, 0.1))), - # int(height/10 * self.random_choice(np.arange(0.7, 1.2, 0.1)))) - - # arrow_image = arrow_image.resize(new_arrow_image_shape, - # resample=Image.BICUBIC) - # Rotate completely randomly in half of the cases and in 180° steps - # in the other cases (higher probability that pasting works) - if self.random_choice([True, False]): - arrow_image = arrow_image.rotate( - self.random_choice(range(360)), - resample=self.random_choice( - [Image.Resampling.BICUBIC, Image.Resampling.NEAREST, Image.Resampling.BILINEAR] - ), - expand=True, - ) - else: - arrow_image = arrow_image.rotate(self.random_choice([180, 360])) - new_arrow_image_shape = arrow_image.size - # Try different positions with the condition that the arrows are - # overlapping with non-white pixels (the structure) - for _ in range(50): - y_position, x_position = self.get_random_arrow_position(width, height) - x2_position = x_position + new_arrow_image_shape[0] - y2_position = y_position + new_arrow_image_shape[1] - # Make sure we only check a region inside of the image - if x2_position > width: - x2_position = width - 1 - if y2_position > height: - y2_position = height - 1 - paste_region = image.crop( - (x_position, y_position, x2_position, y2_position) - ) - try: - mean = ImageStat.Stat(paste_region).mean - if sum(mean) / len(mean) == 255: - image.paste(arrow_image, (x_position, y_position), arrow_image) - break - except ZeroDivisionError: - pass - return np.asarray(image) - - def to_grayscale_float_img(self, image: np.array) -> np.array: - """ - This function takes an image (np.array), converts it to grayscale - and returns it. 
- - Args: - image (np.array): image - - Returns: - np.array: grayscale float image - """ - return img_as_float(rgb2gray(image)) + # Workaround because PIKAChU fails to depict large structures + # TODO: Delete workaround when problem is fixed in PIKAChU + # https://github.com/BTheDragonMaster/pikachu/issues/11 + if len(smiles) > 100: + if self.pikachu_depict in depiction_functions: + depiction_functions.remove(self.pikachu_depict) + return depiction_functions def depict_save( self, @@ -2631,6 +578,7 @@ def batch_depict_save_with_fingerprints( smiles_list = [smi for smi in smiles_list for _ in range(images_per_structure)] # Generate corresponding amount of fingerprints dataset_size = len(smiles_list) + from .depiction_feature_ranges import DepictionFeatureRanges FR = DepictionFeatureRanges() fingerprint_tuples = FR.generate_fingerprints_for_dataset( dataset_size, @@ -2695,6 +643,7 @@ def batch_depict_with_fingerprints( smiles_list = [smi for smi in smiles_list for _ in range(images_per_structure)] # Generate corresponding amount of fingerprints dataset_size = len(smiles_list) + from .depiction_feature_ranges import DepictionFeatureRanges FR = DepictionFeatureRanges() fingerprint_tuples = FR.generate_fingerprints_for_dataset( dataset_size, @@ -2723,59 +672,16 @@ def batch_depict_with_fingerprints( ) return list(depictions) + def random_choice(self, iterable: List, log_attribute: str = False): + """ + This function takes an iterable, calls random.choice() on it, + increases random.seed by 1 and returns the result. This way, results + produced by RanDepict are replicable. -class DepictionFeatureRanges(RandomDepictor): - """Class for depiction feature fingerprint generation""" - - def __init__(self): - super().__init__() - # Fill ranges. By simply using all the depiction and augmentation - # functions, the available features are saved by the overwritten - # random_choice function. 
We just have to make sure to run through - # every available decision once to get all the information about the - # feature space that we need. - smiles = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C" - - # Call every depiction function - depiction = self(smiles) - depiction = self.cdk_depict(smiles) - depiction = self.rdkit_depict(smiles) - depiction = self.indigo_depict(smiles) - depiction = self.pikachu_depict(smiles) - # Call augmentation function - depiction = self.add_augmentations(depiction) - # Generate schemes for Fingerprint creation - self.schemes = self.generate_fingerprint_schemes() - ( - self.CDK_scheme, - self.RDKit_scheme, - self.Indigo_scheme, - self.PIKAChU_scheme, - self.augmentation_scheme, - ) = self.schemes - # Generate the pool of all valid fingerprint combinations - - self.generate_all_possible_fingerprints() - self.FP_length_scheme_dict = { - len(self.CDK_fingerprints[0]): self.CDK_scheme, - len(self.RDKit_fingerprints[0]): self.RDKit_scheme, - len(self.Indigo_fingerprints[0]): self.Indigo_scheme, - len(self.PIKAChU_fingerprints[0]): self.PIKAChU_scheme, - len(self.augmentation_fingerprints[0]): self.augmentation_scheme, - } + Additionally, this function handles the generation of depictions and + augmentations from given fingerprints by handling all random decisions + according to the fingerprint template. - def random_choice(self, iterable: List, log_attribute: str = False) -> Any: - """ - In RandomDepictor, this function would take an iterable, call - random_choice() on it, increase the seed attribute by 1 and return - the result. - ___ - Here, this function is overwritten, so that it also sets the class - attribute $log_attribute_range to contain the iterable. - This way, a DepictionFeatureRanges object can easily be filled with - all the iterables that define the complete depiction feature space - (for fingerprint generation). - ___ Args: iterable (List): iterable to pick from log_attribute (str, optional): ID for fingerprint. 
@@ -2784,683 +690,129 @@ def random_choice(self, iterable: List, log_attribute: str = False) -> Any: Returns: Any: "Randomly" picked element """ - # Save iterables as class attributes (for fingerprint generation) - if log_attribute: - setattr(self, f"{log_attribute}_range", iterable) - # Pseudo-randomly pick element from iterable + # Keep track of seed and change it with every pseudo-random decision. self.seed += 1 random.seed(self.seed) - result = random.choice(iterable) - return result - def generate_fingerprint_schemes(self) -> List[Dict]: - """ - Generates fingerprint schemes (see generate_fingerprint_scheme()) - for the depictions with CDK, RDKit and Indigo as well as the - augmentations. - ___ - Returns: - List[Dict]: [cdk_scheme: Dict, rdkit_scheme: Dict, - indigo_scheme: Dict, augmentation_scheme: Dict] - """ - fingerprint_schemes = [] - range_IDs = [att for att in dir(self) if "range" in att] - # Generate fingerprint scheme for our cdk, indigo and rdkit depictions - depiction_toolkits = ["cdk", "rdkit", "indigo", "pikachu", ""] - for toolkit in depiction_toolkits: - toolkit_range_IDs = [att for att in range_IDs if toolkit in att] - # Delete toolkit-specific ranges - # (The last time this loop runs, only augmentation-related ranges - # are left) - for ID in toolkit_range_IDs: - range_IDs.remove(ID) - # [:-6] --> remove "_range" at the end - toolkit_range_dict = { - attr[:-6]: list(set(getattr(self, attr))) for attr in toolkit_range_IDs - } - fingerprint_scheme = self.generate_fingerprint_scheme(toolkit_range_dict) - fingerprint_schemes.append(fingerprint_scheme) - return fingerprint_schemes - - def generate_fingerprint_scheme(self, ID_range_map: Dict) -> Dict: - """ - This function takes the ID_range_map and returns a dictionary that - defines where each feature is represented in the depiction feature - fingerprint. 
- ___ - Example: - >> example_ID_range_map = {'thickness': [0, 1, 2, 3], - 'kekulized': [True, False]} - >> generate_fingerprint_scheme(example_ID_range_map) - >>>> {'thickness': [{'position': 0, 'one_if': 0}, - {'position': 1, 'one_if': 1}, - {'position': 2, 'one_if': 2}, - {'position': 3, 'one_if': 3}], - 'kekulized': [{'position': 4, 'one_if': True}]} - Args: - ID_range_map (Dict): dict that maps an ID (str) of a feature range - to the feature range itself (iterable) + # Generation from fingerprint: + if self.from_fingerprint and log_attribute: + # Get dictionaries that define positions and linked conditions + pos_cond_dicts = self.active_scheme[log_attribute] + for pos_cond_dict in pos_cond_dicts: + pos = pos_cond_dict["position"] + cond = pos_cond_dict["one_if"] + if self.active_fingerprint[pos]: + # If the condition is a range: adapt iterable and go on + if isinstance(cond, tuple): + iterable = [ + item + for item in iterable + if item > cond[0] - 0.001 + if item < cond[1] + 0.001 + ] + break + # Otherwise, simply return the condition value + else: + return cond + # Pseudo-randomly pick an element from the iterable + result = random.choice(iterable) - Returns: - Dict: Map of feature ID (str) and dictionaries that define the - fingerprint position and a condition - """ - fingerprint_scheme = {} - position = 0 - for feature_ID in ID_range_map.keys(): - feature_range = ID_range_map[feature_ID] - # Make sure numeric ranges don't take up more than 5 positions - # in the fingerprint - if ( - type(feature_range[0]) in [int, float, np.float64, np.float32] - and len(feature_range) > 5 - ): - subranges = self.split_into_n_sublists(feature_range, n=3) - position_dicts = [] - for subrange in subranges: - subrange_minmax = (min(subrange), max(subrange)) - position_dict = {"position": position, "one_if": subrange_minmax} - position_dicts.append(position_dict) - position += 1 - fingerprint_scheme[feature_ID] = position_dicts - # Bools take up only one position in the 
fingerprint - elif isinstance(feature_range[0], bool): - assert len(feature_range) == 2 - position_dicts = [{"position": position, "one_if": True}] - position += 1 - fingerprint_scheme[feature_ID] = position_dicts - else: - # For other types of categorical data: Each category gets one - # position in the FP - position_dicts = [] - for feature in feature_range: - position_dict = {"position": position, "one_if": feature} - position_dicts.append(position_dict) - position += 1 - fingerprint_scheme[feature_ID] = position_dicts - return fingerprint_scheme - - def split_into_n_sublists(self, iterable, n: int) -> List[List]: - """ - Takes an iterable, sorts it, splits it evenly into n lists - and returns the split lists. + return result - Args: - iterable ([type]): Iterable that is supposed to be split - n (int): Amount of sublists to return - Returns: - List[List]: Split list - """ - iterable = sorted(iterable) - iter_len = len(iterable) - sublists = [] - for i in range(0, iter_len, int(np.ceil(iter_len / n))): - sublists.append(iterable[i: i + int(np.ceil(iter_len / n))]) - return sublists - - def get_number_of_possible_fingerprints(self, scheme: Dict) -> int: + def has_r_group(self, smiles: str) -> bool: """ - This function takes a fingerprint scheme (Dict) as returned by - generate_fingerprint_scheme() - and returns the number of possible fingerprints for that scheme. 
+ Determines whether or not a given SMILES str contains an R group Args: - scheme (Dict): Output of generate_fingerprint_scheme() + smiles (str): SMILES representation of molecule Returns: - int: Number of possible fingerprints - """ - comb_count = 1 - for feature_key in scheme.keys(): - if len(scheme[feature_key]) != 1: - # n fingerprint positions -> n options - # (because only one position can be [1]) - # n = 3 --> [1][0][0] or [0][1][0] or [0][0][1] - comb_count *= len(scheme[feature_key]) - else: - # One fingerprint position -> two options: [0] or [1] - comb_count *= 2 - return comb_count - - def get_FP_building_blocks(self, scheme: Dict) -> List[List[List]]: + bool """ - This function takes a fingerprint scheme (Dict) as returned by - generate_fingerprint_scheme() - and returns a list of possible building blocks. - Example: - scheme = {'thickness': [{'position': 0, 'one_if': 0}, - {'position': 1, 'one_if': 1}, - {'position': 2, 'one_if': 2}, - {'position': 3, 'one_if': 3}], - 'kekulized': [{'position': 4, 'one_if': True}]} - - --> Output: [[[1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1]], - [[1], [0]]] - - Args: - scheme (Dict): Output of generate_fingerprint_scheme() - - Returns: - List that contains the valid fingerprint parts that represent the - different features + if re.search("\[.*[RXYZ].*\]", smiles): + return True - """ - FP_building_blocks = [] - for feature_key in scheme.keys(): - position_condition_dicts = scheme[feature_key] - FP_building_blocks.append([]) - # Add every single valid option to the building block - for position_index in range(len(position_condition_dicts)): - # Add list of zeros - FP_building_blocks[-1].append([0] * len(position_condition_dicts)) - # Replace one zero with a one - FP_building_blocks[-1][-1][position_index] = 1 - # If a feature is described by only one position in the FP, - # make sure that 0 and 1 are listed options - if FP_building_blocks[-1] == [[1]]: - FP_building_blocks[-1].append([0]) - return FP_building_blocks - - 
def flatten_fingerprint( + def _smiles_to_mol_block( self, - unflattened_list: List[List], - ) -> List: + smiles: str, + generate_2d: bool = False, + ) -> str: """ - This function takes a list of lists and returns a list. + This function takes a SMILES representation of a molecule and returns + the content of the corresponding SD file using the CDK. ___ - Looks like this could be one line elsewhere but this function used for - parallelisation of FP generation and consequently needs to be wrapped - up in a separate function. - - Args: - unflattened_list (List[List[X,Y,Z]]) - - Returns: - flattened_list (List[X,Y,Z]): - """ - flattened_list = [ - element for sublist in unflattened_list for element in sublist - ] - return flattened_list - - def generate_all_possible_fingerprints_per_scheme( - self, - scheme: Dict, - ) -> List[List[int]]: - """ - This function takes a fingerprint scheme (Dict) as returned by - generate_fingerprint_scheme() - and returns a List of all possible fingerprints for that scheme. - - Args: - scheme (Dict): Output of generate_fingerprint_scheme() - name (str): name that is used for filename of saved FPs - - Returns: - List[List[int]]: List of fingerprints - """ - # Determine valid building blocks for fingerprints - FP_building_blocks = self.get_FP_building_blocks(scheme) - # Determine cartesian product of valid building blocks to get all - # valid fingerprints - FP_generator = product(*FP_building_blocks) - flattened_fingerprints = list(map(self.flatten_fingerprint, FP_generator)) - return flattened_fingerprints - - def generate_all_possible_fingerprints(self) -> None: - """ - This function generates all possible valid fingerprint combinations - for the four available fingerprint schemes if they have not been - created already. Otherwise, they are loaded from files. 
- This function returns None but saves the fingerprint pools as a - class attribute $ID_fingerprints - """ - FP_names = ["CDK", "RDKit", "Indigo", "PIKAChU", "augmentation"] - for scheme_index in range(len(self.schemes)): - exists_already = False - n_FP = self.get_number_of_possible_fingerprints(self.schemes[scheme_index]) - # Load fingerprint pool from file (if it exists) - FP_filename = "{}_fingerprints.npz".format(FP_names[scheme_index]) - FP_file_path = self.HERE.joinpath(FP_filename) - if os.path.exists(FP_file_path): - fps = np.load(FP_file_path)["arr_0"] - if len(fps) == n_FP: - exists_already = True - # Otherwise, generate the fingerprint pool - if not exists_already: - print("No saved fingerprints found. This may take a minute.") - fps = self.generate_all_possible_fingerprints_per_scheme( - self.schemes[scheme_index] - ) - np.savez_compressed(FP_file_path, fps) - print( - "{} fingerprints were saved in {}.".format( - FP_names[scheme_index], FP_file_path - ) - ) - setattr(self, "{}_fingerprints".format(FP_names[scheme_index]), fps) - return - - def convert_to_int_arr( - self, fingerprints: List[List[int]] - ) -> List[DataStructs.cDataStructs.ExplicitBitVect]: - """ - Takes a list of fingerprints (List[int]) and returns them as a list of - rdkit.DataStructs.cDataStructs.ExplicitBitVect so that they can be - processed by RDKit's MaxMinPicker. 
- - Args: - fingerprints (List[List[int]]): List of fingerprints - - Returns: - List[DataStructs.cDataStructs.ExplicitBitVect]: Converted arrays - """ - converted_fingerprints = [] - for fp in fingerprints: - bitstring = "".join(np.array(fp).astype(str)) - fp_converted = DataStructs.cDataStructs.CreateFromBitString(bitstring) - converted_fingerprints.append(fp_converted) - return converted_fingerprints - - def pick_fingerprints( - self, - fingerprints: List[List[int]], - n: int, - ) -> np.array: - """ - Given a list of fingerprints and a number n of fingerprints to pick, - this function uses RDKit's MaxMin Picker to pick n fingerprints and - returns them. - - Args: - fingerprints (List[List[int]]): List of fingerprints - n (int): Number of fingerprints to pick - - Returns: - np.array: Picked fingerprints - """ - - converted_fingerprints = self.convert_to_int_arr(fingerprints) - - """TODO: I don't like this function definition in the function but - according to the RDKit Documentation, the fingerprints need to be - given in the distance function as the default value.""" - - def dice_dist( - fp_index_1: int, - fp_index_2: int, - fingerprints: List[ - DataStructs.cDataStructs.ExplicitBitVect - ] = converted_fingerprints, - ) -> float: - """ - Returns the dice similarity between two fingerprints. 
- Args: - fp_index_1 (int): index of first fingerprint in fingerprints - fp_index_2 (int): index of second fingerprint in fingerprints - fingerprints (List[cDataStructs.ExplicitBitVect]): fingerprints - - Returns: - float: Dice similarity between the two fingerprints - """ - return 1 - DataStructs.DiceSimilarity( - fingerprints[fp_index_1], fingerprints[fp_index_2] - ) - - # If we want to pick more fingerprints than there are in the pool, - # simply distribute the complete pool as often as possible and pick - # the amount that is not dividable by the size of the pool - picked_fingerprints, n = self.correct_amount_of_FP_to_pick(fingerprints, n) - - picker = MaxMinPicker() - pick_indices = picker.LazyPick(dice_dist, len(fingerprints), n, seed=42) - if isinstance(picked_fingerprints, bool): - picked_fingerprints = np.array([fingerprints[i] for i in pick_indices]) - else: - picked_fingerprints = np.concatenate( - (np.array(picked_fingerprints), np.array(([fingerprints[i] for i in pick_indices]))) - ) - return picked_fingerprints - - def correct_amount_of_FP_to_pick(self, fingerprints: List, n: int) -> Tuple[List, int]: - """ - When picking n elements from a list of fingerprints, if the amount of fingerprints is - bigger than n, there is no need to pick n fingerprints. Instead, the complete fingerprint - list is added to the picked fingerprints as often as possible while only the amount - that is not dividable by the fingerprint pool size is picked. + The SMILES parser of the CDK is much more tolerant than the parsers of + RDKit and Indigo. 
___ - Given a list of fingerprints and the amount of fingerprints to pick n, this function - returns a list of "picked" fingerprints and (in the ideal case) a corrected lower number - of fingerprints to be picked - - Args: - fingerprints (List): _description_ - n (int): _description_ - - Returns: - Tuple[List, int]: _description_ - """ - if n > len(fingerprints): - oversize_factor = int(n / len(fingerprints)) - picked_fingerprints = np.concatenate([fingerprints for _ - in range(oversize_factor)]) - n = n - len(fingerprints) * oversize_factor - else: - picked_fingerprints = False - return picked_fingerprints, n - - def generate_fingerprints_for_dataset( - self, - size: int, - indigo_proportion: float = 0.15, - rdkit_proportion: float = 0.25, - pikachu_proportion: float = 0.25, - cdk_proportion: float = 0.35, - aug_proportion: float = 0.5, - ) -> List[List[int]]: - """ - Given a dataset size (int) and (optional) proportions for the - different types of fingerprints, this function returns - - Args: - size (int): Desired dataset size, number of returned fingerprints - indigo_proportion (float): Indigo proportion. Defaults to 0.15. - rdkit_proportion (float): RDKit proportion. Defaults to 0.25. - pikachu_proportion (float): PIKAChU proportion. Defaults to 0.25. - cdk_proportion (float): CDK proportion. Defaults to 0.35. - aug_proportion (float): Augmentation proportion. Defaults to 0.5. - - Raises: - ValueError: - - If the sum of Indigo, RDKit, PIKAChU and CDK proportions is not 1 - - If the augmentation proportion is > 1 - - Returns: - List[List[int]]: List of lists containing the fingerprints. - ___ - Depending on augmentation_proportion, the depiction fingerprints - are paired with augmentation fingerprints or not. 
- - Example output: - [[$some_depiction_fingerprint, $some augmentation_fingerprint], - [$another_depiction_fingerprint] - [$yet_another_depiction_fingerprint]] - - """ - # Make sure that the given proportion arguments make sense - if sum((indigo_proportion, rdkit_proportion, pikachu_proportion, cdk_proportion)) != 1: - raise ValueError( - "Sum of Indigo, CDK, PIKAChU and RDKit proportions arguments has to be 1" - ) - if aug_proportion > 1: - raise ValueError( - "The proportion of augmentation fingerprints can't be > 1." - ) - # Pick and return diverse fingerprints - picked_Indigo_fingerprints = self.pick_fingerprints( - self.Indigo_fingerprints, int(size * indigo_proportion) - ) - picked_RDKit_fingerprints = self.pick_fingerprints( - self.RDKit_fingerprints, int(size * rdkit_proportion) - ) - picked_PIKAChU_fingerprints = self.pick_fingerprints( - self.PIKAChU_fingerprints, int(size * pikachu_proportion) - ) - picked_CDK_fingerprints = self.pick_fingerprints( - self.CDK_fingerprints, int(size * cdk_proportion) - ) - picked_augmentation_fingerprints = self.pick_fingerprints( - self.augmentation_fingerprints, int(size * aug_proportion) - ) - # Distribute augmentation_fingerprints over depiction fingerprints - fingerprint_tuples = self.distribute_elements_evenly( - picked_augmentation_fingerprints, - picked_Indigo_fingerprints, - picked_RDKit_fingerprints, - picked_PIKAChU_fingerprints, - picked_CDK_fingerprints, - ) - # Shuffle fingerprint tuples randomly to avoid the same smiles - # always being depicted with the same cheminformatics toolkit - random.seed(self.seed) - random.shuffle(fingerprint_tuples) - return fingerprint_tuples - - def distribute_elements_evenly( - self, elements_to_be_distributed: List[Any], *args: List[Any] - ) -> List[List[Any]]: - """ - This function distributes the elements from elements_to_be_distributed - evenly over the lists of elements given in args. 
It can be used to link - augmentation fingerprints to given lists of depiction fingerprints. - - Example: - distribute_elements_evenly(["A", "B", "C", "D"], [1, 2, 3], [4, 5, 6]) - Output: [[1, "A"], [2, "B"], [3], [4, "C"], [5, "D"], [6]] - --> see test_distribute_elements_evenly() in ../Tests/test_functions.py - - Args: - elements_to_be_distributed (List[Any]): elements to be distributed - args: Every arg is a list of elements (B) - - Returns: - List[List[Any]]: List of Lists (B, A) where the elements A are - distributed evenly over the elements B according - to the length of the list of elements B - """ - # Make sure that the input is valid - args_total_len = len([element for sublist in args for element in sublist]) - if len(elements_to_be_distributed) > args_total_len: - raise ValueError("Can't take more elements to be distributed than in args.") - - output = [] - start_index = 0 - for element_list in args: - # Define part of elements_to_be_distributed that belongs to this - # element_sublist - sublist_len = len(element_list) - end_index = start_index + int( - sublist_len / args_total_len * len(elements_to_be_distributed) - ) - select_elements_to_be_distributed = elements_to_be_distributed[ - start_index:end_index - ] - for element_index in range(len(element_list)): - if element_index < len(select_elements_to_be_distributed): - output.append( - [ - element_list[element_index], - select_elements_to_be_distributed[element_index], - ] - ) - else: - output.append([element_list[element_index]]) - start_index = start_index + int( - sublist_len / args_total_len * len(elements_to_be_distributed) - ) - return output - - -class RandomMarkushStructureCreator: - def __init__(self, *, variables_list=None, max_index=20): - """ - RandomMarkushStructureCreator objects are instantiated with the desired - inserted R group variables. Otherwise, "R", "X" and "Z" are used. 
- """ - # Instantiate RandomDepictor for reproducible random decisions - self.depictor = RandomDepictor() - # Define R group variables - if variables_list is None: - self.r_group_variables = ["R", "X", "Y", "Z"] - else: - self.r_group_variables = variables_list - - self.potential_indices = range(1, max_index + 1) - - def generate_markush_structure_dataset(self, smiles_list: List[str]) -> List[str]: - """ - This function takes a list of SMILES, replaces 1-4 carbon or hydrogen atoms per - molecule with R groups and returns the resulting list of SMILES. - - Args: - smiles_list (List[str]): SMILES representations of molecules - - Returns: - List[str]: SMILES reprentations of markush structures - """ - numbers = [self.depictor.random_choice(range(1, 5)) for _ in smiles_list] - r_group_smiles = [ - self.insert_R_group_var(smiles_list[index], numbers[index]) - for index in range(len(smiles_list)) - ] - return r_group_smiles - - def insert_R_group_var(self, smiles: str, num: int) -> str: - """ - This function takes a smiles string and a number of R group variables. It then - replaces the given number of H or C atoms with R groups and returns the SMILES str. 
- - Args: - smiles (str): SMILES (absolute) representation of a molecule - num (int): number of R group variables to be inserted - - Returns: - smiles (str): input SMILES with $num inserted R group variables - """ - smiles = self.add_explicite_hydrogen_to_smiles(smiles) - potential_replacement_positions = self.get_valid_replacement_positions(smiles) - r_groups = [] - # Replace C or H in SMILES with * - # If we would directly insert the R group variables, CDK would replace them with '*' - # later when removing the explicite hydrogen atoms - smiles = list(smiles) - for _ in range(num): - if len(potential_replacement_positions) > 0: - position = self.depictor.random_choice(potential_replacement_positions) - smiles[position] = "*" - potential_replacement_positions.remove(position) - r_groups.append(self.get_r_group_smiles()) - else: - break - # Remove explicite hydrogen again and get absolute SMILES - smiles = "".join(smiles) - smiles = self.remove_explicite_hydrogen_from_smiles(smiles) - # Replace * with R groups - for r_group in r_groups: - smiles = smiles.replace("*", r_group, 1) - return smiles - - def get_r_group_smiles(self) -> str: - """ - This function returns a random R group substring that can be inserted - into an existing SMILES str. 
- - Returns: - str: SMILES compatible of R group str - """ - has_indices = self.depictor.random_choice([True, True, True, True, False]) - r_group_var = self.depictor.random_choice(self.r_group_variables) - if has_indices: - index = self.depictor.random_choice(self.potential_indices) - if self.depictor.random_choice([True, False, False]): - index_char = self.depictor.random_choice(["a", "b", "c", "d", "e", "f"]) - else: - index_char = "" - return f"[{r_group_var}{index}{index_char}]" - else: - return f"[{r_group_var}]" - - def get_valid_replacement_positions(self, smiles: str) -> List[int]: - """ - Returns positions in a SMILES str where elements in the str can be replaced with - R groups without endangering its validity - - Args: - smiles (str): SMILES representation of a molecule - - Returns: - replacement_positions (List[int]): valid replacement positions for R group variables - """ - # Add space char to represent insertion position at the end of smiles str - smiles = f"{smiles} " - replacement_positions = [] - for index in range(len(smiles)): - # Be aware of digits --> don't destroy ring syntax - if smiles[index].isdigit(): - continue - # Don't replace isotopes to not to end up with [13R] - elif index >= 2 and smiles[index - 2].isdigit() and smiles[index] == "]": - continue - # Don't produce charged R groups (eg. "R+") - elif smiles[index] in ["+", "-"]: - continue - elif smiles[index - 1] == "H" and smiles[index] == "]": - replacement_positions.append(index - 1) - # Only replace "C" and "H" - elif smiles[index - 1] == "C": - # Don't replace "C" in "Cl", "Ca", Cu", etc... - if smiles[index] not in [ - "s", - "a", - "e", - "o", - "u", - "r", - "l", - "f", - "d", - "n", - "@" # replacing chiral C leads to invalid SMILES - ]: - replacement_positions.append(index - 1) - return replacement_positions - - def add_explicite_hydrogen_to_smiles(self, smiles: str) -> str: - """ - This function takes a SMILES str and uses CDK to add explicite hydrogen atoms. 
- It returns an adapted version of the SMILES str. Args: smiles (str): SMILES representation of a molecule + generate_2d (bool or str, optional): False if no coordinates are created + Otherwise pick tool for coordinate + generation: + "rdkit", "cdk" or "indigo" + If rdkit or Indigo cannot handle + certain Markush SMILES, the CDK is used Returns: - smiles (str): SMILES representation of a molecule with explicite H + mol_block (str): content of SD file of input molecule """ - i_atom_container = self.depictor._cdk_smiles_to_IAtomContainer(smiles) - - # Add explicite hydrogen atoms - cdk_base = "org.openscience.cdk." - manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") - manipulator.convertImplicitToExplicitHydrogens(i_atom_container) - - # Create absolute SMILES - smi_flavor = JClass("org.openscience.cdk.smiles.SmiFlavor").Absolute - smiles_generator = JClass("org.openscience.cdk.smiles.SmilesGenerator")( - smi_flavor - ) - smiles = smiles_generator.create(i_atom_container) - return str(smiles) + if not generate_2d: + molecule = self._cdk_smiles_to_IAtomContainer(smiles) + return self._cdk_iatomcontainer_to_mol_block(molecule) + elif generate_2d == "cdk": + molecule = self._cdk_smiles_to_IAtomContainer(smiles) + molecule = self._cdk_generate_2d_coordinates(molecule) + molecule = self._cdk_rotate_coordinates(molecule) + return self._cdk_iatomcontainer_to_mol_block(molecule) + elif generate_2d == "rdkit": + if re.search("\[[RXZ]\]|\[[XYZ]\d+", smiles): + return self._smiles_to_mol_block(smiles, generate_2d="cdk") + mol_block = self._smiles_to_mol_block(smiles) + molecule = Chem.MolFromMolBlock(mol_block, sanitize=False) + if molecule: + AllChem.Compute2DCoords(molecule) + mol_block = Chem.MolToMolBlock(molecule) + atom_container = self._cdk_mol_block_to_iatomcontainer(mol_block) + atom_container = self._cdk_rotate_coordinates(atom_container) + return self._cdk_iatomcontainer_to_mol_block(atom_container) + else: + raise ValueError(f"RDKit could 
not read molecule: {smiles}") + elif generate_2d == "indigo": + if re.search("\[R0\]|\[X\]|[4-9][0-9]+|3[3-9]|[XYZR]\d+[a-f]", smiles): + return self._smiles_to_mol_block(smiles, generate_2d="cdk") + indigo = Indigo() + mol_block = self._smiles_to_mol_block(smiles) + molecule = indigo.loadMolecule(mol_block) + molecule.layout() + buf = indigo.writeBuffer() + buf.sdfAppend(molecule) + mol_block = buf.toString() + atom_container = self._cdk_mol_block_to_iatomcontainer(mol_block) + atom_container = self._cdk_rotate_coordinates(atom_container) + return self._cdk_iatomcontainer_to_mol_block(atom_container) + elif generate_2d == "pikachu": + pass - def remove_explicite_hydrogen_from_smiles(self, smiles: str) -> str: + def central_square_image(self, im: np.array) -> np.array: """ - This function takes a SMILES str and uses CDK to remove explicite hydrogen atoms. - It returns an adapted version of the SMILES str. + This function takes image (np.array) and will add white padding + so that the image has a square shape with the width/height of the + longest side of the original image. Args: - smiles (str): SMILES representation of a molecule + im (np.array): Input image Returns: - smiles (str): SMILES representation of a molecule with explicite H + np.array: Output image """ - i_atom_container = self.depictor._cdk_smiles_to_IAtomContainer(smiles) - # Remove explicite hydrogen atoms - cdk_base = "org.openscience.cdk." 
- manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") - i_atom_container = manipulator.copyAndSuppressedHydrogens(i_atom_container) - # Create absolute SMILES - smi_flavor = JClass("org.openscience.cdk.smiles.SmiFlavor").Absolute - smiles_generator = JClass("org.openscience.cdk.smiles.SmilesGenerator")( - smi_flavor - ) - smiles = smiles_generator.create(i_atom_container) - return str(smiles) + # Create new blank white image + max_wh = max(im.shape) + new_im = 255 * np.ones((max_wh, max_wh, 3), np.uint8) + # Determine paste coordinates and paste image + upper = int((new_im.shape[0] - im.shape[0]) / 2) + lower = int((new_im.shape[0] - im.shape[0]) / 2) + im.shape[0] + left = int((new_im.shape[1] - im.shape[1]) / 2) + right = int((new_im.shape[1] - im.shape[1]) / 2) + im.shape[1] + new_im[upper:lower, left:right] = im + return new_im diff --git a/RanDepict/random_markush_structure_generator.py b/RanDepict/random_markush_structure_generator.py new file mode 100644 index 0000000..c710a16 --- /dev/null +++ b/RanDepict/random_markush_structure_generator.py @@ -0,0 +1,188 @@ +from jpype import JClass +# import sys +from typing import List +from .randepict import RandomDepictor + +class RandomMarkushStructureCreator: + def __init__(self, *, variables_list=None, max_index=20): + """ + RandomMarkushStructureCreator objects are instantiated with the desired + inserted R group variables. Otherwise, "R", "X" and "Z" are used. 
+ """ + # Instantiate RandomDepictor for reproducible random decisions + self.depictor = RandomDepictor() + # Define R group variables + if variables_list is None: + self.r_group_variables = ["R", "X", "Y", "Z"] + else: + self.r_group_variables = variables_list + + self.potential_indices = range(1, max_index + 1) + + def generate_markush_structure_dataset(self, smiles_list: List[str]) -> List[str]: + """ + This function takes a list of SMILES, replaces 1-4 carbon or hydrogen atoms per + molecule with R groups and returns the resulting list of SMILES. + + Args: + smiles_list (List[str]): SMILES representations of molecules + + Returns: + List[str]: SMILES reprentations of markush structures + """ + numbers = [self.depictor.random_choice(range(1, 5)) for _ in smiles_list] + r_group_smiles = [ + self.insert_R_group_var(smiles_list[index], numbers[index]) + for index in range(len(smiles_list)) + ] + return r_group_smiles + + def insert_R_group_var(self, smiles: str, num: int) -> str: + """ + This function takes a smiles string and a number of R group variables. It then + replaces the given number of H or C atoms with R groups and returns the SMILES str. 
+ + Args: + smiles (str): SMILES (absolute) representation of a molecule + num (int): number of R group variables to be inserted + + Returns: + smiles (str): input SMILES with $num inserted R group variables + """ + smiles = self.add_explicite_hydrogen_to_smiles(smiles) + potential_replacement_positions = self.get_valid_replacement_positions(smiles) + r_groups = [] + # Replace C or H in SMILES with * + # If we would directly insert the R group variables, CDK would replace them with '*' + # later when removing the explicite hydrogen atoms + smiles = list(smiles) + for _ in range(num): + if len(potential_replacement_positions) > 0: + position = self.depictor.random_choice(potential_replacement_positions) + smiles[position] = "*" + potential_replacement_positions.remove(position) + r_groups.append(self.get_r_group_smiles()) + else: + break + # Remove explicite hydrogen again and get absolute SMILES + smiles = "".join(smiles) + smiles = self.remove_explicite_hydrogen_from_smiles(smiles) + # Replace * with R groups + for r_group in r_groups: + smiles = smiles.replace("*", r_group, 1) + return smiles + + def get_r_group_smiles(self) -> str: + """ + This function returns a random R group substring that can be inserted + into an existing SMILES str. 
+ + Returns: + str: SMILES compatible of R group str + """ + has_indices = self.depictor.random_choice([True, True, True, True, False]) + r_group_var = self.depictor.random_choice(self.r_group_variables) + if has_indices: + index = self.depictor.random_choice(self.potential_indices) + if self.depictor.random_choice([True, False, False]): + index_char = self.depictor.random_choice(["a", "b", "c", "d", "e", "f"]) + else: + index_char = "" + return f"[{r_group_var}{index}{index_char}]" + else: + return f"[{r_group_var}]" + + def get_valid_replacement_positions(self, smiles: str) -> List[int]: + """ + Returns positions in a SMILES str where elements in the str can be replaced with + R groups without endangering its validity + + Args: + smiles (str): SMILES representation of a molecule + + Returns: + replacement_positions (List[int]): valid replacement positions for R group variables + """ + # Add space char to represent insertion position at the end of smiles str + smiles = f"{smiles} " + replacement_positions = [] + for index in range(len(smiles)): + # Be aware of digits --> don't destroy ring syntax + if smiles[index].isdigit(): + continue + # Don't replace isotopes to not to end up with [13R] + elif index >= 2 and smiles[index - 2].isdigit() and smiles[index] == "]": + continue + # Don't produce charged R groups (eg. "R+") + elif smiles[index] in ["+", "-"]: + continue + elif smiles[index - 1] == "H" and smiles[index] == "]": + replacement_positions.append(index - 1) + # Only replace "C" and "H" + elif smiles[index - 1] == "C": + # Don't replace "C" in "Cl", "Ca", Cu", etc... + if smiles[index] not in [ + "s", + "a", + "e", + "o", + "u", + "r", + "l", + "f", + "d", + "n", + "@" # replacing chiral C leads to invalid SMILES + ]: + replacement_positions.append(index - 1) + return replacement_positions + + def add_explicite_hydrogen_to_smiles(self, smiles: str) -> str: + """ + This function takes a SMILES str and uses CDK to add explicite hydrogen atoms. 
+ It returns an adapted version of the SMILES str. + + Args: + smiles (str): SMILES representation of a molecule + + Returns: + smiles (str): SMILES representation of a molecule with explicite H + """ + i_atom_container = self.depictor._cdk_smiles_to_IAtomContainer(smiles) + + # Add explicite hydrogen atoms + cdk_base = "org.openscience.cdk." + manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") + manipulator.convertImplicitToExplicitHydrogens(i_atom_container) + + # Create absolute SMILES + smi_flavor = JClass("org.openscience.cdk.smiles.SmiFlavor").Absolute + smiles_generator = JClass("org.openscience.cdk.smiles.SmilesGenerator")( + smi_flavor + ) + smiles = smiles_generator.create(i_atom_container) + return str(smiles) + + def remove_explicite_hydrogen_from_smiles(self, smiles: str) -> str: + """ + This function takes a SMILES str and uses CDK to remove explicite hydrogen atoms. + It returns an adapted version of the SMILES str. + + Args: + smiles (str): SMILES representation of a molecule + + Returns: + smiles (str): SMILES representation of a molecule with explicite H + """ + i_atom_container = self.depictor._cdk_smiles_to_IAtomContainer(smiles) + # Remove explicite hydrogen atoms + cdk_base = "org.openscience.cdk." 
+ manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") + i_atom_container = manipulator.copyAndSuppressedHydrogens(i_atom_container) + # Create absolute SMILES + smi_flavor = JClass("org.openscience.cdk.smiles.SmiFlavor").Absolute + smiles_generator = JClass("org.openscience.cdk.smiles.SmilesGenerator")( + smi_flavor + ) + smiles = smiles_generator.create(i_atom_container) + return str(smiles) diff --git a/RanDepict/rdkit_functionalities.py b/RanDepict/rdkit_functionalities.py new file mode 100644 index 0000000..1eb8234 --- /dev/null +++ b/RanDepict/rdkit_functionalities.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import os +import numpy as np +import io +from skimage import io as sk_io +from skimage.util import img_as_ubyte + +from typing import Tuple, List, Dict + +from rdkit import Chem +from rdkit.Chem.rdAbbreviations import CondenseMolAbbreviations +from rdkit.Chem.rdAbbreviations import GetDefaultAbbreviations +from rdkit.Chem.Draw import rdMolDraw2D + + +class RDKitFuntionalities: + """ + Child class of RandomDepictor that contains all RDKit-related functions. + ___ + This class does not work on its own. It is meant to be used as a child class. + """ + def rdkit_depict( + self, + smiles: str = None, + mol_block: str = None, + has_R_group: bool = False, + shape: Tuple[int, int] = (512, 512) + ) -> np.array: + """ + This function takes a mol_block str and an image shape. + It renders the chemical structures using Rdkit with random + rendering/depiction settings and returns an RGB image (np.array) + with the given image shape. + + Args: + smiles (str, Optional): SMILES representation of molecule + mol block (str, Optional): mol block representation of molecule + has_R_group (bool): Whether the molecule has R groups (used to determine + whether or not to use atom numbering as it can be + confusing with R groups indices) for SMILES, this is + checked using a simple regex. 
This argument only has + an effect if the mol_block is provided. + # TODO: check this in mol_block + shape (Tuple[int, int], optional): im shape. Defaults to (512, 512) + + Returns: + np.array: Chemical structure depiction + """ + if not smiles and not mol_block: + raise ValueError("Either smiles or mol_block must be provided") + if smiles: + has_R_group = self.has_r_group(smiles) + mol_block = self._smiles_to_mol_block(smiles, + generate_2d=self.random_choice( + ["rdkit", "cdk", "indigo"] + )) + + mol = Chem.MolFromMolBlock(mol_block, sanitize=True) + if mol: + return self.rdkit_depict_from_mol_object(mol, has_R_group, shape) + else: + print("RDKit was unable to read input:\n{}\n{}\n".format(smiles, mol_block)) + return None + + def rdkit_depict_from_mol_object( + self, + mol: Chem.rdchem.Mol, + has_R_group: bool = False, + shape: Tuple[int, int] = (512, 512), + ) -> np.array: + """ + This function takes a mol object and an image shape. + It renders the chemical structures using Rdkit with random + rendering/depiction settings and returns an RGB image (np.array) + with the given image shape. + + Args: + mol (Chem.rdchem.Mol): RDKit mol object + has_R_group (bool): Whether the molecule has R groups (used to determine + whether or not to use atom numbering as it can be + confusing with R groups indices) + # TODO: check this in mol object + shape (Tuple[int, int], optional): im shape. 
Defaults to (512, 512) + + Returns: + np.array: Chemical structure depiction + """ + # Get random depiction settings + if self.random_choice([True, False], log_attribute="rdkit_collapse_superatoms"): + abbrevs = self.random_choice(self.get_all_rdkit_abbreviations()) + mol = CondenseMolAbbreviations(mol, abbrevs) + + depiction_settings = self.get_random_rdkit_rendering_settings( + has_R_group=has_R_group) + rdMolDraw2D.PrepareAndDrawMolecule(depiction_settings, mol) + depiction = depiction_settings.GetDrawingText() + depiction = sk_io.imread(io.BytesIO(depiction)) + # Resize image to desired shape + depiction = self.resize(depiction, shape) + depiction = img_as_ubyte(depiction) + return np.asarray(depiction) + + def get_all_rdkit_abbreviations( + self, + ) -> List[Chem.rdAbbreviations._vectstruct]: + """ + This function returns the Default abbreviations for superatom and functional + group collapsing in RDKit as well as alternative abbreviations defined in + rdkit_alternative_superatom_abbreviations.txt. 
+ + Returns: + Chem.rdAbbreviations._vectstruct: RDKit's data structure that contains the + abbreviations + """ + abbreviations = [] + abbreviations.append(GetDefaultAbbreviations()) + abbr_path = self.HERE.joinpath("rdkit_alternative_superatom_abbreviations.txt") + + with open(abbr_path) as alternative_abbreviations: + split_lines = [line[:-1].split(",") + for line in alternative_abbreviations.readlines()] + swap_dict = {line[0]: line[1:] for line in split_lines} + + abbreviations.append(self.get_modified_rdkit_abbreviations(swap_dict)) + for key in swap_dict.keys(): + new_labels = [] + for label in swap_dict[key]: + if label[:2] in ["n-", "i-", "t-"]: + label = f"{label[2:]}-{label[0]}" + elif label[-2:] in ["-n", "-i", "-t"]: + label = f"{label[-1]}-{label[:-2]}" + new_labels.append(label) + swap_dict[key] = new_labels + abbreviations.append(self.get_modified_rdkit_abbreviations(swap_dict)) + return abbreviations + + def get_modified_rdkit_abbreviations( + self, + swap_dict: Dict + ) -> Chem.rdAbbreviations._vectstruct: + """ + This function takes a dictionary that maps the original superatom/FG label in + the RDKit abbreviations to the desired labels, replaces them as defined in the + dictionary and returns the abbreviations in RDKit's preferred format. + + Args: + swap_dict (Dict): Dictionary that maps the original label (eg. "Et") to the + desired label (eg. "C2H5"), a displayed label (eg. + "C2H5") and a reversed display label + (eg. "H5C2"). 
+ Example: + {"Et": [ + "C2H5", + "C2H5" + "H5C2" + ]} + + Returns: + Chem.rdAbbreviations._vectstruct: Modified abbreviations + """ + alt_abbreviations = GetDefaultAbbreviations() + for abbr in alt_abbreviations: + alt_label = swap_dict.get(abbr.label) + if alt_label: + abbr.label, abbr.displayLabel, abbr.displayLabelW = alt_label + return alt_abbreviations + + def get_random_rdkit_rendering_settings( + self, + has_R_group: (bool) = False, + shape: Tuple[int, int] = (299, 299) + ) -> rdMolDraw2D.MolDraw2DCairo: + """ + This function defines random rendering options for the structure + depictions created using rdkit. It returns an MolDraw2DCairo object + with the settings. + + Args: + has_R_group (bool): SMILES representation of molecule + shape (Tuple[int, int], optional): im_shape. Defaults to (299, 299) + + Returns: + rdMolDraw2D.MolDraw2DCairo: Object that contains depiction settings + """ + y, x = shape + # Instantiate object that saves the settings + depiction_settings = rdMolDraw2D.MolDraw2DCairo(y, x) + # Stereo bond annotation + if self.random_choice( + [True, False], log_attribute="rdkit_add_stereo_annotation" + ): + depiction_settings.drawOptions().addStereoAnnotation = True + if self.random_choice( + [True, False], log_attribute="rdkit_add_chiral_flag_labels" + ): + depiction_settings.drawOptions().includeChiralFlagLabel = True + # Atom indices + if self.random_choice( + [True, False, False, False], log_attribute="rdkit_add_atom_indices" + ): + if not has_R_group: + depiction_settings.drawOptions().addAtomIndices = True + # Bond line width + bond_line_width = self.random_choice( + range(1, 5), log_attribute="rdkit_bond_line_width" + ) + depiction_settings.drawOptions().bondLineWidth = bond_line_width + # Draw terminal methyl groups + if self.random_choice( + [True, False], log_attribute="rdkit_draw_terminal_methyl" + ): + depiction_settings.drawOptions().explicitMethyl = True + # Label font type and size + font_dir = self.HERE.joinpath("fonts/") + 
font_path = os.path.join( + str(font_dir), + self.random_choice( + os.listdir(str(font_dir)), log_attribute="rdkit_label_font" + ), + ) + depiction_settings.drawOptions().fontFile = font_path + min_font_size = self.random_choice( + range(10, 20), log_attribute="rdkit_min_font_size" + ) + depiction_settings.drawOptions().minFontSize = min_font_size + depiction_settings.drawOptions().maxFontSize = 30 + # Rotate the molecule + # depiction_settings.drawOptions().rotate = self.random_choice(range(360)) + # Fixed bond length + fixed_bond_length = self.random_choice( + range(30, 45), log_attribute="rdkit_fixed_bond_length" + ) + depiction_settings.drawOptions().fixedBondLength = fixed_bond_length + # Comic mode (looks a bit hand drawn) + if self.random_choice( + [True, False, False, False, False], log_attribute="rdkit_comic_style" + ): + depiction_settings.drawOptions().comicMode = True + # Keep it black and white + depiction_settings.drawOptions().useBWAtomPalette() + return depiction_settings diff --git a/Tests/test_functions.py b/Tests/test_functions.py index 2fe40f4..50203fa 100644 --- a/Tests/test_functions.py +++ b/Tests/test_functions.py @@ -1,5 +1,6 @@ from RanDepict import RandomDepictor, DepictionFeatureRanges, RandomMarkushStructureCreator from rdkit import DataStructs +import re import numpy as np from omegaconf import OmegaConf @@ -384,7 +385,20 @@ def test_depict_and_resize_pikachu(self): for smiles in test_smiles: im = self.depictor.pikachu_depict(smiles) assert type(im) == np.ndarray - + + def test_random_depiction_with_coordinates(self): + smiles = "CCC" + with RandomDepictor() as depictor: + for index in range(20): + if index < 10: + depiction, cx_smiles = depictor.random_depiction_with_coordinates(smiles) + + else: + depiction, cx_smiles = depictor.random_depiction_with_coordinates(smiles, + augment=True) + assert type(depiction) == np.ndarray + assert cx_smiles[:3] == smiles + def test_get_depiction_functions_normal(self): # For a molecule without 
isotopes or R groups, all toolkits can be used observed = self.depictor.get_depiction_functions('c1ccccc1C(O)=O') @@ -432,7 +446,7 @@ def test_get_depiction_functions_X(self): def test_smiles_to_mol_str(self): # Compare generated mol file str with reference string - mol_str = self.depictor._cdk_smiles_to_mol_block("CC") + mol_str = self.depictor._smiles_to_mol_block("CC") mol_str_lines = mol_str.split('\n') with open('Tests/test.mol', 'r') as ref_mol_file: ref_lines = ref_mol_file.readlines() diff --git a/examples/RanDepictNotebook.ipynb b/examples/RanDepictNotebook.ipynb index ba222bc..ec2dc39 100644 --- a/examples/RanDepictNotebook.ipynb +++ b/examples/RanDepictNotebook.ipynb @@ -6,41 +6,38 @@ "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'1.1.6'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import os\n", "import ipyplot\n", "\n", - "from RanDepict import RandomDepictor, RandomMarkushStructureCreator" + "from RanDepict import RandomDepictor, RandomMarkushStructureCreator\n", + "import RanDepict\n", + "RanDepict.__version__" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Depict chemical structures with CDK, RDKit, Indigo or PIKAChU\n", "\n", - "After calling an instance of RandomDepictor, depictions with randomly chosen parameters are created by calling the functions\n", - "\n", - "- depict_and_resize_cdk(\n", - " smiles: str, \n", - " image_shape: Tuple[int,int]\n", - " )\n", - "- depict_and_resize_rdkit(\n", - " smiles: str, \n", - " image_shape: Tuple[int,int]\n", - " )\n", - "- depict_and_resize_indigo(\n", - " smiles: str, \n", - " image_shape: Tuple[int,int]\n", - " )\n", - "- depict_and_resize_pikachu(\n", - " smiles: str, \n", - " image_shape: Tuple[int,int]\n", - " )\n", + "After calling an instance of RandomDepictor, depictions with randomly chosen parameters are created by calling the functions 
`cdk_depict`, `rdkit_depict`, `indigo_depict` and `pikachu_depict`.\n", "\n", "\n", - "The SMILES string needs to be given, the image shape defaults to (299,299,3).\n", + "The SMILES or mol_block string needs to be given.\n", "\n", "Each of these functions returns an np.array which represents an RGB image of the chemical structure." ] @@ -62,14 +59,14 @@ "text/html": [ "\n", " \n", "
\n", - " \n", - " \n", - " \n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + "
\n", + "
\n", + "
\n", + "

0

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

1

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

2

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

3

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

4

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

5

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

6

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

7

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

8

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

9

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

10

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

11

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

12

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

13

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

14

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

15

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

16

\n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

6

\n", - " \n", + "
\n", + "
\n", + "

17

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

7

\n", - " \n", + "
\n", + "
\n", + "

18

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

8

\n", - " \n", + "
\n", + "
\n", + "

19

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

9

\n", - " \n", + "
\n", + "
\n", + "

20

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

10

\n", - " \n", + "
\n", + "
\n", + "

21

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

11

\n", - " \n", + "
\n", + "
\n", + "

22

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

12

\n", - " \n", + "
\n", + "
\n", + "

23

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

13

\n", - " \n", + "
\n", + "
\n", + "

24

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

14

\n", - " \n", + "
\n", + "
\n", + "

25

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

15

\n", - " \n", + "
\n", + "
\n", + "

26

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

16

\n", - " \n", + "
\n", + "
\n", + "

27

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

17

\n", - " \n", + "
\n", + "
\n", + "

28

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

18

\n", - " \n", + "
\n", + "
\n", + "

29

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

19

\n", - " \n", + "
\n", + "
\n", + "

30

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", + " \n", + "
\n", + "
\n", + "

31

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", "
\n", - " \n", + " \n", + "
\n", + "
\n", + "

32

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - "
\n", - "
\n", - "

0

\n", - " \n", + " \n", + "
\n", + "
\n", + "

33

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

34

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

35

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

36

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

37

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

38

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

39

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

40

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

41

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

42

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

43

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

44

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

45

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

46

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

47

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

48

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

49

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

50

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

51

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

52

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

1

\n", - " \n", + "
\n", + "
\n", + "

53

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

2

\n", - " \n", + "
\n", + "
\n", + "

54

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

3

\n", - " \n", + "
\n", + "
\n", + "

55

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

4

\n", - " \n", + "
\n", + "
\n", + "

56

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

5

\n", - " \n", + "
\n", + "
\n", + "

57

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

6

\n", - " \n", + "
\n", + "
\n", + "

58

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

7

\n", - " \n", + "
\n", + "
\n", + "

59

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

8

\n", - " \n", + "
\n", + "
\n", + "

60

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

9

\n", - " \n", + "
\n", + "
\n", + "

61

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

10

\n", - " \n", + "
\n", + "
\n", + "

62

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

11

\n", - " \n", + "
\n", + "
\n", + "

63

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

12

\n", - " \n", + "
\n", + "
\n", + "

64

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

13

\n", - " \n", + "
\n", + "
\n", + "

65

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

14

\n", - " \n", + "
\n", + "
\n", + "

66

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

15

\n", - " \n", + "
\n", + "
\n", + "

67

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

16

\n", - " \n", + "
\n", + "
\n", + "

68

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

17

\n", - " \n", + "
\n", + "
\n", + "

69

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

18

\n", - " \n", + "
\n", + "
\n", + "

70

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

19

\n", - " \n", + "
\n", + "
\n", + "

71

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", - "
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Depict and save two batches of images\n", - "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n", - "\n", - "with RandomDepictor(42) as depictor:\n", - " fp_depictions = depictor.batch_depict_with_fingerprints([smiles],\n", - " 20,\n", - " aug_proportion = 0)\n", - " fp_aug_depictions = depictor.batch_depict_with_fingerprints([smiles],\n", - " 20,\n", - " aug_proportion = 1)\n", - "ipyplot.plot_images(fp_depictions, max_images=20, img_width=100)\n", - "ipyplot.plot_images(fp_aug_depictions, max_images=20, img_width=100)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create and save a batch of images while ensuring diversity using feature fingerprints\n", - "\n", - "\n", - "After calling an instance of RandomDepictor, simply call the method batch_depict_save_with_fingerprints().\n", - "\n", - "Args:\n", - "\n", - "- smiles_list: List[str]\n", - "- images_per_structure: int\n", - "- output_dir: str\n", - "- ID_list: List[str]\n", - "- indigo_proportion: float = 0.15\n", - "- rdkit_proportion: float = 0.3\n", - "- cdk_proportion: float = 0.55\n", - "- aug_proportion: float = 0.5\n", - "- shape: Tuple[int, int] = (299, 299)\n", - "- processes: int = 4\n", - "- seed: int = 42\n", - "\n", - "\n", - "*Note: The images that are created here, were used for the animations in the GitHub repository" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "# Make sure the output directories exist\n", - "if not os.path.exists('not_augmented_fingerprint'):\n", - " os.mkdir('not_augmented_fingerprint')\n", - " \n", - "if not os.path.exists('augmented_fingerprint'):\n", - " os.mkdir('augmented_fingerprint')\n", - "\n", - "# Depict and save two batches of images\n", - "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)CC(=O)O\"\n", - "with RandomDepictor(42) as depictor:\n", - " 
depictor.batch_depict_save_with_fingerprints([smiles], \n", - " 100, \n", - " 'not_augmented_fingerprint',\n", - " ['caffeine_{}'.format(n) for n in range(100)],\n", - " aug_proportion = 0)\n", - " depictor.batch_depict_save_with_fingerprints([smiles], \n", - " 100, \n", - " 'augmented_fingerprint',\n", - " ['caffeine_{}'.format(n) for n in range(100)],\n", - " aug_proportion = 1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Artificial generation of SMILES that represent markush structures\n", - "\n", - "Generate markush structures based on list of SMILES strings" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['CC(C)[R]',\n", - " 'C1=C[R](=[R]([X15])[X13]([H])=C1)C=O',\n", - " 'C1CC(C(CC1[R10])[Z14])[R4]']" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "markush_generator = RandomMarkushStructureCreator()\n", - "input_smiles = ['CCC', 'C1=CC=CC=C1C(=O)', 'C1CCCCC1']\n", - "markush_smiles = markush_generator.generate_markush_structure_dataset(input_smiles)\n", - "markush_smiles" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Depict the markush structures using RanDepict" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Warning! Rogue electron.\n", - "R10_6\n", - "Warning! Rogue electron.\n", - "Z14_7\n", - "Warning! Rogue electron.\n", - "R4_8\n", - "Warning! Rogue electron.\n", - "R10_6\n", - "Warning! Rogue electron.\n", - "Z14_7\n", - "Warning! Rogue electron.\n", - "R4_8\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - " \n", - " \n", - " \n", + " \n", + "
\n", + "
\n", + "

77

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - "
\n", - "
\n", - "

0

\n", - " \n", + " \n", + "
\n", + "
\n", + "

78

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

1

\n", - " \n", + "
\n", + "
\n", + "

79

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

2

\n", - " \n", + "
\n", + "
\n", + "

80

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

81

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

82

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

83

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

84

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

85

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

86

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

87

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

88

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

89

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

90

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

91

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

92

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

93

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

94

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

95

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

96

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

97

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

98

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

99

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", @@ -6306,22 +7543,30 @@ ], "source": [ "with RandomDepictor() as depictor:\n", - " markush_depictions = [depictor.random_depiction(smi)\n", + " markush_depictions = [depictor.random_depiction(smi, shape=(100,100))\n", " for smi in markush_smiles]\n", - "ipyplot.plot_images(markush_depictions, max_images=100, img_width=299)" + "ipyplot.plot_images(markush_depictions, max_images=100, img_width=100)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "?\n" + ] + } + ], "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.7.10 ('RanDepict')", + "display_name": "base", "language": "python", "name": "python3" }, @@ -6339,7 +7584,7 @@ }, "vscode": { "interpreter": { - "hash": "fbd02a9bd0ce7b123cc507013b34dda769d49e72a01d3fbf30ab009e81556d20" + "hash": "30f3beb01d96c3b334f501aee3e5616b49ff465bee6f3dcf7506c52a9acac780" } } }, diff --git a/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py b/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py index bbd7ba3..fbfee97 100644 --- a/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py +++ b/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py @@ -179,13 +179,13 @@ def depict_from_fingerprint( # Depict molecule try: if "indigo" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_indigo(smiles, shape) + depiction = depictor.indigo_depict(smiles, shape) elif "rdkit" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_rdkit(smiles, shape) + depiction = depictor.rdkit_depict(smiles, shape) elif "cdk" in list(schemes[0].keys())[0]: depiction = depictor.cdk_depict(smiles, shape) elif "pikachu" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_pikachu(smiles, shape) + depiction = 
depictor.pikachu_depict(smiles, shape) except IndexError: depiction = None diff --git a/examples/randepict_batch_run_tfrecord_output.py b/examples/randepict_batch_run_tfrecord_output.py index 6405510..2a72d27 100644 --- a/examples/randepict_batch_run_tfrecord_output.py +++ b/examples/randepict_batch_run_tfrecord_output.py @@ -1,7 +1,7 @@ import os import io from typing import Tuple, List -import argparse +import sys import numpy as np from multiprocessing import Process import time @@ -193,25 +193,21 @@ def main() -> None: The annotation is an array saved as a str ([1 2 3 4 .. .. X]) ___ ''' - # Parse arguments - parser = argparse.ArgumentParser() - parser.add_argument("file", nargs="+") - args = parser.parse_args() - - # Read input data from file - for file_in in args.file: - ID_list = [] - smiles_list = [] - tokens_list = [] - with open(file_in, "r") as fp: - for line in fp.readlines(): - if line[-1] == '\n': - line = line[:-1] - line = line.replace(";[ ", ";[").replace(" ", " ").replace(" ", ",") - ID, smiles, tokens = line.split(";") - ID_list.append(ID) - smiles_list.append(smiles) - tokens_list.append(np.array(eval(tokens))) + + num_procs = int(sys.argv[2]) + + ID_list = [] + smiles_list = [] + tokens_list = [] + with open(sys.argv[1], "r") as fp: + for line in fp.readlines(): + if line[-1] == '\n': + line = line[:-1] + line = line.replace(";[ ", ";[").replace(" ", " ").replace(" ", ",") + ID, smiles, tokens = line.split(";") + ID_list.append(ID) + smiles_list.append(smiles) + tokens_list.append(np.array(eval(tokens))) # Set desired image shape and number of depictions per SMILES and output paths im_per_SMILES_noaug = 1 @@ -230,8 +226,8 @@ def main() -> None: ID_list, SMILES_chunksize, depiction_img_shape, - 20, - random.randint(0, 100), + num_procs, + 42, 1800) From 2ab7d2ecdbbc930db7722300f1ee7eb685099628 Mon Sep 17 00:00:00 2001 From: Otto Brinkhaus Date: Thu, 15 Jun 2023 14:45:20 +0200 Subject: [PATCH 7/8] version bump: 1.1.6 --> 1.1.7 --- CITATION.cff | 
2 +- RanDepict/__init__.py | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index 687433c..e013149 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -8,7 +8,7 @@ authors: given-names: "Kohulan" orcid: "https://orcid.org/0000-0003-1066-7792" title: "RanDepict" -version: 1.1.6 +version: 1.1.7 doi: 10.5281/zenodo.5205528 date-released: 2021-08-17 url: "https://github.com/OBrink/RanDepict" diff --git a/RanDepict/__init__.py b/RanDepict/__init__.py index d65b6e7..138d2ba 100644 --- a/RanDepict/__init__.py +++ b/RanDepict/__init__.py @@ -21,7 +21,7 @@ """ -__version__ = "1.1.6" +__version__ = "1.1.7" __all__ = [ "RanDepict", diff --git a/setup.py b/setup.py index a0a335a..02e74b3 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setuptools.setup( name="RanDepict", - version="1.1.6", + version="1.1.7", author="Otto Brinkhaus", author_email="otto.brinkhaus@uni-jena.de, kohulan.rajan@uni-jena.de", maintainer="Otto Brinkhaus, Kohulan Rajan", From b545aca65c2a770943a51fb47908552742d5bfab Mon Sep 17 00:00:00 2001 From: Otto Brinkhaus Date: Thu, 15 Jun 2023 14:47:46 +0200 Subject: [PATCH 8/8] update documentation --- docs/tutorial.ipynb | 8065 ++++++++++-------------------- examples/RanDepictNotebook.ipynb | 4 +- 2 files changed, 2569 insertions(+), 5500 deletions(-) diff --git a/docs/tutorial.ipynb b/docs/tutorial.ipynb index f4a7d0c..0fe35f4 100644 --- a/docs/tutorial.ipynb +++ b/docs/tutorial.ipynb @@ -6,41 +6,38 @@ "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'1.1.7'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import os\n", "import ipyplot\n", "\n", - "from RanDepict import RandomDepictor, RandomMarkushStructureCreator" + "from RanDepict import RandomDepictor, RandomMarkushStructureCreator\n", + "import RanDepict\n", + "RanDepict.__version__" ] }, { + "attachments": {}, "cell_type": 
"markdown", "metadata": {}, "source": [ "## Depict chemical structures with CDK, RDKit, Indigo or PIKAChU\n", "\n", - "After calling an instance of RandomDepictor, depictions with randomly chosen parameters are created by calling the functions\n", - "\n", - "- depict_and_resize_cdk(\n", - " smiles: str, \n", - " image_shape: Tuple[int,int]\n", - " )\n", - "- depict_and_resize_rdkit(\n", - " smiles: str, \n", - " image_shape: Tuple[int,int]\n", - " )\n", - "- depict_and_resize_indigo(\n", - " smiles: str, \n", - " image_shape: Tuple[int,int]\n", - " )\n", - "- depict_and_resize_pikachu(\n", - " smiles: str, \n", - " image_shape: Tuple[int,int]\n", - " )\n", + "After calling an instance of RandomDepictor, depictions with randomly chosen parameters are created by calling the functions `cdk_depict`, `rdkit_depict`, `indigo_depict` and `pikachu_depict`.\n", "\n", "\n", - "The SMILES string needs to be given, the image shape defaults to (299,299,3).\n", + "The SMILES or mol_block string needs to be given.\n", "\n", "Each of these functions returns an np.array which represents an RGB image of the chemical structure." ] @@ -62,14 +59,14 @@ "text/html": [ "\n", " \n", "
\n", - " \n", - " \n", - " \n", "
\n", - " \n", - "
\n", - "
\n", - "

24

\n", - " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + "
\n", + "
\n", + "
\n", + "

0

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

25

\n", - " \n", + "
\n", + "
\n", + "

1

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

26

\n", - " \n", + "
\n", + "
\n", + "

2

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

27

\n", - " \n", + "
\n", + "
\n", + "

3

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

28

\n", - " \n", + "
\n", + "
\n", + "

4

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

29

\n", - " \n", + "
\n", + "
\n", + "

5

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

30

\n", - " \n", + "
\n", + "
\n", + "

6

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

31

\n", - " \n", + "
\n", + "
\n", + "

7

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

32

\n", - " \n", + "
\n", + "
\n", + "

8

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

33

\n", - " \n", + "
\n", + "
\n", + "

9

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

34

\n", - " \n", + "
\n", + "
\n", + "

10

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

35

\n", - " \n", + "
\n", + "
\n", + "

11

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

36

\n", - " \n", + "
\n", + "
\n", + "

12

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

37

\n", - " \n", + "
\n", + "
\n", + "

13

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

38

\n", - " \n", + "
\n", + "
\n", + "

14

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

39

\n", - " \n", + "
\n", + "
\n", + "

15

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

40

\n", - " \n", + "
\n", + "
\n", + "

16

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

41

\n", - " \n", + "
\n", + "
\n", + "

17

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

42

\n", - " \n", + "
\n", + "
\n", + "

18

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

43

\n", - " \n", + "
\n", + "
\n", + "

19

\n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

44

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

45

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

46

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

47

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

48

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

49

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

50

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

51

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

52

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

53

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

54

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

55

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

56

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

57

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

58

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

59

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

60

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

61

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

62

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

63

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

64

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

65

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

66

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

67

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

68

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

69

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

70

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

71

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

72

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

73

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

74

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

75

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

76

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

77

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

78

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

79

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

80

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

81

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

82

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

83

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

84

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

85

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

86

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

87

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

88

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

89

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

90

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

91

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

92

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

93

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

94

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

95

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

96

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

97

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

98

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

99

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - "
\n", - " \n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - "
\n", - "
\n", - "

0

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

1

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

2

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

3

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

4

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

5

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

6

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

7

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

8

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

9

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

10

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

11

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

12

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

13

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

14

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

15

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

16

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

17

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

18

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

19

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

20

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

21

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

22

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

23

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

24

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

25

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

26

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

27

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

28

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

29

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

30

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

31

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

32

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

33

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

34

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

35

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

36

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

37

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

38

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

39

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

40

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

41

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

42

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

43

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

44

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

45

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

46

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

47

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

48

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

49

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

50

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

51

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

52

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

53

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

54

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

55

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

56

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

57

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

58

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

59

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

60

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

61

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

62

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

63

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

64

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

65

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

66

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

67

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

68

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

69

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

70

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

71

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

72

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

73

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

74

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

75

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

76

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

77

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

78

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

79

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

80

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

81

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

82

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

83

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

84

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

85

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

86

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

87

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

88

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

89

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

90

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

91

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

92

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

93

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

94

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

95

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

96

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

97

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

98

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

99

\n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", "
\n", @@ -7223,20 +4351,395 @@ }, "metadata": {}, "output_type": "display_data" + } + ], + "source": [ + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)CC(=O)O\"\n", + "with RandomDepictor(hand_drawn=True) as depictor:\n", + " random_augmented_images = []\n", + " for _ in range(20):\n", + " random_augmented_images.append(depictor(smiles))\n", + " \n", + "\n", + "ipyplot.plot_images(random_augmented_images, max_images=20, img_width=100)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create and save a batch of images\n", + "\n", + "After calling an instance of RandomDepictor, simply call the method batch_depict_save().\n", + "\n", + "Args:\n", + "\n", + "- smiles_list (List[str]): List of SMILES str\n", + "- images_per_structure (int): Amount of images to create per SMILES str\n", + "- output_dir (str): Output directory \n", + "- augment (bool): Boolean that indicates whether or not to use augmentations\n", + "- ID_list (List[str]): List of IDs (should be as long as smiles_list)\n", + "- shape (Tuple[int, int], optional): image shape. Defaults to (299, 299).\n", + "- processes (int, optional): Number of parallel threads. Defaults to 4.\n", + "- seed (int, optional): Seed for pseudo-random decisions. Defaults to 42."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Make sure the output directories exist\n", + "if not os.path.exists('not_augmented'):\n", + " os.mkdir('not_augmented')\n", + " \n", + "if not os.path.exists('augmented'):\n", + " os.mkdir('augmented')\n", + "\n", + "# Depict and save two batches of images\n", + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)CC(=O)O\"\n", + "with RandomDepictor(42) as depictor:\n", + " depictor.batch_depict_save([smiles], 20, 'not_augmented', False, ['caffeine'], (299, 299), 5)\n", + " depictor.batch_depict_save([smiles], 20, 'augmented', True, ['caffeine'], (299, 299), 5)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not os.path.exists('kohulan'):\n", + " os.mkdir(\"kohulan\")\n", + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n", + "r_smiles = \"[R1]N1C=NC2=C1[X](=O)N(C(=O)N2C)[R]\"\n", + "seed = 233\n", + "r_seed = 1\n", + "with RandomDepictor(1) as depictor:\n", + " depictor.depict_save(smiles, 1, 'kohulan', False, 'caffeine_299_299', (299, 299), seed=seed)\n", + " depictor.depict_save(smiles, 1, 'kohulan', True, 'caffeine_aug_299_299', (299, 299), seed=seed)\n", + " depictor.depict_save(smiles, 1, 'kohulan', False, 'caffeine_512_512', (512, 512), seed=seed)\n", + " depictor.depict_save(smiles, 1, 'kohulan', True, 'caffeine_aug_512_512', (512, 512), seed=seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', False, 'caffeine_R_299_299', (299, 299), seed=r_seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', True, 'caffeine_R_aug_299_299', (299, 299), seed=r_seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', False, 'caffeine_R_512_512', (512, 512), seed=r_seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', True, 'caffeine_R_aug_512_512', (512, 512), seed=r_seed)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create a 
batch of images while ensuring diversity using feature fingerprints\n", + "\n", + "\n", + "After calling an instance of RandomDepictor, simply call the method batch_depict_with_fingerprints().\n", + "\n", + "Args:\n", + "\n", + "- smiles_list: List[str]\n", + "- images_per_structure: int\n", + "- indigo_proportion: float = 0.15\n", + "- rdkit_proportion: float = 0.25\n", + "- pikachu_proportion: float = 0.25\n", + "- cdk_proportion: float = 0.35\n", + "- aug_proportion: float = 0.5\n", + "- shape: Tuple[int, int] = (299, 299)\n", + "- processes: int = 4\n", + "- seed: int = 42\n", + "\n", + "* Note: Have a look at examples/generate_depiction_grids_with_fingerprints.py to see how this function was used to generate the grid figures from our publication." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Depict two batches of images (without and with augmentations)\n", + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n", + "\n", + "with RandomDepictor(42) as depictor:\n", + " fp_depictions = depictor.batch_depict_with_fingerprints([smiles],\n", + " 20,\n", + " aug_proportion = 0)\n", + " fp_aug_depictions = depictor.batch_depict_with_fingerprints([smiles],\n", + " 20,\n", + " aug_proportion = 1)\n", + "ipyplot.plot_images(fp_depictions, max_images=20, img_width=100)\n", + "ipyplot.plot_images(fp_aug_depictions, max_images=20, img_width=100)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create and save a batch of images while ensuring diversity using feature fingerprints\n", + "\n", + "\n", + "After calling an instance of RandomDepictor, simply call the method batch_depict_save_with_fingerprints().\n", + "\n", + "Args:\n", + "\n", + "- smiles_list: List[str]\n", + "- images_per_structure: int\n", + "- output_dir: str\n", + "- ID_list: List[str]\n", + "- indigo_proportion: float = 0.15\n", + "- rdkit_proportion: float = 0.3\n", + "- cdk_proportion: float = 0.55\n", + "- 
aug_proportion: float = 0.5\n", + "- shape: Tuple[int, int] = (299, 299)\n", + "- processes: int = 4\n", + "- seed: int = 42\n", + "\n", + "\n", + "*Note: The images that are created here, were used for the animations in the GitHub repository" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Make sure the output directories exist\n", + "if not os.path.exists('not_augmented_fingerprint'):\n", + " os.mkdir('not_augmented_fingerprint')\n", + " \n", + "if not os.path.exists('augmented_fingerprint'):\n", + " os.mkdir('augmented_fingerprint')\n", + "\n", + "# Depict and save two batches of images\n", + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)CC(=O)O\"\n", + "with RandomDepictor(42) as depictor:\n", + " depictor.batch_depict_save_with_fingerprints([smiles], \n", + " 100, \n", + " 'not_augmented_fingerprint',\n", + " ['caffeine_{}'.format(n) for n in range(100)],\n", + " aug_proportion = 0)\n", + " depictor.batch_depict_save_with_fingerprints([smiles], \n", + " 100, \n", + " 'augmented_fingerprint',\n", + " ['caffeine_{}'.format(n) for n in range(100)],\n", + " aug_proportion = 1)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Artificial generation of SMILES that represent markush structures\n", + "\n", + "Generate markush structures based on list of SMILES strings" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['CN1C=NC2=C1C(=O)N(C)C(=O)N2C[X20]',\n", + " 'C1=NC2=C(N1C[Z7])[X](=O)N(C(=O)N2C[X2])[X8]([H])([H])[H]',\n", + " 'C1=NC2=C(C(=O)N(C[Y1])C(=O)N2[R7e]([H])([H])[H])N1[Y4]([H])([H])[H]',\n", + " 'CN1C=NC2=C1C(=O)N(C)C(=O)N2C[X12]',\n", + " 'CN1C2=C(N(C=N2)[R]([H])([H])[H])[R13d](=O)N(C[X9])C1=O',\n", + " 'CN1C(=O)C2=C(N=[Y2c]([H])N2C[X8c])N(C1=O)[Z13]([H])([H])[H]',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C=N2)[X11]([H])([H])[H]',\n", + " 
'CN1C=NC2=C1C(=O)N(C[R18])[X2](=O)N2C([Z15])[R16]',\n", + " 'CN1C=NC2=C1[Y14](=O)N(C)C(=O)N2C[X18b]',\n", + " 'CN1C2=C(N(C=N2)[R12c]([H])([H])[H])[R2](=O)N(C1=O)[Z]([H])([H])[H]',\n", + " 'CN1C=NC2=C1C(=O)N(C[Z])C(=O)N2C[X10e]',\n", + " 'CN1C=NC2=C1[Z2b](=O)N(C)[X](=O)N2C',\n", + " 'CN1C=NC2=C1C(=O)N(C[X7])C(=O)N2C',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C=N2)C[Z9]',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C)[R17]([H])=N2',\n", + " 'CN1C2=C(C(=O)N(C1=O)[X]([H])([H])[H])N(C=N2)C[X4d]',\n", + " 'CN1C=NC2=C1[Y18](=O)N(C)C(=O)N2C[Y20]',\n", + " 'CN1C2=C(N=C1[X8b])N(C(=O)N(C)C2=O)C([X])[R2]',\n", + " 'CN1C2=C(C(=O)N(C[R11f])C1=O)N(C=N2)[R5]([H])([Z])[Z]',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C=N2)[Z12c]([H])([H])[Z2]',\n", + " 'C(N1C2=C(N=C1[R14])N(C[X])C(=O)N(C2=O)[Z14]([H])([H])[H])[Y]',\n", + " 'CN1C2=C(C(=O)N(C1=O)[Y18]([H])([H])[Y20])N(C[X])C(=N2)[Z17d]',\n", + " 'CN1[Y18](=O)C2=C(N=CN2C[X4c])N(C[R15a])[Y19c]1=O',\n", + " 'CN1C=NC2=C1C(=O)N(C[Y2])C(=O)N2C',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C=N2)C[Z20]',\n", + " 'CN1C=NC2=C1C(=O)N(C[Y])[X20b](=O)N2[R9]([H])([H])[Y16d]',\n", + " 'CN1C2=C(N=C1[X16])N(C)C(=O)N(C)C2=O',\n", + " 'CN1C=NC2=C1C(=O)N(C[R1])C(=O)N2C',\n", + " 'CN1C=NC2=C1C(=O)N(C[Z1a])C(=O)N2C([R5])[Y5d]',\n", + " 'CN1C=NC2=C1[Y10](=O)N(C)C(=O)N2C',\n", + " 'CN1C2=C(N=[X1a]1[H])N(C[X1])C(=O)N(C[Z1])C2=O',\n", + " 'CN1C=NC2=C1[Y10](=O)N(C)C(=O)N2C',\n", + " 'CN1C(=O)C2=C(N=CN2C([R])[Y10])N(C1=O)C([Y11])[R20f]',\n", + " 'CN1C2=C(N(C=N2)C([Z3])[Z4f])[Z6b](=O)N(C)C1=O',\n", + " 'CN1C2=C(N=[Z]1[H])N(C[Y9a])C(=O)N(C[R8])C2=O',\n", + " 'CN1C2=C(N(C=N2)[Y9]([H])([H])[H])[Y12b](=O)N(C)C1=O',\n", + " 'CN1C=NC2=C1C(=O)N(C[Z1])C(=O)N2C[Z]',\n", + " 'CN1C=NC2=C1C(=O)N(C[Y11b])[R](=O)N2[X]([H])([H])[H]',\n", + " 'CN1C2=C(N=C1[R17])N(C(=O)N(C[Z10])C2=O)[Y13a]([H])([H])[Z20]',\n", + " 'CN1C=NC2=C1C(=O)N(C[Y13])C(=O)N2[X1]([H])([H])[H]',\n", + " 'CN1C(=O)N(C[Y12f])C2=C(N(C(=N2)[Y9])[Y13]([H])([H])[H])[R17b]1=O',\n", + " 'CN1C2=C(N(C[X18])[X2]([H])=N2)[R14](=O)N(C)C1=O',\n", + " 
'CN1C(=O)C2=C(N=CN2[X19d]([H])([H])[H])N(C1=O)[X14]([H])([H])[H]',\n", + " 'CN1C=NC2=C1C(=O)N(C[R])C(=O)N2C',\n", + " 'CN1C(=O)C2=C(N=CN2C[Z11b])N(C1=O)[R6c]([H])([H])[H]',\n", + " 'CN1C2=C(C(=O)N(C[Z4c])[X2]1=O)N([R18]([H])([H])[H])[Y19]([H])=N2',\n", + " 'CN1C=NC2=C1C(=O)N(C)C(=O)N2C[Y]',\n", + " 'CN1C(=O)C2=C(N=CN2C[X16e])N(C1=O)[Z15]([H])([H])[H]',\n", + " 'CN1C=NC2=C1C(=O)N(C)C(=O)N2C[R5]',\n", + " 'CN1C2=C(N=[Y7]1[H])N(C[X9])[X4f](=O)N(C)[X14]2=O',\n", + " 'CN1[R](=O)C2=C(N=CN2C[R])N([Z1]([H])([H])[H])[Z]1=O',\n", + " 'CN1C=NC2=C1[R19d](=O)N(C[Y])C(=O)N2C',\n", + " 'CN1C2=C(C(=O)N(C[Z8])C1=O)N(C)[X17]([H])=N2',\n", + " 'CN1C(=O)C2=C(N=CN2[Y1]([H])([H])[X16f])N(C[X17])C1=O',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C[X20])C(=N2)[Y12f]',\n", + " 'CN1C=NC2=C1C(=O)N(C[Y4b])[Z](=O)N2C[Y]',\n", + " 'C1=NC2=C(C(=O)N(C(=O)N2C[Z5])C([Z])[Z])N1C[X15c]',\n", + " 'CN1C2=C(N=C1[Z2e])N(C)C(=O)N(C2=O)[X4b]([H])([H])[H]',\n", + " 'CN1C(=O)C2=C(N=[Z13]([H])N2C)N(C[Z])C1=O',\n", + " 'CN1C(=O)C2=C(N=CN2C[Z10])N(C1=O)[Y16]([H])([H])[H]',\n", + " 'CN1C(=O)C2=C(N=CN2C[X17])N(C1=O)[R18]([H])([X])[X6]',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C=N2)C[R17]',\n", + " 'CN1C2=C(N=C1[Z6])N(C)[Z17c](=O)N(C[R4])C2=O',\n", + " 'C(N1C2=C(C(=O)N(C[R1])C1=O)N(C(=N2)[R13])[Z3f]([H])([H])[H])[X19f]',\n", + " 'CN1C2=C(C(=O)N(C[Y])C1=O)N(C=N2)[Y4a]([H])([H])[H]',\n", + " 'CN1C(=O)C2=C(N=CN2C[Y4b])N(C[R14])C1=O',\n", + " 'C1=NC2=C(C(=O)N(C[Z14])C(=O)N2C[Z16])N1C[R20]',\n", + " 'C1=NC2=C(C(=O)N(C(=O)N2C[Y20])C([Z17d])[Y13])N1[X8]([H])([H])[H]',\n", + " 'CN1C=NC2=C1C(=O)N(C(=O)N2C)[Y9]([H])([H])[R20]',\n", + " 'CN1C2=C(C(=O)N(C1=O)C([Y16f])[X14e])N(C=N2)[Y4]([H])([H])[X17f]',\n", + " 'CN1C=NC2=C1C(=O)N(C(=O)N2C)[X19e]([H])([H])[H]',\n", + " 'CN1C=NC2=C1C(=O)N(C(=O)N2[X15]([H])([H])[H])[X5]([H])([H])[H]',\n", + " 'CN1C2=C(C(=O)N(C)[Z18]1=O)N([X13]([H])([H])[H])[Y8]([H])=N2',\n", + " 'CN1C2=C(N(C=N2)C[R19])[Y11](=O)N(C[Y3])C1=O',\n", + " 'CN1C2=C(N(C=N2)[Z11]([H])([H])[H])[R](=O)N(C)C1=O',\n", + " 
'CN1C=NC2=C1C(=O)N(C)C(=O)N2C([Z18e])[R19]',\n", + " 'C1=NC2=C(C(=O)N(C[R])C(=O)N2C[Z])N1C[Z14]',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C[Z3])C(=N2)[X]',\n", + " 'CN1C=NC2=C1C(=O)N(C(=O)N2C)[R10]([H])([H])[H]',\n", + " 'CN1C2=C(C(=O)N(C1=O)[X]([H])([H])[H])N(C=N2)C[R7a]',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C=N2)C[Y19]',\n", + " 'CN1C2=C(C(=O)N(C1=O)[X]([H])([H])[H])N(C[Z12a])[Y11]([H])=N2',\n", + " 'CN1C2=C(N=[Z]1[H])N(C)[R16](=O)N(C)C2=O',\n", + " 'CN1C=NC2=C1C(=O)N(C)[Z12](=O)N2C',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C)[Z8e]([H])=N2',\n", + " 'CN1C(=O)C2=C(N=CN2C[Y9])N(C1=O)[Z]([H])([H])[H]',\n", + " 'CN1C2=C(N(C)C(=N2)[X5])[R](=O)N(C[R17])C1=O',\n", + " 'CN1C2=C(C(=O)N(C)[R11]1=O)N(C(=N2)[Z13d])C([X18])[R]',\n", + " 'C1=NC2=C(C(=O)N(C[X6])C(=O)N2[X19]([H])([H])[H])N1C[R1d]',\n", + " 'C1=NC2=C(C(=O)N(C[R])[Y5](=O)N2C[R])N1C[Y13]',\n", + " 'CN1C2=C(C(=O)N(C1=O)C([X6])([R14a])[Y4])N(C=N2)C[Z4]',\n", + " 'C1=NC2=C(C(=O)N(C[X15c])C(=O)N2C([X4])[R1])N1C[Z17e]',\n", + " 'CN1C2=C(N(C[Y3])[Y]([H])=N2)[Y18](=O)N(C)C1=O',\n", + " 'CN1C(=O)C2=C(N=C(N2[R9]([H])([H])[Y11])[X8a])N(C[Z])C1=O',\n", + " 'CN1C2=C(C(=O)N(C)C1=O)N(C=N2)C[R20c]',\n", + " 'CN1C=NC2=C1C(=O)N(C[Z8])C(=O)N2C[Z3b]',\n", + " 'CN1C=NC2=C1C(=O)N(C(=O)N2C)[Z14b]([H])([H])[H]',\n", + " 'CN1C=NC2=C1C(=O)N(C[X16])C(=O)N2C',\n", + " 'CN1C=NC2=C1C(=O)N(C(=O)N2C)[X8]([H])([H])[H]',\n", + " 'CN1C2=C(N(C(=N2)[Y20])[R2e]([H])([H])[H])[R12](=O)N(C1=O)[Y12b]([H])([H])[H]']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "markush_generator = RandomMarkushStructureCreator()\n", + "input_smiles = ['CN1C=NC2=C1C(=O)N(C(=O)N2C)C'] * 100\n", + "markush_smiles = markush_generator.generate_markush_structure_dataset(input_smiles)\n", + "markush_smiles" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Depict the markush structures using RanDepict" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + 
"outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning! Rogue electron.\n", + "X12_14\n", + "Warning! Rogue electron.\n", + "X12_14\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[14:09:41] WARNING: not removing hydrogen atom with dummy atom neighbors\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning! Rogue electron.\n", + "R_10\n", + "Warning! Rogue electron.\n", + "R_10\n", + "Warning! Rogue electron.\n", + "Z14_8\n", + "Warning! Rogue electron.\n", + "Z16_13\n", + "Warning! Rogue electron.\n", + "R20_16\n", + "Warning! Rogue electron.\n", + "Z14_8\n", + "Warning! Rogue electron.\n", + "Z16_13\n", + "Warning! Rogue electron.\n", + "R20_16\n", + "Warning! Rogue electron.\n", + "X16_10\n", + "Warning! Rogue electron.\n", + "X16_10\n" + ] }, { "data": { "text/html": [ "\n", " \n", "
\n", - " \n", - " \n", - " \n", + " \n", + "
\n", + "
\n", + "

80

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - "
\n", - "
\n", - "

0

\n", - " \n", + " \n", + "
\n", + "
\n", + "

81

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

1

\n", - " \n", + "
\n", + "
\n", + "

82

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
\n", "
\n", " \n", - "
\n", - "
\n", - "

2

\n", - " \n", + "
\n", + "
\n", + "

83

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

84

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

85

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

86

\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

87

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

88

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

89

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

90

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

91

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

92

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

93

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

94

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

95

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

96

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

97

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

98

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

99

\n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
\n", @@ -10482,24 +7543,32 @@ ], "source": [ "with RandomDepictor() as depictor:\n", - " markush_depictions = [depictor.random_depiction(smi)\n", + " markush_depictions = [depictor.random_depiction(smi, shape=(100,100))\n", " for smi in markush_smiles]\n", - "ipyplot.plot_images(markush_depictions, max_images=100, img_width=299)" + "ipyplot.plot_images(markush_depictions, max_images=100, img_width=100)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "?\n" + ] + } + ], "source": [] } ], "metadata": { "kernelspec": { - "display_name": "RanDepict", + "display_name": "base", "language": "python", - "name": "randepict" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -10515,7 +7584,7 @@ }, "vscode": { "interpreter": { - "hash": "7161fd5581079cf7327a177c4ccb46649bc05fdb264323da3438b807a661b6f2" + "hash": "30f3beb01d96c3b334f501aee3e5616b49ff465bee6f3dcf7506c52a9acac780" } } }, diff --git a/examples/RanDepictNotebook.ipynb b/examples/RanDepictNotebook.ipynb index ec2dc39..1b710f2 100644 --- a/examples/RanDepictNotebook.ipynb +++ b/examples/RanDepictNotebook.ipynb @@ -10,7 +10,7 @@ { "data": { "text/plain": [ - "'1.1.6'" + "'1.1.7'" ] }, "execution_count": 1, @@ -4513,7 +4513,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [