From 38f4133d1583c2a15ff6c62bc05fccca15496b6e Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Fri, 27 Oct 2023 17:40:41 -0400 Subject: [PATCH] Add dimer fitting target (#44) --- .github/PULL_REQUEST_TEMPLATE.md | 5 + README.md | 3 + descent/{optimizers => optim}/__init__.py | 2 +- descent/{optimizers => optim}/_lm.py | 0 descent/targets/__init__.py | 1 + descent/targets/dimers.py | 346 ++++++++++++++++++ descent/tests/conftest.py | 14 + descent/tests/data/DESMOCK/DESMOCK.csv | 2 + .../DESMOCK/geometries/4321/DESMOCK_123.mol | 21 ++ .../tests/{optimizers => optim}/__init__.py | 0 .../tests/{optimizers => optim}/test_lm.py | 6 +- descent/tests/targets/__init__.py | 0 descent/tests/targets/test_dimers.py | 206 +++++++++++ descent/tests/utils/__init__.py | 0 descent/tests/utils/test_reporting.py | 24 ++ descent/utils/__init__.py | 1 + descent/utils/reporting.py | 73 ++++ devtools/envs/base.yaml | 3 + 18 files changed, 703 insertions(+), 4 deletions(-) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md rename descent/{optimizers => optim}/__init__.py (55%) rename descent/{optimizers => optim}/_lm.py (100%) create mode 100644 descent/targets/__init__.py create mode 100644 descent/targets/dimers.py create mode 100644 descent/tests/data/DESMOCK/DESMOCK.csv create mode 100644 descent/tests/data/DESMOCK/geometries/4321/DESMOCK_123.mol rename descent/tests/{optimizers => optim}/__init__.py (100%) rename descent/tests/{optimizers => optim}/test_lm.py (96%) create mode 100644 descent/tests/targets/__init__.py create mode 100644 descent/tests/targets/test_dimers.py create mode 100644 descent/tests/utils/__init__.py create mode 100644 descent/tests/utils/test_reporting.py create mode 100644 descent/utils/__init__.py create mode 100644 descent/utils/reporting.py diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..7d780b3 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,5 @@ +## Description +Provide a brief description of the PR's purpose here. + +## Status +- [ ] Ready to go \ No newline at end of file diff --git a/README.md b/README.md index bf53976..82a2075 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,9 @@ DESCENT The `descent` framework aims to offer a modern API for training classical force field parameters (either from a traditional format such as SMIRNOFF or from some ML model) against reference data using `pytorch`. +This framework benefited hugely from [ForceBalance](https://github.com/leeping/forcebalance), and a significant +number of learning from that project, and from Lee-Ping, have influenced the design of this one. + ***Warning**: This code is currently experimental and under active development. If you are using this it, please be aware that it is not guaranteed to provide correct results, the documentation and testing maybe be incomplete, and the API can change without notice.* diff --git a/descent/optimizers/__init__.py b/descent/optim/__init__.py similarity index 55% rename from descent/optimizers/__init__.py rename to descent/optim/__init__.py index 85e4b09..387a9ca 100644 --- a/descent/optimizers/__init__.py +++ b/descent/optim/__init__.py @@ -1,5 +1,5 @@ """Custom parameter optimizers.""" -from descent.optimizers._lm import LevenbergMarquardt, LevenbergMarquardtConfig +from descent.optim._lm import LevenbergMarquardt, LevenbergMarquardtConfig __all__ = ["LevenbergMarquardt", "LevenbergMarquardtConfig"] diff --git a/descent/optimizers/_lm.py b/descent/optim/_lm.py similarity index 100% rename from descent/optimizers/_lm.py rename to descent/optim/_lm.py diff --git a/descent/targets/__init__.py b/descent/targets/__init__.py new file mode 100644 index 0000000..f43bc67 --- /dev/null +++ b/descent/targets/__init__.py @@ -0,0 +1 @@ +"""Targets to train / assess models to / against.""" diff --git a/descent/targets/dimers.py b/descent/targets/dimers.py new file mode 100644 index 0000000..dae32d6 --- /dev/null +++ b/descent/targets/dimers.py @@ -0,0 +1,346 @@ +"""Train against dimer energies.""" +import pathlib +import typing + +import pyarrow +import smee +import smee.utils +import torch + +import descent.utils.reporting + +if typing.TYPE_CHECKING: + import pandas + from rdkit import Chem + + +EnergyFn = typing.Callable[ + ["pandas.DataFrame", tuple[str, ...], torch.Tensor], torch.Tensor +] + + +DATA_SCHEMA = pyarrow.schema( + [ + ("smiles_a", pyarrow.string()), + ("smiles_b", pyarrow.string()), + ("coords", pyarrow.list_(pyarrow.float64())), + ("energy", pyarrow.list_(pyarrow.float64())), + ("source", pyarrow.string()), + ] +) + + +class Dimer(typing.TypedDict): + """Represents a single experimental data point.""" + + smiles_a: str + smiles_b: str + + coords: torch.Tensor + energy: torch.Tensor + + source: str + + +def create_dataset(entries: list[Dimer]) -> pyarrow.Table: + """Create a dataset from a list of existing dimers. + + Args: + entries: The dimers to create the dataset from. + + Returns: + The created dataset. + """ + # TODO: validate rows + return pyarrow.Table.from_pylist( + [ + { + "smiles_a": entry["smiles_a"], + "smiles_b": entry["smiles_b"], + "coords": torch.tensor(entry["coords"]).flatten().tolist(), + "energy": torch.tensor(entry["energy"]).flatten().tolist(), + "source": entry["source"], + } + for entry in entries + ], + schema=DATA_SCHEMA, + ) + + +def _mol_to_smiles(mol: "Chem.Mol") -> str: + """Convert a molecule to a SMILES string with atom mapping. + + Args: + mol: The molecule to convert. + + Returns: + The SMILES string. + """ + from rdkit import Chem + + mol = Chem.AddHs(mol) + + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(atom.GetIdx() + 1) + + return Chem.MolToSmiles(mol) + + +def create_from_des( + data_dir: pathlib.Path, + energy_fn: EnergyFn, +) -> pyarrow.Table: + """Create a dataset from a DESXXX dimer set. + + Args: + data_dir: The path to the DESXXX directory. + energy_fn: A function which computes the reference energy of a dimer. This + should take as input a pandas DataFrame containing the metadata for a + given group, a tuple of geometry IDs, and a tensor of coordinates with + ``shape=(n_dimers, n_atoms, 3)``. It should return a tensor of energies + with ``shape=(n_dimers,)`` and units of [kcal/mol]. + + Returns: + The created dataset. + """ + import pandas + from rdkit import Chem + + metadata = pandas.read_csv(data_dir / f"{data_dir.name}.csv", index_col=False) + + system_ids = metadata["system_id"].unique() + entries: list[Dimer] = [] + + for system_id in system_ids: + system_data = metadata[metadata["system_id"] == system_id] + + group_ids = metadata[metadata["system_id"] == system_id]["group_id"].unique() + + for group_id in group_ids: + group_data = system_data[system_data["group_id"] == group_id] + group_orig = group_data["group_orig"].unique()[0] + + geometry_ids = tuple(group_data["geom_id"].values) + + dimer_example = Chem.MolFromMolFile( + f"{data_dir}/geometries/{system_id}/DES{group_orig}_{geometry_ids[0]}.mol", + removeHs=False, + ) + mol_a, mol_b = Chem.GetMolFrags(dimer_example, asMols=True) + + smiles_a = _mol_to_smiles(mol_a) + smiles_b = _mol_to_smiles(mol_b) + + source = ( + f"{data_dir.name} system={system_id} orig={group_orig} group={group_id}" + ) + + coords_raw = [ + Chem.MolFromMolFile( + f"{data_dir}/geometries/{system_id}/DES{group_orig}_{geometry_id}.mol", + removeHs=False, + ) + .GetConformer() + .GetPositions() + .tolist() + for geometry_id in geometry_ids + ] + + coords = torch.tensor(coords_raw) + energy = energy_fn(group_data, geometry_ids, coords) + + entries.append( + { + "smiles_a": smiles_a, + "smiles_b": smiles_b, + "coords": coords, + "energy": energy, + "source": source, + } + ) + + return create_dataset(entries) + + +def extract_smiles(dataset: pyarrow.Table) -> list[str]: + """Return a list of unique SMILES strings in the dataset. + + Args: + dataset: The dataset to extract the SMILES strings from. + + Returns: + The list of unique SMILES strings. + """ + + smiles_a = dataset["smiles_a"].drop_null().unique().to_pylist() + smiles_b = dataset["smiles_b"].drop_null().unique().to_pylist() + + return sorted({*smiles_a, *smiles_b}) + + +def compute_dimer_energy( + topology_a: smee.TensorTopology, + topology_b: smee.TensorTopology, + force_field: smee.TensorForceField, + coords: torch.Tensor, +) -> torch.Tensor: + """Compute the energy of a dimer in a series of conformers. + + Args: + topology_a: The topology of the first monomer. + topology_b: The topology of the second monomer. + force_field: The force field to use. + coords: The coordinates of the dimer with ``shape=(n_dimers, n_atoms, 3)``. + + Returns: + The energy [kcal/mol] of the dimer in each conformer. + """ + dimer = smee.TensorSystem([topology_a, topology_b], [1, 1], False) + + coords_a = coords[:, : topology_a.n_atoms, :] + + if topology_a.v_sites is not None: + coords_a = smee.geometry.add_v_site_coords( + topology_a.v_sites, coords_a, force_field + ) + + coords_b = coords[:, topology_a.n_atoms :, :] + + if topology_b.v_sites is not None: + coords_b = smee.geometry.add_v_site_coords( + topology_b.v_sites, coords_b, force_field + ) + + coords = torch.cat([coords_a, coords_b], dim=1) + + energy_dimer = smee.compute_energy(dimer, force_field, coords) + + energy_a = smee.compute_energy(topology_a, force_field, coords_a) + energy_b = smee.compute_energy(topology_b, force_field, coords_b) + + return energy_dimer - energy_a - energy_b + + +def _predict( + dimer: Dimer, + force_field: smee.TensorForceField, + topologies: dict[str, smee.TensorTopology], +) -> tuple[torch.Tensor, torch.Tensor]: + """Predict the energies of a single dimer in multiple conformations. + + Args: + dimer: The dimer to predict the energies of. + force_field: The force field to use. + topologies: The topologies of each monomer. Each key should be a fully + mapped SMILES string. + + Returns: + The reference and predicted energies [kcal/mol] with ``shape=(n_confs,)``. + """ + + n_coords = len(dimer["energy"]) + + coords_flat = smee.utils.tensor_like( + dimer["coords"], force_field.potentials[0].parameters + ) + coords = coords_flat.reshape(n_coords, -1, 3) + + predicted = compute_dimer_energy( + topologies[dimer["smiles_a"]], + topologies[dimer["smiles_b"]], + force_field, + coords, + ) + reference = smee.utils.tensor_like(dimer["energy"], predicted) + + return reference, predicted + + +def predict( + dataset: pyarrow.Table, + force_field: smee.TensorForceField, + topologies: dict[str, smee.TensorTopology], +) -> tuple[torch.Tensor, torch.Tensor]: + """Predict the energies of each dimer in the dataset. + + Args: + dataset: The dataset to predict the energies of. + force_field: The force field to use. + topologies: The topologies of each monomer. Each key should be a fully + mapped SMILES string. + + Returns: + The reference and predicted energies [kcal/mol] of each dimer, each with + ``shape=(n_dimers * n_conf_per_dimer,)``. + """ + + dimers: list[Dimer] = dataset.to_pylist() + + reference, predicted = zip( + *[_predict(dimer, force_field, topologies) for dimer in dimers] + ) + return torch.stack(reference).flatten(), torch.stack(predicted).flatten() + + +def _plot_energies(energies: dict[str, torch.Tensor]) -> str: + from matplotlib import pyplot + + figure, axis = pyplot.subplots(1, 1, figsize=(4.0, 4.0)) + + for i, (k, v) in enumerate(energies.items()): + axis.plot( + v.cpu().detach().numpy(), + label=k, + linestyle="none", + marker=descent.utils.reporting.DEFAULT_MARKERS[i], + color=descent.utils.reporting.DEFAULT_COLORS[i], + ) + + axis.set_xlabel("Idx") + axis.set_ylabel("Energy [kcal / mol]") + + axis.legend() + + figure.tight_layout() + img = descent.utils.reporting.figure_to_img(figure) + + pyplot.close(figure) + + return img + + +def report( + dataset: pyarrow.Table, + force_fields: dict[str, smee.TensorForceField], + topologies: dict[str, smee.TensorTopology], + output_path: pathlib.Path, +): + """Generate a report comparing the predicted and reference energies of each dimer. + + Args: + dataset: The dataset to generate the report for. + force_fields: The force fields to use to predict the energies. + topologies: The topologies of each monomer. Each key should be a fully + mapped SMILES string. + output_path: The path to write the report to. + """ + import pandas + + rows = [] + + for entry in dataset.to_pylist(): + energies = {"ref": torch.tensor(entry["energy"])} + energies.update( + (force_field_name, _predict(entry, force_field, topologies)[1]) + for force_field_name, force_field in force_fields.items() + ) + + plot_img = _plot_energies(energies) + + mol_img = descent.utils.reporting.mols_to_img( + entry["smiles_a"], entry["smiles_b"] + ) + rows.append({"Dimer": mol_img, "Energy [kcal/mol]": plot_img}) + + output_path.parent.mkdir(parents=True, exist_ok=True) + return pandas.DataFrame(rows).to_html(output_path, escape=False, index=False) diff --git a/descent/tests/conftest.py b/descent/tests/conftest.py index e69de29..a23081f 100644 --- a/descent/tests/conftest.py +++ b/descent/tests/conftest.py @@ -0,0 +1,14 @@ +import pathlib + +import pytest + + +@pytest.fixture +def tmp_cwd(tmp_path, monkeypatch) -> pathlib.Path: + monkeypatch.chdir(tmp_path) + yield tmp_path + + +@pytest.fixture +def data_dir() -> pathlib.Path: + return pathlib.Path(__file__).parent / "data" diff --git a/descent/tests/data/DESMOCK/DESMOCK.csv b/descent/tests/data/DESMOCK/DESMOCK.csv new file mode 100644 index 0000000..ff3b24e --- /dev/null +++ b/descent/tests/data/DESMOCK/DESMOCK.csv @@ -0,0 +1,2 @@ +smiles0,smiles1,system_id,group_orig,group_id,geom_id,reference +CO,O,4321,MOCK,1423,123,-1.23 \ No newline at end of file diff --git a/descent/tests/data/DESMOCK/geometries/4321/DESMOCK_123.mol b/descent/tests/data/DESMOCK/geometries/4321/DESMOCK_123.mol new file mode 100644 index 0000000..8fd5db6 --- /dev/null +++ b/descent/tests/data/DESMOCK/geometries/4321/DESMOCK_123.mol @@ -0,0 +1,21 @@ + + RDKit 3D + + 9 7 0 0 0 0 0 0 0 0999 V2000 + 0.0000 1.0000 2.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 + 3.0000 4.0000 5.0000 O 0 0 0 0 0 0 0 0 0 0 0 0 + 6.0000 7.0000 8.0000 H 0 0 0 0 0 0 0 0 0 0 0 0 + 9.0000 10.0000 11.0000 H 0 0 0 0 0 0 0 0 0 0 0 0 + 12.0000 13.0000 14.0000 H 0 0 0 0 0 0 0 0 0 0 0 0 + 15.0000 16.0000 17.0000 H 0 0 0 0 0 0 0 0 0 0 0 0 + 18.0000 19.0000 20.0000 O 0 0 0 0 0 0 0 0 0 0 0 0 + 21.0000 22.0000 23.0000 H 0 0 0 0 0 0 0 0 0 0 0 0 + 24.0000 25.0000 26.0000 H 0 0 0 0 0 0 0 0 0 0 0 0 + 1 2 1 0 + 1 3 1 0 + 1 4 1 0 + 1 5 1 0 + 2 6 1 0 + 7 8 1 0 + 7 9 1 0 +M END diff --git a/descent/tests/optimizers/__init__.py b/descent/tests/optim/__init__.py similarity index 100% rename from descent/tests/optimizers/__init__.py rename to descent/tests/optim/__init__.py diff --git a/descent/tests/optimizers/test_lm.py b/descent/tests/optim/test_lm.py similarity index 96% rename from descent/tests/optimizers/test_lm.py rename to descent/tests/optim/test_lm.py index 913c579..955e038 100644 --- a/descent/tests/optimizers/test_lm.py +++ b/descent/tests/optim/test_lm.py @@ -3,7 +3,7 @@ import pytest import torch -from descent.optimizers._lm import ( +from descent.optim._lm import ( LevenbergMarquardt, _damping_factor_loss_fn, _solver, @@ -82,7 +82,7 @@ def test_damping_factor_loss_fn(mocker): hessian = mocker.Mock() solver_fn = mocker.patch( - "descent.optimizers._lm._solver", autospec=True, return_value=(dx, 0.0) + "descent.optim._lm._solver", autospec=True, return_value=(dx, 0.0) ) trust_radius = 12 @@ -118,7 +118,7 @@ def test_levenberg_marquardt_adaptive(mocker, caplog): ), ] mock_step_fn = mocker.patch( - "descent.optimizers._lm._step", autospec=True, side_effect=mock_dx_traj + "descent.optim._lm._step", autospec=True, side_effect=mock_dx_traj ) mock_loss_traj = [ diff --git a/descent/tests/targets/__init__.py b/descent/tests/targets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/descent/tests/targets/test_dimers.py b/descent/tests/targets/test_dimers.py new file mode 100644 index 0000000..dd6f692 --- /dev/null +++ b/descent/tests/targets/test_dimers.py @@ -0,0 +1,206 @@ +import openff.interchange +import openff.toolkit +import openff.units +import pytest +import smee.converters +import torch + +from descent.targets.dimers import ( + Dimer, + compute_dimer_energy, + create_dataset, + create_from_des, + extract_smiles, + predict, + report, +) + + +@pytest.fixture +def mock_dimer() -> Dimer: + return { + "smiles_a": "[C:1]([O:2][H:6])([H:3])([H:4])[H:5]", + "smiles_b": "[O:1]([H:2])[H:3]", + "coords": torch.arange(54, dtype=torch.float32).reshape(2, 9, 3), + "energy": 3.0 * torch.arange(2, dtype=torch.float32), + "source": "some source...", + } + + +def test_create_dataset(mock_dimer): + expected_data_entries = [ + { + "smiles_a": mock_dimer["smiles_a"], + "smiles_b": mock_dimer["smiles_b"], + "coords": mock_dimer["coords"].flatten().tolist(), + "energy": mock_dimer["energy"].tolist(), + "source": mock_dimer["source"], + }, + ] + + dataset = create_dataset([mock_dimer]) + assert len(dataset) == 1 + + data_entries = dataset.to_pylist() + + assert data_entries == pytest.approx(expected_data_entries) + + +def test_create_from_des(data_dir): + expected_coords = torch.arange(6 * 3 + 3 * 3, dtype=torch.float32).reshape(1, 9, 3) + + def energy_fn(data, ids, coords): + assert coords.shape == expected_coords.shape + assert torch.allclose(coords, expected_coords) + + assert ids == (123,) + + return torch.tensor(data["reference"].values) + + dataset = create_from_des(data_dir / "DESMOCK", energy_fn) + assert len(dataset) == 1 + + expected = { + "smiles_a": "[C:1]([O:2][H:6])([H:3])([H:4])[H:5]", + "smiles_b": "[O:1]([H:2])[H:3]", + "coords": expected_coords.flatten().tolist(), + "energy": [-1.23], + "source": "DESMOCK system=4321 orig=MOCK group=1423", + } + + assert dataset.to_pylist() == [pytest.approx(expected)] + + +def test_extract_smiles(mock_dimer): + expected_smiles = ["[C:1]([O:2][H:6])([H:3])([H:4])[H:5]", "[O:1]([H:2])[H:3]"] + + dataset = create_dataset([mock_dimer, mock_dimer]) + smiles = extract_smiles(dataset) + + assert smiles == expected_smiles + + +def test_compute_dimer_energy(): + openff_ff = openff.toolkit.ForceField() + openff_ff.get_parameter_handler("vdW").add_parameter( + { + "smirks": "[Ar:1]", + "epsilon": 1.0 * openff.units.unit.kilocalorie / openff.units.unit.mole, + "sigma": 1.0 * openff.units.unit.angstrom, + } + ) + openff_ff.get_parameter_handler("vdW").add_parameter( + { + "smirks": "[He:1]", + "epsilon": 4.0 * openff.units.unit.kilocalorie / openff.units.unit.mole, + "sigma": 1.0 * openff.units.unit.angstrom, + } + ) + + interchanges = [ + openff.interchange.Interchange.from_smirnoff( + openff_ff, openff.toolkit.Molecule.from_smiles(smiles).to_topology() + ) + for smiles in ("[Ar]", "[He]") + ] + tensor_ff, [top_a, top_b] = smee.converters.convert_interchange(interchanges) + + coords = torch.tensor( + [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]] + ) + + # eps = sqrt(4 * 1) + expected_energies = ( + 4.0 * 2.0 * torch.tensor([0.0, (1.0 / 2.0) ** 12 - (1.0 / 2.0) ** 6]) + ) + + energies = compute_dimer_energy(top_a, top_b, tensor_ff, coords) + assert energies.shape == expected_energies.shape + assert torch.allclose(energies, expected_energies) + + +def test_compute_dimer_energy_v_sites(): + openff_ff = openff.toolkit.ForceField("tip4p_fb.offxml") + + interchange = openff.interchange.Interchange.from_smirnoff( + openff_ff, openff.toolkit.Molecule.from_smiles("O").to_topology() + ) + tensor_ff, [top] = smee.converters.convert_interchange(interchange) + + coords = torch.tensor( + [ + [ + [-1.0, -1.0, 0.0], + [0.0, 0.0, 0.0], + [1.0, -1.0, 0.0], + [-1.0, 2.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 2.0, 0.0], + ] + ], + dtype=torch.float64, + ) + + energies = compute_dimer_energy(top, top, tensor_ff, coords) + assert energies.shape == (1,) + assert not torch.isnan(energies).any() + + +def test_predict(mock_dimer, mocker): + dataset = create_dataset([mock_dimer]) + + expected_y_pred = torch.Tensor([-1.23, 4.56]) + + mock_energy_fn = mocker.patch( + "descent.targets.dimers.compute_dimer_energy", + autospec=True, + return_value=expected_y_pred, + ) + + mock_ff = mocker.MagicMock() + mock_ff.potentials[0].parameters = torch.zeros(1) + + mock_top_a = mocker.Mock() + mock_tob_b = mocker.Mock() + + topologies = { + mock_dimer["smiles_a"]: mock_top_a, + mock_dimer["smiles_b"]: mock_tob_b, + } + + y_ref, y_pred = predict(dataset, mock_ff, topologies) + + assert y_pred.shape == (2,) + assert torch.allclose(y_pred, expected_y_pred) + + assert y_ref.shape == mock_dimer["energy"].shape + assert torch.allclose(y_ref, mock_dimer["energy"]) + + expected_coords = mock_dimer["coords"] + + mock_energy_fn.assert_called_once_with( + mock_top_a, mock_tob_b, mock_ff, pytest.approx(expected_coords) + ) + + +def test_report(tmp_cwd, mock_dimer, mocker): + dataset = create_dataset([mock_dimer]) + + expected_y_pred = torch.Tensor([-1.23, 4.56]) + + mock_predict_fn = mocker.patch( + "descent.targets.dimers._predict", + autospec=True, + return_value=(None, expected_y_pred), + ) + + mock_ff = mocker.MagicMock() + mock_tops = mocker.MagicMock() + + expected_path = tmp_cwd / "report.html" + report(dataset, {"A": mock_ff}, mock_tops, expected_path) + + assert expected_path.exists() + assert expected_path.read_text().startswith("') + + +def test_figure_to_img(): + figure = pyplot.figure() + img = figure_to_img(figure) + pyplot.close(figure) + + assert img.startswith('') diff --git a/descent/utils/__init__.py b/descent/utils/__init__.py new file mode 100644 index 0000000..6462a92 --- /dev/null +++ b/descent/utils/__init__.py @@ -0,0 +1 @@ +"""Utilities functions.""" diff --git a/descent/utils/reporting.py b/descent/utils/reporting.py new file mode 100644 index 0000000..f128d56 --- /dev/null +++ b/descent/utils/reporting.py @@ -0,0 +1,73 @@ +"""Utilities for reporting results.""" +import base64 +import io +import itertools +import typing + +if typing.TYPE_CHECKING: + from matplotlib import pyplot + from rdkit import Chem + + +DEFAULT_COLORS, DEFAULT_MARKERS = zip( + *itertools.product(["red", "green", "blue", "black"], ["x", "o", "+", "^"]) +) + + +def _mol_from_smiles(smiles: str) -> "Chem.Mol": + from rdkit import Chem + + mol = Chem.RemoveHs(Chem.MolFromSmiles(smiles)) + + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(0) + + return mol + + +def mols_to_img(*smiles: str, width: int = 400, height: int = 200) -> str: + """Renders a set of molecules as an embeddable HTML image tag. + + Args: + *smiles: The SMILES patterns of the molecules to render. + width: The width of the image. + height: The height of the image. + + Returns: + The HTML image tag. + """ + from rdkit import Chem + from rdkit.Chem import Draw + + assert len(smiles) > 0 + + mol = _mol_from_smiles(smiles[0]) + + for pattern in smiles[1:]: + mol = Chem.CombineMols(mol, _mol_from_smiles(pattern)) + + mol = Draw.PrepareMolForDrawing(mol, forceCoords=True) + + drawer = Draw.rdMolDraw2D.MolDraw2DSVG(width, height) + drawer.DrawMolecule(mol) + drawer.FinishDrawing() + + data = base64.b64encode(drawer.GetDrawingText().encode()).decode() + return f'' + + +def figure_to_img(figure: "pyplot.Figure") -> str: + """Convert a matplotlib figure to an embeddable HTML image tag. + + Args: + figure: The figure to convert. + + Returns: + The HTML image tag. + """ + + with io.BytesIO() as stream: + figure.savefig(stream, format="svg") + data = base64.b64encode(stream.getvalue()).decode() + + return f'' diff --git a/devtools/envs/base.yaml b/devtools/envs/base.yaml index 6c1e58f..255ad87 100644 --- a/devtools/envs/base.yaml +++ b/devtools/envs/base.yaml @@ -13,11 +13,14 @@ dependencies: - pytorch - pydantic + - pyarrow ### Levenberg Marquardt - scipy # Optional packages + - rdkit + - matplotlib-base # Examples - jupyter