From 915ec4cfa87e32ec5c30094435cfdac9c644d928 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=B6ffler=2C=20Hannes?= Date: Fri, 19 Apr 2024 13:12:15 +0200 Subject: [PATCH] sync with AZ internal version --- .gitattributes | 3 +- CHANGELOG.md | 112 +++++++- NEWS.md | 17 ++ README.md | 27 +- configs/toml/PARAMS.md | 108 ++++---- configs/toml/sampling.toml | 2 +- configs/toml/staged_learning.toml | 2 +- configs/toml/transfer_learning.toml | 2 +- pyproject.toml | 29 +- reinvent/Reinvent.py | 14 +- reinvent/chemistry/conversions.py | 4 +- reinvent/models/__init__.py | 1 + .../model_factory/linkinvent_adapter.py | 7 +- reinvent/models/reinvent/models/dataset.py | 23 +- reinvent/models/reinvent/models/vocabulary.py | 4 +- .../models/transformer/core/vocabulary.py | 22 +- reinvent/runmodes/RL/learning.py | 29 +- reinvent/runmodes/RL/linkinvent.py | 10 +- reinvent/runmodes/RL/mol2mol.py | 19 +- reinvent/runmodes/RL/reports/csv_summmary.py | 4 +- reinvent/runmodes/RL/reports/tensorboard.py | 4 +- reinvent/runmodes/RL/run_staged_learning.py | 2 +- reinvent/runmodes/TL/configurations.py | 2 + reinvent/runmodes/TL/learning.py | 59 +++- reinvent/runmodes/TL/reinvent.py | 2 + reinvent/runmodes/TL/reports/tensorboard.py | 53 ++-- reinvent/runmodes/TL/run_transfer_learning.py | 12 +- reinvent/runmodes/create_adapter.py | 113 +++++--- reinvent/runmodes/samplers/libinvent.py | 2 +- reinvent/runmodes/samplers/linkinvent.py | 51 +++- reinvent/runmodes/samplers/mol2mol.py | 52 ++-- reinvent/runmodes/samplers/reinvent.py | 7 +- .../runmodes/samplers/reports/tensorboard.py | 2 +- reinvent/runmodes/samplers/run_sampling.py | 5 +- reinvent/runmodes/samplers/sampler.py | 31 +-- reinvent/runmodes/setup_sampler.py | 26 +- reinvent/runmodes/utils/__init__.py | 1 + reinvent/runmodes/utils/helpers.py | 29 +- reinvent/version.py | 2 +- .../components/OpenEye/comp_rocs.py | 11 +- .../components/RDKit/comp_group_count.py | 2 +- .../RDKit/comp_matching_substructure.py | 4 +- .../components/RDKit/comp_mol_volume.py | 7 +- reinvent_plugins/components/RDKit/comp_pmi.py | 2 +- .../RDKit/comp_rdkit_descriptors.py | 2 +- .../components/RDKit/comp_similarity.py | 2 +- .../components/RDKit/comp_tpsa.py | 7 +- reinvent_plugins/components/comp_chemprop.py | 5 +- .../components/comp_custom_alerts.py | 2 +- .../components/comp_dockstream.py | 2 +- .../components/comp_external_process.py | 2 +- .../components/comp_generic_rest.py | 2 +- reinvent_plugins/components/comp_icolos.py | 2 +- reinvent_plugins/components/comp_maize.py | 13 +- reinvent_plugins/components/comp_mmp.py | 10 +- reinvent_plugins/components/comp_qptuna.py | 2 +- requirements-linux-64.lock | 259 ++++++++++-------- 57 files changed, 789 insertions(+), 441 deletions(-) diff --git a/.gitattributes b/.gitattributes index 5d94f3b..9709b8d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,2 @@ -*.prior binary +*.prior filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text diff --git a/CHANGELOG.md b/CHANGELOG.md index d34b8ee..46b250d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,120 @@ This follows the guideline on [keep a changelog](https://keepachangelog.com/) ## [Unreleased] ### Changed -- Fragment generators using transformers - CAZP scoring component +## [4.3.5] 2024-04-18 + +### Changed + +- Code clean-up in create\_adapter() + +### Fixed + +- Import mol2mol vocabulary rather than copying the file + + +## [4.3.4] 2024-04-16 + +### Added + +- Write invalid SMILES unchanged to RL CSV + + +## [4.3.3] 2024-04-16 + +### Added + +- 
Notebook: demo on how to analyse RL CSV + + +## [4.3.2] 2024-04-15 + +### Added + +- Dataclass validation for scoring component parameters + +### Fixed + +- Datatype in MatchingSubstructure's Parameters: only single SMARTS is allowed + + +## [4.3.1] 2024-04-15 + +### Added + +- Notebook to demo simple RL run, TensorBoard visualisation and TensorBoard data extraction. + + +## [4.3.0] 2024-04-15 + +### Added + +- Linkinvent based on unified Transformer model supported by RL and sampling. Both beam search and multinomial sampling are implemented. + + +## [4.2.13] 2024-04-12 + +### Fixed + +- downgraded Chemprop to 1.5.2 and sklearn to 1.2.2 to retain backward compatibility + + +## [4.2.12] 2024-04-10 + +### Changed + +- New default torch device setup from PyTorch 2.x + +### Added + +- Config parameter "device" to explicitly set torch device e.g. "cuda:0" + + +## [4.2.11] 2024-04-08 + +### Fixed + +- Fixed unknown token handling for Mol2mol TL + + +## [4.2.10] 2024-04-05 + +### Fixed + +- Fixed dataloader for TL to use incomplete batch + + +## [4.2.9] 2024-04-04 + +### Fixed + +- Skip hash check in metadata if no metadata in model file + + +## [4.2.8] 2024-04-03 + +### Added + +- Mol2Mol supports unknown tokens for all the priors + + + +## [4.2.7] 2024-03-27 + +### Added + +- Optional randomization in all TL epochs for Reinvent + + +## [4.2.6] 2024-03-21 + +### Fixed + +- Return from make\_grid\_image() + + ## [4.2.5] 2024-03-20 ### Added diff --git a/NEWS.md b/NEWS.md index bbd0325..5775e73 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,20 @@ +New in REINVENT 4.3 +=================== + +For details see CHANGELOG.md. + +* Upgrade to **PyTorch 2.2**: rerun `pip install -r requirements-linux-64.lock` +* 2 new **notebooks** demoing Reinvent with reinforcement learning and also transfer learning, includes TensorBoard visualisation and basic analysis +* New Linkinvent model code based on unified transformer +* New PubChem Mol2Mol prior +* Unknown token support for PubChem based transformer models +* New "device" config parameter to allow for explicit device e.g. "cuda:0" +* Optional SMILES randomization in every TL epoch for Reinvent +* Dataclass parameter validation for most scoring components +* Invalid SMILES are now written to the reinforcement learning CSV +* Code improvements and fixes + + New in REINVENT 4.2 =================== diff --git a/README.md b/README.md index d300ca2..e4dd197 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ Installation ``` 4. Optional: if you want to use **AMD GPUs** on Linux you would need to install the [ROCm PyTorch version](https://pytorch.org/get-started/locally/) manually _after_ installation of the dependencies in point 3, e.g. ```shell - pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/rocm5.2 + pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/rocm5.7 ``` 5. Install the tool. The dependencies were already installed in the previous step, so there is no need to install them again (flag `--no-deps). If you want to install in editable mode (changes to the code are automatically picked up) add -e before the dot. ```shell @@ -90,22 +90,23 @@ appropriate run mode depending on the research problem you are trying to address There is additional information in `config/toml` in several `*.md` files with instructions on how to configure the TOML file. 
-
-
Tutorials / `Jupyter` notebooks
-------------------------------

-NOTE: these will be updated at a later time!
+Basic instructions can be found in the comments in the config examples in `config/toml`.
+
+Notebooks will be provided in the `notebook/` directory.  Please note that we provide the notebooks in jupytext "light script" format.  To work with the light scripts you will need to install jupytext.  A few other packages will come in handy too.
+
+```shell
+pip install jupytext mols2grid seaborn
+```
+
+The Python files in `notebook/` can then be converted to a notebook, e.g.
-
+```shell
+jupytext --to ipynb -o Reinvent_demo.ipynb Reinvent_demo.py
+```


Updating dependencies
---------------------

Update the lock files with [pip-tools](https://pypi.org/project/pip-tools/) (please, do not edit the files manually):

```shell
-pip-compile --extra-index-url=https://download.pytorch.org/whl/cu113 --extra-index-url=https://pypi.anaconda.org/OpenEye/simple --resolver=backtracking pyproject.toml
+pip-compile --extra-index-url=https://download.pytorch.org/whl/cu121 --extra-index-url=https://pypi.anaconda.org/OpenEye/simple --resolver=backtracking pyproject.toml
```

To update a single package, use `pip-compile --upgrade-package somepackage` (see the documentation for pip-tools).
diff --git a/configs/toml/PARAMS.md b/configs/toml/PARAMS.md
index 8301817..ba8d20d 100644
--- a/configs/toml/PARAMS.md
+++ b/configs/toml/PARAMS.md
@@ -7,20 +7,21 @@ This is a summary of TOML parameters for each run mode.

Sample a number of SMILES with associated NLLs.

-| Parameter | Description |
-|--------------------|--------------------------------------------------------------------------------------------------------|
-| run\_type | set to "sampling" |
-| use\_cuda | "true" to use GPU, "false" to use CPU |
-| json\_out\_config | filename of the TOML file in JSON format |
-| [parameters] | starts the parameter section |
-| model\_file | filename to model file from which to sample |
-| smiles\_file | filename for inpurt SMILES for Lib/LinkInvent and Mol2Mol |
-| sample\_strategy | Mol2Mol only: "beamsearch" or "multinomial" |
-| output\_file | filename for the CSV file with samples SMILES and NLLs |
-| num\_smiles | number of SMILES to sample, note: this is multiplied by the number of input SMILES |
-| unique\_molecules | if "true" only return unique canonicalized SMILES |
-| randomize\_smiles | if "true" shuffle atoms in input SMILES randomly |
-| tb\_logdir | if not empty string name of the TensorBoard logging directory |
+| Parameter | Description |
+|--------------------|------------------------------------------------------------------------------------------------------|
+| run\_type | set to "sampling" |
+| device | set the torch device, e.g. "cuda:0" or "cpu" |
+| use\_cuda | (deprecated) "true" to use GPU, "false" to use CPU |
+| json\_out\_config | filename of the TOML file in JSON format |
+| [parameters] | starts the parameter section |
+| model\_file | filename of the model file from which to sample |
+| smiles\_file | filename for input SMILES for Lib/LinkInvent and Mol2Mol |
+| sample\_strategy | Transformer models: "beamsearch" or "multinomial" |
+| output\_file | filename for the CSV file with sampled SMILES and NLLs |
+| num\_smiles | number of SMILES to sample, note: this is multiplied by the number of input SMILES |
+| unique\_molecules | if "true" only return unique canonicalized SMILES |
+| randomize\_smiles | if "true" shuffle atoms in input SMILES randomly |
+| tb\_logdir | if not empty, name of the TensorBoard logging directory |
| temperature | Mol2Mol only: default 1.0 |
| target\_smiles\_path | Mol2Mol only: if not empty, filename to provided SMILES, check NLL of generating the provided SMILES |


@@ -32,7 +33,8 @@
Interface to the scoring component. Does not use any models.

| Parameter | Description |
|---------------------|---------------------------------------------------------------------------------------------------------|
| run\_type | set to "scoring" |
-| use\_cuda | "true" to use GPU, "false" to use CPU |
+| device | set the torch device, e.g. "cuda:0" or "cpu" |
+| use\_cuda | (deprecated) "true" to use GPU, "false" to use CPU |
| json\_out\_config | filename of the TOML file in JSON format |
| [parameters] | starts the parameter section |
| smiles\_file | SMILES filename, SMILES are expected in the first column |


@@ -51,10 +53,11 @@
Run transfer learning on a set of input SMILES.

| Parameter | Description |
|------------------------|---------------------------------------------------------------|
| run\_type | set to "transfer\_learning" |
-| use\_cuda | "true" to use GPU, "false" to use CPU |
+| device | set the torch device, e.g. "cuda:0" or "cpu" |
+| use\_cuda | (deprecated) "true" to use GPU, "false" to use CPU |
| json\_out\_config | filename of the TOML file in JSON format |
| tb\_logdir | if not empty string name of the TensorBoard logging directory |
-| number_of_cpus | optional parameter to control number of cpus to generate pairs. if not provided the maximum cpus will be allocated. |
+| number\_of\_cpus | optional parameter to control the number of CPUs for pair generation. If not provided, only one CPU will be used. |
| [parameters] | starts the parameter section |
| num\_epochs | number of epochs to run |
| save\_every\_n\_epochs | save checkpoint file every N epochs |
| batch\_size | batch size |
| num\_refs | number of references for similarity if > 0, DO NOT use with large dataset (> 200 molecules) |
| input\_model\_file | filename of input prior model |
| smiles\_file | SMILES file for Lib/Linkinvent and Molformer |
-| output_model\_file | filename of the final model |
+| output\_model\_file | filename of the final model |
| pairs.upper\_threshold | Molformer: upper similarity |
| pairs.lower\_threshold | Molformer: lower similarity |
| pairs.min\_cardinality | Molformer: |


@@ -74,40 +77,41 @@

Run reinforcement learning (RL) and/or curriculum learning (CL). CL is simply multi-stage RL learning.
-| Parameter | Description |
-|----------------------|--------------------------------------------------------------------------------------------------------------------------------------------|
-| run\_type | set to "transfer\_learning" |
-| use\_cuda | "true" to use GPU, "false" to use CPU |
-| json\_out\_config | filename of the TOML file in JSON format |
-| tb\_logdir | if not empty string name of the TensorBoard logging directory |
-| [parameters] | starts the parameter section |
-| summary\_csv\_prefix | prefix for output CSV filename |
-| use\_checkpoint | if "true" use diversity filter from agent\_file if present |
-| purge\_memories | if "true" purge all diversity filter memories (scaffold, SMILES) after each stage |
-| prior\_file | filename of the prior model file, serves as reference |
-| agent\_file | filename of the agent model file, used for training, replace with checkpoint file from previous stage when needed |
-| batch\_size | batch size, note: affects SGD |
-| uniquify\_smiles | if "true" only return unique SMILES (sampling) |
-| randomize\_smiles | if "true" shuffle atoms in input SMILES randomly (sampling) |
-| [learning\_strategy] | start section for RL learning strategy |
-| type | use "dap" |
-| sigma | sigma in the reward function |
-| rate | learning rate for the torch optimizer |
-| [diversity\_filter] | starts the section for the diversity filter |
+| Parameter | Description |
+|----------------------|--------------------------------------------------------------------------------------------------------------------------------|
+| run\_type | set to "staged\_learning" |
+| device | set the torch device, e.g. "cuda:0" or "cpu" |
+| use\_cuda | (deprecated) "true" to use GPU, "false" to use CPU |
+| json\_out\_config | filename of the TOML file in JSON format |
+| tb\_logdir | if not empty, name of the TensorBoard logging directory |
+| [parameters] | starts the parameter section |
+| summary\_csv\_prefix | prefix for output CSV filename |
+| use\_checkpoint | if "true" use diversity filter from agent\_file if present |
+| purge\_memories | if "true" purge all diversity filter memories (scaffold, SMILES) after each stage |
+| prior\_file | filename of the prior model file, serves as reference |
+| agent\_file | filename of the agent model file, used for training, replace with checkpoint file from previous stage when needed |
+| batch\_size | batch size, note: affects SGD |
+| unique\_sequences | if "true" only return unique raw sequences (sampling) |
+| randomize\_smiles | if "true" shuffle atoms in input SMILES randomly (sampling) |
+| [learning\_strategy] | start section for RL learning strategy |
+| type | use "dap" |
+| sigma | sigma in the reward function |
+| rate | learning rate for the torch optimizer |
+| [diversity\_filter] | starts the section for the diversity filter |
| type | name of the filter type: "IdenticalMurckoScaffold", "IdenticalTopologicalScaffold", "ScaffoldSimilarity", "PenalizeSameSmiles" |
-| bucket\_size | number of scaffolds to store before molecule is scored zero |
-| minscore | minimum score |
-| minsimilarity | minimum similarity in "ScaffoldSimilarity" |
-| penalty\_multiplier | penalty penalty for each molecule in "PenalizeSameSmiles" |
-| [inception] | starts the inception section |
-| smiles\_file | filename for the "good" SMILES |
-| memory\_size | number of SMILES to hold in inception memory |
-| sample\_size | number of SMILES randomly sampled from memory |
-| [[stage]] | starts a stage, note the double brackets |
-| chkpt\_file | filename of the checkpoint file, will be written on termination and Ctrl-C |
-| termination | use "simple", termination criterion |
-| max\_score | maximum score when to terminate |
-| min\_steps | minimum number of RL steps to avoid early termination |
-| max\_steps | maximum number of RL steps to run, if maximum is hit _all_ stages will be terminated |
+| bucket\_size | number of scaffolds to store before molecule is scored zero |
+| minscore | minimum score |
+| minsimilarity | minimum similarity in "ScaffoldSimilarity" |
+| penalty\_multiplier | penalty multiplier for each molecule in "PenalizeSameSmiles" |
+| [inception] | starts the inception section |
+| smiles\_file | filename for the "good" SMILES |
+| memory\_size | number of SMILES to hold in inception memory |
+| sample\_size | number of SMILES randomly sampled from memory |
+| [[stage]] | starts a stage, note the double brackets |
+| chkpt\_file | filename of the checkpoint file, will be written on termination and Ctrl-C |
+| termination | use "simple", termination criterion |
+| max\_score | maximum score when to terminate |
+| min\_steps | minimum number of RL steps to avoid early termination |
+| max\_steps | maximum number of RL steps to run, if maximum is hit **all** stages will be terminated |


The scoring functions are added as in the scoring run mode but prefixed with `stage`.
diff --git a/configs/toml/sampling.toml b/configs/toml/sampling.toml
index 6efc593..8348b86 100644
--- a/configs/toml/sampling.toml
+++ b/configs/toml/sampling.toml
@@ -3,7 +3,7 @@

run_type = "sampling"

-use_cuda = true  # run on the GPU if true, on the CPU if false
+device = "cuda:0"  # set torch device, e.g. "cuda:0" or "cpu"

json_out_config = "_sampling.json"  # write this TOML to JSON

diff --git a/configs/toml/staged_learning.toml b/configs/toml/staged_learning.toml
index 933beca..80b76f1 100644
--- a/configs/toml/staged_learning.toml
+++ b/configs/toml/staged_learning.toml
@@ -9,7 +9,7 @@

run_type = "staged_learning"

-use_cuda = true  # run on the GPU if true, on the CPU if false
+device = "cuda:0"  # set torch device, e.g. "cuda:0" or "cpu"

tb_logdir = "tb_logs"  # name of the TensorBoard logging directory
json_out_config = "_staged_learning.json"  # write this TOML to JSON

diff --git a/configs/toml/transfer_learning.toml b/configs/toml/transfer_learning.toml
index ce79f7c..879a10f 100644
--- a/configs/toml/transfer_learning.toml
+++ b/configs/toml/transfer_learning.toml
@@ -5,7 +5,7 @@

run_type = "transfer_learning"

-use_cuda = true  # run on the GPU if true, on the CPU if false
+device = "cuda:0"  # set torch device, e.g. "cuda:0" or "cpu"

tb_logdir = "tb_TL"  # name of the TensorBoard logging directory
json_out_config = "json_transfer_learning.json"  # write this TOML to JSON

diff --git a/pyproject.toml b/pyproject.toml
index 7f6884b..8a00712 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ description = "Reinvent 4"
authors = [{name = "AstraZeneca"}]
maintainers = [{name = "Hannes Löffler", email = "hannes.loeffler@gmail.com"}]
readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
license = { file = "LICENSE" }
keywords = [
    "reinvent",
@@ -40,40 +40,37 @@ dynamic = ["version"]

# We go quite defensive and add upper bounds,
# in case major version updates break backward compatibility.
dependencies = [ - "chemprop >=1.5.2,<2.0", - "funcy >=1.18,<2", + "chemprop >=1.5.2,<1.6", + "descriptastorus >=2.6.1,<3.0", # Hidden chemprop dependency + "funcy >=2,<3", "matplotlib >=3.7,<4", "mmpdb >=2.1,<3", "molvs >=0.1.1,<0.2", "numpy >=1.21,<2", "OpenEye-toolkits >=2022", # Requires --extra-index-url=https://pypi.anaconda.org/OpenEye/simple - "pandas >=1.5,<2", + "pandas >=2,<3", "pathos >=0.3.0,<2", - "Pillow >=9.4,<10.0", - "pydantic >=1.10,<2", - "pytest >=7.2,<8", + "Pillow >=10.0,<11.0", + "pydantic >=2,<3", + "pytest >=8,<9", "pytest-mock >=3.7,<4", "python-dotenv >=1.0,<2", "PyYAML >=6.0", "rdkit >=2021.0", "requests >=2.28,<3", "requests_mock >=1.10,<2", + "scikit-learn==1.2.2", "scipy >=1.10,<2", "tenacity >=8.2,<9", "tensorboard", "tomli >=2.0,<3", - # Stick to PyTorch 1.12 for now, to "standardise" pickled models, - # Stick to CUDA 11.3 for driver compatibility. - # Requires --extra-index-url https://download.pytorch.org/whl/cu113 - "torch==1.12.1+cu113", - "torchvision==0.13.1+cu113", # Needed to log molecular images to Tensorboard. -# "torchaudio==0.12.1+cu113", + "torch==2.2.1+cu121", # Requires --extra-index-url https://download.pytorch.org/whl/cu121 + "torchvision==0.17.1+cu121", # Needed to log molecular images to Tensorboard. "tqdm >=4.64,<5", "typing_extensions >=4.0,<5", "xxhash >=3,<4", ] - [project.scripts] reinvent = "reinvent.Reinvent:main" @@ -106,3 +103,7 @@ namespaces = true [tool.setuptools.package-data] "*" = ["*.pkl.gz"] + + +[tool.black] +line-length = 100 diff --git a/reinvent/Reinvent.py b/reinvent/Reinvent.py index 882a102..10b134d 100755 --- a/reinvent/Reinvent.py +++ b/reinvent/Reinvent.py @@ -250,8 +250,18 @@ def main(): logger.info(f"Number of PyTorch CUDA devices {torch.cuda.device_count()}") - use_cuda = input_config.get("use_cuda", True) - actual_device = set_torch_device(args.device, use_cuda) + if "use_cuda" in input_config: + logger.warning("'use_cuda' is deprecated, use 'device' instead") + + device = input_config.get("device", None) + + if not device: + use_cuda = input_config.get("use_cuda", True) + + if use_cuda: + device = "cuda:0" + + actual_device = set_torch_device(args.device, device) if actual_device.type == "cuda": current_device = torch.cuda.current_device() diff --git a/reinvent/chemistry/conversions.py b/reinvent/chemistry/conversions.py index 9245b57..3dea988 100644 --- a/reinvent/chemistry/conversions.py +++ b/reinvent/chemistry/conversions.py @@ -81,7 +81,7 @@ def mol_to_smiles(self, molecule: Mol, isomericSmiles=False, canonical=True) -> if molecule: return MolToSmiles(molecule, isomericSmiles=isomericSmiles, canonical=canonical) - def mol_to_random_smiles(self, molecule: Mol) -> str: + def mol_to_random_smiles(self, molecule: Mol, isomericSmiles=False) -> str: """ Converts a Mol object into a random SMILES string. :return: A SMILES string. 
@@ -90,7 +90,7 @@ def mol_to_random_smiles(self, molecule: Mol) -> str: new_atom_order = list(range(molecule.GetNumAtoms())) random.shuffle(new_atom_order) random_mol = RenumberAtoms(molecule, newOrder=new_atom_order) - return MolToSmiles(random_mol, canonical=False, isomericSmiles=False) + return MolToSmiles(random_mol, canonical=False, isomericSmiles=isomericSmiles) def convert_to_rdkit_smiles( self, smiles: str, allowTautomers=True, sanitize=False, isomericSmiles=False diff --git a/reinvent/models/__init__.py b/reinvent/models/__init__.py index 757599f..23de7ee 100644 --- a/reinvent/models/__init__.py +++ b/reinvent/models/__init__.py @@ -3,6 +3,7 @@ from .reinvent.models.model import Model as ReinventModel from .libinvent.models.model import DecoratorModel as LibinventModel from .linkinvent.link_invent_model import LinkInventModel as LinkinventModel +from .transformer.linkinvent.linkinvent import LinkinventModel as LinkinventTransformerModel from .transformer.mol2mol.mol2mol import Mol2MolModel from .model_factory.model_adapter import * diff --git a/reinvent/models/model_factory/linkinvent_adapter.py b/reinvent/models/model_factory/linkinvent_adapter.py index e2de923..5b5ca3f 100644 --- a/reinvent/models/model_factory/linkinvent_adapter.py +++ b/reinvent/models/model_factory/linkinvent_adapter.py @@ -2,7 +2,7 @@ from __future__ import annotations -__all__ = ["LinkinventAdapter"] +__all__ = ["LinkinventAdapter", "LinkinventTransformerAdapter"] from typing import List, TYPE_CHECKING from .sample_batch import SampleBatch @@ -12,6 +12,7 @@ SampledSequencesDTO, BatchLikelihoodDTO, ) +from reinvent.models.model_factory.transformer_adapter import TransformerAdapter if TYPE_CHECKING: pass @@ -32,3 +33,7 @@ def sample(self, warheads_seqs, warheads_seq_lengths) -> SampleBatch: # warhead SMILES, linker SMILES, NLLs sampled = self.model.sample(warheads_seqs, warheads_seq_lengths) return SampleBatch(*sampled) + + +class LinkinventTransformerAdapter(TransformerAdapter): + pass \ No newline at end of file diff --git a/reinvent/models/reinvent/models/dataset.py b/reinvent/models/reinvent/models/dataset.py index 7b0020f..f1cdf5a 100644 --- a/reinvent/models/reinvent/models/dataset.py +++ b/reinvent/models/reinvent/models/dataset.py @@ -3,24 +3,43 @@ import torch import torch.utils.data as tud +from reinvent.chemistry import Conversions from reinvent.models.reinvent.utils import collate_fn +conversions = Conversions() + + class Dataset(tud.Dataset): """Custom PyTorch Dataset that takes a file containing \n separated SMILES""" - def __init__(self, smiles_list, vocabulary, tokenizer): + def __new__(cls, *args, **kwargs): + if "randomize" in kwargs and kwargs["randomize"]: + cls.__getitem__ = cls._getitem_with_randomization + else: + cls.__getitem__ = cls._getitem + + return super().__new__(cls) + + def __init__(self, smiles_list, vocabulary, tokenizer, randomize=False): self._vocabulary = vocabulary self._tokenizer = tokenizer self._smiles_list = list(smiles_list) - def __getitem__(self, i): + def _getitem(self, i): smiles = self._smiles_list[i] tokens = self._tokenizer.tokenize(smiles) encoded = self._vocabulary.encode(tokens) return torch.tensor(encoded, dtype=torch.long) + def _getitem_with_randomization(self, i): + smiles = conversions.randomize_smiles(self._smiles_list[i]) + tokens = self._tokenizer.tokenize(smiles) + encoded = self._vocabulary.encode(tokens) + + return torch.tensor(encoded, dtype=torch.long) + def __len__(self): return len(self._smiles_list) diff --git 
a/reinvent/models/reinvent/models/vocabulary.py b/reinvent/models/reinvent/models/vocabulary.py index 020aac6..3cda686 100644 --- a/reinvent/models/reinvent/models/vocabulary.py +++ b/reinvent/models/reinvent/models/vocabulary.py @@ -10,7 +10,7 @@ class Vocabulary: """Stores the tokens and their conversion to vocabulary indexes.""" def __init__( - self, tokens=None, starting_id=0, pad_token=0, bos_token=1, eos_token=2, unk_token=3 + self, tokens=None, starting_id=0, pad_token=0, bos_token=1, eos_token=2, unk_token=None ): self._tokens = {} self._current_id = starting_id @@ -98,7 +98,7 @@ def get_dictionary(self): "pad_token": getattr(self, "pad_token", 0), "bos_token": getattr(self, "bos_token", 1), "eos_token": getattr(self, "eos_token", 2), - "unk_token": getattr(self, "unk_token", 3), + "unk_token": getattr(self, "unk_token", None), } @classmethod diff --git a/reinvent/models/transformer/core/vocabulary.py b/reinvent/models/transformer/core/vocabulary.py index 0eb29e6..28746c0 100644 --- a/reinvent/models/transformer/core/vocabulary.py +++ b/reinvent/models/transformer/core/vocabulary.py @@ -8,7 +8,7 @@ class Vocabulary: """Stores the tokens and their conversion to one-hot vectors.""" - def __init__(self, tokens=None, starting_id=0, pad_token=0, bos_token=1, eos_token=2, unk_token=3): + def __init__(self, tokens=None, starting_id=0, pad_token=0, bos_token=1, eos_token=2, unk_token=None): self._tokens = {} self._current_id = starting_id @@ -84,17 +84,17 @@ def encode(self, tokens): :return : An numpy array with the tokens encoded. """ ohe_vect = np.zeros(len(tokens), dtype=np.float32) + ohe_keep_mask = np.ones_like(tokens, dtype=bool) for i, token in enumerate(tokens): if token not in self._tokens: - raise KeyError( - f"{token} is not part of the tokens " - + "the model was trained on. " - + f"The token {token} may have been generated " - + "by the internal canonicalization, but " - + "please check your input SMILES." 
- ) - ohe_vect[i] = self._tokens[token] - return ohe_vect + if hasattr(self, "unk_token") and (self.unk_token is not None): + unk_symbol = self[self.unk_token] + ohe_vect[i] = self._tokens[unk_symbol] + else: + ohe_keep_mask[i] = False + else: + ohe_vect[i] = self._tokens[token] + return ohe_vect[ohe_keep_mask] def decode(self, ohe_vect): """ @@ -126,7 +126,7 @@ def get_dictionary(self): "pad_token": getattr(self, "pad_token", 0), "bos_token": getattr(self, "bos_token", 1), "eos_token": getattr(self, "eos_token", 2), - "unk_token": getattr(self, "unk_token", 3), + "unk_token": getattr(self, "unk_token", None), } @classmethod diff --git a/reinvent/runmodes/RL/learning.py b/reinvent/runmodes/RL/learning.py index 7f1ca2f..3bf4976 100644 --- a/reinvent/runmodes/RL/learning.py +++ b/reinvent/runmodes/RL/learning.py @@ -238,6 +238,31 @@ def _update_common(self, results: ScoreResults): np.argwhere(self.sampled.states == SmilesState.VALID).flatten(), ) + def _update_common_transformer(self, results: ScoreResults): + """Common update for Transformer-based models, Mol2Mol, LibInvent and LinkInvent + + :param results: scoring results object + :return: total loss + """ + likelihood_dto = self._state.agent.likelihood_smiles(self.sampled) + batch = likelihood_dto.batch + + prior_nlls = self.prior.likelihood( + batch.input, batch.input_mask, batch.output, batch.output_mask + ) + + agent_nlls = likelihood_dto.likelihood + + return self.reward_nlls( + agent_nlls, + prior_nlls, + results.total_scores, + self.inception, + results.smilies, + self._state.agent, + np.argwhere(self.sampled.states == SmilesState.VALID).flatten(), + ) + # FIXME: still needed: molecule ID def report( self, @@ -327,7 +352,9 @@ def report( send_report(report, self.reporter) - csv_summary = CSVSummary(step_no, score_results, NLL_prior, NLL_agent, NLL_augm, scaffolds) + csv_summary = CSVSummary( + step_no, score_results, NLL_prior, NLL_agent, NLL_augm, scaffolds, self.sampled.states + ) header, columns = write_summary(csv_summary, write_header=self.__write_csv_header) diff --git a/reinvent/runmodes/RL/linkinvent.py b/reinvent/runmodes/RL/linkinvent.py index e77b26a..07b12e9 100644 --- a/reinvent/runmodes/RL/linkinvent.py +++ b/reinvent/runmodes/RL/linkinvent.py @@ -2,7 +2,7 @@ from __future__ import annotations -__all__ = ["LinkinventLearning"] +__all__ = ["LinkinventLearning", "LinkinventTransformerLearning"] from typing import TYPE_CHECKING import numpy as np @@ -32,4 +32,10 @@ def score(self): return results def update(self, results: ScoreResults): - return self._update_common(results) + if self.prior.model._version == 1: # RNN-based + return self._update_common(results) + elif self.prior.model._version == 2: # Transformer-based + return self._update_common_transformer(results) + + +LinkinventTransformerLearning = LinkinventLearning diff --git a/reinvent/runmodes/RL/mol2mol.py b/reinvent/runmodes/RL/mol2mol.py index f8c2402..e6bf550 100644 --- a/reinvent/runmodes/RL/mol2mol.py +++ b/reinvent/runmodes/RL/mol2mol.py @@ -48,21 +48,4 @@ def score(self): return results def update(self, results: ScoreResults): - likelihood_dto = self._state.agent.likelihood_smiles(self.sampled) - batch = likelihood_dto.batch - - prior_nlls = self.prior.likelihood( - batch.input, batch.input_mask, batch.output, batch.output_mask - ) - - agent_nlls = likelihood_dto.likelihood - - return self.reward_nlls( - agent_nlls, - prior_nlls, - results.total_scores, - self.inception, - results.smilies, - self._state.agent, - np.argwhere(self.sampled.states == 
SmilesState.VALID).flatten(), - ) + return self._update_common_transformer(results) diff --git a/reinvent/runmodes/RL/reports/csv_summmary.py b/reinvent/runmodes/RL/reports/csv_summmary.py index e86cf0f..c2ead51 100644 --- a/reinvent/runmodes/RL/reports/csv_summmary.py +++ b/reinvent/runmodes/RL/reports/csv_summmary.py @@ -20,6 +20,7 @@ class CSVSummary: agent_nll: float augmented_nll: float scaffolds: list + smiles_state: list def write_summary(data: CSVSummary, write_header=False) -> tuple: @@ -30,7 +31,7 @@ def write_summary(data: CSVSummary, write_header=False) -> tuple: :returns: headers and columns """ - header = ["Agent", "Prior", "Target", "Score", "SMILES"] + header = ["Agent", "Prior", "Target", "Score", "SMILES", "SMILES_state"] results = data.score_results columns = [ @@ -39,6 +40,7 @@ def write_summary(data: CSVSummary, write_header=False) -> tuple: [f"{score:.4f}" for score in data.augmented_nll], [f"{score:.7f}" for score in results.total_scores], results.smilies, + [str(state.value) for state in data.smiles_state] ] if data.scaffolds: diff --git a/reinvent/runmodes/RL/reports/tensorboard.py b/reinvent/runmodes/RL/reports/tensorboard.py index de3ba72..041cbb4 100644 --- a/reinvent/runmodes/RL/reports/tensorboard.py +++ b/reinvent/runmodes/RL/reports/tensorboard.py @@ -15,8 +15,8 @@ logger = logging.getLogger(__name__) -ROWS = 3 -COLUMNS = 10 +ROWS = 5 +COLUMNS = 6 @dataclass diff --git a/reinvent/runmodes/RL/run_staged_learning.py b/reinvent/runmodes/RL/run_staged_learning.py index 37e4ab2..837de8a 100644 --- a/reinvent/runmodes/RL/run_staged_learning.py +++ b/reinvent/runmodes/RL/run_staged_learning.py @@ -251,7 +251,7 @@ def run_staged_learning( rdkit_smiles_flags = dict(allowTautomers=True) - if model_type == "Mol2Mol": # this is a special case + if model_type in ["Mol2Mol", "LinkinventTransformer"]: # Transformer-based models agent_mode = "inference" rdkit_smiles_flags.update(sanitize=True, isomericSmiles=True) rdkit_smiles_flags2 = dict(isomericSmiles=True) diff --git a/reinvent/runmodes/TL/configurations.py b/reinvent/runmodes/TL/configurations.py index 7da5bcd..effaf7d 100644 --- a/reinvent/runmodes/TL/configurations.py +++ b/reinvent/runmodes/TL/configurations.py @@ -36,6 +36,8 @@ class Configuration: num_refs: int = 0 # number of reference molecules for similarity starting_epoch: int = 1 shuffle_each_epoch: bool = True + randomize_all_smiles: bool = False + internal_diversity: bool = False @dataclass(frozen=True) diff --git a/reinvent/runmodes/TL/learning.py b/reinvent/runmodes/TL/learning.py index 9878773..71e5a0c 100644 --- a/reinvent/runmodes/TL/learning.py +++ b/reinvent/runmodes/TL/learning.py @@ -11,7 +11,7 @@ __all__ = ["Learning"] from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import List, TYPE_CHECKING import logging import random @@ -31,6 +31,7 @@ from reinvent.chemistry import Conversions from reinvent.chemistry.library_design import BondMaker, AttachmentPoints from reinvent.models.model_factory.sample_batch import SmilesState +from reinvent.runmodes.utils import mutual_similarities, internal_diversity if TYPE_CHECKING: from reinvent.models import ModelAdapter @@ -64,6 +65,7 @@ def __init__( self.validation_dataset = None self.collate_fn = None self.dataset = None + self.randomize_all_smiles = self._config.randomize_all_smiles # FIXME: ugly hard-coded model names if model_type == "Reinvent" or model_type == "Mol2Mol": @@ -73,11 +75,10 @@ def __init__( self._optimizer = configuration.optimizer self._lr_scheduler = 
configuration.learning_rate_scheduler - self.ref_fps = None + self.reference_fingerprints = None self.smilies = self._config.smilies self.validation_smilies = self._config.validation_smilies - self.duplicate_smiles = set() chemistry = ChemistryHelpers( Conversions(), # Lib/LinkInvent, Mol2Mol @@ -91,11 +92,15 @@ def __init__( sample_batch_size = self._config.sample_batch_size sampling_parameters = {"batch_size": sample_batch_size} - sampler, _ = setup_sampler(model_type, sampling_parameters, self.model, chemistry) + sampler, _ = setup_sampler( + model_type, sampling_parameters, self.model, chemistry + ) sampler.unique_sequences = False self.sampler = sampler - self.sampling_smilies = random.choices(self.smilies, k=self._config.sample_batch_size) + self.sampling_smilies = random.choices( + self.smilies, k=self._config.sample_batch_size + ) if not isinstance(self.sampling_smilies[0], str): self.sampling_smilies = [s[0] for s in self.sampling_smilies] @@ -107,6 +112,7 @@ def __init__( self.batch_size = configuration.batch_size self.save_freq = max(self._config.save_every_n_epochs, 1) + self.internal_diversity = self._config.internal_diversity self.reporter = get_reporter() self.tb_reporter = None @@ -116,7 +122,9 @@ def __init__( if model_type == "Reinvent": iv = torch.full((self.batch_size,), 0, dtype=torch.long) - self.tb_reporter.add_graph(self.model.model.network, input_to_model=iv.unsqueeze(1)) + self.tb_reporter.add_graph( + self.model.model.network, input_to_model=iv.unsqueeze(1) + ) if do_similarity: self._compute_similarity() @@ -143,7 +151,7 @@ def _common_dataloader(self): generator=torch.Generator(device=self.device), shuffle=self._config.shuffle_each_epoch, collate_fn=self.collate_fn, - drop_last=True, + drop_last=False, ) self.validation_dataloader = None @@ -238,7 +246,9 @@ def _train_epoch_common(self) -> float: loss.backward() if self.clip_gradient_norm > 0: - clip_grad_norm_(self.model.network.parameters(), self.clip_gradient_norm) + clip_grad_norm_( + self.model.network.parameters(), self.clip_gradient_norm + ) self._optimizer.step() @@ -270,8 +280,10 @@ def _save_model(self, epoch: int = None) -> str: def _prepare_similarity(self): nmols = min(len(self.smilies), self._config.num_refs) ref_smilies = random.sample(self.smilies, nmols) - mols = filter(lambda m: m, [Chem.MolFromSmiles(smiles) for smiles in ref_smilies]) - self.ref_fps = [Chem.RDKFingerprint(mol) for mol in mols] + mols = filter( + lambda m: m, [Chem.MolFromSmiles(smiles) for smiles in ref_smilies] + ) + self.reference_fingerprints = [Chem.RDKFingerprint(mol) for mol in mols] def _compute_similarity(self): mols = filter( @@ -319,27 +331,46 @@ def report( samples = self.sampler.sample(self.sampling_smilies) sampled_smilies = [] sampled_nlls = [] + mols = [] + duplicate_smiles = set() - for smiles, nll, state in zip(samples.smilies, samples.nlls.cpu(), samples.states): + for smiles, nll, state in zip( + samples.smilies, samples.nlls.cpu(), samples.states + ): if state == SmilesState.DUPLICATE: - self.duplicate_smiles.add(smiles) + duplicate_smiles.add(smiles) if state == SmilesState.DUPLICATE or state == SmilesState.VALID: sampled_smilies.append(smiles) sampled_nlls.append(nll) + mol = Chem.MolFromSmiles(smiles) + + if mol: + mols.append(mol) + + intdiv = 0.0 + sampled_fps = [Chem.RDKFingerprint(mol) for mol in mols] + + if self.internal_diversity: + similarities = mutual_similarities(sampled_fps) + intdiv = internal_diversity(similarities, p=2) + if self.tb_reporter: tb_data = TBData( epoch=epoch_no, 
mean_nll=mean_nll, mean_nll_validation=mean_nll_valid, - ref_fps=self.ref_fps, + fingerprints=sampled_fps, + reference_fingerprints=self.reference_fingerprints, sampled_smilies=sampled_smilies, sampled_nlls=np.array(sampled_nlls), fraction_valid=len(sampled_smilies) / len(samples.smilies), + number_duplicates=len(duplicate_smiles), + internal_diversity=intdiv, ) - write_report(self.tb_reporter, tb_data, self.duplicate_smiles) + write_report(self.tb_reporter, tb_data) remote_data = RemoteData( epoch=epoch_no, diff --git a/reinvent/runmodes/TL/reinvent.py b/reinvent/runmodes/TL/reinvent.py index bf1007e..589881d 100644 --- a/reinvent/runmodes/TL/reinvent.py +++ b/reinvent/runmodes/TL/reinvent.py @@ -24,6 +24,7 @@ def prepare_dataloader(self): smiles_list=self.smilies, vocabulary=self.model.vocabulary, tokenizer=SMILESTokenizer(), + randomize=self.randomize_all_smiles, ) self.validation_dataset = None @@ -33,6 +34,7 @@ def prepare_dataloader(self): smiles_list=self.validation_smilies, vocabulary=self.model.vocabulary, tokenizer=SMILESTokenizer(), + randomize=self.randomize_all_smiles, # if true much shallower minimum ) self.collate_fn = Dataset.collate_fn diff --git a/reinvent/runmodes/TL/reports/tensorboard.py b/reinvent/runmodes/TL/reports/tensorboard.py index cf5e63c..fa4c533 100644 --- a/reinvent/runmodes/TL/reports/tensorboard.py +++ b/reinvent/runmodes/TL/reports/tensorboard.py @@ -7,7 +7,7 @@ import numpy as np from rdkit import Chem, DataStructs -from reinvent.runmodes.utils import make_grid_image +from reinvent.runmodes.utils import make_grid_image, compute_similarity_from_sample ROWS = 5 @@ -18,22 +18,27 @@ class TBData: epoch: int mean_nll: float - ref_fps: List sampled_smilies: Sequence sampled_nlls: Sequence + fingerprints: Sequence + reference_fingerprints: Sequence fraction_valid: float + number_duplicates: float + internal_diversity: float mean_nll_validation: float = None -def write_report(reporter, data, duplicates) -> None: +def write_report(reporter, data) -> None: """Write out TensorBoard data :param reporter: TB reporter for writing out the data :param data: data to be written out - :param duplicates: SMILES cache """ - mean_nll_stats = {"Training Loss": data.mean_nll, "Sample Loss": data.sampled_nlls.mean()} + mean_nll_stats = { + "Training Loss": data.mean_nll, + "Sample Loss": data.sampled_nlls.mean(), + } if data.mean_nll_validation is not None: mean_nll_stats["Validation Loss"] = data.mean_nll_validation @@ -41,7 +46,12 @@ def write_report(reporter, data, duplicates) -> None: reporter.add_scalars("A_Mean NLL loss", mean_nll_stats, data.epoch) reporter.add_scalar("B_Fraction valid SMILES", data.fraction_valid, data.epoch) - reporter.add_scalar("C_Duplicate SMILES", len(duplicates), data.epoch) + reporter.add_scalar("C_Duplicate SMILES (per epoch)", data.number_duplicates, data.epoch) + + if data.internal_diversity > 0.0: + reporter.add_scalar( + "D_Internal Diversity of sample", data.internal_diversity, data.epoch + ) # FIXME: rows and cols depend on sample_batch_size image_tensor, nimage = make_grid_image( @@ -56,27 +66,10 @@ def write_report(reporter, data, duplicates) -> None: dataformats="CHW", ) # channel, height, width - if data.ref_fps: - similarities = compute_similarity_from_sample(data.sampled_smilies, data.ref_fps) - reporter.add_histogram("Tanimoto similarity on RDKitFingerprint", similarities, data.epoch) - - -def compute_similarity_from_sample(smilies: List, ref_fps: List): - """Take the first SMIlES from the input set and compute ther - average 
similarity from SMILES from a sample - - :param smilies: list of SMILES - :param ref_fps: reference fingerprints - """ - - mols = filter(lambda m: m, [Chem.MolFromSmiles(smiles) for smiles in smilies]) - fps = [Chem.RDKFingerprint(mol) for mol in mols] - - sims = [] - - for ref_fp in ref_fps: - sims.append(np.array(DataStructs.BulkTanimotoSimilarity(ref_fp, fps))) - - similarities = np.array(sims).mean(axis=0) - - return similarities + if data.reference_fingerprints: + similarities = compute_similarity_from_sample( + data.fingerprints, data.reference_fingerprints + ) + reporter.add_histogram( + "Tanimoto similarity on RDKitFingerprint", similarities, data.epoch + ) diff --git a/reinvent/runmodes/TL/run_transfer_learning.py b/reinvent/runmodes/TL/run_transfer_learning.py index f9acb17..0beaa35 100644 --- a/reinvent/runmodes/TL/run_transfer_learning.py +++ b/reinvent/runmodes/TL/run_transfer_learning.py @@ -41,7 +41,11 @@ def run_transfer_learning( smiles_filename = os.path.abspath(parameters["smiles_file"]) do_standardize = parameters.get("standardize_smiles", True) - do_randomize = parameters.get("randomize_smiles", True) + + randomize_all_smiles = parameters.get("randomize_all_smiles", False) + do_randomize = parameters.get("randomize_smiles", True) and not randomize_all_smiles + + internal_diversity = parameters.get("internal_diversity", False) actions = [] cols = 0 @@ -62,7 +66,9 @@ def run_transfer_learning( cols = slice(0, 2, None) # NOTE: we expect here that all data will fit into memory - smilies = read_smiles_csv_file(smiles_filename, cols, actions=actions, remove_duplicates=True) + smilies = read_smiles_csv_file( + smiles_filename, cols, actions=actions, remove_duplicates=True + ) logger.info(f"Read {len(smilies)} input SMILES from {smiles_filename}") if not smilies: @@ -96,6 +102,8 @@ def run_transfer_learning( sample_batch_size=parameters.get("sample_batch_size", 1), num_refs=parameters.get("num_refs", 0), n_cpus=config.get("number_of_cpus", 1), + randomize_all_smiles=randomize_all_smiles, + internal_diversity=internal_diversity, ) if model_type == "Mol2Mol": diff --git a/reinvent/runmodes/create_adapter.py b/reinvent/runmodes/create_adapter.py index 0caa5d6..c344eec 100644 --- a/reinvent/runmodes/create_adapter.py +++ b/reinvent/runmodes/create_adapter.py @@ -1,8 +1,8 @@ """Create a model adapter from a Torch pickle file""" __all__ = ["create_adapter"] +import os import pprint -from typing import Tuple import logging import torch @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -def create_adapter(dict_filename: str, mode: str, device: torch.device) -> Tuple: +def create_adapter(dict_filename: str, mode: str, device: torch.device) -> tuple: """Read a dict from a Torch pickle find and return an adapter and model dict. 
:param dict_filename: filename of the Torch pickle file
@@ -21,43 +21,24 @@
    :returns: the adapter class, the model type
    """

+    dict_filename = os.path.abspath(dict_filename)
    save_dict = torch.load(dict_filename, map_location="cpu")
+    check_metadata(dict_filename, save_dict)

-    if "metadata" in save_dict:
-        metadata: models.ModelMetaData = save_dict["metadata"]
-
-        if not metadata.hash_id:
-            logger.warning(f"{dict_filename} does not contain a hash ID")
-        else:
-            valid = models.check_valid_hash(save_dict)
-            pp = pprint.PrettyPrinter(indent=2)
+    if "model_type" in save_dict:
+        model_type = save_dict["model_type"]

-            if valid:
-                logger.info(f"{dict_filename} has valid hash:\n{pp.pformat(metadata.as_dict())}")
-            else:
-                logger.error(f"{dict_filename} has invalid hash:\n{pp.pformat(metadata.as_dict())}")
+        # kludge to handle new-style transformers
+        if model_type == "Linkinvent" and "version" in save_dict:
+            if save_dict["version"] == 2:
+                model_type += "Transformer"
    else:
-        logger.warning(f"{dict_filename} does not contain metadata")
+        model_type = orig_style_priors(save_dict)

-    if "model_type" in save_dict:
-        model_type = save_dict["model_type"]
-    else:  # heuristics
-        # FIXME: ugly if
-        if "network" in save_dict:
-            model_type = "Reinvent"
-        elif "model" in save_dict:
-            model_type = "Libinvent"
-        elif "encoder_params" in save_dict["network_parameter"]:
-            model_type = "Linkinvent"
-        elif "num_heads" in save_dict["network_parameter"]:
-            model_type = "Mol2Mol"
-        else:
-            model_type = None
+    adapter_class = getattr(models, f"{model_type}Adapter", None)
+    model_class = getattr(models, f"{model_type}Model", None)

-    try:
-        adapter_class = getattr(models, f"{model_type}Adapter")
-        model_class = getattr(models, f"{model_type}Model")
-    except AttributeError:
+    if not adapter_class or not model_class:
        msg = f"Unknown model type: {model_type}"
        logger.fatal(msg)
        raise RuntimeError(msg)
@@ -65,7 +46,7 @@ def create_adapter(dict_filename: str, mode: str, device: torch.device) -> Tupl
    model = model_class.create_from_dict(save_dict, mode, device)
    adapter = adapter_class(model)

-    compatibility(model)
+    compatibility_setup(model)

    network_params = model.network.parameters()
    num_params = sum([tensor.numel() for tensor in network_params])
@@ -73,8 +54,68 @@ def create_adapter(dict_filename: str, mode: str, device: torch.device) -> Tupl
    return adapter, save_dict, model_type


-def compatibility(model):
-    """Compatibility mode for old Mol2Mol priors"""
+
+def check_metadata(dict_filename: str, save_dict: dict) -> None:
+    """Check the metadata of the save dict from a model file.
+
+    Currently, only logs warnings or errors but does not terminate the run.
+ + :param dict_filename: model pickle file with the save dict + :param save_dict: the save dict + """ + + if "metadata" in save_dict: + metadata: models.ModelMetaData = save_dict["metadata"] + + if metadata is not None: + if not metadata.hash_id: + logger.warning(f"{dict_filename} does not contain a hash ID") + else: + valid = models.check_valid_hash(save_dict) + pp = pprint.PrettyPrinter(indent=2) + pp_dict = pp.pformat(metadata.as_dict()) + + if valid: + logger.info( + f"{dict_filename} has valid hash:\n{pp_dict}") + else: + logger.error( + f"{dict_filename} has invalid hash:\n{pp_dict}") + else: + logger.warning(f"{dict_filename} contains empty metadata") + else: + logger.warning(f"{dict_filename} does not contain metadata") + + +def orig_style_priors(save_dict: dict) -> str: + """Determine model type heuristically + + Originally, prior files did not contain any metadata so the model type + must be guessed from the layout of the save dict. + + :param save_dict: the save dict + :returns: the model type descriptor as a string + """ + + if "network" in save_dict: + model_type = "Reinvent" + elif "model" in save_dict: + model_type = "Libinvent" + elif "encoder_params" in save_dict["network_parameter"]: + model_type = "Linkinvent" + elif "num_heads" in save_dict["network_parameter"]: + model_type = "Mol2Mol" + else: + model_type = "" + + return model_type + + +def compatibility_setup(model): + """Compatibility mode for old Mol2Mol priors + + :param model: model adapter object + """ from reinvent.models.mol2mol.models.vocabulary import Vocabulary diff --git a/reinvent/runmodes/samplers/libinvent.py b/reinvent/runmodes/samplers/libinvent.py index 570fa9e..d23bbbb 100644 --- a/reinvent/runmodes/samplers/libinvent.py +++ b/reinvent/runmodes/samplers/libinvent.py @@ -72,7 +72,7 @@ def sample(self, smilies: List[str]) -> SampleBatch: mols = join_fragments(sampled, reverse=False, keep_labels=True) - sampled.smilies, sampled.states = validate_smiles(mols) + sampled.smilies, sampled.states = validate_smiles(mols, sampled.output) return sampled diff --git a/reinvent/runmodes/samplers/linkinvent.py b/reinvent/runmodes/samplers/linkinvent.py index 74bc869..2fcb42c 100644 --- a/reinvent/runmodes/samplers/linkinvent.py +++ b/reinvent/runmodes/samplers/linkinvent.py @@ -1,6 +1,6 @@ """The LinkInvent sampling module""" -__all__ = ["LinkinventSampler"] +__all__ = ["LinkinventSampler", "LinkinventTransformerSampler"] from typing import List import torch.utils.data as tud @@ -10,6 +10,7 @@ from reinvent.models.linkinvent.dataset.dataset import Dataset from reinvent.models.model_factory.sample_batch import SampleBatch from reinvent.runmodes.utils.helpers import join_fragments +from ...models.transformer.core.dataset.dataset import Dataset as TransformerDataset class LinkinventSampler(Sampler): @@ -22,14 +23,26 @@ def sample(self, smilies: List[str]) -> SampleBatch: :returns: list of SampledSequencesDTO """ + if self.model.model._version == 2: # Transformer-based + smilies = self._standardize_input(smilies) + warheads_list = self._get_randomized_smiles(smilies) if self.randomize_smiles else smilies clean_warheads = [ self.chemistry.attachment_points.remove_attachment_point_numbers(warheads) for warheads in warheads_list ] - clean_warheads = clean_warheads * self.batch_size - dataset = Dataset(clean_warheads, self.model.get_vocabulary().input) + if self.model.model._version == 1: # RNN-based + clean_warheads = clean_warheads * self.batch_size + + dataset = Dataset(clean_warheads, 
self.model.get_vocabulary().input) + elif self.model.model._version == 2: # Transformer-based + if self.sample_strategy == "multinomial": + clean_warheads = clean_warheads * self.batch_size + + dataset = TransformerDataset( + clean_warheads, self.model.get_vocabulary(), self.model.tokenizer + ) dataloader = tud.DataLoader( dataset, @@ -41,8 +54,12 @@ def sample(self, smilies: List[str]) -> SampleBatch: sequences = [] for batch in dataloader: - inputs, input_seq_lengths = batch - sampled = self.model.sample(inputs, input_seq_lengths) + inputs, input_info = batch + + if self.model.model._version == 1: + sampled = self.model.sample(inputs, input_info) + elif self.model.model._version == 2: + sampled = self.model.sample(inputs, input_info, self.sample_strategy) for batch_row in sampled: sequences.append(batch_row) @@ -54,12 +71,26 @@ def sample(self, smilies: List[str]) -> SampleBatch: mols = join_fragments(sampled, reverse=True) - sampled.smilies, sampled.states = validate_smiles(mols) + sampled.smilies, sampled.states = validate_smiles( + mols, sampled.output, isomeric=self.isomeric + ) return sampled + def _standardize_input(self, warheads_list: List[str]): + cano_warheads_list = [] + for warheads in warheads_list: + cano_warheads = "|".join( + [ + self.chemistry.conversions.convert_to_standardized_smiles(warhead) + for warhead in warheads.split("|") + ] + ) + cano_warheads_list.append(cano_warheads) + return cano_warheads_list + def _get_randomized_smiles(self, warhead_pair_list: List[str]): - """Y""" + """Randomize the warhead SMILES""" randomized_warhead_pair_list = [] @@ -69,7 +100,8 @@ def _get_randomized_smiles(self, warhead_pair_list: List[str]): self.chemistry.conversions.smile_to_mol(warhead) for warhead in warhead_list ] warhead_randomized_list = [ - self.chemistry.conversions.mol_to_random_smiles(mol) for mol in warhead_mol_list + self.chemistry.conversions.mol_to_random_smiles(mol, isomericSmiles=self.isomeric) + for mol in warhead_mol_list ] # Note do not use self.self._bond_maker.randomize_scaffold, as it would add unwanted brackets to the # attachment points (which are not part of the warhead vocabulary) @@ -79,3 +111,6 @@ def _get_randomized_smiles(self, warhead_pair_list: List[str]): randomized_warhead_pair_list.append(warhead_pair_randomized) return randomized_warhead_pair_list + + +LinkinventTransformerSampler = LinkinventSampler diff --git a/reinvent/runmodes/samplers/mol2mol.py b/reinvent/runmodes/samplers/mol2mol.py index f7442b9..f807072 100644 --- a/reinvent/runmodes/samplers/mol2mol.py +++ b/reinvent/runmodes/samplers/mol2mol.py @@ -26,7 +26,9 @@ def sample(self, smilies: List[str]) -> SampleBatch: :returns: list of SampledSequencesDTO """ # Standardize smiles in the same way as training data - smilies = [self.chemistry.conversions.convert_to_standardized_smiles(smile) for smile in smilies] + smilies = [ + self.chemistry.conversions.convert_to_standardized_smiles(smile) for smile in smilies + ] smilies = ( [self._get_randomized_smiles(smiles) for smiles in smilies] @@ -34,17 +36,10 @@ def sample(self, smilies: List[str]) -> SampleBatch: else smilies ) - self.model.set_temperature(self.temperature) # FIXME: should probably be done by caller # replace hard-coded strings if self.sample_strategy == "multinomial": smilies = smilies * self.batch_size - elif self.sample_strategy == "beamsearch": - self.model.set_beam_size(self.batch_size) - else: - raise ValueError( - f"Sample strategy `{self.sample_strategy}` is not implemented" - ) tokenizer = SMILESTokenizer() dataset = 
Dataset(smilies, self.model.get_vocabulary(), tokenizer) @@ -75,13 +70,17 @@ def sample(self, smilies: List[str]) -> SampleBatch: for smiles in sampled.output ] - sampled.smilies, sampled.states = validate_smiles(mols, isomeric=self.isomeric) + sampled.smilies, sampled.states = validate_smiles( + mols, sampled.output, isomeric=self.isomeric + ) return sampled def _get_randomized_smiles(self, smiles: str): input_mol = self.chemistry.conversions.smile_to_mol(smiles) - randomized_smile = self.chemistry.conversions.mol_to_random_smiles(input_mol) + randomized_smile = self.chemistry.conversions.mol_to_random_smiles( + input_mol, isomericSmiles=self.isomeric + ) return randomized_smile @@ -91,20 +90,24 @@ def calculate_tanimoto(self, reference_smiles, smiles): returns the largest if multiple reference smiles provided """ specific_parameters = {"radius": 2, "use_features": False} - ref_fingerprints = self.chemistry.conversions.smiles_to_fingerprints(reference_smiles, - radius=specific_parameters['radius'], - use_features=specific_parameters[ - 'use_features']) + ref_fingerprints = self.chemistry.conversions.smiles_to_fingerprints( + reference_smiles, + radius=specific_parameters["radius"], + use_features=specific_parameters["use_features"], + ) valid_mols, valid_idxs = self.chemistry.conversions.smiles_to_mols_and_indices(smiles) - query_fps = self.chemistry.conversions.mols_to_fingerprints(valid_mols, radius=specific_parameters['radius'], - use_features=specific_parameters[ - 'use_features']) + query_fps = self.chemistry.conversions.mols_to_fingerprints( + valid_mols, + radius=specific_parameters["radius"], + use_features=specific_parameters["use_features"], + ) similarity = Similarity() scores = similarity.calculate_tanimoto(query_fps, ref_fingerprints) return scores - def check_nll(self, input_smiles: List[str], target_smiles: List[str]) -> Tuple[ - List[str], List[str], List[float], List[float]]: + def check_nll( + self, input_smiles: List[str], target_smiles: List[str] + ) -> Tuple[List[str], List[str], List[float], List[float]]: """ Compute the NLL of generating each target smiles given each input reference smiles :param input_smiles: list of input SMILES @@ -118,8 +121,9 @@ def check_nll(self, input_smiles: List[str], target_smiles: List[str]) -> Tuple[ current_smi = smi try: - cano_smi = self.chemistry.conversions.convert_to_rdkit_smiles(smi, - sanitize=True, isomericSmiles=True) + cano_smi = self.chemistry.conversions.convert_to_rdkit_smiles( + smi, sanitize=True, isomericSmiles=True + ) current_smi = cano_smi except Exception: print(f"WARNING. SMILES {smi} is invalid") @@ -129,7 +133,9 @@ def check_nll(self, input_smiles: List[str], target_smiles: List[str]) -> Tuple[ tokenized_smi = tokenizer.tokenize(current_smi) self.model.vocabulary.encode(tokenized_smi) except KeyError as e: - print(f"WARNING. SMILES {current_smi} contains an invalid token {e}. It will be ignored") + print( + f"WARNING. SMILES {current_smi} contains an invalid token {e}. 
It will be ignored" + ) else: dto_list.append(SampledSequencesDTO(compound, current_smi, 0)) @@ -152,7 +158,7 @@ def check_nll(self, input_smiles: List[str], target_smiles: List[str]) -> Tuple[ # Compute Tanimoto valid_mols, valid_idxs = self.chemistry.conversions.smiles_to_mols_and_indices(target) valid_scores = self.calculate_tanimoto(input, target) - tanimoto = [None]*len(target) + tanimoto = [None] * len(target) for i, j in enumerate(valid_idxs): tanimoto[j] = valid_scores[i] diff --git a/reinvent/runmodes/samplers/reinvent.py b/reinvent/runmodes/samplers/reinvent.py index d83ff9e..2a3fa73 100644 --- a/reinvent/runmodes/samplers/reinvent.py +++ b/reinvent/runmodes/samplers/reinvent.py @@ -23,8 +23,9 @@ def sample(self, dummy) -> SampleBatch: :param dummy: Reinvent does not need SMILES input :returns: a dataclass """ - smiles_sampled, likelihood_sampled = \ - self.model.model.sample_smiles(self.batch_size, params.DATALOADER_BATCHSIZE) + smiles_sampled, likelihood_sampled = self.model.model.sample_smiles( + self.batch_size, params.DATALOADER_BATCHSIZE + ) sampled = SampleBatch(None, smiles_sampled, Tensor(likelihood_sampled)) if self.unique_sequences: @@ -35,5 +36,5 @@ def sample(self, dummy) -> SampleBatch: for smiles in sampled.output ] - sampled.smilies, sampled.states = validate_smiles(mols) + sampled.smilies, sampled.states = validate_smiles(mols, sampled.output) return sampled diff --git a/reinvent/runmodes/samplers/reports/tensorboard.py b/reinvent/runmodes/samplers/reports/tensorboard.py index 1ffa264..99dcd5e 100644 --- a/reinvent/runmodes/samplers/reports/tensorboard.py +++ b/reinvent/runmodes/samplers/reports/tensorboard.py @@ -94,4 +94,4 @@ def _report_scatter(data: TBData, reporter): if x_key in data.additional_report.keys() and y_key in data.additional_report.keys(): x, y = data.additional_report[x_key], data.additional_report[y_key] figure = plot_scatter(x, y, xlabel=xlabel, ylabel=ylabel, title=f'{len(x)} Unique') - reporter.add_figure(f'{xlabel}_{ylabel}', figure) + reporter.add_figure(f'{xlabel}_{ylabel}', figure) \ No newline at end of file diff --git a/reinvent/runmodes/samplers/run_sampling.py b/reinvent/runmodes/samplers/run_sampling.py index fe710e7..b05dae2 100644 --- a/reinvent/runmodes/samplers/run_sampling.py +++ b/reinvent/runmodes/samplers/run_sampling.py @@ -36,9 +36,12 @@ "Reinvent": ("SMILES", "NLL"), "Libinvent": ("SMILES", "Scaffold", "R-groups", "NLL"), "Linkinvent": ("SMILES", "Warheads", "Linker", "NLL"), + "LinkinventTransformer": ("SMILES", "Warheads", "Linker", "NLL"), "Mol2Mol": ("SMILES", "Input_SMILES", "Tanimoto", "NLL"), } +FRAGMENT_GENERATORS = ["Libinvent", "Linkinvent", "LinkinventTransformer"] + def run_sampling(config: dict, device, *args, **kwargs): """Sampling run setup""" @@ -132,7 +135,7 @@ def run_sampling(config: dict, device, *args, **kwargs): csv_logger.info(HEADERS[model_type]) if model_type == "Reinvent": records = zip(sampled.smilies, sampled.nlls.cpu().tolist()) - elif model_type in ["Libinvent", "Linkinvent"]: + elif model_type in FRAGMENT_GENERATORS: records = zip(sampled.smilies, sampled.items1, sampled.items2, sampled.nlls.cpu().tolist()) elif model_type == "Mol2Mol": records = zip(sampled.smilies, sampled.items1, scores, sampled.nlls.cpu().tolist()) diff --git a/reinvent/runmodes/samplers/sampler.py b/reinvent/runmodes/samplers/sampler.py index 4bfc55e..9edeb3b 100644 --- a/reinvent/runmodes/samplers/sampler.py +++ b/reinvent/runmodes/samplers/sampler.py @@ -14,7 +14,7 @@ from __future__ import annotations -__all__ = 
["Sampler", "remove_duplicate_sequences", "validate_smiles", "INVALID_STR"] +__all__ = ["Sampler", "remove_duplicate_sequences", "validate_smiles"] from dataclasses import dataclass from abc import ABC, abstractmethod from typing import List, Tuple, TYPE_CHECKING @@ -34,8 +34,6 @@ logger = logging.getLogger(__name__) -INVALID_STR = "INVALID" - @dataclass class Sampler(ABC): @@ -45,13 +43,12 @@ class Sampler(ABC): # number of smiles to be generated for each input, # different from batch size used in dataloader which affect cuda memory batch_size: int - sample_strategy: str = "multinomial" # Mol2Mol - isomeric: bool = False # Mol2Mol + sample_strategy: str = "multinomial" # Transformer-based models + isomeric: bool = False # Transformer-based models randomize_smiles: bool = True unique_sequences: bool = False # backwards compatibility for R3 chemistry: ChemistryHelpers = None tokens: TransformationTokens = None # LinkInvent only - temperature: float = 1.0 @abstractmethod def sample(self, smilies: List[str]) -> SampleBatch: @@ -74,18 +71,18 @@ def remove_duplicate_sequences( orig_len = len(sampled.output) - if is_reinvent: - seq_string = np.array(sampled.output) - elif is_mol2mol: + if is_reinvent or is_mol2mol: seq_string = np.array(sampled.output) else: seq_string = np.array([f"{a}{b}" for a, b in zip(sampled.input, sampled.output)]) # order shouldn't matter here smilies, uniq_idx = np.unique(seq_string, return_index=True) + if sampled.items1: sampled.items1 = np.array(sampled.items1) sampled.items1 = list(sampled.items1[uniq_idx]) + sampled.output = list(smilies) sampled.nlls = sampled.nlls[uniq_idx] sampled.items2 = list(np.array(sampled.items2)[uniq_idx]) @@ -97,12 +94,16 @@ def remove_duplicate_sequences( return sampled -def validate_smiles(mols: List[Chem.Mol], isomeric: bool = False) -> Tuple[List, np.ndarray]: +def validate_smiles( + mols: List[Chem.Mol], smilies, isomeric: bool = False +) -> Tuple[List, np.ndarray]: """Basic validation of sampled or joined SMILES The molecules are converted to canonical SMILES. Each SMILES state is determined to be invalid, valid or duplicate. 
+    :param mols: molecules
+    :param smilies: SMILES of molecules including invalid ones
     :returns: validated SMILES and their states
     """

@@ -110,14 +111,12 @@ def validate_smiles(mols: List[Chem.Mol], isomeric: bool = False) -> Tuple[List,
     smilies_states = []  # valid, invalid, duplicate
     seen_before = set()

-    for i, mol in enumerate(mols):
+    for mol, sampled_smiles in zip(mols, smilies):
         if mol:
             failed = Chem.SanitizeMol(mol, catchErrors=True)

             if not failed:
-                canonical_smiles = Chem.MolToSmiles(
-                    mol, canonical=True, isomericSmiles=isomeric
-                )
+                canonical_smiles = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=isomeric)

                 if canonical_smiles in seen_before:
                     smilies_states.append(SmilesState.DUPLICATE)
@@ -127,10 +126,10 @@ def validate_smiles(mols: List[Chem.Mol], isomeric: bool = False) -> Tuple[List,
                     validated_smilies.append(canonical_smiles)
                     seen_before.add(canonical_smiles)
             else:
-                validated_smilies.append(f"{INVALID_STR}{i}")
+                validated_smilies.append(sampled_smiles)
                 smilies_states.append(SmilesState.INVALID)
         else:
-            validated_smilies.append(f"{INVALID_STR}{i}")
+            validated_smilies.append(sampled_smiles)
             smilies_states.append(SmilesState.INVALID)

     smilies_states = np.array(smilies_states)
diff --git a/reinvent/runmodes/setup_sampler.py b/reinvent/runmodes/setup_sampler.py
index 6c56272..889b6b0 100644
--- a/reinvent/runmodes/setup_sampler.py
+++ b/reinvent/runmodes/setup_sampler.py
@@ -15,6 +15,7 @@
 logger = logging.getLogger(__name__)

 warnings.filterwarnings("once", category=FutureWarning)
+TRANSFORMERS = ["Mol2Mol", "LinkinventTransformer"]

 def setup_sampler(model_type: str, config: dict, agent: ModelAdapter, chemistry: ChemistryHelpers):
     """Setup the sampling module.
@@ -36,9 +37,12 @@ def setup_sampler(model_type: str, config: dict, agent: ModelAdapter, chemistry:
     randomize_smiles = config.get("randomize_smiles", True)
     temperature = config.get("temperature", 1.0)

-    if model_type == "Mol2Mol" and randomize_smiles:
+    # Transformer-based models were trained on canonical SMILES
+    if model_type in TRANSFORMERS and randomize_smiles:
         randomize_smiles = False
-        logger.warning(f"randomize_smiles set to false for Mol2Mol")
+        logger.warning(f"randomize_smiles is True but {model_type} was trained "
+                       f"on canonical SMILES and randomization may degrade "
+                       f"performance: resetting randomize_smiles to False")

     unique_sequences = config.get("unique_sequences", False)

@@ -49,31 +53,31 @@ def setup_sampler(model_type: str, config: dict, agent: ModelAdapter, chemistry:
             stacklevel=2,
         )

-    if model_type == "Mol2Mol":
-        try:
-            sample_strategy = config["sample_strategy"]  # for Mol2Mol
-        except KeyError:
-            sample_strategy = "multinomial"
+    if model_type in TRANSFORMERS:
+        sample_strategy = config.get("sample_strategy", "multinomial")
     else:
         sample_strategy = None

     tokens = TransformationTokens()  # LinkInvent only
     isomeric = False

-    if model_type == "Mol2Mol":  # this is a special case
+    if model_type in TRANSFORMERS:  # for Transformer-based models
         isomeric = True

+    agent.model.set_temperature(temperature)
+
+    if sample_strategy == "beamsearch":
+        agent.model.set_beam_size(batch_size)
+
     sampling_model = getattr(samplers, f"{model_type}Sampler")

     sampler = sampling_model(
         agent,
         batch_size=batch_size,
-        sample_strategy=sample_strategy,  # needed for Mol2Mol
-        isomeric=isomeric,  # needed for Mol2Mol
+        sample_strategy=sample_strategy,  # needed for Transformer-based models
+        isomeric=isomeric,  # needed for Transformer-based models
         randomize_smiles=randomize_smiles,
         unique_sequences=unique_sequences,
         chemistry=chemistry,
         tokens=tokens,
-        temperature=temperature
     )

     return sampler, batch_size
diff --git a/reinvent/runmodes/utils/__init__.py b/reinvent/runmodes/utils/__init__.py
index b7b962b..1f73416 100644
--- a/reinvent/runmodes/utils/__init__.py
+++ b/reinvent/runmodes/utils/__init__.py
@@ -5,3 +5,4 @@

 from .image import *
 from .helpers import *
+from .evaluate import *
diff --git a/reinvent/runmodes/utils/helpers.py b/reinvent/runmodes/utils/helpers.py
index 3a4ec1b..cc91e34 100644
--- a/reinvent/runmodes/utils/helpers.py
+++ b/reinvent/runmodes/utils/helpers.py
@@ -36,31 +36,30 @@ def disable_gradients(model: ModelAdapter) -> None:
         param.requires_grad = False

-def set_torch_device(device: str = None, use_cuda: bool = True) -> torch.device:
+def set_torch_device(args_device: str = None, device: str = None) -> torch.device:
     """Set the Torch device

-    :param device: device name from the command line
-    :param use_cuda: whether use_cuda was set in the user config
+    :param args_device: device name from the command line
+    :param device: device name from the config
     """

-    logger.debug(f"{device=} {use_cuda=}")
+    logger.debug(f"{device=} {args_device=}")

-    if device:  # command line overwrites config file
+    # NOTE: ChemProp > 1.5 would need "spawn" but that is 4-5 times slower
+    # Windows requires "spawn"
+    # torch.multiprocessing.set_start_method('fork')
+
+    if args_device:  # command line overwrites config file
         # NOTE: this will throw a RuntimeError if the device is not available
+        torch.set_default_device(args_device)
+        actual_device = torch.device(args_device)
+    elif device:
+        torch.set_default_device(device)
         actual_device = torch.device(device)
-    elif use_cuda and torch.cuda.is_available():
-        actual_device = torch.device("cuda")
     else:  # we assume there are no other devices...
+ torch.set_default_device("cpu") actual_device = torch.device("cpu") - # FIXME: as of PyTorch 2.1 this should be replaced with - # torch.set_default_dtype() and torch.set_default_device() - # The dtype can be set to torch.float32 for both CPU and GPU - if actual_device.type == "cuda": - torch.set_default_tensor_type(torch.cuda.FloatTensor) - else: # assume CPU... - torch.set_default_tensor_type(torch.FloatTensor) - logger.debug(f"{actual_device=}") return actual_device diff --git a/reinvent/version.py b/reinvent/version.py index cb2d978..3bbbf9e 100644 --- a/reinvent/version.py +++ b/reinvent/version.py @@ -1,6 +1,6 @@ """Meta information for Reinvent""" __progname__ = "REINVENT" -__version__ = "4.2.5" +__version__ = "4.3.5" __config_version__ = 4 __copyright__ = "(C) AstraZeneca 2017, 2023" diff --git a/reinvent_plugins/components/OpenEye/comp_rocs.py b/reinvent_plugins/components/OpenEye/comp_rocs.py index 6fb69c9..b1126c6 100644 --- a/reinvent_plugins/components/OpenEye/comp_rocs.py +++ b/reinvent_plugins/components/OpenEye/comp_rocs.py @@ -15,14 +15,15 @@ from __future__ import annotations -import logging -import copy __all__ = ["ROCSSimilarity"] -from dataclasses import dataclass, field +import copy from typing import List, Optional +import logging import numpy as np +from pydantic import Field +from pydantic.dataclasses import dataclass from .rocs.rocs_similarity import ROCSOverlay from ..component_results import ComponentResults @@ -47,9 +48,9 @@ class Parameters: shape_weight: List[float] max_stereocenters: List[int] ewindow: List[int] - maxconfs: [List[int]] + maxconfs: List[int] similarity_measure: List[str] - custom_cff: Optional[List[str]] = field(default_factory=lambda: [None]) + custom_cff: Optional[List[str]] = Field(default_factory=lambda: [None]) @add_tag("__component") diff --git a/reinvent_plugins/components/RDKit/comp_group_count.py b/reinvent_plugins/components/RDKit/comp_group_count.py index 3adc52f..16cd1b8 100644 --- a/reinvent_plugins/components/RDKit/comp_group_count.py +++ b/reinvent_plugins/components/RDKit/comp_group_count.py @@ -8,11 +8,11 @@ __all__ = ["GroupCount"] -from dataclasses import dataclass from typing import List from rdkit import Chem import numpy as np +from pydantic.dataclasses import dataclass from ..component_results import ComponentResults from reinvent_plugins.mol_cache import molcache diff --git a/reinvent_plugins/components/RDKit/comp_matching_substructure.py b/reinvent_plugins/components/RDKit/comp_matching_substructure.py index 86043ab..d0572d7 100644 --- a/reinvent_plugins/components/RDKit/comp_matching_substructure.py +++ b/reinvent_plugins/components/RDKit/comp_matching_substructure.py @@ -8,11 +8,11 @@ __all__ = ["MatchingSubstructure"] -from dataclasses import dataclass from typing import List from rdkit import Chem import numpy as np +from pydantic.dataclasses import dataclass from ..component_results import ComponentResults from reinvent_plugins.mol_cache import molcache @@ -30,7 +30,7 @@ class Parameters: endpoint. """ - smarts: List[List[str]] + smarts: List[str] use_chirality: List[bool] diff --git a/reinvent_plugins/components/RDKit/comp_mol_volume.py b/reinvent_plugins/components/RDKit/comp_mol_volume.py index 756c9eb..3d93a57 100644 --- a/reinvent_plugins/components/RDKit/comp_mol_volume.py +++ b/reinvent_plugins/components/RDKit/comp_mol_volume.py @@ -3,11 +3,12 @@ The quality depends on the quality of the conformer. 
""" -from dataclasses import dataclass, field from typing import List, Optional import numpy as np from rdkit.Chem import AllChem as Chem +from pydantic import Field +from pydantic.dataclasses import dataclass from ..component_results import ComponentResults from ..add_tag import add_tag @@ -17,8 +18,8 @@ @add_tag("__parameters") @dataclass class Parameters: - grid_spacing: Optional[List[float]] = field(default_factory=lambda: [0.2]) - box_margin: Optional[List[float]] = field(default_factory=lambda: [2.0]) + grid_spacing: Optional[List[float]] = Field(default_factory=lambda: [0.2]) + box_margin: Optional[List[float]] = Field(default_factory=lambda: [2.0]) @add_tag("__component") diff --git a/reinvent_plugins/components/RDKit/comp_pmi.py b/reinvent_plugins/components/RDKit/comp_pmi.py index 68629ba..7538b70 100644 --- a/reinvent_plugins/components/RDKit/comp_pmi.py +++ b/reinvent_plugins/components/RDKit/comp_pmi.py @@ -1,10 +1,10 @@ """Compute the PMI score in RDKit""" -from dataclasses import dataclass from typing import List import numpy as np from rdkit.Chem import AllChem as Chem +from pydantic.dataclasses import dataclass from ..component_results import ComponentResults from ..add_tag import add_tag diff --git a/reinvent_plugins/components/RDKit/comp_rdkit_descriptors.py b/reinvent_plugins/components/RDKit/comp_rdkit_descriptors.py index bb268b6..f396131 100644 --- a/reinvent_plugins/components/RDKit/comp_rdkit_descriptors.py +++ b/reinvent_plugins/components/RDKit/comp_rdkit_descriptors.py @@ -1,7 +1,6 @@ """Compute a desired list of RDKit descriptors up to a total of 210""" __all__ = ["RDKitDescriptors"] -from dataclasses import dataclass from typing import List import logging @@ -9,6 +8,7 @@ from rdkit.Chem import Descriptors from rdkit.ML.Descriptors.MoleculeDescriptors import MolecularDescriptorCalculator import numpy as np +from pydantic.dataclasses import dataclass from ..component_results import ComponentResults from reinvent_plugins.mol_cache import molcache diff --git a/reinvent_plugins/components/RDKit/comp_similarity.py b/reinvent_plugins/components/RDKit/comp_similarity.py index b0eec49..f47491c 100644 --- a/reinvent_plugins/components/RDKit/comp_similarity.py +++ b/reinvent_plugins/components/RDKit/comp_similarity.py @@ -4,10 +4,10 @@ __all__ = ["TanimotoDistance"] -from dataclasses import dataclass from typing import List import numpy as np +from pydantic.dataclasses import dataclass from reinvent.chemistry.conversions import Conversions from reinvent.chemistry.similarity import Similarity diff --git a/reinvent_plugins/components/RDKit/comp_tpsa.py b/reinvent_plugins/components/RDKit/comp_tpsa.py index 2788b36..18705e0 100644 --- a/reinvent_plugins/components/RDKit/comp_tpsa.py +++ b/reinvent_plugins/components/RDKit/comp_tpsa.py @@ -8,12 +8,13 @@ __all__ = ["TPSA"] -from dataclasses import dataclass, field -from typing import List +from typing import List, Optional from rdkit import Chem from rdkit.Chem import Descriptors import numpy as np +from pydantic import Field +from pydantic.dataclasses import dataclass from ..component_results import ComponentResults from reinvent_plugins.mol_cache import molcache @@ -25,7 +26,7 @@ class Parameters: """Parameters for the scoring component""" - includeSandP: Optional[List[bool]] = field(default_factory=lambda: [False]) + includeSandP: Optional[List[bool]] = Field(default_factory=lambda: [False]) @add_tag("__component") diff --git a/reinvent_plugins/components/comp_chemprop.py b/reinvent_plugins/components/comp_chemprop.py 
index ca1675d..772264e 100644 --- a/reinvent_plugins/components/comp_chemprop.py +++ b/reinvent_plugins/components/comp_chemprop.py @@ -23,12 +23,13 @@ from __future__ import annotations __all__ = ["ChemProp"] -from dataclasses import dataclass, field from typing import List import logging import chemprop import numpy as np +from pydantic import Field +from pydantic.dataclasses import dataclass from .component_results import ComponentResults from .add_tag import add_tag @@ -50,7 +51,7 @@ class Parameters: """ checkpoint_dir: List[str] - rdkit_2d_normalized: List[bool] = field(default_factory=lambda: [False]) + rdkit_2d_normalized: List[bool] = Field(default_factory=lambda: [False]) @add_tag("__component") diff --git a/reinvent_plugins/components/comp_custom_alerts.py b/reinvent_plugins/components/comp_custom_alerts.py index 4f178ef..2462515 100644 --- a/reinvent_plugins/components/comp_custom_alerts.py +++ b/reinvent_plugins/components/comp_custom_alerts.py @@ -1,11 +1,11 @@ """Compute scores with RDKit's QED""" __all__ = ["CustomAlerts"] -from dataclasses import dataclass from typing import List import numpy as np from rdkit import Chem +from pydantic.dataclasses import dataclass from .component_results import ComponentResults from reinvent_plugins.mol_cache import molcache diff --git a/reinvent_plugins/components/comp_dockstream.py b/reinvent_plugins/components/comp_dockstream.py index 1046af9..c29eafe 100644 --- a/reinvent_plugins/components/comp_dockstream.py +++ b/reinvent_plugins/components/comp_dockstream.py @@ -12,10 +12,10 @@ import logging import copy -from dataclasses import dataclass, field from typing import List, Optional import numpy as np +from pydantic.dataclasses import dataclass from .component_results import ComponentResults from .run_program import run_command diff --git a/reinvent_plugins/components/comp_external_process.py b/reinvent_plugins/components/comp_external_process.py index d288c60..4c67798 100644 --- a/reinvent_plugins/components/comp_external_process.py +++ b/reinvent_plugins/components/comp_external_process.py @@ -11,10 +11,10 @@ import os import shlex import json -from dataclasses import dataclass from typing import List import numpy as np +from pydantic.dataclasses import dataclass from .component_results import ComponentResults from .run_program import run_command diff --git a/reinvent_plugins/components/comp_generic_rest.py b/reinvent_plugins/components/comp_generic_rest.py index e346f81..ca55ff3 100644 --- a/reinvent_plugins/components/comp_generic_rest.py +++ b/reinvent_plugins/components/comp_generic_rest.py @@ -3,11 +3,11 @@ from __future__ import annotations __all__ = ["REST"] -from dataclasses import dataclass from typing import List import requests import numpy as np +from pydantic.dataclasses import dataclass from .component_results import ComponentResults from .add_tag import add_tag diff --git a/reinvent_plugins/components/comp_icolos.py b/reinvent_plugins/components/comp_icolos.py index ad48eb3..2b12a86 100644 --- a/reinvent_plugins/components/comp_icolos.py +++ b/reinvent_plugins/components/comp_icolos.py @@ -9,10 +9,10 @@ import json import tempfile import time -from dataclasses import dataclass from typing import List, IO import numpy as np +from pydantic.dataclasses import dataclass from .component_results import ComponentResults from .run_program import run_command diff --git a/reinvent_plugins/components/comp_maize.py b/reinvent_plugins/components/comp_maize.py index bcce8a8..6c46811 100644 --- 
a/reinvent_plugins/components/comp_maize.py +++ b/reinvent_plugins/components/comp_maize.py @@ -137,10 +137,11 @@ import json import tempfile import time -from dataclasses import dataclass, field from typing import List, Any import numpy as np +from pydantic import Field +from pydantic.dataclasses import dataclass from .component_results import ComponentResults from .run_program import run_command @@ -169,11 +170,11 @@ class Parameters: executable: List[str] workflow: List[str] - debug: List[bool] = field(default_factory=lambda: [False]) - keep: List[bool] = field(default_factory=lambda: [False]) - log: List[str | None] = field(default_factory=lambda: [None]) - config: List[str | None] = field(default_factory=lambda: [None]) - parameters: List[dict[str, Any]] = field(default_factory=lambda: [{}]) + debug: List[bool] = Field(default_factory=lambda: [False]) + keep: List[bool] = Field(default_factory=lambda: [False]) + log: List[str | None] = Field(default_factory=lambda: [None]) + config: List[str | None] = Field(default_factory=lambda: [None]) + parameters: List[dict[str, Any]] = Field(default_factory=lambda: [{}]) CMD = "{exe} {config} --inp {inp} --out {out} --parameters {params}" diff --git a/reinvent_plugins/components/comp_mmp.py b/reinvent_plugins/components/comp_mmp.py index b3ea928..f7c875b 100644 --- a/reinvent_plugins/components/comp_mmp.py +++ b/reinvent_plugins/components/comp_mmp.py @@ -7,13 +7,13 @@ import logging import shlex from io import StringIO -from dataclasses import dataclass, field from typing import List import numpy as np import pandas as pd - from rdkit import Chem +from pydantic import Field +from pydantic.dataclasses import dataclass from .component_results import ComponentResults from .run_program import run_command @@ -34,9 +34,9 @@ class Parameters: """ reference_smiles: List[List[str]] - num_of_cuts: List[int] = field(default_factory=lambda: [1]) - max_variable_heavies: List[int] = field(default_factory=lambda: [40]) - max_variable_ratio: List[float] = field(default_factory=lambda: [0.33]) + num_of_cuts: List[int] = Field(default_factory=lambda: [1]) + max_variable_heavies: List[int] = Field(default_factory=lambda: [40]) + max_variable_ratio: List[float] = Field(default_factory=lambda: [0.33]) FRAG_CMD = "mmpdb --quiet fragment --num-cuts {ncuts}" diff --git a/reinvent_plugins/components/comp_qptuna.py b/reinvent_plugins/components/comp_qptuna.py index 8de575b..c506541 100644 --- a/reinvent_plugins/components/comp_qptuna.py +++ b/reinvent_plugins/components/comp_qptuna.py @@ -4,12 +4,12 @@ __all__ = ["Qptuna"] import pickle -from dataclasses import dataclass from typing import List import logging import json import numpy as np +from pydantic.dataclasses import dataclass from .component_results import ComponentResults from .add_tag import add_tag diff --git a/requirements-linux-64.lock b/requirements-linux-64.lock index 366728b..a7c558f 100644 --- a/requirements-linux-64.lock +++ b/requirements-linux-64.lock @@ -2,62 +2,66 @@ # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# pip-compile --extra-index-url=https://download.pytorch.org/whl/cu113 --extra-index-url=https://pypi.anaconda.org/OpenEye/simple --output-file=requirements-linux-64.lock pyproject.toml +# pip-compile --extra-index-url=https://download.pytorch.org/whl/cu121 --extra-index-url=https://pypi.anaconda.org/OpenEye/simple --output-file=requirements-linux-64.lock pyproject.toml # ---extra-index-url https://download.pytorch.org/whl/cu113 
+--extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.anaconda.org/OpenEye/simple -absl-py==1.4.0 +absl-py==2.1.0 # via tensorboard -alabaster==0.7.13 +alabaster==0.7.16 # via sphinx -attrs==22.2.0 - # via pytest -babel==2.12.1 +annotated-types==0.6.0 + # via pydantic +babel==2.14.0 # via sphinx -cachetools==5.3.0 - # via google-auth -certifi==2022.12.7 +blinker==1.7.0 + # via flask +certifi==2024.2.2 # via requests -charset-normalizer==3.1.0 +charset-normalizer==3.3.2 # via requests chemprop==1.5.2 # via reinvent (pyproject.toml) -click==8.1.3 +click==8.1.7 # via flask -cloudpickle==2.2.1 +cloudpickle==3.0.0 # via hyperopt -contourpy==1.0.7 +contourpy==1.2.1 # via matplotlib -cycler==0.11.0 +cycler==0.12.1 # via matplotlib -dill==0.3.6 +descriptastorus==2.6.1 + # via reinvent (pyproject.toml) +dill==0.3.8 # via # multiprocess # pathos -docutils==0.19 +docstring-parser==0.16 + # via typed-argument-parser +docutils==0.20.1 # via sphinx -exceptiongroup==1.1.0 +exceptiongroup==1.2.0 # via pytest -flask==2.2.3 +filelock==3.13.4 + # via + # torch + # triton +flask==3.0.3 # via chemprop -fonttools==4.38.0 +fonttools==4.51.0 # via matplotlib -funcy==1.18 +fsspec==2024.3.1 + # via torch +funcy==2.0 # via reinvent (pyproject.toml) -future==0.18.3 +future==1.0.0 # via hyperopt -google-auth==2.16.2 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==0.4.6 - # via tensorboard -grpcio==1.51.3 +grpcio==1.62.1 # via tensorboard hyperopt==0.2.7 # via chemprop -idna==3.4 +idna==3.7 # via requests imagesize==1.4.1 # via sphinx @@ -65,23 +69,22 @@ iniconfig==2.0.0 # via pytest itsdangerous==2.1.2 # via flask -jinja2==3.1.2 +jinja2==3.1.3 # via # flask # sphinx -joblib==1.2.0 + # torch +joblib==1.4.0 # via scikit-learn -kiwisolver==1.4.4 +kiwisolver==1.4.5 # via matplotlib -lazy-loader==0.1 - # via pandas-flavor -markdown==3.4.1 +markdown==3.6 # via tensorboard -markupsafe==2.1.2 +markupsafe==2.1.5 # via # jinja2 # werkzeug -matplotlib==3.7.1 +matplotlib==3.8.4 # via # chemprop # reinvent (pyproject.toml) @@ -89,16 +92,21 @@ mmpdb==2.1 # via reinvent (pyproject.toml) molvs==0.1.1 # via reinvent (pyproject.toml) -multiprocess==0.70.14 +mpmath==1.3.0 + # via sympy +multiprocess==0.70.16 # via pathos mypy-extensions==1.0.0 # via typing-inspect -networkx==3.0 - # via hyperopt -numpy==1.24.2 +networkx==3.3 + # via + # hyperopt + # torch +numpy==1.26.4 # via # chemprop # contourpy + # descriptastorus # hyperopt # matplotlib # pandas @@ -110,171 +118,194 @@ numpy==1.24.2 # tensorboardx # torchvision # xarray -oauthlib==3.2.2 - # via requests-oauthlib -openeye-toolkits==2022.2.2 +nvidia-cublas-cu12==12.1.3.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.1.105 + # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via torch +nvidia-cudnn-cu12==8.9.2.26 + # via torch +nvidia-cufft-cu12==11.0.2.54 + # via torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via torch +nvidia-cusparse-cu12==12.1.0.106 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-nccl-cu12==2.19.3 + # via torch +nvidia-nvjitlink-cu12==12.4.127 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch +openeye-toolkits==2023.2.3 # via reinvent (pyproject.toml) -openeye-toolkits-python3-linux-x64==2022.2.2 - # via openeye-toolkits -packaging==23.0 +packaging==24.0 # via # matplotlib # pytest # sphinx # tensorboardx # xarray 
-pandas==1.5.3 +pandas==2.2.2 # via # chemprop # pandas-flavor # reinvent (pyproject.toml) # xarray -pandas-flavor==0.5.0 - # via chemprop -pathos==0.3.0 +pandas-flavor==0.6.0 + # via + # chemprop + # descriptastorus +pathos==0.3.2 # via reinvent (pyproject.toml) -pillow==9.4.0 +pillow==10.3.0 # via # matplotlib # rdkit # reinvent (pyproject.toml) # torchvision -pluggy==1.0.0 +pluggy==1.4.0 # via pytest -pox==0.3.2 +pox==0.3.4 # via pathos -ppft==1.7.6.6 +ppft==1.7.6.8 # via pathos -protobuf==3.20.3 +protobuf==5.26.1 # via # tensorboard # tensorboardx py4j==0.10.9.7 # via hyperopt -pyasn1==0.4.8 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.2.8 - # via google-auth -pydantic==1.10.5 +pydantic==2.7.0 # via reinvent (pyproject.toml) -pygments==2.14.0 +pydantic-core==2.18.1 + # via pydantic +pygments==2.17.2 # via sphinx -pyparsing==3.0.9 +pyparsing==3.1.2 # via matplotlib -pytest==7.2.2 +pytest==8.1.1 # via # pytest-mock # reinvent (pyproject.toml) -pytest-mock==3.12.0 +pytest-mock==3.14.0 # via reinvent (pyproject.toml) -python-dateutil==2.8.2 +python-dateutil==2.9.0.post0 # via # matplotlib # pandas -python-dotenv==1.0.0 +python-dotenv==1.0.1 # via reinvent (pyproject.toml) -pytz==2022.7.1 +pytz==2024.1 # via pandas -pyyaml==6.0 - # via reinvent (pyproject.toml) -rdkit==2022.9.5 +pyyaml==6.0.1 # via reinvent (pyproject.toml) -requests==2.28.2 +rdkit==2023.9.5 + # via + # descriptastorus + # reinvent (pyproject.toml) +requests==2.31.0 # via # reinvent (pyproject.toml) # requests-mock - # requests-oauthlib # sphinx - # tensorboard - # torchvision -requests-mock==1.10.0 +requests-mock==1.12.1 # via reinvent (pyproject.toml) -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rsa==4.9 - # via google-auth -scikit-learn==1.2.1 - # via chemprop -scipy==1.10.1 +scikit-learn==1.2.2 # via # chemprop + # reinvent (pyproject.toml) +scipy==1.13.0 + # via + # chemprop + # descriptastorus # hyperopt # reinvent (pyproject.toml) # scikit-learn six==1.16.0 # via - # google-auth # hyperopt # molvs # python-dateutil - # requests-mock + # tensorboard snowballstemmer==2.2.0 # via sphinx -sphinx==6.1.3 +sphinx==7.2.6 # via chemprop -sphinxcontrib-applehelp==1.0.4 +sphinxcontrib-applehelp==1.0.8 # via sphinx -sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-devhelp==1.0.6 # via sphinx -sphinxcontrib-htmlhelp==2.0.1 +sphinxcontrib-htmlhelp==2.0.5 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-qthelp==1.0.7 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 +sphinxcontrib-serializinghtml==1.1.10 # via sphinx -tenacity==8.2.2 +sympy==1.12 + # via torch +tenacity==8.2.3 # via reinvent (pyproject.toml) -tensorboard==2.12.0 +tensorboard==2.16.2 # via reinvent (pyproject.toml) -tensorboard-data-server==0.7.0 +tensorboard-data-server==0.7.2 # via tensorboard -tensorboard-plugin-wit==1.8.1 - # via tensorboard -tensorboardx==2.6 +tensorboardx==2.6.2.2 # via chemprop -threadpoolctl==3.1.0 +threadpoolctl==3.4.0 # via scikit-learn tomli==2.0.1 # via # pytest # reinvent (pyproject.toml) -torch==1.12.1+cu113 +torch==2.2.1+cu121 # via # chemprop # reinvent (pyproject.toml) # torchvision -torchvision==0.13.1+cu113 +torchvision==0.17.1+cu121 # via reinvent (pyproject.toml) -tqdm==4.65.0 +tqdm==4.66.2 # via # chemprop # hyperopt # reinvent (pyproject.toml) -typed-argument-parser==1.7.2 +triton==2.2.0 + # via torch +typed-argument-parser==1.10.0 # via chemprop -typing-extensions==4.5.0 +typing-extensions==4.11.0 # via # pydantic + # pydantic-core # reinvent (pyproject.toml) # 
torch - # torchvision - # typed-argument-parser # typing-inspect -typing-inspect==0.8.0 +typing-inspect==0.9.0 # via typed-argument-parser -urllib3==1.26.14 +tzdata==2024.1 + # via pandas +urllib3==2.2.1 # via requests -werkzeug==2.2.3 +werkzeug==3.0.2 # via # flask # tensorboard -wheel==0.38.4 - # via tensorboard -xarray==2023.2.0 +xarray==2024.3.0 # via pandas-flavor xxhash==3.4.1 # via reinvent (pyproject.toml)
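
A note on the Linkinvent hunks above: the new `_standardize_input()` canonicalizes each warhead of a `|`-separated pair before sampling, mirroring how the training data were prepared. Below is a rough stand-in using plain RDKit canonicalization; the actual code routes through `ChemistryHelpers.conversions.convert_to_standardized_smiles()`, which may apply more standardization than shown here, and real warheads also carry attachment-point atoms:
```python
from typing import List

from rdkit import Chem


def standardize_warheads(warhead_pairs: List[str], sep: str = "|") -> List[str]:
    """Canonicalize each warhead in strings like 'O=C(O)c1ccccc1|NCC'."""
    standardized = []

    for pair in warhead_pairs:
        canonical = [
            Chem.MolToSmiles(Chem.MolFromSmiles(warhead)) for warhead in pair.split(sep)
        ]
        standardized.append(sep.join(canonical))

    return standardized


print(standardize_warheads(["C1=CC=CC=C1C(=O)O|NCC"]))  # -> ['O=C(O)c1ccccc1|CCN']
```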
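The reworked `validate_smiles()` in `reinvent/runmodes/samplers/sampler.py` now receives the raw sampled SMILES alongside the molecules, so invalid entries are passed through verbatim instead of being replaced by `INVALID<n>` placeholders; this is what lets invalid SMILES appear unchanged in the RL CSV. A condensed, self-contained sketch of that logic follows; `SmilesState` is stubbed with a plain `Enum`, and the duplicate branch is abbreviated relative to the full source:
```python
from enum import Enum
from typing import List, Tuple

import numpy as np
from rdkit import Chem


class SmilesState(Enum):  # stand-in for the enum REINVENT uses
    VALID = 0
    INVALID = 1
    DUPLICATE = 2


def validate_smiles(
    mols: List[Chem.Mol], smilies: List[str], isomeric: bool = False
) -> Tuple[List[str], np.ndarray]:
    validated_smilies = []
    smilies_states = []  # valid, invalid, duplicate
    seen_before = set()

    for mol, sampled_smiles in zip(mols, smilies):
        # a molecule that failed parsing or fails sanitization keeps its raw SMILES
        if mol and not Chem.SanitizeMol(mol, catchErrors=True):
            canonical = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=isomeric)
            state = SmilesState.DUPLICATE if canonical in seen_before else SmilesState.VALID
            seen_before.add(canonical)
            validated_smilies.append(canonical)
            smilies_states.append(state)
        else:
            validated_smilies.append(sampled_smiles)  # previously f"INVALID{i}"
            smilies_states.append(SmilesState.INVALID)

    return validated_smilies, np.array(smilies_states)
```
The samplers call it as in the patch: `sampled.smilies, sampled.states = validate_smiles(mols, sampled.output, isomeric=self.isomeric)`.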
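The `set_torch_device()` rewrite drops the deprecated `torch.set_default_tensor_type()` calls in favour of PyTorch 2.x's `torch.set_default_device()` and honours the new `device` config parameter. A minimal sketch of the precedence it implements (command line over config file over CPU fallback); `pick_device` and its argument names are for this illustration only:
```python
import torch


def pick_device(args_device: str = None, config_device: str = None) -> torch.device:
    """Resolve the torch device: command line > config file > CPU fallback."""
    name = args_device or config_device or "cpu"

    # newly created tensors default to this device (PyTorch >= 2.0)
    torch.set_default_device(name)

    # per the comment in the patch, an unavailable device raises RuntimeError
    return torch.device(name)


# a TOML config containing `device = "cuda:0"` arrives here as config_device
device = pick_device(config_device="cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
```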
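Finally, the scoring components replace `dataclasses.dataclass`/`field` with `pydantic.dataclasses.dataclass`/`Field` (pydantic 2.7 per the lock file), so malformed parameters now fail at construction rather than mid-run. A small sketch of the effect with a hypothetical `Parameters` class modelled on the components above:
```python
from typing import List, Optional

from pydantic import Field, ValidationError
from pydantic.dataclasses import dataclass


@dataclass
class Parameters:  # hypothetical, modelled on the components in this patch
    smarts: List[str]  # a single SMARTS string per endpoint
    use_chirality: List[bool]
    grid_spacing: Optional[List[float]] = Field(default_factory=lambda: [0.2])


ok = Parameters(smarts=["[OX2H]"], use_chirality=[True])
print(ok.grid_spacing)  # -> [0.2]

try:
    # the old List[List[str]] style input for smarts is now rejected up front
    Parameters(smarts=[["[OX2H]"]], use_chirality=[True])
except ValidationError as err:
    print(err.error_count(), "validation error")  # -> 1 validation error
```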