From c3d97b0d3bf74a55c561202373d3f1bae3d1602e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=B6ffler=2C=20Hannes?= Date: Fri, 22 Nov 2024 11:38:28 +0100 Subject: [PATCH] sync with internal AZ version 4.5.11 --- .env | 2 +- .gitattributes | 2 - CHANGELOG.md | 210 +++++++++++++- NEWS.md | 20 ++ README.md | 6 +- configs/json/scoring.json | 4 +- configs/toml/pepinvent.smi | 6 + configs/toml/sampling.toml | 7 + configs/toml/staged_learning.toml | 7 + .../test_comp_unwanted_substructures.py | 4 +- pyproject.toml | 13 +- reinvent/Reinvent.py | 267 ++++-------------- reinvent/__init__.py | 2 +- reinvent/__main__.py | 4 +- .../library_design/attachment_points.py | 12 +- .../chemistry/library_design/bond_maker.py | 3 + reinvent/chemistry/tokens.py | 2 + reinvent/datapipeline/__main__.py | 3 + reinvent/datapipeline/filters/chem.py | 27 +- reinvent/datapipeline/preprocess.py | 14 +- reinvent/datapipeline/validation.py | 1 + reinvent/logger.py | 87 ------ reinvent/models/__init__.py | 2 + .../linkinvent/model_vocabulary/vocabulary.py | 1 - .../linkinvent/networks/attention_layer.py | 1 - .../models/linkinvent/networks/decoder.py | 6 +- .../models/linkinvent/networks/encoder.py | 7 +- .../models/model_factory/libinvent_adapter.py | 3 +- .../models/model_factory/pepinvent_adapter.py | 11 + .../model_factory/transformer_adapter.py | 6 +- .../models/transformer/pepinvent/__init__.py | 0 .../models/transformer/pepinvent/pepinvent.py | 6 + reinvent/runmodes/RL/__init__.py | 1 + reinvent/runmodes/RL/distance_penalty.py | 31 ++ reinvent/runmodes/RL/learning.py | 11 +- reinvent/runmodes/RL/libinvent.py | 5 +- reinvent/runmodes/RL/mol2mol.py | 30 +- reinvent/runmodes/RL/pepinvent.py | 25 ++ reinvent/runmodes/RL/reports/csv_summmary.py | 3 +- reinvent/runmodes/RL/reports/data.py | 2 + reinvent/runmodes/RL/reports/remote.py | 5 +- reinvent/runmodes/RL/reports/tensorboard.py | 2 +- reinvent/runmodes/RL/run_staged_learning.py | 24 +- reinvent/runmodes/RL/validation.py | 1 + reinvent/runmodes/TL/learning.py | 8 +- reinvent/runmodes/TL/run_transfer_learning.py | 3 +- reinvent/runmodes/TL/validation.py | 2 +- reinvent/runmodes/create_adapter.py | 4 +- reinvent/runmodes/dtos/__init__.py | 2 +- reinvent/runmodes/handler.py | 34 ++- reinvent/runmodes/reporter/__init__.py | 1 - reinvent/runmodes/reporter/remote.py | 106 ------- reinvent/runmodes/samplers/__init__.py | 1 + reinvent/runmodes/samplers/libinvent.py | 36 ++- reinvent/runmodes/samplers/linkinvent.py | 31 +- reinvent/runmodes/samplers/mol2mol.py | 10 +- reinvent/runmodes/samplers/pepinvent.py | 77 +++++ reinvent/runmodes/samplers/reinvent.py | 6 +- reinvent/runmodes/samplers/reports/common.py | 12 +- .../runmodes/samplers/reports/tensorboard.py | 4 +- reinvent/runmodes/samplers/run_sampling.py | 13 +- reinvent/runmodes/samplers/validation.py | 1 + reinvent/runmodes/scoring/validation.py | 7 + reinvent/runmodes/setup_sampler.py | 7 +- reinvent/runmodes/utils/helpers.py | 36 +-- reinvent/scoring/config.py | 23 +- reinvent/scoring/importer.py | 7 +- reinvent/scoring/scorer.py | 23 +- reinvent/scoring/transforms/__init__.py | 1 + .../scoring/transforms/exponential_decay.py | 37 +++ reinvent/scoring/transforms/transform.py | 3 +- reinvent/utils/__init__.py | 5 + reinvent/utils/cli.py | 109 +++++++ reinvent/{ => utils}/config_parse.py | 83 ++++-- reinvent/utils/helpers.py | 77 +++++ reinvent/utils/logmon.py | 246 ++++++++++++++++ reinvent/{ => utils}/prior_registry.py | 3 +- reinvent/validation.py | 6 +- reinvent/version.py | 14 +- .../OpenEye/rocs/rocs_similarity.py | 4
+- .../components/RDKit/comp_group_count.py | 2 +- .../RDKit/comp_matching_substructure.py | 2 +- reinvent_plugins/components/RDKit/comp_pmi.py | 2 +- .../RDKit/comp_rdkit_descriptors.py | 2 +- .../components/RDKit/comp_similarity.py | 2 +- reinvent_plugins/components/comp_chemprop.py | 4 +- reinvent_plugins/components/comp_icolos.py | 6 +- reinvent_plugins/components/comp_maize.py | 4 +- reinvent_plugins/components/run_program.py | 2 +- reinvent_plugins/normalizers/rdkit_smiles.py | 6 +- requirements-linux-64.lock | 69 +---- support/run-qsartuna.py | 42 +++ support/run-rascore.py | 4 +- tests/datapipeline/test_percent.py | 32 +++ tests/datapipeline/test_unwanted_tokens.py | 36 +++ .../libinvent/RNN/__init__.py | 0 .../libinvent/RNN/model_tests/__init__.py | 0 .../RNN/model_tests/decorator_model_test.py | 69 +++++ .../RNN/model_tests/test_likelihood.py | 29 ++ .../RNN/vocabulary_tests/__init__.py | 1 + .../test_tokenization_with_model.py | 25 ++ .../libinvent/transformer/__init__.py | 0 .../transformer/test_libinvent_model.py | 78 +++++ .../libinvent/transformer/test_likelihood.py | 31 ++ .../dataset_tests/test_paireddataset.py | 4 +- .../integration_tests/pepinvent/__init__.py | 0 .../pepinvent/test_likelihood.py | 27 ++ .../pepinvent/test_pepinvent_model.py | 71 +++++ .../unit_tests/libinvent/RNN/__init__.py | 0 .../unit_tests/libinvent/RNN/fixtures.py | 55 ++++ .../libinvent/RNN/model_tests/__init__.py | 0 .../RNN/model_tests/decorator_model_test.py | 67 +++++ .../RNN/model_tests/test_likelihood.py | 30 ++ .../RNN/vocabulary_tests/__init__.py | 1 + .../test_tokenization_with_model.py | 13 + .../RNN/vocabulary_tests/test_tokenizer.py | 107 +++++++ .../RNN/vocabulary_tests/test_vocabulary.py | 102 +++++++ .../libinvent/transformer/__init__.py | 0 .../transformer/dataset_tests/__init__.py | 1 + .../dataset_tests/test_paired_dataset.py | 152 ++++++++++ .../libinvent/transformer/fixtures.py | 46 +++ .../transformer/model_tests/__init__.py | 0 .../model_tests/test_libinvent_model.py | 77 +++++ .../model_tests/test_likelihood.py | 24 ++ tests/models/unit_tests/pepinvent/__init__.py | 0 tests/models/unit_tests/pepinvent/fixtures.py | 43 +++ .../unit_tests/pepinvent/test_likelihood.py | 21 ++ .../pepinvent/test_pepinvent_model.py | 67 +++++ .../components/RDKit/test_comp_mol_volume.py | 3 +- .../sampling_tests/test_sampling.py | 1 + .../unit_tests/test_remote_reporter.py | 6 +- tests/scoring/unit_tests/test_parsing.py | 81 ++++++ tests/scoring/unit_tests/test_transforms.py | 19 ++ tests/test_data.py | 6 + 134 files changed, 2663 insertions(+), 737 deletions(-) delete mode 100644 .gitattributes create mode 100644 configs/toml/pepinvent.smi create mode 100644 reinvent/datapipeline/__main__.py delete mode 100644 reinvent/logger.py create mode 100644 reinvent/models/model_factory/pepinvent_adapter.py create mode 100644 reinvent/models/transformer/pepinvent/__init__.py create mode 100644 reinvent/models/transformer/pepinvent/pepinvent.py create mode 100644 reinvent/runmodes/RL/distance_penalty.py create mode 100644 reinvent/runmodes/RL/pepinvent.py mode change 100755 => 100644 reinvent/runmodes/RL/reports/data.py delete mode 100644 reinvent/runmodes/reporter/__init__.py delete mode 100644 reinvent/runmodes/reporter/remote.py create mode 100644 reinvent/runmodes/samplers/pepinvent.py create mode 100644 reinvent/scoring/transforms/exponential_decay.py create mode 100644 reinvent/utils/__init__.py create mode 100644 reinvent/utils/cli.py rename reinvent/{ => utils}/config_parse.py (74%) create mode 
100644 reinvent/utils/helpers.py create mode 100644 reinvent/utils/logmon.py rename reinvent/{ => utils}/prior_registry.py (94%) create mode 100755 support/run-qsartuna.py create mode 100644 tests/datapipeline/test_percent.py create mode 100644 tests/datapipeline/test_unwanted_tokens.py create mode 100644 tests/models/integration_tests/libinvent/RNN/__init__.py create mode 100644 tests/models/integration_tests/libinvent/RNN/model_tests/__init__.py create mode 100644 tests/models/integration_tests/libinvent/RNN/model_tests/decorator_model_test.py create mode 100644 tests/models/integration_tests/libinvent/RNN/model_tests/test_likelihood.py create mode 100644 tests/models/integration_tests/libinvent/RNN/vocabulary_tests/__init__.py create mode 100644 tests/models/integration_tests/libinvent/RNN/vocabulary_tests/test_tokenization_with_model.py create mode 100644 tests/models/integration_tests/libinvent/transformer/__init__.py create mode 100644 tests/models/integration_tests/libinvent/transformer/test_libinvent_model.py create mode 100644 tests/models/integration_tests/libinvent/transformer/test_likelihood.py create mode 100644 tests/models/integration_tests/pepinvent/__init__.py create mode 100644 tests/models/integration_tests/pepinvent/test_likelihood.py create mode 100644 tests/models/integration_tests/pepinvent/test_pepinvent_model.py create mode 100644 tests/models/unit_tests/libinvent/RNN/__init__.py create mode 100644 tests/models/unit_tests/libinvent/RNN/fixtures.py create mode 100644 tests/models/unit_tests/libinvent/RNN/model_tests/__init__.py create mode 100644 tests/models/unit_tests/libinvent/RNN/model_tests/decorator_model_test.py create mode 100644 tests/models/unit_tests/libinvent/RNN/model_tests/test_likelihood.py create mode 100644 tests/models/unit_tests/libinvent/RNN/vocabulary_tests/__init__.py create mode 100644 tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_tokenization_with_model.py create mode 100644 tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_tokenizer.py create mode 100644 tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_vocabulary.py create mode 100644 tests/models/unit_tests/libinvent/transformer/__init__.py create mode 100644 tests/models/unit_tests/libinvent/transformer/dataset_tests/__init__.py create mode 100644 tests/models/unit_tests/libinvent/transformer/dataset_tests/test_paired_dataset.py create mode 100644 tests/models/unit_tests/libinvent/transformer/fixtures.py create mode 100644 tests/models/unit_tests/libinvent/transformer/model_tests/__init__.py create mode 100644 tests/models/unit_tests/libinvent/transformer/model_tests/test_libinvent_model.py create mode 100644 tests/models/unit_tests/libinvent/transformer/model_tests/test_likelihood.py create mode 100644 tests/models/unit_tests/pepinvent/__init__.py create mode 100644 tests/models/unit_tests/pepinvent/fixtures.py create mode 100644 tests/models/unit_tests/pepinvent/test_likelihood.py create mode 100644 tests/models/unit_tests/pepinvent/test_pepinvent_model.py diff --git a/.env b/.env index 20a1897..8fa6892 100644 --- a/.env +++ b/.env @@ -1,4 +1,4 @@ # example dotenv file # make the scoring components in contrib/ available -PYTHONPATH=/location/to/REINVENT4/contrib +#PYTHONPATH=/location/to/REINVENT4/contrib diff --git a/.gitattributes b/.gitattributes deleted file mode 100644 index 9709b8d..0000000 --- a/.gitattributes +++ /dev/null @@ -1,2 +0,0 @@ -*.prior filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text diff --git 
a/CHANGELOG.md b/CHANGELOG.md index e1d2fc4..15b40d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,214 @@ This follows the guideline on [keep a changelog](https://keepachangelog.com/) - CAZP scoring component +## [4.5.11] 2024-11-18 + +### Changed + +- Convert float NaN and Infs to valid JSON format before remote reporting + + +## [4.5.10] 2024-11-16 + +### Added + +- optional tautomer canonicalisation in data pipeline + + +## [4.5.9] 2024-11-07 + +### Fixed + +- read configuration file from stdin + + +## [4.5.8] 2024-11-07 + +### Changed + +- refactor of top level code + + +## [4.5.7] 2024-11-07 + +### Fixed + +- check if DF is set + + +## [4.5.6] 2024-11-05 + +### Added + +- YAML configuration file reader + + +## [4.5.5] 2024-11-05 + +### Added + +- Logging of configuration file absolute path + +### Changed + +- Automatic configuration file format detection + + +## [4.5.4] 2024-10-28 + +### Added + +- Exponential decay transform + +### Fixed + +- Ambiguity in parsing optional parameters with multiple endpoints and multiple optional parameters + + +## [4.5.3] 2024-10-23 + +### Added + +- component-level parameters for scoring components + + +## [4.5.2] 2024-10-23 + +### Added + +- executable module: can run `python -m reinvent` + + +## [4.5.1] 2024-10-23 + +### Added + +- SIGUSR1 for controlled termination + + +## [4.5.0] 2024-10-08 + +### Added + +- PepInvent in Sampling and Staged learning mode with example toml config provided +- PepInvent prior + + +## [4.4.37] 2024-10-07 + +### Fixed + +- Atom map number removal for Libinvent sampling dropped SMILES + + +## [4.4.36] 2024-09-27 + +### Added + +- Stage number for JSON to remote monitor + +### Changed + +- Relaxed dependencies + + +## [4.4.35] 2024-09-26 + +### Added + +- Terminate staged learning on SIGTERM and check if running in multiprocessing environment + +### Changed + +- ValueError for all scoring components such that the staged learning handler can handle failing components + + +## [4.4.34] 2024-09-16 + +### Fixed + +- SMILES in DF memory were wrongly computed + + +## [4.4.33] 2024-09-14 + +### Fixed + +- run-qsartuna.py: convert ndarray to list to make it JSON serializable + + +## [4.4.32] 2024-09-13 + +### Fixed + +- PMI component: check for embedding failure in RDKit's conformer generator + + +## [4.4.31] 2024-09-13 + +### Fixed + +- Dockstream component wrongly quoted the SMILES string +- Diversity filter setup in config was ignored + + +## [4.4.30] 2024-09-12 + +### Fixed + +- Fixed config reading bug for DF + + +## [4.4.29] 2024-09-05 + +### Changed + +- Changed Molformer sampling valid and unique from percentage to fraction on tensorboard + + +## [4.4.28] 2024-08-29 + +### Fixed + +- Fixed incorrect Tanimoto similarity log in Mol2Mol sampling mode + + +## [4.4.27] 2024-07-23 + +### Fixed + +- Corrected typo in Libinvent report + + +## [4.4.26] 2024-07-21 + +### Fixed + +- Report for sampling returned np.array which is incompatible with JSON serialization + + +## [4.4.25] 2024-07-19 + +### Fixed + +- Allowed responder as an optional input in scoring input validation + + +## [4.4.24] 2024-07-19 + +### Fixed + +- Fixed remote for Libinvent +- Batchsize defaults to 1 for TL + + +## [4.4.23] 2024-07-18 + +### Fixed + +- Added temperature parameter in Sampling and RL config validation + + ## [4.4.22] 2024-07-10 ### Fixed @@ -314,7 +522,7 @@ Various code improvements. ### Added -- Stages can now defined their own diversity filters. Global filter always overwrites stage settings.
Currently no mechanism to carry over DF from previous stage, use single stage runs. +- Stages can now define their own diversity filters. Global filter always overwrites stage settings. Currently no mechanism to carry over DF from previous stage, use single stage runs. ## [4.3.11] 2024-04-30 diff --git a/NEWS.md b/NEWS.md index 178dbdf..fa5bcd9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,23 @@ +New in REINVENT 4.5 +=================== + +For details see CHANGELOG.md. + +* PepINVENT: transformer (SMILES) based peptide generator and prior model +* Temperature factor parameter (transformer generators) for sampling and RL +* Support script run-qsartuna.py to run QSARtuna models in an external environment +* Component-level parameters for scoring components +* Renamed Qptuna scoring component to [QSARtuna](https://github.com/MolecularAI/QSARtuna) +* Staged learning terminates on SIGTERM and writes out a checkpoint file +* SIGUSR1 for graceful termination of staged learning runs +* Relaxed dependencies to accommodate install of other software in same environment e.g. QSARtuna +* Updated some dependencies e.g. PyTorch (now at version 2.4.1) +* New notebook in contrib demoing docking with DockStream and OpenEye +* YAML configuration file reader +* Configuration file format is automatically detected from filename extension +* Various code improvements and fixes + + New in REINVENT 4.4 =================== diff --git a/README.md b/README.md index 4bdf261..918dae6 100644 --- a/README.md +++ b/README.md @@ -146,9 +146,9 @@ Unit and Integration Tests -------------------------- This is primarily for developers and admins/users who wish to ensure that the -installation works. The information here is not relevant to the -practical use of REINVENT. Please refer to _Basic Usage_ for instructions on -how to use the `reinvent` command. +installation works. The information here is not relevant to the practical use +of REINVENT. Please refer to _Basic Usage_ for instructions on how to use the +`reinvent` command. The REINVENT project uses the `pytest` framework for its tests. Before you run them you first have to create a configuration file for the tests. diff --git a/configs/json/scoring.json b/configs/json/scoring.json index 9e023d0..802905f 100644 --- a/configs/json/scoring.json +++ b/configs/json/scoring.json @@ -1,9 +1,9 @@ { "run_type": "scoring", - "output_csv": "scoring.csv", "json_out_config": "_scoring.json", "parameters": { - "smiles_file": "compounds.smi" + "smiles_file": "compounds.smi", + "output_csv": "scoring.csv" }, "scoring": { "type": "geometric_mean", diff --git a/configs/toml/pepinvent.smi b/configs/toml/pepinvent.smi new file mode 100644 index 0000000..28daef4 --- /dev/null +++ b/configs/toml/pepinvent.smi @@ -0,0 +1,6 @@ +# Example peptide file for REINVENT4 PepInvent +# +# One masked peptide with CHUCKLES representation per line +# ?
for mask + +?|N[C@@H](CO)C(=O)|?|N[C@@H](Cc1ccc(O)cc1)C(=O)|N(C)[C@@H]([C@@H](C)O)C(=O)|N[C@H](Cc1c[nH]cn1)C(=O)|N[C@@H](CC(=O)N)C2(=O) \ No newline at end of file diff --git a/configs/toml/sampling.toml b/configs/toml/sampling.toml index 8348b86..dde4100 100644 --- a/configs/toml/sampling.toml +++ b/configs/toml/sampling.toml @@ -30,6 +30,13 @@ model_file = "priors/reinvent.prior" #temperature = 1.0 # temperature in multinomial sampling #tb_logdir = "tb_logs" # name of the TensorBoard logging directory +## Pepinvent +#model_file = "priors/pepinvent.prior" +#smiles_file = "pepinvent.smi" +#sample_strategy = "beamsearch" # multinomial or beamsearch (deterministic) +#temperature = 1.0 # temperature in multinomial sampling +#tb_logdir = "tb_logs" # name of the TensorBoard logging directory + output_file = 'sampling.csv' # sampled SMILES and NLL in CSV format num_smiles = 157 # number of SMILES to be sampled, 1 per input SMILES diff --git a/configs/toml/staged_learning.toml b/configs/toml/staged_learning.toml index 88b6e37..e98c4e4 100644 --- a/configs/toml/staged_learning.toml +++ b/configs/toml/staged_learning.toml @@ -45,6 +45,13 @@ agent_file = "priors/reinvent.prior" #sample_strategy = "multinomial" # multinomial or beamsearch (deterministic) #distance_threshold = 100 +## Pepinvent +#prior_file = "priors/pepinvent.prior" +#agent_file = "priors/pepinvent.prior" +#smiles_file = "pepinvent.smi" +#sample_strategy = "multinomial" # multinomial or beamsearch (deterministic) +#distance_threshold = 100 + batch_size = 64 # network unique_sequences = true # if true remove all duplicates raw sequences in each step diff --git a/contrib/tests/reinvent_plugins/unit_tests/components/test_comp_unwanted_substructures.py b/contrib/tests/reinvent_plugins/unit_tests/components/test_comp_unwanted_substructures.py index 84fccf2..da17547 100644 --- a/contrib/tests/reinvent_plugins/unit_tests/components/test_comp_unwanted_substructures.py +++ b/contrib/tests/reinvent_plugins/unit_tests/components/test_comp_unwanted_substructures.py @@ -2,8 +2,8 @@ import numpy as np -from reinvent_plugins.components.comp_unwanted_substructures import Parameters -from reinvent_plugins.components.comp_unwanted_substructures import UnwantedSubstructures +from reinvent_plugins.components.RDKit_extra.comp_unwanted_substructures import Parameters +from reinvent_plugins.components.RDKit_extra.comp_unwanted_substructures import UnwantedSubstructures SMILIES = [ "CC1=C(C=C(C=C1)N2C(=O)C(=C(N2)C)N=NC3=CC=CC(=C3O)C4=CC(=CC=C4)C(=O)O)C", # Eltrombopag diff --git a/pyproject.toml b/pyproject.toml index 538a757..82417e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,9 +47,8 @@ dependencies = [ "mmpdb >=2.1,<3", "molvs >=0.1.1,<0.2", "numpy >=1.21,<2", - "OpenEye-toolkits >=2022", # Requires --extra-index-url=https://pypi.anaconda.org/OpenEye/simple + "openEye-toolkits >=2022", # Requires --extra-index-url=https://pypi.anaconda.org/OpenEye/simple "pandas >=2,<3", - "pathos >=0.3.0,<2", "Pillow >=10.0,<11.0", "pydantic >=2,<3", "pytest >=8,<9", @@ -59,20 +58,18 @@ dependencies = [ "rdkit >=2021.0", "requests >=2.28,<3", "requests_mock >=1.10,<2", - "scikit-learn==1.2.2", - "scipy >=1.10,<2", "tenacity >=8.2,<9", - "tensorboard", + "tensorboard >=2,<3", "tomli >=2.0,<3", - "torch==2.3.1+cu121", # Requires --extra-index-url https://download.pytorch.org/whl/cu121 - "torchvision==0.18.1+cu121", # Needed to log molecular images to Tensorboard. 
+ "torch==2.5.1+cu124", # Requires --extra-index-url https://download.pytorch.org/whl/cu121 "tqdm >=4.64,<5", "typing_extensions >=4.0,<5", "xxhash >=3,<4", ] [project.scripts] -reinvent = "reinvent.Reinvent:main" +reinvent = "reinvent.Reinvent:main_script" +reinvent_datapre = "reinvent.datapipeline.preprocess:main_script" # FIXME: change urls for public release. diff --git a/reinvent/Reinvent.py b/reinvent/Reinvent.py index a28c3b5..d9b9b17 100755 --- a/reinvent/Reinvent.py +++ b/reinvent/Reinvent.py @@ -4,234 +4,73 @@ from __future__ import annotations import os import sys -import argparse from dotenv import load_dotenv, find_dotenv import platform import getpass -import random -import logging import datetime -import subprocess as sp -from typing import List, Optional +from typing import Any + +from reinvent.utils import ( + parse_command_line, + get_cuda_driver_version, + set_seed, + extract_sections, + write_json_config, + enable_rdkit_log, + setup_responder, + config_parse, +) SYSTEM = platform.system() if SYSTEM != "Windows": import resource # Unix only -from rdkit import rdBase, RDLogger -import numpy as np +from rdkit import rdBase import rdkit import torch -from reinvent import version, runmodes, config_parse, setup_logger +from reinvent import version, runmodes +from reinvent.utils import setup_logger from reinvent.runmodes.utils import set_torch_device -from reinvent.runmodes.reporter.remote import setup_reporter +from reinvent.runmodes.handler import StageInterruptedControlled from .validation import ReinventConfig -INPUT_FORMAT_CHOICES = ("toml", "json") -RDKIT_CHOICES = ("all", "error", "warning", "info", "debug") -LOGLEVEL_CHOICES = tuple(level.lower() for level in logging._nameToLevel.keys()) -VERSION_STR = f"{version.__progname__} {version.__version__} {version.__copyright__}" -OVERWRITE_STR = "Overwrites setting in the configuration file" -RESPONDER_TOKEN = "RESPONDER_TOKEN" rdBase.DisableLog("rdApp.*") -# rdBase.LogToPythonLogger() -def enable_rdkit_log(levels: List[str]): - """Enable logging messages from RDKit for a specific logging level. +def main(args: Any): + """Simple entry point into Reinvent's run modes. - :param levels: the specific level(s) that need to be silenced + :param args: arguments object, can be argparse.Namespace or any other class """ - if "all" in levels: - RDLogger.EnableLog("rdApp.*") - return - - for level in levels: - RDLogger.EnableLog(f"rdApp.{level}") - - -def get_cuda_driver_version() -> Optional[str]: - """Get the CUDA driver version via modinfo if possible. - - This is for Linux only. 
- - :returns: driver version or None - """ - - # Alternative - # result = sp.run(["/usr/bin/nvidia-smi"], shell=False, capture_output=True) - # if "Driver Version:" in str_line: - # version = str_line.split()[5] - - try: - result = sp.run(["/sbin/modinfo", "nvidia"], shell=False, capture_output=True) - except Exception: - return - - for line in result.stdout.splitlines(): - str_line = line.decode() - - if str_line.startswith("version:"): - cuda_driver_version = str_line.split()[1] - return cuda_driver_version - - -def set_seed(seed: int): - """Set global seed for reproducibility - - :param seed: the seed to initialize the random generators - """ - - if seed is None: - return - - random.seed(seed) - - os.environ["PYTHONHASHSEED"] = str(seed) - - np.random.seed(seed) - - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - - -def extract_sections(config: dict) -> dict: - """Extract the sections of a config file - - :param config: the config file - :returns: the extracted sections - """ - - # FIXME: stages are a list of dicts in RL, may clash with global lists - return {k: v for k, v in config.items() if isinstance(v, (dict, list))} - - -def parse_command_line(): - parser = argparse.ArgumentParser( - description=f"{version.__progname__}: a molecular design " - f"tool for de novo design, " - "scaffold hopping, R-group replacement, linker design, molecule " - "optimization, and others", - epilog=f"{VERSION_STR}", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "config_filename", - nargs="?", - default=None, - metavar="FILE", - type=os.path.abspath, - help="Input configuration file with runtime parameters", - ) - - parser.add_argument( - "-f", - "--config-format", - metavar="FORMAT", - choices=INPUT_FORMAT_CHOICES, - default="toml", - help=f"File format of the configuration file: {', '.join(INPUT_FORMAT_CHOICES)}", - ) - - parser.add_argument( - "-d", - "--device", - metavar="DEV", - default=None, - help=f"Device to run on: cuda, cpu. {OVERWRITE_STR}.", - ) - - parser.add_argument( - "-l", - "--log-filename", - metavar="FILE", - default=None, - type=os.path.abspath, - help=f"File for logging information, otherwise writes to stderr.", - ) - - parser.add_argument( - "--log-level", - metavar="LEVEL", - choices=LOGLEVEL_CHOICES, - default="info", - help=f"Enable this and 'higher' log levels: {', '.join(LOGLEVEL_CHOICES)}.", - ) - - parser.add_argument( - "-s", - "--seed", - metavar="N", - type=int, - default=None, - help="Sets the random seeds for reproducibility", - ) - - parser.add_argument( - "--dotenv-filename", - metavar="FILE", - default=None, - type=os.path.abspath, - help=f"Dotenv file with environment setup needed for some scoring components. 
" - "By default the one from the installation directory will be loaded.", - ) - - parser.add_argument( - "--enable-rdkit-log-levels", - metavar="LEVEL", - choices=RDKIT_CHOICES, - nargs="+", - help=f"Enable specific RDKit log levels: {', '.join(RDKIT_CHOICES)}.", + logger = setup_logger( + name=__package__, level=args.log_level.upper(), filename=args.log_filename ) - parser.add_argument( - "-V", - "--version", - action="version", - version=f"{VERSION_STR}.", + logger.info( + f"Started {version.__progname__} {version.__version__} {version.__copyright__} on " + f"{datetime.datetime.now().strftime('%Y-%m-%d')}" ) - return parser.parse_args() - - -def setup_responder(config): - """Setup for remote monitor - - :param config: configuration - """ - - endpoint = config.get("endpoint", False) - - if not endpoint: - return - - token = os.environ.get(RESPONDER_TOKEN, None) - setup_reporter(endpoint, token) - - -def write_json_config(global_dict, json_out_config): - def dummy(config): - global_dict.update(config) - config_parse.write_json(global_dict, json_out_config) - - return dummy + logger.info(f"Command line: {' '.join(sys.argv)}") + dotenv_loaded = load_dotenv(args.dotenv_filename) # set up the environment for scoring -def main(): - """Simple entry point into Reinvent's run modes.""" + ext = None - args = parse_command_line() + if args.config_filename: + ext = args.config_filename.suffix - dotenv_loaded = load_dotenv(args.dotenv_filename) # set up the environment for scoring + if ext in (f".{e}" for e in config_parse.INPUT_FORMAT_CHOICES): + fmt = ext[1:] + else: + fmt = args.config_format - reader = getattr(config_parse, f"read_{args.config_format}") - input_config = reader(args.config_filename) + logger.info(f"Reading run configuration from {args.config_filename} using format {fmt}") + input_config = config_parse.read_config(args.config_filename, fmt) val_config = ReinventConfig(**input_config) if args.enable_rdkit_log_levels: @@ -239,9 +78,6 @@ def main(): run_type = input_config["run_type"] runner = getattr(runmodes, f"run_{run_type}") - logger = setup_logger( - name=__package__, level=args.log_level.upper(), filename=args.log_filename - ) have_version = input_config.get("version", version.__config_version__) @@ -250,13 +86,6 @@ def main(): logger.fatal(msg) raise RuntimeError(msg) - logger.info( - f"Started {version.__progname__} {version.__version__} {version.__copyright__} on " - f"{datetime.datetime.now().strftime('%Y-%m-%d')}" - ) - - logger.info(f"Command line: {' '.join(sys.argv)}") - if dotenv_loaded: if args.dotenv_filename: filename = args.dotenv_filename @@ -331,13 +160,16 @@ def main(): f"with frequency {input_config['responder']['frequency']}" ) - runner( - input_config=extract_sections(input_config), - device=actual_device, - tb_logdir=tb_logdir, - responder_config=responder_config, - write_config=write_config, - ) + try: + runner( + input_config=extract_sections(input_config), + device=actual_device, + tb_logdir=tb_logdir, + responder_config=responder_config, + write_config=write_config, + ) + except StageInterruptedControlled as e: + logger.info(f"Requested to terminate: {e.args[0]}") if SYSTEM != "Windows": maxrss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss @@ -356,5 +188,12 @@ def main(): ) +def main_script(): + """Main entry point from the command line""" + + args = parse_command_line() + main(args) + + if __name__ == "__main__": - main() + main_script() diff --git a/reinvent/__init__.py b/reinvent/__init__.py index a811125..af9cdfa 100644 --- 
a/reinvent/__init__.py +++ b/reinvent/__init__.py @@ -5,8 +5,8 @@ import sys -from .logger import * from reinvent import models +from reinvent.version import * from reinvent.models.libinvent.models import vocabulary # not sure why needed diff --git a/reinvent/__main__.py b/reinvent/__main__.py index 63e0f1e..6e4e555 100644 --- a/reinvent/__main__.py +++ b/reinvent/__main__.py @@ -1,3 +1,3 @@ -from .Reinvent import main +from .Reinvent import main_script -main() +main_script() diff --git a/reinvent/chemistry/library_design/attachment_points.py b/reinvent/chemistry/library_design/attachment_points.py index 61bc597..bc7686e 100644 --- a/reinvent/chemistry/library_design/attachment_points.py +++ b/reinvent/chemistry/library_design/attachment_points.py @@ -5,6 +5,7 @@ from reinvent.chemistry import conversions, tokens + def add_attachment_point_numbers(mol_or_smi, canonicalize=True): """ Adds the numbers for the attachment points throughout the molecule. @@ -36,16 +37,15 @@ def _ap_callback(_): idx += 1 return conversions.mol_to_smiles(mol) + def get_attachment_points(smile: str) -> List: """ Gets all attachment points from SMILES string. :param smile: A SMILES string :return : A list with the numbers ordered by appearance. """ - return [ - int(match.group(1)) - for match in re.finditer(tokens.ATTACHMENT_POINT_NUM_REGEXP, smile) - ] + return [int(match.group(1)) for match in re.finditer(tokens.ATTACHMENT_POINT_NUM_REGEXP, smile)] + def get_attachment_points_for_molecule(molecule: Mol) -> List: """ @@ -61,6 +61,7 @@ def get_attachment_points_for_molecule(molecule: Mol) -> List: and atom.HasProp("molAtomMapNumber") ] + def add_first_attachment_point_number(smi, num): """ Changes/adds a number to the first attachment point. @@ -75,6 +76,7 @@ def add_first_attachment_point_number(smi, num): count=1, ) + def remove_attachment_point_numbers(smile: str) -> str: """ Removes the numbers for the attachment points throughout the molecule. @@ -88,6 +90,7 @@ def remove_attachment_point_numbers(smile: str) -> str: ) return result + def remove_attachment_point_numbers_from_mol(molecule: Mol) -> Mol: """ Removes the numbers for the attachment points throughout the molecule. @@ -99,6 +102,7 @@ def remove_attachment_point_numbers_from_mol(molecule: Mol) -> Mol: atom.ClearProp("molAtomMapNumber") return molecule + def add_brackets_to_attachment_points(scaffold: str): """ Adds brackets to the attachment points (if they don't have them). diff --git a/reinvent/chemistry/library_design/bond_maker.py b/reinvent/chemistry/library_design/bond_maker.py index 2d1c152..cd5d3cf 100644 --- a/reinvent/chemistry/library_design/bond_maker.py +++ b/reinvent/chemistry/library_design/bond_maker.py @@ -29,6 +29,7 @@ def join_scaffolds_and_decorations( return None return mol + def join_molecule_fragments(scaffold: Mol, decoration: Mol, keep_label_on_atoms=False): """ Joins a RDKit MOL scaffold with a decoration. They must be labelled. @@ -110,6 +111,7 @@ def join_molecule_fragments(scaffold: Mol, decoration: Mol, keep_label_on_atoms= return scaffold + def add_attachment_point_num(atom, idx): idxs = [] if atom.HasProp("molAtomMapNumber"): @@ -127,6 +129,7 @@ def add_attachment_point_num(atom, idx): # instead, keep the last attachement point only. 
This implies that reaction filters will not be compatible # in the case of attachment points on the same atom (they did not work in R3 either) + def randomize_scaffold(scaffold: Mol): smi = conversions.mol_to_random_smiles(scaffold) conv_smi = None diff --git a/reinvent/chemistry/tokens.py b/reinvent/chemistry/tokens.py index 5a39878..8a62c08 100644 --- a/reinvent/chemistry/tokens.py +++ b/reinvent/chemistry/tokens.py @@ -6,3 +6,5 @@ ATTACHMENT_POINT_NUM_REGEXP = r"\[{}:(\d+)\]".format(re.escape(ATTACHMENT_POINT_TOKEN)) ATTACHMENT_POINT_REGEXP = r"(?:{0}|\[{0}[^\]]*\])".format(re.escape(ATTACHMENT_POINT_TOKEN)) ATTACHMENT_POINT_NO_BRACKETS_REGEXP = r"(? Optional[str]: cleanup_params = Standardizer.CleanupParameters() self.normalizer = Standardizer.NormalizerFromData(self.transforms, cleanup_params) + self.tautomer_enumerator = Standardizer.TautomerEnumerator() + self.instantiated = True try: @@ -109,13 +110,23 @@ def clean_smiles(self, smiles, config: FilterSection): if self.config.uncharge: Standardizer.ReionizeInPlace(mol) - new_smiles = Chem.MolToSmiles( - mol, - canonical=True, - isomericSmiles=config.keep_stereo, - kekuleSmiles=self.config.kekulize, - doRandom=self.config.randomize_smiles, - ) + # NOTE: this can be very slow, easily by a factor of 10 or more + if config.canonical_tautomer: + mol = self.tautomer_enumerator.Canonicalize(mol) + + try: + new_smiles = Chem.MolToSmiles( + mol, + canonical=True, + isomericSmiles=config.keep_stereo, + kekuleSmiles=self.config.kekulize, + doRandom=self.config.randomize_smiles, + ) + except RuntimeError as e: + if "Invariant Violation" in e.args[0]: + return None + else: + raise # FIXME: an atom may have 3 ring numbers or more e.g. # C%108%11 which is %10 8 %11 and should become 8%11 %10 diff --git a/reinvent/datapipeline/preprocess.py b/reinvent/datapipeline/preprocess.py index bc31f23..ae38c4d 100755 --- a/reinvent/datapipeline/preprocess.py +++ b/reinvent/datapipeline/preprocess.py @@ -54,10 +54,7 @@ def parse_command_line(): return parser.parse_args() -def main(): - mp.set_start_method("spawn", force=True) - - args = parse_command_line() +def main(args): with open(args.config_filename, "rb") as tf: cfg = tomli.load(tf) @@ -245,5 +242,12 @@ def main(): listener.join() +def main_script(): + mp.set_start_method("spawn", force=True) + + args = parse_command_line() + main(args) + + if __name__ == "__main__": - main() + main_script() diff --git a/reinvent/datapipeline/validation.py b/reinvent/datapipeline/validation.py index 8773209..4d484d1 100644 --- a/reinvent/datapipeline/validation.py +++ b/reinvent/datapipeline/validation.py @@ -20,6 +20,7 @@ class FilterSection(GlobalConfig): keep_stereo: bool = True keep_isotope_molecules: bool = True uncharge: bool = True + canonical_tautomer: bool = False kekulize: bool = False randomize_smiles: bool = False report_errors: bool = False diff --git a/reinvent/logger.py b/reinvent/logger.py deleted file mode 100644 index 35e0067..0000000 --- a/reinvent/logger.py +++ /dev/null @@ -1,87 +0,0 @@ -__all__ = ["CsvFormatter", "setup_logger"] -import sys -import logging -import csv -import io -from logging.config import dictConfig, fileConfig - - -class CsvFormatter(logging.Formatter): - def __init__(self): - super().__init__() - self.output = io.StringIO() - self.writer = csv.writer(self.output) - - def format(self, record): - self.writer.writerow(record.msg) # needs to be a iterable - data = self.output.getvalue() - self.output.truncate(0) - self.output.seek(0) - return data.strip() - - -def setup_logger(
name: str = None, - config: dict = None, - filename: str = None, - formatter=None, - stream=sys.stderr, - cfg_filename: str = None, - propagate: bool = True, - level=logging.INFO, - debug=False, -): - """Setup a logging facility. - - :param name: name of the logger, root if empty or None - :param config: dictionary configuration - :param filename: optional filename for logging output - :param formatter: a logging formatter - :param stream: the output stream - :param cfg_filename: filename of a logger configuration file - :param propagate: whether to propagate to higher level loggers - :param level: logging level - :param debug: set special format for debugging - :returns: the newly set up logger - """ - - logging.captureWarnings(True) - - logger = logging.getLogger(name) - logger.setLevel(level) - - for handler in logger.handlers[:]: - logger.removeHandler(handler) - - if config is not None: - dictConfig(config) - return - - if cfg_filename is not None: - fileConfig(cfg_filename) - return - - if filename: - handler = logging.FileHandler(filename, mode="w+") - else: - handler = logging.StreamHandler(stream) - - handler.setLevel(level) - - if debug: - log_format = "%(asctime)s %(module)s.%(funcName)s +%(lineno)s: %(levelname)-4s %(message)s" - else: - log_format = "%(asctime)s <%(levelname)-4.4s> %(message)s" - - if not formatter: - formatter = logging.Formatter( - fmt=log_format, - datefmt="%H:%M:%S", - ) - - handler.setFormatter(formatter) - - logger.addHandler(handler) - logger.propagate = propagate - - return logger diff --git a/reinvent/models/__init__.py b/reinvent/models/__init__.py index 201ea08..6c618e3 100644 --- a/reinvent/models/__init__.py +++ b/reinvent/models/__init__.py @@ -6,6 +6,7 @@ from .transformer.linkinvent.linkinvent import LinkinventModel as LinkinventTransformerModel from .transformer.libinvent.libinvent import LibinventModel as LibinventTransformerModel from .transformer.mol2mol.mol2mol import Mol2MolModel +from .transformer.pepinvent.pepinvent import PepinventModel from .model_factory.model_adapter import * from .model_factory.reinvent_adapter import * @@ -13,5 +14,6 @@ from .model_factory.linkinvent_adapter import * from .model_factory.mol2mol_adapter import * from .model_factory.transformer_adapter import * +from .model_factory.pepinvent_adapter import * from .meta_data import * diff --git a/reinvent/models/linkinvent/model_vocabulary/vocabulary.py b/reinvent/models/linkinvent/model_vocabulary/vocabulary.py index 757de8a..fa6114d 100644 --- a/reinvent/models/linkinvent/model_vocabulary/vocabulary.py +++ b/reinvent/models/linkinvent/model_vocabulary/vocabulary.py @@ -105,7 +105,6 @@ def encode(self, tokens): raise KeyError(f"{token} is not supported! Supported tokens are {self.tokens()}.") return ohe_vect - def decode(self, ohe_vect): """ Decodes a one-hot encoded vector matrix to a list of tokens. 
diff --git a/reinvent/models/linkinvent/networks/attention_layer.py b/reinvent/models/linkinvent/networks/attention_layer.py index cd3678c..d421c1a 100644 --- a/reinvent/models/linkinvent/networks/attention_layer.py +++ b/reinvent/models/linkinvent/networks/attention_layer.py @@ -5,7 +5,6 @@ class AttentionLayer(tnn.Module): - def __init__(self, num_dimensions: int): super(AttentionLayer, self).__init__() diff --git a/reinvent/models/linkinvent/networks/decoder.py b/reinvent/models/linkinvent/networks/decoder.py index 2a7d1cb..5d9a245 100644 --- a/reinvent/models/linkinvent/networks/decoder.py +++ b/reinvent/models/linkinvent/networks/decoder.py @@ -44,11 +44,7 @@ def forward( seq_lengths: torch.Tensor, encoder_padded_seqs: torch.Tensor, hidden_states: Tuple[torch.Tensor], - ) -> ( - torch.Tensor, - Tuple[torch.Tensor], - torch.Tensor, - ): # pylint: disable=arguments-differ + ) -> (torch.Tensor, Tuple[torch.Tensor], torch.Tensor,): # pylint: disable=arguments-differ """ Performs the forward pass. :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim). diff --git a/reinvent/models/linkinvent/networks/encoder.py b/reinvent/models/linkinvent/networks/encoder.py index cc4d9ee..5347a46 100644 --- a/reinvent/models/linkinvent/networks/encoder.py +++ b/reinvent/models/linkinvent/networks/encoder.py @@ -31,10 +31,9 @@ def __init__(self, num_layers: int, num_dimensions: int, vocabulary_size: int, d bidirectional=True, ) - def forward(self, padded_seqs: torch.Tensor, seq_lengths: torch.Tensor) -> ( - torch.Tensor, - (torch.Tensor, torch.Tensor), - ): # pylint: disable=arguments-differ + def forward( + self, padded_seqs: torch.Tensor, seq_lengths: torch.Tensor + ) -> (torch.Tensor, (torch.Tensor, torch.Tensor),): # pylint: disable=arguments-differ """ Performs the forward pass. :param padded_seqs: A tensor with the sequences (batch, seq). 
diff --git a/reinvent/models/model_factory/libinvent_adapter.py b/reinvent/models/model_factory/libinvent_adapter.py index 964903d..2fc72dc 100644 --- a/reinvent/models/model_factory/libinvent_adapter.py +++ b/reinvent/models/model_factory/libinvent_adapter.py @@ -42,5 +42,6 @@ def sample(self, scaffold_seqs, scaffold_seq_lengths) -> SampleBatch: sampled = self.model.sample_decorations(scaffold_seqs, scaffold_seq_lengths) return SampleBatch(*sampled) + class LibinventTransformerAdapter(TransformerAdapter): - pass \ No newline at end of file + pass diff --git a/reinvent/models/model_factory/pepinvent_adapter.py b/reinvent/models/model_factory/pepinvent_adapter.py new file mode 100644 index 0000000..b0b1075 --- /dev/null +++ b/reinvent/models/model_factory/pepinvent_adapter.py @@ -0,0 +1,11 @@ +"""Adapter for Pepinvent""" + +from __future__ import annotations + +__all__ = ["PepinventAdapter"] + +from reinvent.models.model_factory.transformer_adapter import TransformerAdapter + + +class PepinventAdapter(TransformerAdapter): + pass diff --git a/reinvent/models/model_factory/transformer_adapter.py b/reinvent/models/model_factory/transformer_adapter.py index 6d48504..ebb60f3 100644 --- a/reinvent/models/model_factory/transformer_adapter.py +++ b/reinvent/models/model_factory/transformer_adapter.py @@ -34,7 +34,11 @@ def likelihood_smiles( output = [dto.output for dto in sampled_sequence_list] dataset = PairedDataset(input, output, vocabulary=self.vocabulary, tokenizer=self.tokenizer) data_loader = tud.DataLoader( - dataset, LIKELIHOOD_BATCH_SIZE, drop_last=False, shuffle=False, collate_fn=PairedDataset.collate_fn + dataset, + LIKELIHOOD_BATCH_SIZE, + drop_last=False, + shuffle=False, + collate_fn=PairedDataset.collate_fn, ) likelihood = [] diff --git a/reinvent/models/transformer/pepinvent/__init__.py b/reinvent/models/transformer/pepinvent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/reinvent/models/transformer/pepinvent/pepinvent.py b/reinvent/models/transformer/pepinvent/pepinvent.py new file mode 100644 index 0000000..8c27efe --- /dev/null +++ b/reinvent/models/transformer/pepinvent/pepinvent.py @@ -0,0 +1,6 @@ +from reinvent.models.transformer.transformer import TransformerModel + + +class PepinventModel(TransformerModel): + _model_type = "Pepinvent" + _version = 1 diff --git a/reinvent/runmodes/RL/__init__.py b/reinvent/runmodes/RL/__init__.py index 089c868..4e5f1ad 100644 --- a/reinvent/runmodes/RL/__init__.py +++ b/reinvent/runmodes/RL/__init__.py @@ -5,5 +5,6 @@ from .libinvent import * from .linkinvent import * from .mol2mol import * +from .pepinvent import * from .reward import * from .terminators import * diff --git a/reinvent/runmodes/RL/distance_penalty.py b/reinvent/runmodes/RL/distance_penalty.py new file mode 100644 index 0000000..3fce259 --- /dev/null +++ b/reinvent/runmodes/RL/distance_penalty.py @@ -0,0 +1,31 @@ +import numpy as np +import torch + +from reinvent.runmodes.RL import Learning + + +def get_distance_to_prior(likelihood, distance_threshold: float) -> np.ndarray: + # FIXME: the datatype should not be variable + if isinstance(likelihood, torch.Tensor): + ones = torch.ones_like(likelihood, requires_grad=False) + mask = torch.where(likelihood < distance_threshold, ones, distance_threshold / likelihood) + mask = mask.cpu().numpy() + else: + ones = np.ones_like(likelihood) + mask = np.where(likelihood < distance_threshold, ones, distance_threshold / likelihood) + + return mask + + +def score(learning: Learning): + """Compute the score for the 
SMILES strings.""" + prior_nll = learning.prior.likelihood_smiles(learning.sampled).likelihood + distance_penalties = get_distance_to_prior(prior_nll, learning.distance_threshold) + + results = learning.scoring_function( + learning.sampled.smilies, learning.invalid_mask, learning.duplicate_mask + ) + + results.total_scores *= distance_penalties + + return results diff --git a/reinvent/runmodes/RL/learning.py b/reinvent/runmodes/RL/learning.py index c4dda85..38f15fa 100644 --- a/reinvent/runmodes/RL/learning.py +++ b/reinvent/runmodes/RL/learning.py @@ -20,7 +20,7 @@ from .reports import RLTBReporter, RLCSVReporter, RLRemoteReporter, RLReportData from reinvent.runmodes.RL.data_classes import ModelState from reinvent.models.model_factory.sample_batch import SmilesState -from reinvent.runmodes.reporter.remote import get_reporter +from reinvent.utils import get_reporter from reinvent_plugins.normalizers.rdkit_smiles import normalize if TYPE_CHECKING: @@ -36,9 +36,11 @@ class Learning(ABC): """Partially abstract base class for the Template Method pattern""" + # FIXME: too many arguments def __init__( self, max_steps: int, + stage_no: int, prior: ModelAdapter, state: ModelState, scoring_function: Scorer, @@ -54,6 +56,7 @@ def __init__( """Setup of the common framework""" self.max_steps = max_steps + self.stage_no = stage_no self.prior = prior # Seed the starting state, need update in every stage @@ -117,7 +120,7 @@ def optimize(self, converged: terminator_callable) -> bool: results = self.score() if self.prior.model_type == "Libinvent": - results.smilies = normalize(results.smilies) + results.smilies = normalize(results.smilies, keep_all=True) if self._state.diversity_filter: df_mask = np.where(self.invalid_mask, True, False) @@ -308,11 +311,12 @@ def report( smilies = np.array(self.sampled.smilies)[mask_valid] if self.prior.model_type == "Libinvent": - smilies = normalize(smilies) + smilies = normalize(smilies, keep_all=True) mask_idx = (np.argwhere(mask_valid).flatten(),) report_data = RLReportData( step=step_no, + stage=self.stage_no, smilies=smilies, scaffolds=scaffolds, sampled=self.sampled, @@ -326,6 +330,7 @@ def report( loss=loss, fraction_valid_smiles=fract_valid_smiles, fraction_duplicate_smiles=fract_duplicate_smiles, + df_memory_smilies=len(diversity_filter.smiles_memory) if diversity_filter else 0, bucket_max_size=( diversity_filter.scaffold_memory.max_size if diversity_filter else None ), diff --git a/reinvent/runmodes/RL/libinvent.py b/reinvent/runmodes/RL/libinvent.py index ab48bfb..3f94d05 100644 --- a/reinvent/runmodes/RL/libinvent.py +++ b/reinvent/runmodes/RL/libinvent.py @@ -15,8 +15,7 @@ class LibinventLearning(Learning): - """LibInvent optimization - """ + """LibInvent optimization""" def update(self, results: ScoreResults): if self.prior.version == 1: # RNN-based @@ -25,4 +24,4 @@ def update(self, results: ScoreResults): return self._update_common_transformer(results) -LibinventTransformerLearning = LibinventLearning \ No newline at end of file +LibinventTransformerLearning = LibinventLearning diff --git a/reinvent/runmodes/RL/mol2mol.py b/reinvent/runmodes/RL/mol2mol.py index e6bf550..1095f1d 100644 --- a/reinvent/runmodes/RL/mol2mol.py +++ b/reinvent/runmodes/RL/mol2mol.py @@ -6,11 +6,8 @@ import logging from typing import TYPE_CHECKING -import torch -import numpy as np - from .learning import Learning -from reinvent.models.model_factory.sample_batch import SmilesState +from .distance_penalty import score as _score if TYPE_CHECKING: from reinvent.scoring import 
ScoreResults @@ -18,34 +15,11 @@ logger = logging.getLogger(__name__) -def get_distance_to_prior(likelihood, distance_threshold: float) -> np.ndarray: - # FIXME: the datatype should not be variable - if isinstance(likelihood, torch.Tensor): - ones = torch.ones_like(likelihood, requires_grad=False) - mask = torch.where(likelihood < distance_threshold, ones, distance_threshold / likelihood) - mask = mask.cpu().numpy() - else: - ones = np.ones_like(likelihood) - mask = np.where(likelihood < distance_threshold, ones, distance_threshold / likelihood) - - return mask - - class Mol2MolLearning(Learning): """Mol2Mol optimization""" def score(self): - """Compute the score for the SMILES stings.""" - - prior_nll = self.prior.likelihood_smiles(self.sampled).likelihood - distance_penalty = get_distance_to_prior(prior_nll, self.distance_threshold) - - results = self.scoring_function( - self.sampled.smilies, self.invalid_mask, self.duplicate_mask - ) - results.total_scores *= distance_penalty - - return results + return _score(self) def update(self, results: ScoreResults): return self._update_common_transformer(results) diff --git a/reinvent/runmodes/RL/pepinvent.py b/reinvent/runmodes/RL/pepinvent.py new file mode 100644 index 0000000..e017ff4 --- /dev/null +++ b/reinvent/runmodes/RL/pepinvent.py @@ -0,0 +1,25 @@ +"""The Pepinvent optimization algorithm""" + +from __future__ import annotations + +__all__ = ["PepinventLearning"] +import logging +from typing import TYPE_CHECKING + +from .learning import Learning +from .distance_penalty import score as _score + +if TYPE_CHECKING: + from reinvent.scoring import ScoreResults + +logger = logging.getLogger(__name__) + + +class PepinventLearning(Learning): + """Pepinvent optimization""" + + def score(self): + return _score(self) + + def update(self, results: ScoreResults): + return self._update_common_transformer(results) diff --git a/reinvent/runmodes/RL/reports/csv_summmary.py b/reinvent/runmodes/RL/reports/csv_summmary.py index e286415..3be4951 100644 --- a/reinvent/runmodes/RL/reports/csv_summmary.py +++ b/reinvent/runmodes/RL/reports/csv_summmary.py @@ -19,9 +19,10 @@ ], # Named so to be different from Scaffold from diversity filter "Linkinvent": ["Warheads", "Linker"], "Mol2Mol": ["Input_SMILES"], + "Pepinvent": ["Masked_input_peptide", "Fillers"], } -FRAGMENT_GENERATORS = ["Libinvent", "Linkinvent"] +FRAGMENT_GENERATORS = ["Libinvent", "Linkinvent", "Pepinvent"] class RLCSVReporter: diff --git a/reinvent/runmodes/RL/reports/data.py b/reinvent/runmodes/RL/reports/data.py old mode 100755 new mode 100644 index c254e01..33ca454 --- a/reinvent/runmodes/RL/reports/data.py +++ b/reinvent/runmodes/RL/reports/data.py @@ -15,6 +15,7 @@ @dataclass class RLReportData: step: int + stage: int smilies: list scaffolds: list sampled: SampleBatch @@ -28,6 +29,7 @@ class RLReportData: loss: float fraction_valid_smiles: float fraction_duplicate_smiles: float + df_memory_smilies: int bucket_max_size: int num_full_buckets: int num_total_buckets: int diff --git a/reinvent/runmodes/RL/reports/remote.py b/reinvent/runmodes/RL/reports/remote.py index 6bf6137..350f75a 100644 --- a/reinvent/runmodes/RL/reports/remote.py +++ b/reinvent/runmodes/RL/reports/remote.py @@ -29,6 +29,7 @@ def submit(self, data: RLReportData) -> None: """ step = data.step + stage = data.stage if not (step == 0 or step % self.logging_frequency == 0): return @@ -45,7 +46,6 @@ def submit(self, data: RLReportData) -> None: "prior NLL": float(data.prior_mean_nll), "agent NLL": float(data.agent_mean_nll), } - 
smarts_pattern = "" # get_matching_substructure(data.score_results) smiles_legend_pairs = get_smiles_legend_pairs( np.array(data.score_results.smilies)[mask_idx], @@ -56,6 +56,7 @@ def submit(self, data: RLReportData) -> None: record = { "step": step, + "stage": stage, "timestamp": time.time(), # gives microsecond resolution on Linux "components": score_components, "learning": learning_curves, @@ -65,7 +66,7 @@ def submit(self, data: RLReportData) -> None: "smarts_pattern": smarts_pattern, "smiles_legend_pairs": smiles_legend_pairs, }, - "collected smiles in memory": len(data.smilies), + "collected smiles in memory": data.df_memory_smilies, } self.reporter.send(record) diff --git a/reinvent/runmodes/RL/reports/tensorboard.py b/reinvent/runmodes/RL/reports/tensorboard.py index 4745c32..ce83fd9 100644 --- a/reinvent/runmodes/RL/reports/tensorboard.py +++ b/reinvent/runmodes/RL/reports/tensorboard.py @@ -85,7 +85,7 @@ def submit(self, data: RLReportData) -> None: labels = [f"score={score:.2f}" for score in results.total_scores] sample_size = ROWS * COLUMNS - + image_tensor = make_grid_image(data.smilies, labels, sample_size, ROWS) if image_tensor is not None: diff --git a/reinvent/runmodes/RL/run_staged_learning.py b/reinvent/runmodes/RL/run_staged_learning.py index 3785f01..d244e9a 100644 --- a/reinvent/runmodes/RL/run_staged_learning.py +++ b/reinvent/runmodes/RL/run_staged_learning.py @@ -7,7 +7,7 @@ import torch -from reinvent import config_parse, setup_logger, CsvFormatter +from reinvent.utils import setup_logger, CsvFormatter, config_parse from reinvent.runmodes import Handler, RL, create_adapter from reinvent.runmodes.setup_sampler import setup_sampler from reinvent.runmodes.RL import terminators, memories @@ -28,6 +28,8 @@ logger = logging.getLogger(__name__) +TRANSFORMERS = ["Mol2Mol", "LinkinventTransformer", "LibinventTransformer", "Pepinvent"] + def setup_diversity_filter(config: SectionDiversityFilter, rdkit_smiles_flags: dict): """Setup of the diversity filter @@ -40,17 +42,12 @@ def setup_diversity_filter(config: SectionDiversityFilter, rdkit_smiles_flags: d :return: the set up diversity filter """ - if config is None: + if config is None or not hasattr(config, "type"): return None - memory_type = config.type - - if hasattr(config, "type"): - diversity_filter = getattr(memories, memory_type) - else: - return None + diversity_filter = getattr(memories, config.type) - logger.info(f"Using diversity filter {memory_type}") + logger.info(f"Using diversity filter {config.type}") return diversity_filter( bucket_size=config.bucket_size, @@ -242,7 +239,7 @@ def run_staged_learning( rdkit_smiles_flags = dict(allowTautomers=True) - if model_type in ["Mol2Mol", "LinkinventTransformer", "LibinventTransformer"]: # Transformer-based models + if model_type in TRANSFORMERS: # Transformer-based models agent_mode = "inference" rdkit_smiles_flags.update(sanitize=True, isomericSmiles=True) rdkit_smiles_flags2 = dict(isomericSmiles=True) @@ -275,14 +272,12 @@ def run_staged_learning( global_df_only = False - if config.diversity_filter: - global_df_only = True - if parameters.use_checkpoint and "staged_learning" in agent_save_dict: logger.info(f"Using diversity filter from {agent_model_filename}") diversity_filter = agent_save_dict["staged_learning"]["diversity_filter"] - else: + elif config.diversity_filter: diversity_filter = setup_diversity_filter(config.diversity_filter, rdkit_smiles_flags2) + global_df_only = True if parameters.purge_memories: logger.info("Purging diversity filter memories after 
each stage") @@ -337,6 +332,7 @@ def run_staged_learning( optimize = model_learning( max_steps=package.max_steps, + stage_no=stage_no, prior=prior, state=state, scoring_function=package.scoring_function, diff --git a/reinvent/runmodes/RL/validation.py b/reinvent/runmodes/RL/validation.py index cd4fec6..02cb40e 100644 --- a/reinvent/runmodes/RL/validation.py +++ b/reinvent/runmodes/RL/validation.py @@ -19,6 +19,7 @@ class SectionParameters(GlobalConfig): batch_size: int = 100 randomize_smiles: bool = True unique_sequences: bool = False + temperature: float = 1.0 class SectionLearningStrategy(GlobalConfig): diff --git a/reinvent/runmodes/TL/learning.py b/reinvent/runmodes/TL/learning.py index 69c0eb1..16bd97f 100644 --- a/reinvent/runmodes/TL/learning.py +++ b/reinvent/runmodes/TL/learning.py @@ -24,7 +24,7 @@ from tqdm.contrib.logging import tqdm_logging_redirect from reinvent.runmodes.TL.reports import TLTBReporter, TLRemoteReporter, TLReportData -from reinvent.runmodes.reporter.remote import get_reporter +from reinvent.utils.logmon import get_reporter from reinvent.runmodes.setup_sampler import setup_sampler from reinvent.runmodes.utils.tensorboard import SummaryWriter # monkey patch from reinvent.models.meta_data import update_model_data @@ -210,10 +210,12 @@ def optimize(self): __call__ = optimize @abstractmethod - def train_epoch(self): ... + def train_epoch(self): + ... @abstractmethod - def compute_nll(self, batch): ... + def compute_nll(self, batch): + ... def _train_epoch_common(self) -> float: """Run one epoch of training diff --git a/reinvent/runmodes/TL/run_transfer_learning.py b/reinvent/runmodes/TL/run_transfer_learning.py index c80f9f7..16e89c9 100644 --- a/reinvent/runmodes/TL/run_transfer_learning.py +++ b/reinvent/runmodes/TL/run_transfer_learning.py @@ -10,8 +10,7 @@ import torch.optim as topt from reinvent.runmodes import TL, create_adapter -from reinvent.config_parse import read_smiles_csv_file -from reinvent.runmodes.reporter.remote import setup_reporter +from reinvent.utils import setup_reporter, read_smiles_csv_file from reinvent.chemistry import conversions from reinvent.chemistry.standardization.rdkit_standardizer import ( RDKitStandardizer, diff --git a/reinvent/runmodes/TL/validation.py b/reinvent/runmodes/TL/validation.py index 8e70b97..4117650 100644 --- a/reinvent/runmodes/TL/validation.py +++ b/reinvent/runmodes/TL/validation.py @@ -9,7 +9,7 @@ class SectionParameters(GlobalConfig): num_epochs: int = Field(ge=1) - batch_size: int = Field(ge=10) + batch_size: int = Field(ge=1) input_model_file: str output_model_file: str # FIXME: consider for removal smiles_file: str diff --git a/reinvent/runmodes/create_adapter.py b/reinvent/runmodes/create_adapter.py index da4c27d..63dccb3 100644 --- a/reinvent/runmodes/create_adapter.py +++ b/reinvent/runmodes/create_adapter.py @@ -12,7 +12,7 @@ import torch from reinvent import models -from reinvent.prior_registry import registry +from reinvent.utils import prior_registry from reinvent.models import meta_data logger = logging.getLogger(__name__) @@ -70,7 +70,7 @@ def resolve_model_filename(name: str) -> str: :returns: the filename of the model """ - filename = registry.get(name, None) + filename = prior_registry.get(name, None) if not filename: filename = pathlib.Path(name).resolve() diff --git a/reinvent/runmodes/dtos/__init__.py b/reinvent/runmodes/dtos/__init__.py index 4d24a4e..fc996b6 100644 --- a/reinvent/runmodes/dtos/__init__.py +++ b/reinvent/runmodes/dtos/__init__.py @@ -1 +1 @@ -from .dtos import * \ No newline 
at end of file +from .dtos import * diff --git a/reinvent/runmodes/handler.py b/reinvent/runmodes/handler.py index c41f2d7..575a503 100644 --- a/reinvent/runmodes/handler.py +++ b/reinvent/runmodes/handler.py @@ -7,7 +7,9 @@ """ import signal +import multiprocessing as mp import platform +import logging from pathlib import Path from typing import Callable, Dict @@ -16,15 +18,19 @@ from reinvent.models.meta_data import update_model_data +logger = logging.getLogger(__name__) + if platform.system() != "Windows": - # NOTE: SIGTERM (signal 15) seems to be triggered by terminating processes. So - # multiprocessing triggers the handler for every terminating child. - SUPPORTED_SIGNALS = (signal.SIGINT, signal.SIGQUIT) + SUPPORTED_SIGNALS = (signal.SIGINT, signal.SIGQUIT, signal.SIGTERM, signal.SIGUSR1) else: SUPPORTED_SIGNALS = (signal.SIGINT,) -class StageInterrupted(Exception): +class StageInterruptedUncontrolled(Exception): + pass + + +class StageInterruptedControlled(Exception): pass @@ -59,19 +65,22 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """Save the data and reset the signal handler""" - # FIXME: check if exc_type is really the one we want? - for _signal in SUPPORTED_SIGNALS: signal.signal(_signal, signal.SIG_IGN) - self.save() + msg = "" + + if exc_val and exc_val.args: + msg = exc_val.args[0] + + logger.critical(f"Received exception ('{msg}'): saving checkpoint and then terminate") + + if mp.current_process().name == "MainProcess": + self.save() for _signal, _handler in self._default_handlers: signal.signal(_signal, _handler) - # prevent exception from bubbling up - # return True - @property def out_filename(self): """Getter to obtain the output filename""" @@ -126,4 +135,7 @@ def _signal_handler(self, signum, frame) -> None: :raises: StageInterrrupted """ - raise StageInterrupted + if signum == signal.SIGUSR1 and signal.SIGUSR1 in SUPPORTED_SIGNALS: + raise StageInterruptedControlled(f"Signal {signum}") + else: + raise StageInterruptedUncontrolled(f"Signal {signum}") diff --git a/reinvent/runmodes/reporter/__init__.py b/reinvent/runmodes/reporter/__init__.py deleted file mode 100644 index 38caeb2..0000000 --- a/reinvent/runmodes/reporter/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Reporters.""" diff --git a/reinvent/runmodes/reporter/remote.py b/reinvent/runmodes/reporter/remote.py deleted file mode 100644 index 40818d7..0000000 --- a/reinvent/runmodes/reporter/remote.py +++ /dev/null @@ -1,106 +0,0 @@ -"""A simple reporter facilituy to write information to a remote host - -The functionality is somewhat reminiscent of Python logging but different -enough to call it a "reporter" rather than a "logger". The code is also very -simplistic and only supports the task at hand. - -The caller is expected to set up the logger before use. If this does not -happen sending to the reporter will still be possible but will have no effect -i.e. a reporter will always be available. 
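# handler.py now distinguishes a "controlled" stop (SIGUSR1) from an "uncontrolled"
# one (SIGINT/SIGQUIT/SIGTERM) and only checkpoints from the main process.  Below is a
# minimal, self-contained sketch of the same pattern on POSIX systems (mirroring the
# Windows guard above); the class and function names are illustrative, not REINVENT code.
import os
import signal
import time

class ControlledStop(Exception):
    pass

class UncontrolledStop(Exception):
    pass

def _handler(signum, frame):
    if signum == signal.SIGUSR1:
        raise ControlledStop(f"Signal {signum}")
    raise UncontrolledStop(f"Signal {signum}")

for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGUSR1):
    signal.signal(sig, _handler)

try:
    os.kill(os.getpid(), signal.SIGUSR1)   # e.g. `kill -USR1 <pid>` from a shell
    time.sleep(1)                          # give the handler a chance to run
except ControlledStop:
    print("controlled stop requested: save checkpoint, then terminate")
except UncontrolledStop:
    print("uncontrolled stop: save checkpoint and abort")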
-""" - -__all__ = ["setup_reporter", "get_reporter"] -import requests -import json -import logging -from typing import Mapping, Optional - -logger = logging.getLogger(__name__) - - -HEADERS = { - "Accept": "application/json", - "Content-Type": "application/json", - "Authorization": None, -} - -MAX_ERR_MSG = 5 - - -class RemoteJSONReporter: - """Simplistic reporter that sends JSON to a remote server""" - - def __init__(self, url, token=None): - """Set up the reporter - - :param url: URL to send JSON to - :param token: access token for the URL - """ - - self.url = url - - if token: - self.headers["Authorization"] = token - - self.headers = HEADERS - self.max_msg = 0 - - def send(self, record) -> None: - """Send a record to a remote URL - - :param record: dictionary-like record to send to remote URL - """ - - if not isinstance(record, Mapping): - raise TypeError("The record is expected to be a mapping") - - json_msg = json.dumps(record, indent=2) - - logger.debug( - "Data sent to {url}\n\n{headers}\n\n{json_data}".format( - url=self.url, - headers="\n".join(f"{k}: {v}" for k, v in self.headers.items()), - json_data=json_msg, - ) - ) - - response = requests.post(self.url, json=json.loads(json_msg), headers=self.headers) - - # alternative: check if response.status_code != request.codes.created - if not response.ok and self.max_msg < MAX_ERR_MSG: - self.max_msg += 1 - logger.error(f"Failed to send record to: {self.url}") - logger.error(f"{response.text=}") - logger.error(f"{response.headers=}") - logger.error(f"{response.reason=}") - logger.error(f"{response.url=}") - - -_reporter = None - - -def get_reporter() -> Optional[RemoteJSONReporter]: - """Return the current reporter - - :return: reporter object - """ - - return _reporter - - -def setup_reporter(url, token=None) -> bool: - """Set up the reporter - - :param url: URL to send JSON to - :param token: access token for the URL - :returns: whether reporter was setup successfully - """ - - global _reporter - - if url: - # assume endpoint is readily available... - _reporter = RemoteJSONReporter(url, token) - return True - - return False diff --git a/reinvent/runmodes/samplers/__init__.py b/reinvent/runmodes/samplers/__init__.py index 15f76ff..6e6dd66 100644 --- a/reinvent/runmodes/samplers/__init__.py +++ b/reinvent/runmodes/samplers/__init__.py @@ -5,3 +5,4 @@ from .libinvent import * from .linkinvent import * from .mol2mol import * +from .pepinvent import * diff --git a/reinvent/runmodes/samplers/libinvent.py b/reinvent/runmodes/samplers/libinvent.py index 7a6c39c..db4a043 100644 --- a/reinvent/runmodes/samplers/libinvent.py +++ b/reinvent/runmodes/samplers/libinvent.py @@ -1,17 +1,16 @@ """The LibInvent sampling module""" __all__ = ["LibinventSampler", "LibinventTransformerSampler"] -from typing import List, Tuple +from typing import List import logging -import torch import torch.utils.data as tud +from rdkit import Chem from .sampler import Sampler, validate_smiles, remove_duplicate_sequences from . 
import params from reinvent.models.libinvent.models.dataset import Dataset from reinvent.models.model_factory.sample_batch import SampleBatch -from reinvent.runmodes.utils.helpers import join_fragments from reinvent.chemistry import conversions from reinvent.chemistry.library_design import attachment_points, bond_maker from reinvent.models.transformer.core.dataset.dataset import Dataset as TransformerDataset @@ -26,7 +25,7 @@ def sample(self, smilies: List[str]) -> SampleBatch: """Samples the LibInvent model for the given number of SMILES :param smilies: list of SMILES used for sampling - :returns: list of SampledSequencesDTO + :returns: SampleBatch """ if self.model.version == 2: # Transformer-based @@ -78,7 +77,7 @@ def sample(self, smilies: List[str]) -> SampleBatch: if self.unique_sequences: sampled = remove_duplicate_sequences(sampled) - mols = join_fragments(sampled, reverse=False, keep_labels=True) + mols = self._join_fragments(sampled) sampled.smilies, sampled.states = validate_smiles( mols, sampled.output, isomeric=self.isomeric @@ -87,8 +86,7 @@ def sample(self, smilies: List[str]) -> SampleBatch: return sampled def _standardize_input(self, scaffold_list: List[str]): - return [conversions.convert_to_standardized_smiles(scaffold) - for scaffold in scaffold_list] + return [conversions.convert_to_standardized_smiles(scaffold) for scaffold in scaffold_list] def _get_randomized_smiles(self, scaffolds: List[str]): """Randomize the scaffold SMILES""" @@ -98,5 +96,29 @@ def _get_randomized_smiles(self, scaffolds: List[str]): return randomized + def _join_fragments(self, sequences: SampleBatch) -> List[Chem.Mol]: + """Join input scaffold and generated decorators + + :param sequences: a batch of sequences + :returns: a list of RDKit molecules + """ + + mols = [] + + for sample in sequences: + input_scaffold = sample.input + decorators = sample.output + + scaffold = attachment_points.add_attachment_point_numbers( + input_scaffold, canonicalize=False + ) + mol: Chem.Mol = bond_maker.join_scaffolds_and_decorations( # may return None + scaffold, decorators, keep_labels_on_atoms=True + ) + + mols.append(mol) + + return mols + LibinventTransformerSampler = LibinventSampler diff --git a/reinvent/runmodes/samplers/linkinvent.py b/reinvent/runmodes/samplers/linkinvent.py index 7bf261d..09481ee 100644 --- a/reinvent/runmodes/samplers/linkinvent.py +++ b/reinvent/runmodes/samplers/linkinvent.py @@ -5,14 +5,14 @@ import logging import torch.utils.data as tud +from rdkit import Chem from .sampler import Sampler, validate_smiles, remove_duplicate_sequences from . 
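# LibinventSampler now joins the input scaffold and generated decorators itself, via
# attachment_points.add_attachment_point_numbers and
# bond_maker.join_scaffolds_and_decorations (which may return None).  The sketch below
# shows the same joining idea with plain RDKit molzip instead of the REINVENT helpers;
# the SMILES are made-up examples using numbered dummy atoms as attachment points.
from rdkit import Chem

scaffold = Chem.MolFromSmiles("[*:1]c1ccccc1[*:2]")      # hypothetical scaffold
decorations = Chem.MolFromSmiles("[*:1]CC.[*:2]OC")      # hypothetical decorators

combined = Chem.CombineMols(scaffold, decorations)
joined = Chem.molzip(combined)                           # bonds matching map numbers
print(Chem.MolToSmiles(joined))                          # e.g. CCc1ccccc1OC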
import params from reinvent.models.linkinvent.dataset.dataset import Dataset from reinvent.models.model_factory.sample_batch import SampleBatch -from reinvent.runmodes.utils.helpers import join_fragments from reinvent.chemistry import conversions, tokens -from reinvent.chemistry.library_design import attachment_points +from reinvent.chemistry.library_design import attachment_points, bond_maker from ...models.transformer.core.dataset.dataset import Dataset as TransformerDataset logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ def sample(self, smilies: List[str]) -> SampleBatch: """Samples the model for the given number of SMILES :param smilies: list of SMILES used for sampling - :returns: list of SampledSequencesDTO + :returns: SampleBatch """ if self.model.version == 2: # Transformer-based @@ -76,7 +76,7 @@ def sample(self, smilies: List[str]) -> SampleBatch: if self.unique_sequences: sampled = remove_duplicate_sequences(sampled) - mols = join_fragments(sampled, reverse=True) + mols = self._join_fragments(sampled) sampled.smilies, sampled.states = validate_smiles( mols, sampled.output, isomeric=self.isomeric @@ -117,5 +117,28 @@ def _get_randomized_smiles(self, warhead_pair_list: List[str]): return randomized_warhead_pair_list + def _join_fragments(self, sequences: SampleBatch) -> List[Chem.Mol]: + """Join input warheads with generated linker + + :param sequences: a batch of sequences + :returns: a list of RDKit molecules + """ + + mols = [] + + for sample in sequences: + warheads = sample.input + generated_linker = sample.output + + linker = attachment_points.add_attachment_point_numbers( + generated_linker, canonicalize=False + ) + mol: Chem.Mol = bond_maker.join_scaffolds_and_decorations( # may return None + linker, warheads + ) + mols.append(mol) + + return mols + LinkinventTransformerSampler = LinkinventSampler diff --git a/reinvent/runmodes/samplers/mol2mol.py b/reinvent/runmodes/samplers/mol2mol.py index e63f1a9..ab4749f 100644 --- a/reinvent/runmodes/samplers/mol2mol.py +++ b/reinvent/runmodes/samplers/mol2mol.py @@ -27,12 +27,10 @@ def sample(self, smilies: List[str]) -> SampleBatch: """Samples the Mol2Mol model for the given number of SMILES :param smilies: list of SMILES used for sampling - :returns: list of SampledSequencesDTO + :returns: SampleBatch """ # Standardize smiles in the same way as training data - smilies = [ - conversions.convert_to_standardized_smiles(smile) for smile in smilies - ] + smilies = [conversions.convert_to_standardized_smiles(smile) for smile in smilies] smilies = ( [self._get_randomized_smiles(smiles) for smiles in smilies] @@ -82,9 +80,7 @@ def sample(self, smilies: List[str]) -> SampleBatch: def _get_randomized_smiles(self, smiles: str): input_mol = conversions.smile_to_mol(smiles) - randomized_smile = conversions.mol_to_random_smiles( - input_mol, isomericSmiles=self.isomeric - ) + randomized_smile = conversions.mol_to_random_smiles(input_mol, isomericSmiles=self.isomeric) return randomized_smile diff --git a/reinvent/runmodes/samplers/pepinvent.py b/reinvent/runmodes/samplers/pepinvent.py new file mode 100644 index 0000000..4cfa7de --- /dev/null +++ b/reinvent/runmodes/samplers/pepinvent.py @@ -0,0 +1,77 @@ +"""The Pepinvent sampling module""" + +__all__ = ["PepinventSampler"] +from typing import List +import logging + +import torch.utils.data as tud +from rdkit import Chem +from reinvent.chemistry import tokens + +from .sampler import Sampler, validate_smiles, remove_duplicate_sequences +from . 
import params +from reinvent.models.transformer.core.dataset.dataset import Dataset +from reinvent.models.transformer.core.vocabulary import SMILESTokenizer +from reinvent.models.model_factory.sample_batch import SampleBatch + +logger = logging.getLogger(__name__) + + +class PepinventSampler(Sampler): + """Carry out sampling with Pepinvent""" + + def sample(self, smilies: List[str]) -> SampleBatch: + """Samples the Pepinvent model for the given number of SMILES + + :param smilies: list of SMILES used for sampling + :returns: SampleBatch + """ + if self.sample_strategy == "multinomial": + smilies = smilies * self.batch_size + + tokenizer = SMILESTokenizer() + dataset = Dataset(smilies, self.model.get_vocabulary(), tokenizer) + dataloader = tud.DataLoader( + dataset, + batch_size=params.DATALOADER_BATCHSIZE, + shuffle=False, + collate_fn=Dataset.collate_fn, + ) + + sequences = [] + + for batch in dataloader: + src, src_mask = batch + + sampled = self.model.sample(src, src_mask, self.sample_strategy) + + for batch_row in sampled: + sequences.append(batch_row) + + sampled = SampleBatch.from_list(sequences) + + mols = self._join_fragments(sampled) + + sampled.smilies, sampled.states = validate_smiles( + mols, sampled.output, isomeric=self.isomeric + ) + + return sampled + + def _join_fragments(self, sequences: SampleBatch) -> List[Chem.Mol]: + """Join input masked peptide with generated fillers + + :param sequences: a batch of sequences + :returns: a list of RDKit molecules + """ + + mols = [] + + for sample in sequences: + smiles = sample.input + for replacement in sample.output.split(tokens.PEPINVENT_CHUCKLES_SEPARATOR_TOKEN): + smiles = smiles.replace(tokens.PEPINVENT_MASK_TOKEN, replacement, 1) + mol = Chem.MolFromSmiles(smiles.replace(tokens.PEPINVENT_CHUCKLES_SEPARATOR_TOKEN, "")) + mols.append(mol) + + return mols diff --git a/reinvent/runmodes/samplers/reinvent.py b/reinvent/runmodes/samplers/reinvent.py index 1156391..353fdbf 100644 --- a/reinvent/runmodes/samplers/reinvent.py +++ b/reinvent/runmodes/samplers/reinvent.py @@ -4,11 +4,9 @@ import logging from rdkit import Chem -from torch import Tensor from .sampler import Sampler, remove_duplicate_sequences, validate_smiles -from reinvent.models.model_factory.sample_batch import SampleBatch - +from ...models.model_factory.sample_batch import SampleBatch logger = logging.getLogger(__name__) @@ -20,7 +18,7 @@ def sample(self, dummy) -> SampleBatch: """Samples the Reinvent model for the given number of SMILES :param dummy: Reinvent does not need SMILES input - :returns: a dataclass + :returns: SampleBatch """ sampled = self.model.sample(self.batch_size) diff --git a/reinvent/runmodes/samplers/reports/common.py b/reinvent/runmodes/samplers/reports/common.py index f87e50e..e9e1f76 100644 --- a/reinvent/runmodes/samplers/reports/common.py +++ b/reinvent/runmodes/samplers/reports/common.py @@ -5,7 +5,9 @@ def common_report(sampled: SampleBatch, **kwargs): valid_mask = np.where( - (sampled.states == SmilesState.VALID) | (sampled.states == SmilesState.DUPLICATE), True, False + (sampled.states == SmilesState.VALID) | (sampled.states == SmilesState.DUPLICATE), + True, + False, ) unique_mask = np.where(sampled.states == SmilesState.VALID, True, False) @@ -18,9 +20,9 @@ def common_report(sampled: SampleBatch, **kwargs): tanimoto_scores = kwargs["Tanimoto"] nlls = sampled.nlls.cpu().detach().numpy() - additional_report["Tanimoto_valid"] = np.array(tanimoto_scores)[valid_mask] - additional_report["Tanimoto_unique"] = 
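# PepinventSampler._join_fragments fills each masked position of the input peptide with
# the corresponding generated filler and strips the CHUCKLES separators before parsing.
# A string-level sketch of that logic; the concrete token values "?" and "|" are
# assumptions for illustration, the real ones live in reinvent.chemistry.tokens.
from rdkit import Chem

MASK = "?"    # assumed stand-in for tokens.PEPINVENT_MASK_TOKEN
SEP = "|"     # assumed stand-in for tokens.PEPINVENT_CHUCKLES_SEPARATOR_TOKEN

masked_input = f"N{MASK}C(=O)O{SEP}CC{MASK}"   # hypothetical masked peptide
fillers = f"C{SEP}O"                           # hypothetical generated fillers

smiles = masked_input
for replacement in fillers.split(SEP):
    smiles = smiles.replace(MASK, replacement, 1)

mol = Chem.MolFromSmiles(smiles.replace(SEP, ""))   # may be None if the result is invalid
print(Chem.MolToSmiles(mol) if mol else "invalid")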
np.array(tanimoto_scores)[unique_mask] - additional_report["Output_likelihood_valid"] = nlls[valid_mask] - additional_report["Output_likelihood_unique"] = nlls[unique_mask] + additional_report["Tanimoto_valid"] = np.array(tanimoto_scores)[valid_mask].tolist() + additional_report["Tanimoto_unique"] = np.array(tanimoto_scores)[unique_mask].tolist() + additional_report["Output_likelihood_valid"] = nlls[valid_mask].tolist() + additional_report["Output_likelihood_unique"] = nlls[unique_mask].tolist() return fraction_valid_smiles, fraction_unique_molecules, additional_report diff --git a/reinvent/runmodes/samplers/reports/tensorboard.py b/reinvent/runmodes/samplers/reports/tensorboard.py index bfdeb8b..60c13fa 100644 --- a/reinvent/runmodes/samplers/reports/tensorboard.py +++ b/reinvent/runmodes/samplers/reports/tensorboard.py @@ -44,8 +44,8 @@ def submit(self, sampled: SampleBatch, **kwargs): self.reporter.add_text( "Data", - f"Valid SMILES: {fraction_valid_smiles}% " - f"Unique Molecules: {fraction_unique_molecules}% ", + f"Valid SMILES fraction: {fraction_valid_smiles} " + f"Unique Molecules fraction: {fraction_unique_molecules} ", ) if image_tensor is not None: diff --git a/reinvent/runmodes/samplers/run_sampling.py b/reinvent/runmodes/samplers/run_sampling.py index 6b1f56d..b6a3819 100644 --- a/reinvent/runmodes/samplers/run_sampling.py +++ b/reinvent/runmodes/samplers/run_sampling.py @@ -11,16 +11,14 @@ from torch.utils.tensorboard import SummaryWriter import numpy as np import torch -from rdkit import Chem from reinvent.runmodes import create_adapter from reinvent.runmodes.samplers.reports import ( SamplingTBReporter, SamplingRemoteReporter, ) -from reinvent.runmodes.reporter.remote import get_reporter +from reinvent.utils import get_reporter, read_smiles_csv_file from reinvent.runmodes.setup_sampler import setup_sampler -from reinvent.config_parse import read_smiles_csv_file from reinvent.models.model_factory.sample_batch import SampleBatch, SmilesState from reinvent.chemistry import conversions from reinvent_plugins.normalizers.rdkit_smiles import normalize @@ -35,9 +33,16 @@ "LibinventTransformer": ("SMILES", "Scaffold", "R-groups", "NLL"), "LinkinventTransformer": ("SMILES", "Warheads", "Linker", "NLL"), "Mol2Mol": ("SMILES", "Input_SMILES", "Tanimoto", "NLL"), + "Pepinvent": ("SMILES", "Masked_input_peptide", "Fillers", "NLL"), } -FRAGMENT_GENERATORS = ["Libinvent", "Linkinvent", "LinkinventTransformer"] +FRAGMENT_GENERATORS = [ + "Libinvent", + "Linkinvent", + "LinkinventTransformer", + "LibinventTransformer", + "Pepinvent", +] def run_sampling( diff --git a/reinvent/runmodes/samplers/validation.py b/reinvent/runmodes/samplers/validation.py index a02615a..7cc8daa 100644 --- a/reinvent/runmodes/samplers/validation.py +++ b/reinvent/runmodes/samplers/validation.py @@ -15,6 +15,7 @@ class SectionParameters(GlobalConfig): target_nll_file: str = "target_nll.csv" unique_molecules: bool = True randomize_smiles: bool = True + temperature: float = 1.0 class SectionResponder(GlobalConfig): diff --git a/reinvent/runmodes/scoring/validation.py b/reinvent/runmodes/scoring/validation.py index a9a5c33..8b2777f 100644 --- a/reinvent/runmodes/scoring/validation.py +++ b/reinvent/runmodes/scoring/validation.py @@ -1,5 +1,6 @@ """Config Validation""" +from typing import Optional from pydantic import Field from reinvent.validation import GlobalConfig @@ -12,6 +13,12 @@ class SectionParameters(GlobalConfig): standardize_smiles: bool = True +class SectionResponder(GlobalConfig): + endpoint: str + 
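# common_report now converts the per-molecule arrays with .tolist(): numpy arrays are
# not JSON serialisable, and the remote reporter ships these report entries as JSON.
# A minimal check of that behaviour:
import json
import numpy as np

scores = np.array([0.1, 0.7, 0.9])

try:
    json.dumps({"Tanimoto_valid": scores})
except TypeError as err:
    print(f"ndarray rejected by json.dumps: {err}")

print(json.dumps({"Tanimoto_valid": scores.tolist()}))   # {"Tanimoto_valid": [0.1, 0.7, 0.9]}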
frequency: Optional[int] = Field(1, ge=1) + + class ScoringConfig(GlobalConfig): parameters: SectionParameters scoring: dict = Field(default_factory=dict) # validate in Scorer + responder: Optional[SectionResponder] = None diff --git a/reinvent/runmodes/setup_sampler.py b/reinvent/runmodes/setup_sampler.py index 35c8090..8da01fb 100644 --- a/reinvent/runmodes/setup_sampler.py +++ b/reinvent/runmodes/setup_sampler.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) warnings.filterwarnings("once", category=FutureWarning) -TRANSFORMERS = ["Mol2Mol", "LinkinventTransformer", "LibinventTransformer"] +TRANSFORMERS = ["Mol2Mol", "LinkinventTransformer", "LibinventTransformer", "Pepinvent"] def setup_sampler(model_type: str, config: dict, agent: ModelAdapter): @@ -37,9 +37,8 @@ def setup_sampler(model_type: str, config: dict, agent: ModelAdapter): if model_type in TRANSFORMERS and randomize_smiles: randomize_smiles = False logger.warning( - f"randomize_smiles is set to be True by user. But the model was trained using canonical SMILES" - f"where randomize_smiles might undermine the performance (this needs more investigation), " - f"but randomize_smiles is reset to be False for now." + f"randomize_smiles was set to True but the model was not trained " + f"with randomized SMILES. Setting randomize_smiles to False." ) unique_sequences = config.get("unique_sequences", False) diff --git a/reinvent/runmodes/utils/helpers.py b/reinvent/runmodes/utils/helpers.py index 800a6bc..0c925d0 100644 --- a/reinvent/runmodes/utils/helpers.py +++ b/reinvent/runmodes/utils/helpers.py @@ -5,19 +5,14 @@ from __future__ import annotations -__all__ = ["disable_gradients", "set_torch_device", "join_fragments"] +__all__ = ["disable_gradients", "set_torch_device"] import logging from typing import List, TYPE_CHECKING import torch -from rdkit import Chem - -from reinvent.chemistry.library_design import bond_maker, attachment_points if TYPE_CHECKING: - from rdkit import Chem from reinvent.models import ModelAdapter - from reinvent.models.model_factory.sample_batch import SampleBatch logger = logging.getLogger(__name__) @@ -60,32 +55,3 @@ def set_torch_device(args_device: str = None, device: str = None) -> torch.devic logger.debug(f"{actual_device=}") return actual_device - - -def join_fragments( - sequences: SampleBatch, reverse: bool, keep_labels: bool = False -) -> List[Chem.Mol]: - """Join two fragments: for LibInvent and LinkInvent - - :param sequences: a batch of sequences - :param reverse: order of fragments FIXME: needs better name! 
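# The scoring run mode gains an optional responder section (an endpoint URL plus a
# reporting frequency of at least 1).  A small sketch against the new SectionResponder
# model; the URL is a placeholder.
from pydantic import ValidationError
from reinvent.runmodes.scoring.validation import SectionResponder

responder = SectionResponder(endpoint="http://127.0.0.1:8080/report", frequency=5)

try:
    SectionResponder(endpoint="http://127.0.0.1:8080/report", frequency=0)
except ValidationError as err:
    print(err)   # frequency must be greater than or equal to 1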
- :returns: a list of RDKit molecules - """ - - mols = [] - - for sample in sequences: - if not reverse: # LibInvent - frag1 = sample.input - frag2 = sample.output - else: # LinkInvent - frag1 = sample.output - frag2 = sample.input - - scaffold = attachment_points.add_attachment_point_numbers(frag1, canonicalize=False) - mol: Chem.Mol = bond_maker.join_scaffolds_and_decorations( # may return None - scaffold, frag2, keep_labels_on_atoms=keep_labels - ) - mols.append(mol) - - return mols diff --git a/reinvent/scoring/config.py b/reinvent/scoring/config.py index 50f65e1..cd175a6 100644 --- a/reinvent/scoring/config.py +++ b/reinvent/scoring/config.py @@ -3,7 +3,7 @@ __all__ = ["get_components"] from dataclasses import dataclass from collections import defaultdict -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Optional, Tuple, Any import logging from .importer import get_registry @@ -45,6 +45,7 @@ def get_components(components: list[dict[str, dict]]) -> ComponentType: for component in components: component_type, component_value = list(component.items())[0] endpoints: dict = component_value["endpoint"] + complevel_params = component_value.get("params", {}) # Component-level params, if exist. component_type_lookup = component_type.lower().replace("-", "").replace("_", "") @@ -67,7 +68,11 @@ def get_components(components: list[dict[str, dict]]) -> ComponentType: if weight < 0: raise RuntimeError(f"weight must be equal to or larger than zero but is {weight}") - parameters.append(params) + # Merge component-level params with this endpoint params. + # Endpoint params take precedence. + merged_params = {**complevel_params, **params} + + parameters.append(merged_params) transform = None @@ -111,12 +116,13 @@ def get_components(components: list[dict[str, dict]]) -> ComponentType: return ComponentType(scorers, filters, penalties) -def collect_params(params: List[Dict]) -> defaultdict: +def collect_params(params: List[Dict]) -> Dict[str, List[Any]]: """Convert a list of dictionaries to a dictionary Collect the values with the same key in each dictionary of the list into a dictionary. The number of key/value pairs in the dictionaries in each - item of the passed in parameters may be different. + item of the passed in parameters may be different. Missing keys will be + filled with None. 
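# get_components now accepts optional component-level "params" and merges them into
# every endpoint's params, with endpoint values taking precedence.  The keys below are
# made up purely to show the precedence of the dict merge used in the patch.
component_level = {"executable": "/usr/bin/model", "batch_size": 32}   # component params
endpoint_level = {"batch_size": 8, "target_column": "pIC50"}           # endpoint params

merged = {**component_level, **endpoint_level}                          # endpoint wins
assert merged == {"executable": "/usr/bin/model", "batch_size": 8, "target_column": "pIC50"}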
:param params: list of dictionaries :returns: a dictionary @@ -124,8 +130,13 @@ def collect_params(params: List[Dict]) -> defaultdict: collected_params = defaultdict(list) + # Collect all keys from all dictionaries + keys = set() for param_dict in params: - for key, value in param_dict.items(): - collected_params[key].append(value) + keys.update(param_dict.keys()) + # Go through each dictionary and collect values for each key + for param_dict in params: + for key in keys: + collected_params[key].append(param_dict.get(key, None)) return collected_params diff --git a/reinvent/scoring/importer.py b/reinvent/scoring/importer.py index c432645..e42903f 100644 --- a/reinvent/scoring/importer.py +++ b/reinvent/scoring/importer.py @@ -32,7 +32,12 @@ def get_registry() -> dict[str, Tuple[type, type]]: if not basename.startswith("comp_"): continue - module = importlib.import_module(name) + try: + module = importlib.import_module(name) + except ImportError as e: + logger.error(f"Component {name} could not be imported: {e}") + continue + component_classes = [] param_class = None diff --git a/reinvent/scoring/scorer.py b/reinvent/scoring/scorer.py index 24ec72c..a6a1d03 100644 --- a/reinvent/scoring/scorer.py +++ b/reinvent/scoring/scorer.py @@ -17,10 +17,8 @@ import logging import numpy as np -import pathos as pa -from pathos.pools import ParallelPool -from reinvent import config_parse +from reinvent.utils import config_parse from . import aggregators from .config import get_components from .compute_scores import compute_transform @@ -39,16 +37,23 @@ def setup_scoring(config: dict) -> dict: """ component_filename = config.get("filename", "") - component_filetype = config.get("filetype", "") + component_filetype = config.get("filetype", "toml") - if component_filename and component_filetype: + if component_filename: component_filename = Path(component_filename).resolve() if component_filename.exists(): + ext = component_filename.suffix + + if ext in (f".{e}" for e in config_parse.INPUT_FORMAT_CHOICES): + fmt = ext[1:] + else: + fmt = component_filetype + logger.info(f"Reading score components from {component_filename}") - parser = getattr(config_parse, f"read_{component_filetype.lower()}") - components_config = parser(str(component_filename)) - config.update(components_config) + + input_config = config_parse.read_config(component_filename, fmt) + config.update(input_config) else: logger.error(f"Component file {component_filename} not found") @@ -99,7 +104,7 @@ def compute_results( # if self.parallel and ntasks > 1: if False: - cpu_count = pa.helpers.cpu_count() + cpu_count = 2 nodes = min(cpu_count, ntasks) pool = ParallelPool(nodes=nodes) diff --git a/reinvent/scoring/transforms/__init__.py b/reinvent/scoring/transforms/__init__.py index e8339a0..55cda20 100644 --- a/reinvent/scoring/transforms/__init__.py +++ b/reinvent/scoring/transforms/__init__.py @@ -7,3 +7,4 @@ from .sigmoids import * from .double_sigmoid import * from .value_mapping import * +from .exponential_decay import * diff --git a/reinvent/scoring/transforms/exponential_decay.py b/reinvent/scoring/transforms/exponential_decay.py new file mode 100644 index 0000000..f946a9e --- /dev/null +++ b/reinvent/scoring/transforms/exponential_decay.py @@ -0,0 +1,37 @@ +""" +Exponential decay, or exp(-x). + +For values x < 0, the output is 1.0 ("rectified" of "clamped" exponential decay). 
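# The revised collect_params pads missing keys with None, so every collected list has
# exactly one entry per endpoint.  A self-contained restatement of the patched function
# with a small worked example:
from collections import defaultdict

def collect_params(params):
    collected = defaultdict(list)
    keys = set()

    for param_dict in params:
        keys.update(param_dict.keys())

    for param_dict in params:
        for key in keys:
            collected[key].append(param_dict.get(key, None))

    return collected

result = collect_params([{"radius": 2}, {"radius": 3, "use_counts": True}])
assert result["radius"] == [2, 3]
assert result["use_counts"] == [None, True]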
+""" + +__all__ = ["ExponentialDecay"] +from dataclasses import dataclass + +import numpy as np + +from .transform import Transform + + +@dataclass +class Parameters: + type: str + k: float + + +def expdecay(x, k=1.0): + return np.where(x < 0, 1, np.exp(-k * x)) + + +class ExponentialDecay(Transform, param_cls=Parameters): + def __init__(self, params: Parameters): + super().__init__(params) + + self.k = params.k + + if self.k <= 0: + raise ValueError(f"ExponentialDecay Transform: k must be > 0, got {self.k}") + + def __call__(self, values) -> np.ndarray: + values = np.array(values, dtype=np.float32) + transformed = expdecay(values, self.k) + return transformed diff --git a/reinvent/scoring/transforms/transform.py b/reinvent/scoring/transforms/transform.py index 3a77eca..37abb2d 100644 --- a/reinvent/scoring/transforms/transform.py +++ b/reinvent/scoring/transforms/transform.py @@ -35,4 +35,5 @@ def __init_subclass__(cls, param_cls, **kwargs): registry[registry_name] = cls, param_cls @abstractmethod - def __call__(self, predictions): ... + def __call__(self, predictions): + ... diff --git a/reinvent/utils/__init__.py b/reinvent/utils/__init__.py new file mode 100644 index 0000000..aee302c --- /dev/null +++ b/reinvent/utils/__init__.py @@ -0,0 +1,5 @@ +from .cli import * +from .config_parse import * +from .helpers import * +from .logmon import * +from .prior_registry import * diff --git a/reinvent/utils/cli.py b/reinvent/utils/cli.py new file mode 100644 index 0000000..0299d06 --- /dev/null +++ b/reinvent/utils/cli.py @@ -0,0 +1,109 @@ +"""Command line setup + +FIXME: replace with dataclass and automatic conversion to argparse? +""" + +from __future__ import annotations + +__all__ = ["parse_command_line"] +import argparse +import os +from pathlib import Path +import logging + +from reinvent import version +from reinvent.utils import config_parse + +RDKIT_CHOICES = ("all", "error", "warning", "info", "debug") +LOGLEVEL_CHOICES = tuple(level.lower() for level in logging._nameToLevel.keys()) +VERSION_STR = f"{version.__progname__} {version.__version__} {version.__copyright__}" +OVERWRITE_STR = "Overwrites setting in the configuration file" + + +def parse_command_line(): + parser = argparse.ArgumentParser( + description=f"{version.__progname__}: a molecular design " + f"tool for de novo design, " + "scaffold hopping, R-group replacement, linker design, molecule " + "optimization, and others", + epilog=f"{VERSION_STR}", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "config_filename", + nargs="?", + default=None, + metavar="FILE", + type=lambda fn: Path(fn).resolve(), + help="Input configuration file with runtime parameters", + ) + + parser.add_argument( + "-f", + "--config-format", + metavar="FORMAT", + choices=config_parse.INPUT_FORMAT_CHOICES, + default="toml", + help=f"File format of the configuration file: {', '.join(config_parse.INPUT_FORMAT_CHOICES)}. This can be used to force a specific format. By default the format is derived from the file extension.", + ) + + parser.add_argument( + "-d", + "--device", + metavar="DEV", + default=None, + help=f"Device to run on: cuda, cpu. 
{OVERWRITE_STR}.", + ) + + parser.add_argument( + "-l", + "--log-filename", + metavar="FILE", + default=None, + type=os.path.abspath, + help=f"File for logging information, otherwise writes to stderr.", + ) + + parser.add_argument( + "--log-level", + metavar="LEVEL", + choices=LOGLEVEL_CHOICES, + default="info", + help=f"Enable this and 'higher' log levels: {', '.join(LOGLEVEL_CHOICES)}.", + ) + + parser.add_argument( + "-s", + "--seed", + metavar="N", + type=int, + default=None, + help="Sets the random seeds for reproducibility", + ) + + parser.add_argument( + "--dotenv-filename", + metavar="FILE", + default=None, + type=os.path.abspath, + help=f"Dotenv file with environment setup needed for some scoring components. " + "By default the one from the installation directory will be loaded.", + ) + + parser.add_argument( + "--enable-rdkit-log-levels", + metavar="LEVEL", + choices=RDKIT_CHOICES, + nargs="+", + help=f"Enable specific RDKit log levels: {', '.join(RDKIT_CHOICES)}.", + ) + + parser.add_argument( + "-V", + "--version", + action="version", + version=f"{VERSION_STR}.", + ) + + return parser.parse_args() diff --git a/reinvent/config_parse.py b/reinvent/utils/config_parse.py similarity index 74% rename from reinvent/config_parse.py rename to reinvent/utils/config_parse.py index ca627c7..9079f8c 100644 --- a/reinvent/config_parse.py +++ b/reinvent/utils/config_parse.py @@ -3,10 +3,13 @@ FIXME: about everything """ -__all__ = ["read_smiles_csv_file", "read_toml", "read_json", "write_json"] +__all__ = ["read_smiles_csv_file", "read_config", "write_json"] import sys +import io +from pathlib import Path import csv import json +import yaml from typing import List, Tuple, Union, Optional, Callable import tomli @@ -14,9 +17,54 @@ from rdkit import Chem smiles_func = Callable[[str], str] +FMT_CONVERT = {"toml": tomli, "json": json, "yaml": yaml} +INPUT_FORMAT_CHOICES = tuple(FMT_CONVERT.keys()) -def has_multiple_attachment_points_to_same_atom(smiles): +def monkey_patch_yaml_load(fct): + def load(filehandle, loader=yaml.SafeLoader) -> dict: + """Monkey patch for PyYAML's load + + yaml.load requires a loader or yaml.safe_load with a default + + :param filehandle: the filehandle to read the YAML from + :returns: the parsed dictionary + """ + + return fct(filehandle, loader) + + return load + + +def yaml_loads(s) -> dict: + """loads() implementation for PyWAML + + PyWAML does not have loading from string + + :param s: the string to load + :returns: the parsed dictionary + """ + + fh = io.StringIO(s) + data = yaml.safe_load(fh) + + return data + + +# only read first YAML document +yaml.load = monkey_patch_yaml_load(yaml.load) +yaml.loads = yaml_loads + + +def has_multiple_attachment_points_to_same_atom(smiles) -> bool: + """Check a molecule for multiple attachment points on one atom + + An attachment point is a dummy atoom ("[*]") + + :param smiles: the SMILES string + :returns: True if multiple attachment points exist, False otherwise + """ + mol = Chem.MolFromSmiles(smiles) if not mol: @@ -141,34 +189,22 @@ def read_smiles_csv_file( return smilies -def read_toml(filename: Optional[str]) -> dict: - """Read a TOML file. +def read_config(filename: Optional[Path], fmt: str) -> dict: + """Read a config file in TOML, JON or (Py)YAML (safe load) format. 
:param filename: name of input file to be parsed as TOML, if None read from stdin + :param fmt: name of the format of the configuration + :returns: parsed dictionary """ - if isinstance(filename, str): - with open(filename, "rb") as tf: - config = tomli.load(tf) - else: - config_str = "\n".join(sys.stdin.readlines()) - config = tomli.loads(config_str) + pkg = FMT_CONVERT[fmt] - return config - - -def read_json(filename: Optional[str]) -> dict: - """Read JSON file. - - :param filename: name of input file to be parsed as JSON, if None read from stdin - """ - - if isinstance(filename, str): - with open(filename, "rb") as jf: - config = json.load(jf) + if isinstance(filename, (str, Path)): + with open(filename, "rb") as tf: + config = pkg.load(tf) else: config_str = "\n".join(sys.stdin.readlines()) - config = json.loads(config_str) + config = pkg.loads(config_str) return config @@ -179,5 +215,6 @@ def write_json(data: str, filename: str) -> None: :param data: data in a format JSON accepts :param filename: output filename """ + with open(filename, "w") as jf: json.dump(data, jf, ensure_ascii=False, indent=4) diff --git a/reinvent/utils/helpers.py b/reinvent/utils/helpers.py new file mode 100644 index 0000000..9d0db76 --- /dev/null +++ b/reinvent/utils/helpers.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import os +import random + +import subprocess as sp +from typing import Optional + +import numpy as np +import torch + +from reinvent.utils import config_parse + + +def get_cuda_driver_version() -> Optional[str]: + """Get the CUDA driver version via modinfo if possible. + + This is for Linux only. + + :returns: driver version or None + """ + + # Alternative + # result = sp.run(["/usr/bin/nvidia-smi"], shell=False, capture_output=True) + # if "Driver Version:" in str_line: + # version = str_line.split()[5] + + try: + result = sp.run(["/sbin/modinfo", "nvidia"], shell=False, capture_output=True) + except Exception: + return + + for line in result.stdout.splitlines(): + str_line = line.decode() + + if str_line.startswith("version:"): + cuda_driver_version = str_line.split()[1] + return cuda_driver_version + + +def set_seed(seed: int): + """Set global seed for reproducibility + + :param seed: the seed to initialize the random generators + """ + + if seed is None: + return + + random.seed(seed) + + os.environ["PYTHONHASHSEED"] = str(seed) + + np.random.seed(seed) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def extract_sections(config: dict) -> dict: + """Extract the sections of a config file + + :param config: the config file + :returns: the extracted sections + """ + + # FIXME: stages are a list of dicts in RL, may clash with global lists + return {k: v for k, v in config.items() if isinstance(v, (dict, list))} + + +def write_json_config(global_dict, json_out_config): + def dummy(config): + global_dict.update(config) + config_parse.write_json(global_dict, json_out_config) + + return dummy diff --git a/reinvent/utils/logmon.py b/reinvent/utils/logmon.py new file mode 100644 index 0000000..536cb12 --- /dev/null +++ b/reinvent/utils/logmon.py @@ -0,0 +1,246 @@ +"""Logging and monitoring support + +Setup for Python and RDKit logging. 
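# read_toml/read_json are folded into a single read_config that also understands YAML;
# the format string is normally derived from the file extension (see the CLI and
# setup_scoring changes).  A self-contained sketch with a made-up minimal config:
from pathlib import Path
from reinvent.utils import config_parse

filename = Path("example_config.toml")                       # hypothetical file
filename.write_text('run_type = "sampling"\ndevice = "cpu"\n')

fmt = filename.suffix[1:]                                    # "toml"
if fmt not in config_parse.INPUT_FORMAT_CHOICES:             # ("toml", "json", "yaml")
    fmt = "toml"

config = config_parse.read_config(filename, fmt)
assert config == {"run_type": "sampling", "device": "cpu"}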
Setup for remote monito +""" + +from __future__ import annotations + +__all__ = [ + "CsvFormatter", + "setup_logger", + "enable_rdkit_log", + "setup_responder", + "setup_reporter", + "get_reporter", +] + +import json +import os +import sys +import csv +import io +import logging +from logging.config import dictConfig, fileConfig +from typing import List, Mapping, Optional + +import math +import requests +from rdkit import RDLogger + +logger = logging.getLogger(__name__) + +HEADERS = { + "Accept": "application/json", + "Content-Type": "application/json", + "Authorization": None, +} + +MAX_ERR_MSG = 5 +RESPONDER_TOKEN = "RESPONDER_TOKEN" + + +class CsvFormatter(logging.Formatter): + def __init__(self): + super().__init__() + self.output = io.StringIO() + self.writer = csv.writer(self.output) + + def format(self, record): + self.writer.writerow(record.msg) # needs to be a iterable + data = self.output.getvalue() + self.output.truncate(0) + self.output.seek(0) + return data.strip() + + +def setup_logger( + name: str = None, + config: dict = None, + filename: str = None, + formatter=None, + stream=sys.stderr, + cfg_filename: str = None, + propagate: bool = True, + level=logging.INFO, + debug=False, +): + """Setup a logging facility. + + :param name: name of the logger, root if empty or None + :param config: dictionary configuration + :param filename: optional filename for logging output + :param formatter: a logging formatter + :param stream: the output stream + :param cfg_filename: filename of a logger configuration file + :param propagate: whether to propagate to higher level loggers + :param level: logging level + :param debug: set special format for debugging + :returns: the newly set up logger + """ + + logging.captureWarnings(True) + + logger = logging.getLogger(name) + logger.setLevel(level) + + for handler in logger.handlers[:]: + logger.removeHandler(handler) + + if config is not None: + dictConfig(config) + return + + if cfg_filename is not None: + fileConfig(cfg_filename) + return + + if filename: + handler = logging.FileHandler(filename, mode="w+") + else: + handler = logging.StreamHandler(stream) + + handler.setLevel(level) + + if debug: + log_format = "%(asctime)s %(module)s.%(funcName)s +%(lineno)s: %(levelname)-4s %(message)s" + else: + log_format = "%(asctime)s <%(levelname)-4.4s> %(message)s" + + if not formatter: + formatter = logging.Formatter( + fmt=log_format, + datefmt="%H:%M:%S", + ) + + handler.setFormatter(formatter) + + logger.addHandler(handler) + logger.propagate = propagate + + return logger + + +def enable_rdkit_log(levels: List[str]): + """Enable logging messages from RDKit for a specific logging level. 
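# logmon.py bundles the logging helpers that used to live in reinvent/logger.py;
# setup_logger plus CsvFormatter give a CSV log where each record is written as one
# row.  The logger name and output filename below are arbitrary examples.
import logging
from reinvent.utils import CsvFormatter, setup_logger

csv_logger = setup_logger(
    name="example_csv",            # hypothetical logger name
    filename="results.csv",        # hypothetical output file
    formatter=CsvFormatter(),
    propagate=False,
    level=logging.INFO,
)

csv_logger.info(["step", "score", "SMILES"])   # header row
csv_logger.info([1, 0.42, "c1ccccc1"])         # one data row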
+ + :param levels: the specific level(s) that need to be silenced + """ + + if "all" in levels: + RDLogger.EnableLog("rdApp.*") + return + + for level in levels: + RDLogger.EnableLog(f"rdApp.{level}") + + +def setup_responder(config): + """Setup for remote monitor + + :param config: configuration + """ + + endpoint = config.get("endpoint", False) + + if not endpoint: + return + + token = os.environ.get(RESPONDER_TOKEN, None) + setup_reporter(endpoint, token) + + +class NanInfEncoder(json.JSONEncoder): + def _custom_encoder(self, obj): + """Recursively clean nested dictionaries and handle NaN/Infinity""" + if isinstance(obj, float): + if math.isnan(obj) or math.isinf(obj): + return None # Return None for NaN or Infinity + elif isinstance(obj, dict): + # Recursively clean nested dictionaries + return {key: self._custom_encoder(value) for key, value in obj.items()} + elif isinstance(obj, (list, tuple)): + # Recursively clean lists and tuple + return [self._custom_encoder(item) for item in obj] + return obj + + def encode(self, obj, *args, **kwargs): + return super().encode(self._custom_encoder(obj), *args, **kwargs) + + +class RemoteJSONReporter: + """Simplistic reporter that sends JSON to a remote server""" + + def __init__(self, url, token=None): + """Set up the reporter + + :param url: URL to send JSON to + :param token: access token for the URL + """ + + self.url = url + self.headers = HEADERS + + if token: + self.headers["Authorization"] = token + + self.max_msg = 0 + + def send(self, record) -> None: + """Send a record to a remote URL + + :param record: dictionary-like record to send to remote URL + """ + + if not isinstance(record, Mapping): + raise TypeError("The record is expected to be a mapping") + + json_msg = json.dumps(record, cls=NanInfEncoder, indent=2) + + logger.debug( + "Data sent to {url}\n\n{headers}\n\n{json_data}".format( + url=self.url, + headers="\n".join(f"{k}: {v}" for k, v in self.headers.items()), + json_data=json_msg, + ) + ) + + response = requests.post(self.url, json=json.loads(json_msg), headers=self.headers) + + # alternative: check if response.status_code != request.codes.created + if not response.ok and self.max_msg < MAX_ERR_MSG: + self.max_msg += 1 + logger.error(f"Failed to send record to: {self.url}") + logger.error(f"{response.text=}") + logger.error(f"{response.headers=}") + logger.error(f"{response.reason=}") + logger.error(f"{response.url=}") + + +_reporter = None + + +def get_reporter() -> Optional[RemoteJSONReporter]: + """Return the current reporter + + :return: reporter object + """ + + return _reporter + + +def setup_reporter(url, token=None) -> bool: + """Set up the reporter + + :param url: URL to send JSON to + :param token: access token for the URL + :returns: whether reporter was setup successfully + """ + + global _reporter + + if url: + # assume endpoint is readily available... + _reporter = RemoteJSONReporter(url, token) + return True + + return False diff --git a/reinvent/prior_registry.py b/reinvent/utils/prior_registry.py similarity index 94% rename from reinvent/prior_registry.py rename to reinvent/utils/prior_registry.py index 4ac686f..683b47b 100644 --- a/reinvent/prior_registry.py +++ b/reinvent/utils/prior_registry.py @@ -3,6 +3,7 @@ Maps a key to an actual prior filename. 
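# The remote reporter now serialises records with NanInfEncoder, so NaN and Infinity
# (which are not valid JSON) reach the endpoint as null.  A minimal check, importing
# the encoder directly from the new module:
import json
import math

from reinvent.utils.logmon import NanInfEncoder

record = {"score": float("nan"), "nll": [1.5, math.inf]}
print(json.dumps(record, cls=NanInfEncoder))   # {"score": null, "nll": [1.5, null]}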
""" +__all__ = ["prior_registry"] import os import pathlib @@ -14,7 +15,7 @@ else: PRIOR_BASE = pathlib.Path(reinvent.__file__).parents[1] / "priors" -registry = { +prior_registry = { ".reinvent": PRIOR_BASE / "reinvent.prior", ".libinvent": PRIOR_BASE / "libinvent.prior", ".linkinvent": PRIOR_BASE / "linkinvent.prior", diff --git a/reinvent/validation.py b/reinvent/validation.py index 6b922b0..e21346b 100644 --- a/reinvent/validation.py +++ b/reinvent/validation.py @@ -6,7 +6,7 @@ class GlobalConfig(BaseModel): - model_config = ConfigDict(extra="forbid", protected_namespaces = ()) + model_config = ConfigDict(extra="forbid", protected_namespaces=()) class ReinventConfig(GlobalConfig): @@ -18,9 +18,9 @@ class ReinventConfig(GlobalConfig): parameters: dict # run mode dependent - scoring: Optional[dict] = None # RL, scoring + scoring: Optional[dict] = None # RL, scoring scheduler: Optional[dict] = None # TL - responder: Optional[dict] = None # Rl, TL, sampling + responder: Optional[dict] = None # Rl, TL, sampling # RL stage: Optional[list] = None diff --git a/reinvent/version.py b/reinvent/version.py index 8936112..0d014d0 100644 --- a/reinvent/version.py +++ b/reinvent/version.py @@ -1,6 +1,18 @@ """Meta information for Reinvent""" +__all__ = ["__version__", "__maintainer__", "__email__", "__copyright__"] __progname__ = "REINVENT" -__version__ = "4.4.22" +__version__ = "4.5.11" +__authors__ = [ + "Hannes H Loeffer", + "Jiazhen He", + "Alessandro Tibo", + "Jon Paul Janet", + "Alexey Voronov", + "Lewis Mervin", +] +__maintainer__ = "Hannes H Loeffer" +__email__ = "hannes.loffler@astrazeneca.com" __config_version__ = 4 __copyright__ = "(C) AstraZeneca 2017, 2023" +__license__ = "Apache 2.0" diff --git a/reinvent_plugins/components/OpenEye/rocs/rocs_similarity.py b/reinvent_plugins/components/OpenEye/rocs/rocs_similarity.py index 149706e..406e7da 100644 --- a/reinvent_plugins/components/OpenEye/rocs/rocs_similarity.py +++ b/reinvent_plugins/components/OpenEye/rocs/rocs_similarity.py @@ -119,7 +119,7 @@ def _setup_overlay_from_shape_query(self): if oeshape.OEReadShapeQuery(self.rocs_input, qry): overlay.SetupRef(qry) else: - raise Exception("error reading in SQ file") + raise ValueError("error reading in SQ file") self.rocs_overlay = overlay def _setup_overlay_from_sdf_file(self): @@ -130,7 +130,7 @@ def _setup_overlay_from_sdf_file(self): if input_stream.open(self.rocs_input): oechem.OEReadMolecule(input_stream, refmol) else: - raise Exception("error reading in ROCS sdf file") + raise ValueError("error reading in ROCS sdf file") self.overlay_prep.Prep(refmol) overlay = oeshape.OEMultiRefOverlay() overlay.SetupRef(refmol) diff --git a/reinvent_plugins/components/RDKit/comp_group_count.py b/reinvent_plugins/components/RDKit/comp_group_count.py index 31c4cad..9aa49b7 100644 --- a/reinvent_plugins/components/RDKit/comp_group_count.py +++ b/reinvent_plugins/components/RDKit/comp_group_count.py @@ -45,7 +45,7 @@ def __init__(self, params: Parameters): self.patterns.append(pattern) if not self.patterns: - raise RuntimeError(f"{__name__}: no valid SMARTS patterns found") + raise ValueError(f"{__name__}: no valid SMARTS patterns found") self.number_of_endpoints = len(params.smarts) diff --git a/reinvent_plugins/components/RDKit/comp_matching_substructure.py b/reinvent_plugins/components/RDKit/comp_matching_substructure.py index d401ea8..9e8ae56 100644 --- a/reinvent_plugins/components/RDKit/comp_matching_substructure.py +++ b/reinvent_plugins/components/RDKit/comp_matching_substructure.py @@ -47,7 +47,7 
@@ def __init__(self, params: Parameters): self.patterns.append(pattern) if not self.patterns: - raise RuntimeError(f"{__name__}: no valid SMARTS patterns found") + raise ValueError(f"{__name__}: no valid SMARTS patterns found") self.number_of_endpoints = len(params.smarts) diff --git a/reinvent_plugins/components/RDKit/comp_pmi.py b/reinvent_plugins/components/RDKit/comp_pmi.py index 7bdcbd2..3c9647c 100644 --- a/reinvent_plugins/components/RDKit/comp_pmi.py +++ b/reinvent_plugins/components/RDKit/comp_pmi.py @@ -28,7 +28,7 @@ def __init__(self, params: Parameters): self.properties = params.property if not "npr1" in self.properties and not "npr2" in self.properties: - raise RuntimeError(f"{__name__}: need one or both of: 'npr1', 'npr2'") + raise ValueError(f"{__name__}: need one or both of: 'npr1', 'npr2'") self.number_of_endpoints = len(params.property) diff --git a/reinvent_plugins/components/RDKit/comp_rdkit_descriptors.py b/reinvent_plugins/components/RDKit/comp_rdkit_descriptors.py index 9ea60c5..9585889 100644 --- a/reinvent_plugins/components/RDKit/comp_rdkit_descriptors.py +++ b/reinvent_plugins/components/RDKit/comp_rdkit_descriptors.py @@ -35,7 +35,7 @@ def __init__(self, params: Parameters): desc = descriptor.lower() if desc not in KNOWN_DESCRIPTORS: - raise RuntimeError(f"{__name__}: unknown descriptor {desc}") + raise ValueError(f"{__name__}: unknown descriptor {desc}") descriptors.append(KNOWN_DESCRIPTORS[desc]) diff --git a/reinvent_plugins/components/RDKit/comp_similarity.py b/reinvent_plugins/components/RDKit/comp_similarity.py index ddd53db..6165d9b 100644 --- a/reinvent_plugins/components/RDKit/comp_similarity.py +++ b/reinvent_plugins/components/RDKit/comp_similarity.py @@ -54,7 +54,7 @@ def __init__(self, params: Parameters): ) if not fingerprints: - raise RuntimeError(f"{__name__}: unable to convert any SMILES to fingerprints") + raise ValueError(f"{__name__}: unable to convert any SMILES to fingerprints") self.fp_params.append((fingerprints, radius, use_counts, use_features)) diff --git a/reinvent_plugins/components/comp_chemprop.py b/reinvent_plugins/components/comp_chemprop.py index 0462d78..d3d317c 100644 --- a/reinvent_plugins/components/comp_chemprop.py +++ b/reinvent_plugins/components/comp_chemprop.py @@ -90,12 +90,12 @@ def __init__(self, params: Parameters): f"{', '.join(target_columns)})" ) logger.critical(msg) - raise RuntimeError(msg) + raise ValueError(msg) if target_column in seen: msg = f"{__name__}: target columns must be unique ({params.target_column})" logger.critical(msg) - raise RuntimeError(msg) + raise ValueError(msg) seen.add(target_column) diff --git a/reinvent_plugins/components/comp_icolos.py b/reinvent_plugins/components/comp_icolos.py index 17c2ea5..94ac17c 100644 --- a/reinvent_plugins/components/comp_icolos.py +++ b/reinvent_plugins/components/comp_icolos.py @@ -107,18 +107,18 @@ def parse_output(filename: str, name: str) -> List: """ if not os.path.isfile(filename): - raise RuntimeError(f"{__name__}: failed, missing output file") + raise ValueError(f"{__name__}: failed, missing output file") with open(filename, "r") as jfile: data = json.load(jfile) # TODO: this should be properly validated if "results" not in data: - raise RuntimeError(f"{__name__}: JSON file does not contain 'results'") + raise ValueError(f"{__name__}: JSON file does not contain 'results'") # FIXME: check if scores are really in the same order as the SMILES for entry in data["results"]: if entry["values_key"] == name: return entry["values"] - raise 
RuntimeError(f"{__name__}: JSON file does not contain scores for {name}") + raise ValueError(f"{__name__}: JSON file does not contain scores for {name}") diff --git a/reinvent_plugins/components/comp_maize.py b/reinvent_plugins/components/comp_maize.py index e95a1ce..91ff1fb 100644 --- a/reinvent_plugins/components/comp_maize.py +++ b/reinvent_plugins/components/comp_maize.py @@ -268,12 +268,12 @@ def parse_output(filename: str) -> List[float]: """ if not os.path.isfile(filename): - raise RuntimeError(f"{__name__}: failed, missing output file") + raise ValueError(f"{__name__}: failed, missing output file") with open(filename, "r", encoding="utf-8") as jfile: data = json.load(jfile) if "scores" not in data: - raise RuntimeError(f"{__name__}: JSON file does not contain 'scores'") + raise ValueError(f"{__name__}: JSON file does not contain 'scores'") return data["scores"] diff --git a/reinvent_plugins/components/run_program.py b/reinvent_plugins/components/run_program.py index 8e04bb6..5fc3f68 100644 --- a/reinvent_plugins/components/run_program.py +++ b/reinvent_plugins/components/run_program.py @@ -29,7 +29,7 @@ def run_command(command: List[str], env: dict = None, input=None, cwd=None) -> s out = error.stdout err = error.stderr - raise RuntimeError( + raise ValueError( f"{__name__}: {' '.join(command)} has failed with exit " f"code {ret}: stdout={out}, stderr={err}" ) diff --git a/reinvent_plugins/normalizers/rdkit_smiles.py b/reinvent_plugins/normalizers/rdkit_smiles.py index b7be880..0ba4130 100644 --- a/reinvent_plugins/normalizers/rdkit_smiles.py +++ b/reinvent_plugins/normalizers/rdkit_smiles.py @@ -10,7 +10,7 @@ logger = logging.getLogger("reinvent") -def normalize(smilies: List[str]) -> List: +def normalize(smilies: List[str], keep_all: bool=False) -> List: """Remove annotations from SMILES :param smilies: list of SMILES strings @@ -22,7 +22,11 @@ def normalize(smilies: List[str]) -> List: mol = Chem.MolFromSmiles(smiles) if not mol: + if keep_all: + cleaned_smilies.append(smiles) + logger.warning(f"{__name__}: {smiles} could not be converted") + continue for atom in mol.GetAtoms(): diff --git a/requirements-linux-64.lock b/requirements-linux-64.lock index a7c558f..d4f9ddf 100644 --- a/requirements-linux-64.lock +++ b/requirements-linux-64.lock @@ -9,22 +9,14 @@ absl-py==2.1.0 # via tensorboard -alabaster==0.7.16 - # via sphinx annotated-types==0.6.0 # via pydantic -babel==2.14.0 - # via sphinx -blinker==1.7.0 - # via flask certifi==2024.2.2 # via requests charset-normalizer==3.3.2 # via requests chemprop==1.5.2 # via reinvent (pyproject.toml) -click==8.1.7 - # via flask cloudpickle==3.0.0 # via hyperopt contourpy==1.2.1 @@ -33,22 +25,14 @@ cycler==0.12.1 # via matplotlib descriptastorus==2.6.1 # via reinvent (pyproject.toml) -dill==0.3.8 - # via - # multiprocess - # pathos docstring-parser==0.16 # via typed-argument-parser -docutils==0.20.1 - # via sphinx exceptiongroup==1.2.0 # via pytest filelock==3.13.4 # via # torch # triton -flask==3.0.3 - # via chemprop fonttools==4.51.0 # via matplotlib fsspec==2024.3.1 @@ -63,16 +47,10 @@ hyperopt==0.2.7 # via chemprop idna==3.7 # via requests -imagesize==1.4.1 - # via sphinx iniconfig==2.0.0 # via pytest -itsdangerous==2.1.2 - # via flask jinja2==3.1.3 # via - # flask - # sphinx # torch joblib==1.4.0 # via scikit-learn @@ -94,8 +72,6 @@ molvs==0.1.1 # via reinvent (pyproject.toml) mpmath==1.3.0 # via sympy -multiprocess==0.70.16 - # via pathos mypy-extensions==1.0.0 # via typing-inspect networkx==3.3 @@ -129,7 +105,7 @@ 
nvidia-cuda-nvrtc-cu12==12.1.105 # via torch nvidia-cuda-runtime-cu12==12.1.105 # via torch -nvidia-cudnn-cu12==8.9.2.26 +nvidia-cudnn-cu12==9.1.0.70 # via torch nvidia-cufft-cu12==11.0.2.54 # via torch @@ -141,7 +117,7 @@ nvidia-cusparse-cu12==12.1.0.106 # via # nvidia-cusolver-cu12 # torch -nvidia-nccl-cu12==2.19.3 +nvidia-nccl-cu12==2.21.5 # via torch nvidia-nvjitlink-cu12==12.4.127 # via @@ -155,7 +131,6 @@ packaging==24.0 # via # matplotlib # pytest - # sphinx # tensorboardx # xarray pandas==2.2.2 @@ -168,8 +143,6 @@ pandas-flavor==0.6.0 # via # chemprop # descriptastorus -pathos==0.3.2 - # via reinvent (pyproject.toml) pillow==10.3.0 # via # matplotlib @@ -178,10 +151,6 @@ pillow==10.3.0 # torchvision pluggy==1.4.0 # via pytest -pox==0.3.4 - # via pathos -ppft==1.7.6.8 - # via pathos protobuf==5.26.1 # via # tensorboard @@ -192,8 +161,6 @@ pydantic==2.7.0 # via reinvent (pyproject.toml) pydantic-core==2.18.1 # via pydantic -pygments==2.17.2 - # via sphinx pyparsing==3.1.2 # via matplotlib pytest==8.1.1 @@ -220,13 +187,8 @@ requests==2.31.0 # via # reinvent (pyproject.toml) # requests-mock - # sphinx requests-mock==1.12.1 # via reinvent (pyproject.toml) -scikit-learn==1.2.2 - # via - # chemprop - # reinvent (pyproject.toml) scipy==1.13.0 # via # chemprop @@ -240,27 +202,11 @@ six==1.16.0 # molvs # python-dateutil # tensorboard -snowballstemmer==2.2.0 - # via sphinx -sphinx==7.2.6 - # via chemprop -sphinxcontrib-applehelp==1.0.8 - # via sphinx -sphinxcontrib-devhelp==1.0.6 - # via sphinx -sphinxcontrib-htmlhelp==2.0.5 - # via sphinx -sphinxcontrib-jsmath==1.0.1 - # via sphinx -sphinxcontrib-qthelp==1.0.7 - # via sphinx -sphinxcontrib-serializinghtml==1.1.10 - # via sphinx -sympy==1.12 +sympy==1.13.1 # via torch tenacity==8.2.3 # via reinvent (pyproject.toml) -tensorboard==2.16.2 +tensorboard==2.17.1 # via reinvent (pyproject.toml) tensorboard-data-server==0.7.2 # via tensorboard @@ -272,19 +218,19 @@ tomli==2.0.1 # via # pytest # reinvent (pyproject.toml) -torch==2.2.1+cu121 +torch==2.5.1+cu121 # via # chemprop # reinvent (pyproject.toml) # torchvision -torchvision==0.17.1+cu121 +torchvision==0.20.1+cu121 # via reinvent (pyproject.toml) tqdm==4.66.2 # via # chemprop # hyperopt # reinvent (pyproject.toml) -triton==2.2.0 +triton>3.0.0 # via torch typed-argument-parser==1.10.0 # via chemprop @@ -303,7 +249,6 @@ urllib3==2.2.1 # via requests werkzeug==3.0.2 # via - # flask # tensorboard xarray==2024.3.0 # via pandas-flavor diff --git a/support/run-qsartuna.py b/support/run-qsartuna.py new file mode 100755 index 0000000..60e2991 --- /dev/null +++ b/support/run-qsartuna.py @@ -0,0 +1,42 @@ +#!/bin/env python3 +# +# This is an example how to use the ExternalProcess scoring component using +# QSARtuna, see DOI 10.1021/acs.jcim.4c00457. The scripts expects a list of +# SMILES from stdin and will # write a JSON string to stdout. +# +# QSARtuna code at https://github.com/MolecularAI/QSARtuna. +# +# [[component.ExternalProcess.endpoint]] +# name = "QSARtuna model" +# weight = 0.6 +# +# # Run Qptuna in its own environment +# # The --no-capture-output is necessary to pass through stdout from REINVENT4 +# params.executable = "/home/user/miniconda3/condabin/mamba" +# params.args = "run --no-capture-output -n qsartuna /path/to/run-qsartuna.py model_filename +# +# # Don't forget the transform if needed! 
+# + + +import sys +import json +import pickle + + +smilies = [smiles.strip() for smiles in sys.stdin] + +# Everything from here to END is specific to the scorer + +with open(sys.argv[1], "rb") as mfile: + model = pickle.load(mfile) + +scores = model.predict_from_smiles(smilies, uncert=False) + +# END + + +# Format the JSON string for REINVENT4 and write it to stdout +data = {"version": 1, "payload": {"predictions": list(scores)}} + +print(json.dumps(data)) diff --git a/support/run-rascore.py b/support/run-rascore.py index d6df410..5e72d54 100755 --- a/support/run-rascore.py +++ b/support/run-rascore.py @@ -17,7 +17,7 @@ # # The --no-capture-output is necessary to pass through stdout from REINVENT4 # params.executable = "/home/user/miniconda3/condabin/mamba" # params.args = "run --no-capture-output -n rascore /home/user/projects/RAScore/run-rascore.py -# # No transform needed as score is alread between 0 and 1 +# # No transform needed as score is already between 0 and 1 # @@ -27,7 +27,7 @@ from RAscore import RAscore_NN -# Created from default model in the repository +# Created from default model in the repository: # from tensorflow import keras # model = keras.models.load_model("models/DNN_chembl_fcfp_counts/model.tf") # model.save("/home/user/projects/RAScore/new_tf_2.5") diff --git a/tests/datapipeline/test_percent.py b/tests/datapipeline/test_percent.py new file mode 100644 index 0000000..1e83579 --- /dev/null +++ b/tests/datapipeline/test_percent.py @@ -0,0 +1,32 @@ +import pytest +from types import SimpleNamespace +from dataclasses import dataclass + +from reinvent.datapipeline.filters import RDKitFilter + + +SMILES = [ +"CN1CCC23c4c5ccc(OC(=O)c6ccc7ccc(C(=O)Oc8ccc9c%10c8OC8C(O)C=CC%11C(C9)N(C)CCC%108%11)cc7c6)c4OC2C(O)C=CC3C1C5", +"CC1(C)CCC(C)(C)c2cc(N3c4cc5c(cc4B4c6oc7ccc(C89CC%10CC%11CC(C8)C%119%10)cc7c6N(c6ccccc6)c6cccc3c64)C(C)(C)CCC5(C)C)ccc21", +"CC1C2C3C4C5CC6C7C(C)C8(C)C9(C)C%10(C)C%11(C)C(C)(C)C%12(C)C1(C)C21C32C43C65C78C93C2%10C%121%11", +"CC12CCC/C(=N\OCC(=O)O)C1C13c4c5c6c7c8c9c%10c%11c%12c%13c%14c%15c(c6c6c4c4c%16c%17c%18c(c%19c%20c1c5c8c1c%20c5c%19c8c%18c%18c%19c%17c%17c4c6c%15c4c%14c6c%12c%12c%14c%11c(c91)c5c%14c8c%18c%12c6c%19c4%17)C%1623)C7C%10%13", +] + + +@pytest.fixture +def rdkit_filter(): + config = SimpleNamespace(max_ring_size=7, max_num_rings=12, + keep_stereo=False, report_errors=False) + yield RDKitFilter(config) + pass # no tear-down + + +@pytest.mark.xfail +def test_large_percent(rdkit_filter): + good_smilies = [] + + for smiles in SMILES: + rdkit_smiles = rdkit_filter(smiles) + good_smilies.append(rdkit_smiles) + + assert len(good_smilies) == 0 diff --git a/tests/datapipeline/test_unwanted_tokens.py b/tests/datapipeline/test_unwanted_tokens.py new file mode 100644 index 0000000..6cbc904 --- /dev/null +++ b/tests/datapipeline/test_unwanted_tokens.py @@ -0,0 +1,36 @@ +import pytest +from types import SimpleNamespace +from dataclasses import dataclass + +from reinvent.datapipeline.filters import RegexFilter + + +SMILES = [ +"CC(=O)O[Cl+][O-]", +"CC(C)(C)[PH+]([BH2-][P+]([BH2-][ClH+2])(C(C)(C)C)C(C)(C)C)C(C)(C)C", +"[O-][Br+3]([O-])(O)Oc1cccnc1Br", +"CCC=CCI1C(=O)CCC1([IH])CC(=O)O", +"CC1=Cc2ccccc2S1(O[I+3]([O-])([O-])O)C(F)(F)F", +"C[N+]1(C2C[I-]C2)CC1" +] + + +@pytest.fixture +def regex_filter(): + config = SimpleNamespace(keep_stereo=True, keep_isotopes=False, + max_heavy_atoms=70, max_mol_weight=1200, min_heavy_atoms=2, + min_carbons=2, elements=["B", "P"]) + yield RegexFilter(config) + pass # no tear-down + + +def test_unwanted_tokens(regex_filter): + 
good_smilies = [] + + for smiles in SMILES: + regex_smiles = regex_filter(smiles) + print(regex_smiles) + good_smilies.append(regex_smiles) + + assert len(good_smilies) == 6 + assert not all(good_smilies) diff --git a/tests/models/integration_tests/libinvent/RNN/__init__.py b/tests/models/integration_tests/libinvent/RNN/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/integration_tests/libinvent/RNN/model_tests/__init__.py b/tests/models/integration_tests/libinvent/RNN/model_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/integration_tests/libinvent/RNN/model_tests/decorator_model_test.py b/tests/models/integration_tests/libinvent/RNN/model_tests/decorator_model_test.py new file mode 100644 index 0000000..436c731 --- /dev/null +++ b/tests/models/integration_tests/libinvent/RNN/model_tests/decorator_model_test.py @@ -0,0 +1,69 @@ +import pytest +import unittest + +import torch +import torch.utils.data as tud + +from reinvent.runmodes.utils.helpers import set_torch_device +from reinvent.models.model_mode_enum import ModelModeEnum +from reinvent.models.libinvent.models.dataset import Dataset +from reinvent.models.libinvent.models.model import DecoratorModel +from tests.test_data import SCAFFOLD_SUZUKI + + +@pytest.mark.integration +@pytest.mark.usefixtures("device", "json_config") +class TestDecoratorModel(unittest.TestCase): + def setUp(self): + input_scaffold = SCAFFOLD_SUZUKI + scaffold_list_2 = [input_scaffold, input_scaffold] + scaffold_list_3 = [input_scaffold, input_scaffold, input_scaffold] + self._model_regime = ModelModeEnum() + self._decorator = DecoratorModel.load_from_file( + self.json_config["LIBINVENT_CHEMBL_PRIOR_PATH"], + "inference", + torch.device(self.device), + ) + set_torch_device(self.device) + + dataset_2 = Dataset( + scaffold_list_2, + self._decorator.vocabulary.scaffold_vocabulary, + self._decorator.vocabulary.scaffold_tokenizer, + ) + self.dataloader_2 = tud.DataLoader( + dataset_2, batch_size=32, shuffle=False, collate_fn=Dataset.collate_fn + ) + + dataset_3 = Dataset( + scaffold_list_3, + self._decorator.vocabulary.scaffold_vocabulary, + self._decorator.vocabulary.scaffold_tokenizer, + ) + self.dataloader_3 = tud.DataLoader( + dataset_3, batch_size=32, shuffle=False, collate_fn=Dataset.collate_fn + ) + + def test_double_scaffold_input(self): + for batch in self.dataloader_2: + ( + scaffold_smiles, + decoration_smiles, + nlls, + ) = self._decorator.sample_decorations(*batch) + + self.assertEqual(2, len(scaffold_smiles)) + self.assertEqual(2, len(decoration_smiles)) + self.assertEqual(2, len(nlls)) + + def test_triple_scaffold_input(self): + for batch in self.dataloader_3: + ( + scaffold_smiles, + decoration_smiles, + nlls, + ) = self._decorator.sample_decorations(*batch) + + self.assertEqual(3, len(scaffold_smiles)) + self.assertEqual(3, len(decoration_smiles)) + self.assertEqual(3, len(nlls)) diff --git a/tests/models/integration_tests/libinvent/RNN/model_tests/test_likelihood.py b/tests/models/integration_tests/libinvent/RNN/model_tests/test_likelihood.py new file mode 100644 index 0000000..edc04c0 --- /dev/null +++ b/tests/models/integration_tests/libinvent/RNN/model_tests/test_likelihood.py @@ -0,0 +1,29 @@ +import pytest +import unittest + +import torch + +from reinvent.models import LibinventAdapter, SampledSequencesDTO +from reinvent.runmodes.utils.helpers import set_torch_device +from reinvent.models.libinvent.models.model import DecoratorModel +from tests.test_data import ETHANE, 
HEXANE, PROPANE, BUTANE + + +@pytest.mark.integration +@pytest.mark.usefixtures("device", "json_config") +class TestLibInventLikelihoodSMILES(unittest.TestCase): + def setUp(self): + dto1 = SampledSequencesDTO(ETHANE, PROPANE, 0.9) + dto2 = SampledSequencesDTO(HEXANE, BUTANE, 0.1) + self.sampled_sequence_list = [dto1, dto2] + + save_dict = torch.load(self.json_config["LIBINVENT_CHEMBL_PRIOR_PATH"], map_location=self.device) + model = DecoratorModel.create_from_dict(save_dict, "inference", torch.device(self.device)) + set_torch_device(self.device) + + self.adapter = LibinventAdapter(model) + + def test_len_likelihood_smiles(self): + results = self.adapter.likelihood_smiles(self.sampled_sequence_list) + + self.assertEqual([2], list(results.likelihood.shape)) diff --git a/tests/models/integration_tests/libinvent/RNN/vocabulary_tests/__init__.py b/tests/models/integration_tests/libinvent/RNN/vocabulary_tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/models/integration_tests/libinvent/RNN/vocabulary_tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/models/integration_tests/libinvent/RNN/vocabulary_tests/test_tokenization_with_model.py b/tests/models/integration_tests/libinvent/RNN/vocabulary_tests/test_tokenization_with_model.py new file mode 100644 index 0000000..e69ba52 --- /dev/null +++ b/tests/models/integration_tests/libinvent/RNN/vocabulary_tests/test_tokenization_with_model.py @@ -0,0 +1,25 @@ +import unittest + +import pytest +import torch + +from reinvent.models import LibinventAdapter +from reinvent.models.libinvent.models.model import DecoratorModel +from reinvent.runmodes.utils.helpers import set_torch_device + + +@pytest.mark.integration +@pytest.mark.usefixtures("device", "json_config") +class TestTokenizationWithModel(unittest.TestCase): + def setUp(self): + self.smiles = "c1ccccc1CC0C" + + save_dict = torch.load(self.json_config["LIBINVENT_CHEMBL_PRIOR_PATH"], map_location=self.device) + model = DecoratorModel.create_from_dict(save_dict, "inference", torch.device(self.device)) + set_torch_device(self.device) + + self.adapter = LibinventAdapter(model) + + def test_tokenization(self): + tokenized = self.adapter.vocabulary.scaffold_tokenizer.tokenize(self.smiles) + self.assertEqual(14, len(tokenized)) diff --git a/tests/models/integration_tests/libinvent/transformer/__init__.py b/tests/models/integration_tests/libinvent/transformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/integration_tests/libinvent/transformer/test_libinvent_model.py b/tests/models/integration_tests/libinvent/transformer/test_libinvent_model.py new file mode 100644 index 0000000..4e6399f --- /dev/null +++ b/tests/models/integration_tests/libinvent/transformer/test_libinvent_model.py @@ -0,0 +1,78 @@ +import unittest + +import pytest +import torch +import torch.utils.data as tud + +from reinvent.models.transformer.core.dataset.dataset import Dataset +from reinvent.models.transformer.core.enums.sampling_mode_enum import SamplingModesEnum +from reinvent.models.transformer.core.vocabulary import SMILESTokenizer +from reinvent.models.transformer.libinvent.libinvent import LibinventModel +from reinvent.runmodes.utils.helpers import set_torch_device +from tests.test_data import SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT, SCAFFOLD_QUADRUPLE_POINT + + +@pytest.mark.integration +@pytest.mark.usefixtures("device", "json_config") +class TestLibInventModel(unittest.TestCase): + def setUp(self): + + save_dict = 
torch.load(self.json_config["LIBINVENT_PRIOR_PATH"], map_location=self.device) + self._model = LibinventModel.create_from_dict( + save_dict, "inference", torch.device(self.device) + ) + set_torch_device(self.device) + + self._sample_mode_enum = SamplingModesEnum() + + smiles_list = [SCAFFOLD_SINGLE_POINT] + self.data_loader_1 = self.initialize_dataloader(smiles_list) + + smiles_list = [SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT] + self.data_loader_2 = self.initialize_dataloader(smiles_list) + + smiles_list = [SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT] + self.data_loader_3 = self.initialize_dataloader(smiles_list) + + smiles_list = [SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT, SCAFFOLD_QUADRUPLE_POINT] + self.data_loader_4 = self.initialize_dataloader(smiles_list) + + def initialize_dataloader(self, data): + dataset = Dataset(data, vocabulary=self._model.vocabulary, tokenizer=SMILESTokenizer()) + dataloader = tud.DataLoader( + dataset, len(dataset), shuffle=False, collate_fn=Dataset.collate_fn + ) + + return dataloader + + def _sample_decorations(self, data_loader): + for batch in data_loader: + return self._model.sample(*batch, decode_type=self._sample_mode_enum.MULTINOMIAL) + + def test_single_attachment_input(self): + results = self._sample_decorations(self.data_loader_1) + + self.assertEqual(1, len(results[0])) + self.assertEqual(1, len(results[1])) + self.assertEqual(1, len(results[2])) + + def test_double_attachment_input(self): + results = self._sample_decorations(self.data_loader_2) + + self.assertEqual(2, len(results[0])) + self.assertEqual(2, len(results[1])) + self.assertEqual(2, len(results[2])) + + def test_triple_attachment_input(self): + results = self._sample_decorations(self.data_loader_3) + + self.assertEqual(3, len(results[0])) + self.assertEqual(3, len(results[1])) + self.assertEqual(3, len(results[2])) + + def test_quadruple_attachment_input(self): + results = self._sample_decorations(self.data_loader_4) + + self.assertEqual(4, len(results[0])) + self.assertEqual(4, len(results[1])) + self.assertEqual(4, len(results[2])) diff --git a/tests/models/integration_tests/libinvent/transformer/test_likelihood.py b/tests/models/integration_tests/libinvent/transformer/test_likelihood.py new file mode 100644 index 0000000..313f746 --- /dev/null +++ b/tests/models/integration_tests/libinvent/transformer/test_likelihood.py @@ -0,0 +1,31 @@ +import unittest + +import pytest +import torch + +from reinvent.models import LibinventTransformerAdapter, SampledSequencesDTO +from reinvent.models.transformer.libinvent.libinvent import LibinventModel +from reinvent.runmodes.utils.helpers import set_torch_device +from tests.test_data import ETHANE, HEXANE, PROPANE, BUTANE +from tests.test_data import SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT, \ + SCAFFOLD_QUADRUPLE_POINT, DECORATION_NO_SUZUKI, TWO_DECORATIONS_ONE_SUZUKI, THREE_DECORATIONS, FOUR_DECORATIONS + + +@pytest.mark.integration +@pytest.mark.usefixtures("device", "json_config") +class TestLibInventLikelihoodSMILES(unittest.TestCase): + def setUp(self): + dto1 = SampledSequencesDTO(SCAFFOLD_SINGLE_POINT, DECORATION_NO_SUZUKI, 0.4) + dto2 = SampledSequencesDTO(SCAFFOLD_DOUBLE_POINT, TWO_DECORATIONS_ONE_SUZUKI, 0.6) + dto3 = SampledSequencesDTO(SCAFFOLD_TRIPLE_POINT, THREE_DECORATIONS, 0.3) + dto4 = SampledSequencesDTO(SCAFFOLD_QUADRUPLE_POINT, FOUR_DECORATIONS, 0.5) + self.sampled_sequence_list = [dto1, dto2, dto3, dto4] + save_dict = 
torch.load(self.json_config["LIBINVENT_PRIOR_PATH"], map_location=self.device) + model = LibinventModel.create_from_dict(save_dict, "inference", torch.device(self.device)) + set_torch_device(self.device) + + self.adapter = LibinventTransformerAdapter(model) + + def test_len_likelihood_smiles(self): + results = self.adapter.likelihood_smiles(self.sampled_sequence_list) + self.assertEqual([4], list(results.likelihood.shape)) diff --git a/tests/models/integration_tests/mol2mol/dataset_tests/test_paireddataset.py b/tests/models/integration_tests/mol2mol/dataset_tests/test_paireddataset.py index 7d4856a..939a5ad 100644 --- a/tests/models/integration_tests/mol2mol/dataset_tests/test_paireddataset.py +++ b/tests/models/integration_tests/mol2mol/dataset_tests/test_paireddataset.py @@ -95,7 +95,7 @@ def test_trg_mask_shape(self): def test_src_content(self): result = self._get_src() comparison = torch.equal( - result, torch.tensor([[1, 17, 17, 2, 0], [1, 17, 17, 17, 2]]).to(self.device) + result, torch.tensor([[1, 60, 60, 2, 0], [1, 60, 60, 60, 2]]).to(self.device) ) self.assertTrue(comparison) @@ -113,7 +113,7 @@ def test_trg_content(self): result = self._get_trg() comparison = torch.equal( result, - torch.tensor([[1, 17, 17, 17, 17, 17, 17, 2], [1, 17, 17, 17, 17, 2, 0, 0]]).to( + torch.tensor([[1, 60, 60, 60, 60, 60, 60, 2], [1, 60, 60, 60, 60, 2, 0, 0]]).to( self.device ), ) diff --git a/tests/models/integration_tests/pepinvent/__init__.py b/tests/models/integration_tests/pepinvent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/integration_tests/pepinvent/test_likelihood.py b/tests/models/integration_tests/pepinvent/test_likelihood.py new file mode 100644 index 0000000..bd152dc --- /dev/null +++ b/tests/models/integration_tests/pepinvent/test_likelihood.py @@ -0,0 +1,27 @@ +import unittest + +import pytest +import torch + +from reinvent.models import PepinventAdapter, SampledSequencesDTO +from reinvent.models.transformer.pepinvent.pepinvent import PepinventModel +from reinvent.runmodes.utils.helpers import set_torch_device +from tests.test_data import PEPINVENT_INPUT1, PEPINVENT_OUTPUT1, PEPINVENT_INPUT2, PEPINVENT_OUTPUT2 + + +@pytest.mark.integration +@pytest.mark.usefixtures("device", "json_config") +class TestPepInventLikelihoodSMILES(unittest.TestCase): + def setUp(self): + dto1 = SampledSequencesDTO(PEPINVENT_INPUT1, PEPINVENT_OUTPUT1, 0.9) + dto2 = SampledSequencesDTO(PEPINVENT_INPUT2, PEPINVENT_OUTPUT2, 0.1) + self.sampled_sequence_list = [dto1, dto2] + save_dict = torch.load(self.json_config["PEPINVENT_PRIOR_PATH"], map_location=self.device) + model = PepinventModel.create_from_dict(save_dict, "inference", torch.device(self.device)) + set_torch_device(self.device) + + self.adapter = PepinventAdapter(model) + + def test_len_likelihood_smiles(self): + results = self.adapter.likelihood_smiles(self.sampled_sequence_list) + self.assertEqual([2], list(results.likelihood.shape)) diff --git a/tests/models/integration_tests/pepinvent/test_pepinvent_model.py b/tests/models/integration_tests/pepinvent/test_pepinvent_model.py new file mode 100644 index 0000000..e734fbf --- /dev/null +++ b/tests/models/integration_tests/pepinvent/test_pepinvent_model.py @@ -0,0 +1,71 @@ +import unittest + +import pytest +import torch +import torch.utils.data as tud + +from reinvent.models.transformer.core.dataset.dataset import Dataset +from reinvent.models.transformer.core.enums.sampling_mode_enum import SamplingModesEnum +from reinvent.models.transformer.core.vocabulary import 
SMILESTokenizer +from reinvent.models.transformer.pepinvent.pepinvent import PepinventModel +from reinvent.runmodes.utils.helpers import set_torch_device +from tests.test_data import PEPINVENT_INPUT1, PEPINVENT_INPUT2, PEPINVENT_INPUT3 + + +@pytest.mark.integration +@pytest.mark.usefixtures("device", "json_config") +class TestPepInventModel(unittest.TestCase): + def setUp(self): + + save_dict = torch.load(self.json_config["PEPINVENT_PRIOR_PATH"], map_location=self.device) + self._model = PepinventModel.create_from_dict( + save_dict, "inference", torch.device(self.device) + ) + set_torch_device(self.device) + + self._sample_mode_enum = SamplingModesEnum() + + smiles_list = [PEPINVENT_INPUT1] + self.data_loader_1 = self.initialize_dataloader(smiles_list) + + smiles_list = [PEPINVENT_INPUT1, PEPINVENT_INPUT2] + self.data_loader_2 = self.initialize_dataloader(smiles_list) + + smiles_list = [PEPINVENT_INPUT1, PEPINVENT_INPUT2, PEPINVENT_INPUT3] + self.data_loader_3 = self.initialize_dataloader(smiles_list) + + smiles_list = [PEPINVENT_INPUT1, PEPINVENT_INPUT2, PEPINVENT_INPUT3] + self.data_loader_4 = self.initialize_dataloader(smiles_list) + + def initialize_dataloader(self, data): + dataset = Dataset(data, vocabulary=self._model.vocabulary, tokenizer=SMILESTokenizer()) + dataloader = tud.DataLoader( + dataset, len(dataset), shuffle=False, collate_fn=Dataset.collate_fn + ) + + return dataloader + + def _sample(self, data_loader): + for batch in data_loader: + return self._model.sample(*batch, decode_type=self._sample_mode_enum.MULTINOMIAL) + + def test_single_input(self): + results = self._sample(self.data_loader_1) + + self.assertEqual(1, len(results[0])) + self.assertEqual(1, len(results[1])) + self.assertEqual(1, len(results[2])) + + def test_double_input(self): + results = self._sample(self.data_loader_2) + + self.assertEqual(2, len(results[0])) + self.assertEqual(2, len(results[1])) + self.assertEqual(2, len(results[2])) + + def test_triple_input(self): + results = self._sample(self.data_loader_3) + + self.assertEqual(3, len(results[0])) + self.assertEqual(3, len(results[1])) + self.assertEqual(3, len(results[2])) diff --git a/tests/models/unit_tests/libinvent/RNN/__init__.py b/tests/models/unit_tests/libinvent/RNN/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/unit_tests/libinvent/RNN/fixtures.py b/tests/models/unit_tests/libinvent/RNN/fixtures.py new file mode 100644 index 0000000..aabf30c --- /dev/null +++ b/tests/models/unit_tests/libinvent/RNN/fixtures.py @@ -0,0 +1,55 @@ +import torch.nn as nn + +from reinvent.models.model_mode_enum import ModelModeEnum +from reinvent.models.model_parameter_enum import ModelParametersEnum +from reinvent.models.libinvent.models.decorator import Decorator +from reinvent.models.libinvent.models.model import DecoratorModel +from reinvent.models.libinvent.models.vocabulary import DecoratorVocabulary +from reinvent.models import meta_data +from tests.test_data import SCAFFOLD_SUZUKI + + +def _init_params(parameters): + """ + Fixed weights + """ + for p in parameters: + if p.dim() > 1: + nn.init.constant_(p, 0.5) + + +def mocked_decorator_model(): + smiles_list = [SCAFFOLD_SUZUKI] + decorator_vocabulary = DecoratorVocabulary.from_lists(smiles_list, smiles_list) + scaffold_vocabulary_size = decorator_vocabulary.len_scaffold() + decoration_vocabulary_size = decorator_vocabulary.len_decoration() + + parameter_enums = ModelParametersEnum + encoder_params = { + parameter_enums.NUMBER_OF_LAYERS: 2, + 
parameter_enums.NUMBER_OF_DIMENSIONS: 128, + parameter_enums.VOCABULARY_SIZE: scaffold_vocabulary_size, + parameter_enums.DROPOUT: 0, + } + decoder_params = { + parameter_enums.NUMBER_OF_LAYERS: 2, + parameter_enums.NUMBER_OF_DIMENSIONS: 128, + parameter_enums.VOCABULARY_SIZE: decoration_vocabulary_size, + parameter_enums.DROPOUT: 0, + } + decorator = Decorator(encoder_params, decoder_params) + + model_regime = ModelModeEnum() + + metadata = meta_data.ModelMetaData( + hash_id=None, + hash_id_format="", + model_id="", + origina_data_source="", + creation_date=0, + ) + + model = DecoratorModel(decorator_vocabulary, decorator, metadata, mode=model_regime.INFERENCE) + _init_params(model.network.parameters()) + + return model diff --git a/tests/models/unit_tests/libinvent/RNN/model_tests/__init__.py b/tests/models/unit_tests/libinvent/RNN/model_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/unit_tests/libinvent/RNN/model_tests/decorator_model_test.py b/tests/models/unit_tests/libinvent/RNN/model_tests/decorator_model_test.py new file mode 100644 index 0000000..c28aa7d --- /dev/null +++ b/tests/models/unit_tests/libinvent/RNN/model_tests/decorator_model_test.py @@ -0,0 +1,67 @@ +import pytest +import unittest + +import torch +import torch.utils.data as tud + +from reinvent.models.libinvent.models.dataset import Dataset +from reinvent.runmodes.utils.helpers import set_torch_device +from tests.test_data import SCAFFOLD_SUZUKI +from tests.models.unit_tests.libinvent.RNN.fixtures import mocked_decorator_model + + +@pytest.mark.usefixtures("device") +class TestDecoratorModel(unittest.TestCase): + def setUp(self): + input_scaffold = SCAFFOLD_SUZUKI + scaffold_list_2 = [input_scaffold, input_scaffold] + scaffold_list_3 = [input_scaffold, input_scaffold, input_scaffold] + + device = torch.device(self.device) + self._decorator = mocked_decorator_model() + self._decorator.network.to(device) + self._decorator.device = device + + set_torch_device(device) + + dataset_2 = Dataset( + scaffold_list_2, + self._decorator.vocabulary.scaffold_vocabulary, + self._decorator.vocabulary.scaffold_tokenizer, + ) + self.dataloader_2 = tud.DataLoader( + dataset_2, batch_size=32, shuffle=False, collate_fn=Dataset.collate_fn + ) + + dataset_3 = Dataset( + scaffold_list_3, + self._decorator.vocabulary.scaffold_vocabulary, + self._decorator.vocabulary.scaffold_tokenizer, + ) + self.dataloader_3 = tud.DataLoader( + dataset_3, batch_size=32, shuffle=False, collate_fn=Dataset.collate_fn + ) + + def test_double_scaffold_input(self): + for batch in self.dataloader_2: + ( + scaffold_smiles, + decoration_smiles, + nlls, + ) = self._decorator.sample_decorations(*batch) + + self.assertEqual(2, len(scaffold_smiles)) + self.assertEqual(2, len(decoration_smiles)) + self.assertEqual(2, len(nlls)) + + def test_triple_scaffold_input(self): + for batch in self.dataloader_3: + ( + scaffold_smiles, + decoration_smiles, + nlls, + ) = self._decorator.sample_decorations(*batch) + + self.assertEqual(3, len(scaffold_smiles)) + self.assertEqual(3, len(decoration_smiles)) + self.assertEqual(3, len(nlls)) diff --git a/tests/models/unit_tests/libinvent/RNN/model_tests/test_likelihood.py b/tests/models/unit_tests/libinvent/RNN/model_tests/test_likelihood.py new file mode 100644 index 0000000..4d29f7b --- /dev/null +++ b/tests/models/unit_tests/libinvent/RNN/model_tests/test_likelihood.py @@ -0,0 +1,30 @@ +import pytest +import unittest + +import torch + +from reinvent.models import LibinventAdapter, 
SampledSequencesDTO +from reinvent.runmodes.utils.helpers import set_torch_device +from tests.test_data import ETHANE, HEXANE, PROPANE, BUTANE +from tests.models.unit_tests.libinvent.RNN.fixtures import mocked_decorator_model + + +@pytest.mark.usefixtures("device") +class TestLibInventLikelihoodSMILES(unittest.TestCase): + def setUp(self): + dto1 = SampledSequencesDTO(ETHANE, PROPANE, 0.9) + dto2 = SampledSequencesDTO(HEXANE, BUTANE, 0.1) + self.sampled_sequence_list = [dto1, dto2] + + device = torch.device(self.device) + decoder_model = mocked_decorator_model() + decoder_model.network.to(device) + decoder_model.device = device + + set_torch_device(device) + + self._model = LibinventAdapter(decoder_model) + + def test_len_likelihood_smiles(self): + results = self._model.likelihood_smiles(self.sampled_sequence_list) + self.assertEqual([2], list(results.likelihood.shape)) diff --git a/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/__init__.py b/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_tokenization_with_model.py b/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_tokenization_with_model.py new file mode 100644 index 0000000..503e2ce --- /dev/null +++ b/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_tokenization_with_model.py @@ -0,0 +1,13 @@ +import unittest + +from tests.models.unit_tests.libinvent.RNN.fixtures import mocked_decorator_model + + +class TestTokenizationWithModel(unittest.TestCase): + def setUp(self): + self.smiles = "c1ccccc1CC0C" + self.actor = mocked_decorator_model() + + def test_tokenization(self): + tokenized = self.actor.vocabulary.scaffold_tokenizer.tokenize(self.smiles) + self.assertEqual(14, len(tokenized)) diff --git a/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_tokenizer.py b/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_tokenizer.py new file mode 100644 index 0000000..1d685f2 --- /dev/null +++ b/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_tokenizer.py @@ -0,0 +1,107 @@ +import unittest + +from reinvent.models.libinvent.models.vocabulary import SMILESTokenizer + + +class TestSmilesTokenizer(unittest.TestCase): + def setUp(self): + self.tokenizer = SMILESTokenizer() + + def test_tokenize(self): + self.assertListEqual( + self.tokenizer.tokenize("CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O"), + [ + "^", + "C", + "C", + "(", + "C", + ")", + "C", + "c", + "1", + "c", + "c", + "c", + "(", + "c", + "c", + "1", + ")", + "[C@@H]", + "(", + "C", + ")", + "C", + "(", + "=", + "O", + ")", + "O", + "$", + ], + ) + + self.assertListEqual( + self.tokenizer.tokenize("C%12CC(Br)C1CC%121[ClH]", with_begin_and_end=False), + [ + "C", + "%12", + "C", + "C", + "(", + "Br", + ")", + "C", + "1", + "C", + "C", + "%12", + "1", + "[ClH]", + ], + ) + + def test_untokenize(self): + self.assertEqual( + self.tokenizer.untokenize( + [ + "^", + "C", + "C", + "(", + "C", + ")", + "C", + "c", + "1", + "c", + "c", + "c", + "(", + "c", + "c", + "1", + ")", + "[C@@H]", + "(", + "C", + ")", + "C", + "(", + "=", + "O", + ")", + "O", + "$", + ] + ), + "CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O", + ) + + self.assertEqual( + self.tokenizer.untokenize( + ["C", "1", "C", "C", "(", "Br", ")", "C", "C", "C", "1", "[ClH]"] + ), + "C1CC(Br)CCC1[ClH]", + ) diff --git 
a/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_vocabulary.py b/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_vocabulary.py new file mode 100644 index 0000000..5b773db --- /dev/null +++ b/tests/models/unit_tests/libinvent/RNN/vocabulary_tests/test_vocabulary.py @@ -0,0 +1,102 @@ +import unittest + +import numpy as np +import numpy.testing as npt + +from reinvent.models.libinvent.models.vocabulary import Vocabulary +from tests.test_data import SIMPLE_TOKENS + + +class TestVocabulary(unittest.TestCase): + def setUp(self): + self.voc = Vocabulary(tokens=SIMPLE_TOKENS) + + def test_add_to_vocabulary_1(self): + idx = self.voc.add("#") + self.assertTrue("#" in self.voc) + self.assertTrue(idx in self.voc) + self.assertEqual(self.voc["#"], idx) + self.assertEqual(self.voc[idx], "#") + + def test_add_to_vocabulary_2(self): + idx = self.voc.add("7") + self.assertTrue("7" in self.voc) + self.assertTrue(idx in self.voc) + self.assertEqual(self.voc["7"], idx) + self.assertEqual(self.voc[idx], "7") + + def test_add_to_vocabulary_3(self): + idx = self.voc.add("1") + self.assertTrue("1" in self.voc) + self.assertTrue(idx in self.voc) + self.assertEqual(self.voc[idx], "1") + self.assertEqual(self.voc["1"], idx) + + def test_add_to_vocabulary_4(self): + with self.assertRaises(TypeError) as context: + self.voc.add(1) + self.assertTrue("Token is not a string" in str(context.exception)) + + def test_includes(self): + self.assertTrue(2 in self.voc) + self.assertTrue("1" in self.voc) + self.assertFalse(21 in self.voc) + self.assertFalse("6" in self.voc) + + def test_equal(self): + self.assertEqual(self.voc, Vocabulary(tokens=SIMPLE_TOKENS)) + self.voc.add("#") + self.assertNotEqual(self.voc, Vocabulary(tokens=SIMPLE_TOKENS)) + + def test_update_vocabulary_1(self): + idxs = self.voc.update(["5", "#"]) + self.assertTrue("5" in self.voc) + self.assertTrue(idxs[0] in self.voc) + self.assertTrue("#" in self.voc) + self.assertTrue(idxs[1] in self.voc) + self.assertEqual(self.voc["5"], idxs[0]) + self.assertEqual(self.voc[idxs[0]], "5") + self.assertEqual(self.voc["#"], idxs[1]) + self.assertEqual(self.voc[idxs[1]], "#") + + def test_update_vocabulary_2(self): + idx = self.voc.update(["1", "2"]) + self.assertTrue("1" in self.voc) + self.assertTrue("2" in self.voc) + self.assertTrue(idx[0] in self.voc) + self.assertTrue(idx[1] in self.voc) + self.assertEqual(self.voc["1"], idx[0]) + self.assertEqual(self.voc["2"], idx[1]) + self.assertEqual(idx[0], self.voc["1"]) + self.assertEqual(idx[1], self.voc["2"]) + self.assertEqual("1", self.voc[4]) + self.assertEqual("2", self.voc[5]) + self.assertEqual("1", self.voc[idx[0]]) + self.assertEqual("2", self.voc[idx[1]]) + + def test_update_vocabulary_3(self): + with self.assertRaises(TypeError) as context: + self.voc.update([1, 2]) + self.assertTrue("Token is not a string" in str(context.exception)) + + def test_delete_vocabulary_1(self): + idx3 = self.voc["1"] + del self.voc["1"] + self.assertFalse("1" in self.voc) + self.assertFalse(idx3 in self.voc) + + def test_delete_vocabulary_2(self): + idx4 = self.voc[5] + del self.voc[5] + self.assertFalse("2" in self.voc) + self.assertFalse(idx4 in self.voc) + + def test_len(self): + self.assertEqual(len(self.voc), 15) + self.assertEqual(len(Vocabulary()), 0) + + def test_encode(self): + npt.assert_almost_equal(self.voc.encode(["^", "C", "C", "$"]), np.array([1, 8, 8, 0])) + + def test_decode(self): + self.assertEqual(self.voc.decode(np.array([0, 8, 9, 8, 1])), ["$", "C", "F", "C", "^"]) diff --git 
a/tests/models/unit_tests/libinvent/transformer/__init__.py b/tests/models/unit_tests/libinvent/transformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/unit_tests/libinvent/transformer/dataset_tests/__init__.py b/tests/models/unit_tests/libinvent/transformer/dataset_tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/models/unit_tests/libinvent/transformer/dataset_tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/models/unit_tests/libinvent/transformer/dataset_tests/test_paired_dataset.py b/tests/models/unit_tests/libinvent/transformer/dataset_tests/test_paired_dataset.py new file mode 100644 index 0000000..d639233 --- /dev/null +++ b/tests/models/unit_tests/libinvent/transformer/dataset_tests/test_paired_dataset.py @@ -0,0 +1,152 @@ +import unittest + +import torch +import torch.utils.data as tud + +from reinvent.models.transformer.core.dataset.paired_dataset import PairedDataset +from reinvent.models.transformer.core.vocabulary import SMILESTokenizer +from tests.models.unit_tests.libinvent.transformer.fixtures import mocked_vocabulary +from tests.test_data import SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT, \ + SCAFFOLD_QUADRUPLE_POINT, DECORATION_NO_SUZUKI, TWO_DECORATIONS_ONE_SUZUKI, THREE_DECORATIONS, FOUR_DECORATIONS + + +class TestPairedDataset(unittest.TestCase): + def setUp(self): + self.smiles_input = [SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT, + SCAFFOLD_QUADRUPLE_POINT] + self.smiles_output = [DECORATION_NO_SUZUKI, TWO_DECORATIONS_ONE_SUZUKI, THREE_DECORATIONS, FOUR_DECORATIONS] + self.vocabulary = mocked_vocabulary() + self.data_loader = self.initialize_dataloader(self.smiles_input, self.smiles_output) + + def initialize_dataloader(self, smiles_input, smiles_output): + dataset = PairedDataset( + smiles_input, + smiles_output, + vocabulary=self.vocabulary, + tokenizer=SMILESTokenizer(), + ) + dataloader = tud.DataLoader( + dataset, len(dataset), shuffle=False, collate_fn=PairedDataset.collate_fn + ) + return dataloader + + def _get_src(self): + for batch in self.data_loader: + return batch.input + + def _get_src_mask(self): + for batch in self.data_loader: + return batch.input_mask + + def _get_trg(self): + for batch in self.data_loader: + return batch.output + + def _get_trg_mask(self): + for batch in self.data_loader: + return batch.output_mask + + def _get_src_shape(self): + for batch in self.data_loader: + return batch.input.shape + + def _get_src_mask_shape(self): + for batch in self.data_loader: + return batch.input_mask.shape + + def _get_trg_shape(self): + for batch in self.data_loader: + return batch.output.shape + + def _get_trg_mask_shape(self): + for batch in self.data_loader: + return batch.output_mask.shape + + def test_src_shape(self): + result = self._get_src_shape() + self.assertEqual(list(result), [4, 60]) + + def test_src_mask_shape(self): + result = self._get_src_mask_shape() + self.assertEqual(list(result), [4, 1, 60]) + + def test_trg_shape(self): + result = self._get_trg_shape() + self.assertEqual(list(result), [4, 27]) + + def test_trg_mask_shape(self): + result = self._get_trg_mask_shape() + self.assertEqual(list(result), [4, 26, 26]) + + def test_src_content(self): + result = self._get_src() + comparison = torch.equal( + result, + torch.tensor([[ + 1, 15, 10, 4, 9, 13, 5, 12, 16, 4, 10, 5, 10, 4, 9, 13, 5, 10, + 4, 10, 5, 10, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 
0, 0, 0, 0, 0], + [1, 15, 10, 4, 10, 10, 19, 7, 19, 19, 18, 8, 18, 18, 18, 18, 18, 8, + 18, 7, 9, 13, 5, 10, 4, 9, 13, 5, 15, 2, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0], + [1, 15, 18, 7, 19, 18, 4, 15, 5, 19, 18, 4, 12, 5, 18, 7, 10, 4, + 9, 13, 5, 10, 16, 4, 10, 12, 14, 4, 12, 5, 4, 9, 13, 5, 9, 13, + 5, 10, 4, 9, 13, 5, 15, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0], + [1, 15, 10, 4, 9, 13, 5, 10, 7, 10, 10, 4, 13, 5, 10, 12, 7, 10, + 4, 13, 5, 10, 4, 17, 7, 18, 18, 4, 15, 5, 19, 4, 6, 18, 8, 18, + 18, 18, 4, 15, 5, 18, 4, 11, 5, 18, 8, 5, 19, 7, 5, 10, 4, 10, + 5, 4, 10, 5, 15, 2]]), + ) + self.assertTrue(comparison) + + def test_src_mask_content(self): + result = self._get_src_mask() + comparison = torch.equal( + result, + torch.tensor([[[True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True, + True, True, True, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False]], + + [[True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True, + False, False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False]], + + [[True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True, + True, True, True, True, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False]], + + [[True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, True]]]), + ) + self.assertTrue(comparison) + + def test_trg_content(self): + result = self._get_trg() + comparison = torch.equal( + result, + torch.tensor([[1, 15, 10, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 15, 18, 7, 19, 18, 19, 18, 18, 7, 20, 15, 10, 2, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 15, 18, 7, 19, 18, 19, 18, 18, 7, 20, 15, 18, 7, 19, 18, 19, 18, + 18, 7, 20, 15, 10, 2, 0, 0, 0], + [1, 15, 18, 7, 19, 18, 19, 18, 18, 7, 20, 15, 18, 7, 19, 18, 19, 18, + 18, 7, 20, 15, 10, 20, 15, 10, 2]]), + ) + self.assertTrue(comparison) diff --git a/tests/models/unit_tests/libinvent/transformer/fixtures.py b/tests/models/unit_tests/libinvent/transformer/fixtures.py new file mode 100644 index 0000000..9a4db7a --- /dev/null +++ b/tests/models/unit_tests/libinvent/transformer/fixtures.py @@ -0,0 +1,46 @@ +import pytest +from torch import nn + +from reinvent.models.transformer.core.network.encode_decode.model import EncoderDecoder +from reinvent.models.transformer.core.vocabulary import build_vocabulary +from reinvent.models.transformer.libinvent.libinvent import LibinventModel +from 
reinvent.models import meta_data +from tests.conftest import device +from tests.test_data import SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT, \ + SCAFFOLD_QUADRUPLE_POINT, DECORATION_NO_SUZUKI, TWO_DECORATIONS_ONE_SUZUKI, THREE_DECORATIONS, FOUR_DECORATIONS + + +def _init_params(parameters): + """ + Fixed weights + """ + for p in parameters: + if p.dim() > 1: + nn.init.constant_(p, 0.5) + + +def mocked_libinvent_model(): + vocabulary = mocked_vocabulary() + encoder_decoder = EncoderDecoder(len(vocabulary)) + + metadata = meta_data.ModelMetaData( + hash_id=None, + hash_id_format="", + model_id="", + origina_data_source="", + creation_date=0, + ) + + model = LibinventModel(vocabulary, encoder_decoder, metadata) + _init_params(model.network.parameters()) + return model + + +def mocked_vocabulary(): + smiles_list = [ + SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT, SCAFFOLD_QUADRUPLE_POINT, + DECORATION_NO_SUZUKI, TWO_DECORATIONS_ONE_SUZUKI, THREE_DECORATIONS, FOUR_DECORATIONS + ] + vocabulary = build_vocabulary(smiles_list) + + return vocabulary diff --git a/tests/models/unit_tests/libinvent/transformer/model_tests/__init__.py b/tests/models/unit_tests/libinvent/transformer/model_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/unit_tests/libinvent/transformer/model_tests/test_libinvent_model.py b/tests/models/unit_tests/libinvent/transformer/model_tests/test_libinvent_model.py new file mode 100644 index 0000000..69a0a24 --- /dev/null +++ b/tests/models/unit_tests/libinvent/transformer/model_tests/test_libinvent_model.py @@ -0,0 +1,77 @@ +import pytest +import unittest + +import torch +import torch.utils.data as tud + +from reinvent.models.transformer.core.dataset.dataset import Dataset +from reinvent.models.transformer.core.enums.sampling_mode_enum import SamplingModesEnum +from reinvent.models.transformer.core.vocabulary import SMILESTokenizer +from reinvent.runmodes.utils import set_torch_device +from tests.models.unit_tests.libinvent.transformer.fixtures import mocked_libinvent_model +from tests.test_data import SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT, SCAFFOLD_QUADRUPLE_POINT + + +@pytest.mark.usefixtures("device") +class TestLibInventModel(unittest.TestCase): + def setUp(self): + + device = torch.device(self.device) + self._model = mocked_libinvent_model() + self._model.network.to(device) + self._model.device = device + self._sample_mode_enum = SamplingModesEnum() + + set_torch_device(device) + + smiles_list = [SCAFFOLD_SINGLE_POINT] + self.data_loader_1 = self.initialize_dataloader(smiles_list) + + smiles_list = [SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT] + self.data_loader_2 = self.initialize_dataloader(smiles_list) + + smiles_list = [SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT] + self.data_loader_3 = self.initialize_dataloader(smiles_list) + + smiles_list = [SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT, SCAFFOLD_QUADRUPLE_POINT] + self.data_loader_4 = self.initialize_dataloader(smiles_list) + + def initialize_dataloader(self, data): + dataset = Dataset(data, vocabulary=self._model.vocabulary, tokenizer=SMILESTokenizer()) + dataloader = tud.DataLoader( + dataset, len(dataset), shuffle=False, collate_fn=Dataset.collate_fn + ) + + return dataloader + + def _sample_decorations(self, data_loader): + for batch in data_loader: + return self._model.sample(*batch, decode_type=self._sample_mode_enum.MULTINOMIAL) + + def 
test_single_attachment_input(self): + results = self._sample_decorations(self.data_loader_1) + + self.assertEqual(1, len(results[0])) + self.assertEqual(1, len(results[1])) + self.assertEqual(1, len(results[2])) + + def test_double_attachment_input(self): + results = self._sample_decorations(self.data_loader_2) + + self.assertEqual(2, len(results[0])) + self.assertEqual(2, len(results[1])) + self.assertEqual(2, len(results[2])) + + def test_triple_attachment_input(self): + results = self._sample_decorations(self.data_loader_3) + + self.assertEqual(3, len(results[0])) + self.assertEqual(3, len(results[1])) + self.assertEqual(3, len(results[2])) + + def test_quadruple_attachment_input(self): + results = self._sample_decorations(self.data_loader_4) + + self.assertEqual(4, len(results[0])) + self.assertEqual(4, len(results[1])) + self.assertEqual(4, len(results[2])) diff --git a/tests/models/unit_tests/libinvent/transformer/model_tests/test_likelihood.py b/tests/models/unit_tests/libinvent/transformer/model_tests/test_likelihood.py new file mode 100644 index 0000000..3628afd --- /dev/null +++ b/tests/models/unit_tests/libinvent/transformer/model_tests/test_likelihood.py @@ -0,0 +1,24 @@ +import pytest +import unittest + +from reinvent.models import LibinventTransformerAdapter, SampledSequencesDTO +from tests.models.unit_tests.libinvent.transformer.fixtures import mocked_libinvent_model +from tests.test_data import SCAFFOLD_SINGLE_POINT, SCAFFOLD_DOUBLE_POINT, SCAFFOLD_TRIPLE_POINT, \ + SCAFFOLD_QUADRUPLE_POINT, DECORATION_NO_SUZUKI, TWO_DECORATIONS_ONE_SUZUKI, THREE_DECORATIONS, FOUR_DECORATIONS + + +@pytest.mark.usefixtures("device") +class TestLibInventLikelihoodSMILES(unittest.TestCase): + def setUp(self): + dto1 = SampledSequencesDTO(SCAFFOLD_SINGLE_POINT, DECORATION_NO_SUZUKI, 0.4) + dto2 = SampledSequencesDTO(SCAFFOLD_DOUBLE_POINT, TWO_DECORATIONS_ONE_SUZUKI, 0.6) + dto3 = SampledSequencesDTO(SCAFFOLD_TRIPLE_POINT, THREE_DECORATIONS, 0.3) + dto4 = SampledSequencesDTO(SCAFFOLD_QUADRUPLE_POINT, FOUR_DECORATIONS, 0.5) + self.sampled_sequence_list = [dto1, dto2, dto3, dto4] + + libinvent_model = mocked_libinvent_model() + self._model = LibinventTransformerAdapter(libinvent_model) + + def test_len_likelihood_smiles(self): + results = self._model.likelihood_smiles(self.sampled_sequence_list) + self.assertEqual([4], list(results.likelihood.shape)) diff --git a/tests/models/unit_tests/pepinvent/__init__.py b/tests/models/unit_tests/pepinvent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/unit_tests/pepinvent/fixtures.py b/tests/models/unit_tests/pepinvent/fixtures.py new file mode 100644 index 0000000..4e7ebec --- /dev/null +++ b/tests/models/unit_tests/pepinvent/fixtures.py @@ -0,0 +1,43 @@ +import pytest +from torch import nn + +from reinvent.models.transformer.core.network.encode_decode.model import EncoderDecoder +from reinvent.models.transformer.core.vocabulary import build_vocabulary +from reinvent.models.transformer.pepinvent.pepinvent import PepinventModel +from reinvent.models import meta_data +from tests.test_data import PEPINVENT_INPUT1, PEPINVENT_INPUT2, PEPINVENT_INPUT3 + + +def _init_params(parameters): + """ + Fixed weights + """ + for p in parameters: + if p.dim() > 1: + nn.init.constant_(p, 0.5) + + +def mocked_pepinvent_model(): + vocabulary = mocked_vocabulary() + encoder_decoder = EncoderDecoder(len(vocabulary)) + + metadata = meta_data.ModelMetaData( + hash_id=None, + hash_id_format="", + model_id="", + origina_data_source="", + 
creation_date=0, + ) + + model = PepinventModel(vocabulary, encoder_decoder, metadata) + _init_params(model.network.parameters()) + return model + + +def mocked_vocabulary(): + smiles_list = [ + PEPINVENT_INPUT1, PEPINVENT_INPUT2, PEPINVENT_INPUT3 + ] + vocabulary = build_vocabulary(smiles_list) + + return vocabulary diff --git a/tests/models/unit_tests/pepinvent/test_likelihood.py b/tests/models/unit_tests/pepinvent/test_likelihood.py new file mode 100644 index 0000000..7ee3b12 --- /dev/null +++ b/tests/models/unit_tests/pepinvent/test_likelihood.py @@ -0,0 +1,21 @@ +import pytest +import unittest + +from reinvent.models import PepinventAdapter, SampledSequencesDTO +from tests.models.unit_tests.pepinvent.fixtures import mocked_pepinvent_model +from tests.test_data import PEPINVENT_INPUT1, PEPINVENT_INPUT2, PEPINVENT_OUTPUT1, PEPINVENT_OUTPUT2 + + +@pytest.mark.usefixtures("device") +class TestPepInventLikelihoodSMILES(unittest.TestCase): + def setUp(self): + dto1 = SampledSequencesDTO(PEPINVENT_INPUT1, PEPINVENT_OUTPUT1, 0.9) + dto2 = SampledSequencesDTO(PEPINVENT_INPUT2, PEPINVENT_OUTPUT2, 0.1) + self.sampled_sequence_list = [dto1, dto2] + + pepinvent_model = mocked_pepinvent_model() + self._model = PepinventAdapter(pepinvent_model) + + def test_len_likelihood_smiles(self): + results = self._model.likelihood_smiles(self.sampled_sequence_list) + self.assertEqual([2], list(results.likelihood.shape)) diff --git a/tests/models/unit_tests/pepinvent/test_pepinvent_model.py b/tests/models/unit_tests/pepinvent/test_pepinvent_model.py new file mode 100644 index 0000000..0683590 --- /dev/null +++ b/tests/models/unit_tests/pepinvent/test_pepinvent_model.py @@ -0,0 +1,67 @@ +import pytest +import unittest + +import torch +import torch.utils.data as tud + +from reinvent.models.transformer.core.dataset.dataset import Dataset +from reinvent.models.transformer.core.enums.sampling_mode_enum import SamplingModesEnum +from reinvent.models.transformer.core.vocabulary import SMILESTokenizer +from reinvent.runmodes.utils import set_torch_device +from tests.models.unit_tests.pepinvent.fixtures import mocked_pepinvent_model +from tests.test_data import PEPINVENT_INPUT1, PEPINVENT_INPUT2, PEPINVENT_INPUT3 + + +@pytest.mark.usefixtures("device") +class TestPepInventModel(unittest.TestCase): + def setUp(self): + + device = torch.device(self.device) + self._model = mocked_pepinvent_model() + self._model.network.to(device) + self._model.device = device + self._sample_mode_enum = SamplingModesEnum() + + set_torch_device(device) + + smiles_list = [PEPINVENT_INPUT1] + self.data_loader_1 = self.initialize_dataloader(smiles_list) + + smiles_list = [PEPINVENT_INPUT1, PEPINVENT_INPUT2] + self.data_loader_2 = self.initialize_dataloader(smiles_list) + + smiles_list = [PEPINVENT_INPUT1, PEPINVENT_INPUT2, PEPINVENT_INPUT3] + self.data_loader_3 = self.initialize_dataloader(smiles_list) + + def initialize_dataloader(self, data): + dataset = Dataset(data, vocabulary=self._model.vocabulary, tokenizer=SMILESTokenizer()) + dataloader = tud.DataLoader( + dataset, len(dataset), shuffle=False, collate_fn=Dataset.collate_fn + ) + + return dataloader + + def _sample(self, data_loader): + for batch in data_loader: + return self._model.sample(*batch, decode_type=self._sample_mode_enum.MULTINOMIAL) + + def test_single_input(self): + results = self._sample(self.data_loader_1) + + self.assertEqual(1, len(results[0])) + self.assertEqual(1, len(results[1])) + self.assertEqual(1, len(results[2])) + + def test_double_input(self): + results = 
self._sample(self.data_loader_2) + + self.assertEqual(2, len(results[0])) + self.assertEqual(2, len(results[1])) + self.assertEqual(2, len(results[2])) + + def test_triple_input(self): + results = self._sample(self.data_loader_3) + + self.assertEqual(3, len(results[0])) + self.assertEqual(3, len(results[1])) + self.assertEqual(3, len(results[2])) diff --git a/tests/reinvent_plugins/unit_tests/components/RDKit/test_comp_mol_volume.py b/tests/reinvent_plugins/unit_tests/components/RDKit/test_comp_mol_volume.py index 81a6462..80218b8 100644 --- a/tests/reinvent_plugins/unit_tests/components/RDKit/test_comp_mol_volume.py +++ b/tests/reinvent_plugins/unit_tests/components/RDKit/test_comp_mol_volume.py @@ -11,5 +11,6 @@ def test_comp_mol_volume(): results = mol_volume(smiles) expected_results = [np.array([95.144, 123.544])] + expected_results_new = [np.array([95.144, 123.04])] # RDKit's new conformer generator - assert np.allclose(np.concatenate(results.scores), expected_results) + assert np.allclose(np.concatenate(results.scores), expected_results) or np.allclose(np.concatenate(results.scores), expected_results_new) diff --git a/tests/runmodes/integration_tests/sampling_tests/test_sampling.py b/tests/runmodes/integration_tests/sampling_tests/test_sampling.py index 97b4205..543de0d 100644 --- a/tests/runmodes/integration_tests/sampling_tests/test_sampling.py +++ b/tests/runmodes/integration_tests/sampling_tests/test_sampling.py @@ -59,6 +59,7 @@ def param(request, json_config): "smiles_file": json_config["MOLFORMER_SMILES_SET_PATH"], "sample_strategy": "multinomial", "num_cols": 4, + "temperature": 1, }, "mol2mol-beam": { "model_file": ".m2m_high", diff --git a/tests/runmodes/unit_tests/test_remote_reporter.py b/tests/runmodes/unit_tests/test_remote_reporter.py index b596bc7..281bbaa 100644 --- a/tests/runmodes/unit_tests/test_remote_reporter.py +++ b/tests/runmodes/unit_tests/test_remote_reporter.py @@ -1,10 +1,6 @@ import pytest -from reinvent.runmodes.reporter.remote import ( - setup_reporter, - get_reporter, - RemoteJSONReporter, -) +from reinvent.utils.logmon import RemoteJSONReporter, get_reporter, setup_reporter def test_noop_reporter_without_setup(): diff --git a/tests/scoring/unit_tests/test_parsing.py b/tests/scoring/unit_tests/test_parsing.py index 55f1cb3..ef3d255 100644 --- a/tests/scoring/unit_tests/test_parsing.py +++ b/tests/scoring/unit_tests/test_parsing.py @@ -1,3 +1,4 @@ +from reinvent.scoring.config import collect_params from reinvent.scoring.scorer import get_components from dataclasses import fields @@ -34,3 +35,83 @@ def test_get_components(): assert not components_dict.filters assert not components_dict.penalties + + +def test_complevel_params(): + components = [ + {"QED": {"endpoint": [{"name": "QED drug-like score", "weight": 0.79}]}}, + { + "external_process": { + "params": { # Component-level params. + "args": "--loglevel DEBUG", + }, + "endpoint": [ + { + "name": "Endpoint1", + "weight": 0.5, + "params": { # Endpoint-level params. 
+ "executable": "path/to/executable1", + }, + }, + { + "name": "Endpoint2", + "weight": 0.7, + "params": { + "executable": "path/to/executable2", + "args": "--loglevel INFO", + }, + }, + ], + } + }, + ] + + components_dict = get_components(components) + + assert len(components_dict.scorers) == 2 + assert "qed" == components_dict.scorers[0].component_type + assert "externalprocess" == components_dict.scorers[1].component_type + name, comp, transform, weights = components_dict.scorers[1].params + assert comp.executables == ["path/to/executable1", "path/to/executable2"] + assert comp.args == ["--loglevel DEBUG", "--loglevel INFO"] + + assert not components_dict.filters + assert not components_dict.penalties + + +def test_collect_params(): + # Test case 1: multiple dictionaries + params = [ + {"x": 1, "y": 1}, + {"y": 2, "z": 2}, + {"z": 3, "x": 3}, + ] + expected_output = { + "x": [1, None, 3], + "y": [1, 2, None], + "z": [None, 2, 3], + } + assert collect_params(params) == expected_output + + # Test case 2: Empty list + params = [] + expected_output = {} + assert collect_params(params) == expected_output + + # Test case 3: List with one dictionary + params = [{"key1": "value1"}] + expected_output = {"key1": ["value1"]} + assert collect_params(params) == expected_output + + # Test case 4: List with dictionaries having different keys + params = [ + {"key1": "value1"}, + {"key2": "value2"}, + {"key3": "value3"} + ] + expected_output = { + "key1": ["value1", None, None], + "key2": [None, "value2", None], + "key3": [None, None, "value3"] + } + assert collect_params(params) == expected_output diff --git a/tests/scoring/unit_tests/test_transforms.py b/tests/scoring/unit_tests/test_transforms.py index e95b85d..c3a9b18 100644 --- a/tests/scoring/unit_tests/test_transforms.py +++ b/tests/scoring/unit_tests/test_transforms.py @@ -9,6 +9,7 @@ from reinvent.scoring.transforms import RightStep from reinvent.scoring.transforms import Step from reinvent.scoring.transforms import ValueMapping +from reinvent.scoring.transforms import ExponentialDecay @pytest.mark.parametrize( @@ -121,3 +122,21 @@ def test_step(low, high): transform = Step(params) results = transform(data) assert np.all(results == ((data <= high) & (data >= low))) + + +@pytest.mark.parametrize("k", [1.0, 10]) +def test_exponential_decay(k): + from reinvent.scoring.transforms.exponential_decay import Parameters + + data = np.linspace(-5, 5, 11, dtype=np.float32) + params = Parameters( + type="", + k=k, + ) + transform = ExponentialDecay(params) + results = transform(data) + assert np.all(np.diff(results) <= 0) # Check that the function is monotonically decreasing. 
+ assert np.all(results <= 1.0) + assert np.all(results >= 0.0) + assert results[0] == 1.0 + assert results[-1] <= 0.01 diff --git a/tests/test_data.py b/tests/test_data.py index 224ab1d..10b7c5b 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -53,6 +53,12 @@ THREE_DECORATIONS = "[*]c1ncncc1|[*]c1ncncc1|[*]C" FOUR_DECORATIONS = "[*]c1ncncc1|[*]c1ncncc1|[*]C|[*]C" +PEPINVENT_INPUT1 = "?|N[C@@H](CO)C(=O)|?|N[C@@H](Cc1ccc(O)cc1)C(=O)|N(C)[C@@H]([C@@H](C)O)C(=O)|N[C@H](Cc1c[nH]cn1)C(=O)|N[C@@H](CC(=O)N)C2(=O)" +PEPINVENT_INPUT2 = "?|?|?|N[C@@H](CCC(=O)O)C(=O)|N[C@@H]([C@@H](C)O)C(=O)|NCC(=O)|N[C@@H](CCC(=O)O)C(=O)|N[C@@H](Cc1ccccc1)C(=O)|N[C@@H](CC(C)C)C(=O)O" +PEPINVENT_INPUT3 = "?|NCC(=O)|N[C@@H](CC(=O)O)C(=O)|N[C@@H]([C@H]1C[C@H](OC(C)(C)O1)CO)C(=O)|N(C)[C@@H](CCC(=O)O)C(=O)|?|N(C)[C@@H](Cc1c[nH]c2ccccc12)C(=O)|?|N[C@@H](CCSC)C(=O)|?|N[C@@H](Cc1c[nH]cn1)C(=O)|N[C@@H](c1sc(S3)nc1c1ccc(F)cc1)C(=O)O" +PEPINVENT_OUTPUT1 = "N2[C@@H](CC(=O)N)C(=O)|N[C@@H](CNC(=O)N1CCC[C@@H]1[C@H](O)C(F)(F)F)C(=O)" +PEPINVENT_OUTPUT2 = "N[C@@H](Cc1ccccc1)C(=O)|N[C@@H]([C@@H](C)O)C(=O)|NCC(=O)|N(C)[C@@H](CC(C)C)C(=O)|N1[C@@H](CCC1)C(=O)" + IBUPROFEN_TOKENIZED = [ "^", "C",
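Editor's note on test_exponential_decay above: the test pins down the transform's observable behaviour (monotonically non-increasing, bounded to [0, 1], equal to 1.0 at the most negative input, and at most 0.01 at x = 5 for k >= 1) without reproducing the implementation in reinvent/scoring/transforms/exponential_decay.py. The sketch below is purely illustrative and consistent with those assertions; it is an assumption, not the actual code.

import numpy as np


def exponential_decay(x: np.ndarray, k: float) -> np.ndarray:
    """Illustrative only: 1.0 for x <= 0, exp(-k * x) for x > 0."""
    # Clamping at zero keeps the output in [0, 1] and yields 1.0 for negative inputs
    return np.exp(-k * np.clip(x, 0.0, None))


# Mirrors the checks in test_exponential_decay
values = exponential_decay(np.linspace(-5, 5, 11, dtype=np.float32), k=1.0)
assert np.all(np.diff(values) <= 0)
assert values[0] == 1.0 and values[-1] <= 0.01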