update dependencies (#288)
* update dependencies

* fix trapz import

* add python 3.11 to workflows

* remove python 3.8 support

* clean up workflows

* add retries to get_airfoil (see the sketch after this list)

* skip mnist download if not linux for speed

* clean up
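
The get_airfoil retry change itself is not among the file diffs shown below. The following is only a hedged sketch of how the tenacity dependency added to setup.cfg in this commit can wrap a flaky dataset download; the URL, helper name, and parsing details are assumptions, not the code from this commit.

# Hedged sketch only: retrying a flaky download with tenacity (newly added to
# install_requires in this commit). URL and parsing details are assumptions.
import io

import pandas as pd
import requests
from tenacity import retry, stop_after_attempt, wait_exponential

# Assumed location of the UCI airfoil self-noise dataset.
AIRFOIL_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00291/airfoil_self_noise.dat"


@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10), reraise=True)
def get_airfoil() -> pd.DataFrame:
    # A transient network error raises, tenacity backs off and retries,
    # and after five failed attempts the last exception is re-raised.
    resp = requests.get(AIRFOIL_URL, timeout=30)
    resp.raise_for_status()
    return pd.read_csv(io.StringIO(resp.text), sep=r"\s+", header=None)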
robsdavis authored Sep 3, 2024
1 parent 943fa28 commit 9dd7f6c
Showing 34 changed files with 409 additions and 121 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -11,7 +11,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
os: [macos-latest]

steps:
2 changes: 1 addition & 1 deletion .github/workflows/test_all_tutorials.yml
@@ -10,7 +10,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion .github/workflows/test_full.yml
@@ -10,7 +10,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
os: [macos-latest, ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion .github/workflows/test_tutorials.yml
@@ -14,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v2
4 changes: 2 additions & 2 deletions prereq.txt
@@ -1,4 +1,4 @@
-numpy>=1.20, <1.24
-torch>=1.10.0,<2.0
+numpy>=1.20
+torch>=2.1, <2.3 # Max due to tsai
tsai
wheel>=0.40
18 changes: 9 additions & 9 deletions setup.cfg
@@ -29,32 +29,33 @@ include_package_data = True
package_dir =
=src

-python_requires = >=3.8
+python_requires = >=3.9

install_requires =
importlib-metadata
-pandas>=1.4,<2
-torch>=1.10.0,<2.0
+pandas>=2.1 # min due to lifelines
+torch>=2.1, <2.3 # Max due to tsai
scikit-learn>=1.2
nflows>=0.14
-numpy>=1.20, <1.24
-lifelines>=0.27,!= 0.27.5, <0.27.8
+numpy>=1.20, <2.0
+lifelines>=0.29.0, <0.30.0 # max due to xgbse
opacus>=1.3
networkx>2.0,<3.0
decaf-synthetic-data>=0.1.6
optuna>=3.1
shap
+tenacity
tqdm
loguru
pydantic<2.0
cloudpickle
scipy
-xgboost<2.0.0
+xgboost<3.0.0
geomloss
pgmpy
redis
pycox
-xgbse
+xgbse>=0.3.1
pykeops
fflows
monai
@@ -96,13 +97,12 @@ testing =
click

goggle =
-dgl<2.0
+dgl
torch_geometric
torch_sparse
torch_scatter

all =
importlib-metadata;python_version<"3.8"
%(testing)s
%(goggle)s

15 changes: 12 additions & 3 deletions src/synthcity/plugins/core/dataloader.py
@@ -931,11 +931,20 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:

if as_numpy:
longest_observation_seq = max([len(seq) for seq in temporal_data])
+padded_temporal_data = np.zeros(
+(len(temporal_data), longest_observation_seq, 5)
+)
+mask = np.ones((len(temporal_data), longest_observation_seq, 5), dtype=bool)
+for i, arr in enumerate(temporal_data):
+padded_temporal_data[i, : arr.shape[0], :] = arr # Copy the actual data
+mask[
+i, : arr.shape[0], :
+] = False # Set mask to False where actual data is present

+masked_temporal_data = ma.masked_array(padded_temporal_data, mask)
return (
np.asarray(static_data),
-np.asarray(
-temporal_data
-), # TODO: check this works with time series benchmarks
+masked_temporal_data, # TODO: check this works with time series benchmarks
+# masked array to handle variable length sequences
ma.vstack(
[
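
For reference, a minimal standalone sketch (toy shapes, not from this commit) of the pad-and-mask pattern introduced above: shorter sequences are zero-padded and the mask marks the padded cells as invalid.

import numpy as np
import numpy.ma as ma

temporal_data = [np.ones((2, 3)), np.ones((4, 3))]  # variable-length sequences
longest = max(len(seq) for seq in temporal_data)

padded = np.zeros((len(temporal_data), longest, 3))
mask = np.ones_like(padded, dtype=bool)  # True = padded cell, no real data
for i, arr in enumerate(temporal_data):
    padded[i, : arr.shape[0], :] = arr
    mask[i, : arr.shape[0], :] = False  # real observations are unmasked

masked = ma.masked_array(padded, mask)
print(masked.shape)       # (2, 4, 3)
print(masked[0].count())  # 6 valid cells in the shorter sequence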
src/synthcity/plugins/core/models/survival_analysis/metrics.py
@@ -5,9 +5,16 @@
import numpy as np
import pandas as pd
from lifelines import KaplanMeierFitter
-from scipy.integrate import trapz
from xgbse.non_parametric import _get_conditional_probs_from_survival

+try:
+# third party
+from scipy.integrate import trapz
+except ImportError:
+from numpy import (
+trapz,
+) # Fall back to numpy's trapz where newer SciPy no longer provides it

# synthcity absolute
from synthcity.plugins.core.models.survival_analysis.third_party.metrics import (
brier_score,
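
The same guarded import is repeated in the time-to-event models below. As a usage illustration only (toy numbers, not from this commit), trapz integrates a survival curve over time, which is roughly how these modules turn survival probabilities into expected time-to-event estimates.

import numpy as np

try:
    # third party
    from scipy.integrate import trapz  # removed in newer SciPy releases
except ImportError:
    from numpy import trapz  # numpy fallback, available under the numpy<2.0 pin

time_points = np.array([0.0, 1.0, 2.0, 3.0])
survival = np.array([1.0, 0.9, 0.7, 0.5])  # toy survival probabilities S(t)
print(trapz(survival, time_points))  # area under the curve ~= 2.35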
9 changes: 8 additions & 1 deletion src/synthcity/plugins/core/models/time_to_event/tte_aft.py
@@ -5,7 +5,14 @@
import pandas as pd
from lifelines import WeibullAFTFitter
from pydantic import validate_arguments
-from scipy.integrate import trapz

+try:
+# third party
+from scipy.integrate import trapz
+except ImportError:
+from numpy import (
+trapz,
+) # Fall back to numpy's trapz where newer SciPy no longer provides it

# synthcity absolute
from synthcity.plugins.core.distribution import Distribution, FloatDistribution
9 changes: 8 additions & 1 deletion src/synthcity/plugins/core/models/time_to_event/tte_coxph.py
@@ -5,7 +5,14 @@
import pandas as pd
from lifelines import CoxPHFitter
from pydantic import validate_arguments
-from scipy.integrate import trapz

+try:
+# third party
+from scipy.integrate import trapz
+except ImportError:
+from numpy import (
+trapz,
+) # Fall back to numpy's trapz where newer SciPy no longer provides it

# synthcity absolute
from synthcity.plugins.core.distribution import Distribution, FloatDistribution
src/synthcity/plugins/core/models/time_to_event/tte_deephit.py
@@ -8,9 +8,16 @@
import torchtuples as tt
from pycox.models import DeepHitSingle
from pydantic import validate_arguments
-from scipy.integrate import trapz
from sklearn.model_selection import train_test_split

+try:
+# third party
+from scipy.integrate import trapz
+except ImportError:
+from numpy import (
+trapz,
+) # Fall back to numpy's trapz where newer SciPy no longer provides it

# synthcity absolute
from synthcity.plugins.core.distribution import (
CategoricalDistribution,
9 changes: 8 additions & 1 deletion src/synthcity/plugins/core/models/time_to_event/tte_xgb.py
@@ -5,10 +5,17 @@
import numpy as np
import pandas as pd
from pydantic import validate_arguments
-from scipy.integrate import trapz
from xgbse import XGBSEDebiasedBCE, XGBSEKaplanNeighbors, XGBSEStackedWeibull
from xgbse.converters import convert_to_structured

+try:
+# third party
+from scipy.integrate import trapz
+except ImportError:
+from numpy import (
+trapz,
+) # Fall back to numpy's trapz where newer SciPy no longer provides it

# synthcity absolute
from synthcity.plugins.core.distribution import (
CategoricalDistribution,
9 changes: 9 additions & 0 deletions src/synthcity/plugins/core/serializable.py
@@ -22,7 +22,11 @@ class Serializable:
"""Utility class for model persistence."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
+derived_module_path: Optional[Path] = None
+self.fitted = (
+False # make sure all serializable objects are not fitted by default
+)

search_module = self.__class__.__module__
if not search_module.endswith(".py"):
@@ -58,9 +62,14 @@ def save_dict(self) -> dict:
data = self.__dict__[key]
if isinstance(data, Serializable):
members[key] = self.__dict__[key].save_dict()
elif key == "model":
members[key] = serialize(self.__dict__[key])
else:
members[key] = copy.deepcopy(self.__dict__[key])

if "fitted" not in members:
members["fitted"] = self.fitted # Ensure 'fitted' is always serialized

return {
"source": "synthcity",
"data": members,
145 changes: 136 additions & 9 deletions src/synthcity/utils/serialization.py
@@ -1,19 +1,146 @@
# stdlib
import hashlib
from pathlib import Path
-from typing import Any, Union
+from typing import Any, List, Union

# third party
import cloudpickle
import pandas as pd


-def save(model: Any) -> bytes:
-return cloudpickle.dumps(model)


-def load(buff: bytes) -> Any:
-return cloudpickle.loads(buff)
from opacus import PrivacyEngine

# The list of plugins that are not simply loadable with cloudpickle
unloadable_plugins: List[str] = [
"dpgan", # DP-GAN plugin id not loadable with cloudpickle due to the DPOptimizer
]


# TODO: simplify this function back to just cloudpickle.dumps(model), if possible (i.e. if the DPOptimizer is not needed or becomes loadable with cloudpickle)
def save(custom_model: Any) -> bytes:
"""
Serialize a custom model object that may or may not contain a PyTorch model with a privacy engine.
Args:
custom_model: The custom model object to serialize, potentially containing a PyTorch model with a privacy engine.
Returns:
bytes: Serialized model state as bytes.
"""
# Check whether custom_model is a plugin (plugins expose name()) without a circular import
if not hasattr(custom_model, "name"):
return cloudpickle.dumps(custom_model)

if custom_model.name() not in unloadable_plugins:
return cloudpickle.dumps(custom_model)

# Initialize the checkpoint dictionary
checkpoint = {
"custom_model_state": None,
"pytorch_model_state": None,
"privacy_engine_state": None,
"optimizer_state": None,
"optimizer_class": None,
"optimizer_defaults": None,
}

# Save the state of the custom model object (excluding the PyTorch model and optimizer)
custom_model_state = {
key: value for key, value in custom_model.__dict__.items() if key != "model"
}
checkpoint["custom_model_state"] = cloudpickle.dumps(custom_model_state)

# Check if the custom model contains a PyTorch model
pytorch_model = None
if hasattr(custom_model, "model"):
pytorch_model = getattr(custom_model, "model")

# If a PyTorch model is found, check if it's using Opacus for DP
if pytorch_model:
checkpoint["pytorch_model_state"] = pytorch_model.state_dict()
if hasattr(pytorch_model, "privacy_engine") and isinstance(
pytorch_model.privacy_engine, PrivacyEngine
):
# Handle DP Optimizer
optimizer = pytorch_model.privacy_engine.optimizer

checkpoint.update(
{
"optimizer_state": optimizer.state_dict(),
"privacy_engine_state": pytorch_model.privacy_engine.state_dict(),
"optimizer_class": optimizer.__class__,
"optimizer_defaults": optimizer.defaults,
}
)

# Serialize the entire state with cloudpickle
return cloudpickle.dumps(checkpoint)


# TODO: simplify this function back to just cloudpickle.loads(model), if possible (i.e. if the DPOptimizer is not needed or becomes loadable with cloudpickle)
def load(buff: bytes, custom_model: Any = None) -> Any:
"""
Deserialize a custom model object that may or may not contain a PyTorch model with a privacy engine.
Args:
buff (bytes): Serialized model state as bytes.
custom_model: The custom model instance to load the state into.
Returns:
custom_model: The deserialized custom model with its original state.
"""
# Load the checkpoint
if custom_model is None or custom_model.name() not in unloadable_plugins:
return cloudpickle.loads(buff)

if custom_model is None:
raise ValueError(
f"custom_model must be provided when loading one of the following plugins: {unloadable_plugins}"
)

checkpoint = cloudpickle.loads(buff)
# Restore the custom model's own state (excluding the PyTorch model)
custom_model_state = cloudpickle.loads(checkpoint["custom_model_state"])
for key, value in custom_model_state.items():
setattr(custom_model, key, value)

# Find the PyTorch model inside the custom model if it exists
pytorch_model = None
if hasattr(custom_model, "model"):
pytorch_model = getattr(custom_model, "model")

# Load the states into the PyTorch model if it exists
if pytorch_model and checkpoint["pytorch_model_state"] is not None:
pytorch_model.load_state_dict(checkpoint["pytorch_model_state"])

# Check if the serialized model had a privacy engine
if checkpoint["privacy_engine_state"] is not None:
# If there was a privacy engine, recreate and reattach it
optimizer_class = checkpoint["optimizer_class"]
optimizer_defaults = checkpoint["optimizer_defaults"]

# Ensure the optimizer is correctly created with model's parameters
optimizer = optimizer_class(
pytorch_model.parameters(), **optimizer_defaults
)

# Recreate the privacy engine
privacy_engine = PrivacyEngine(
pytorch_model,
sample_rate=optimizer.defaults.get(
"sample_rate", 0.01
), # Use saved or default values
noise_multiplier=optimizer.defaults.get("noise_multiplier", 1.0),
max_grad_norm=optimizer.defaults.get("max_grad_norm", 1.0),
)
privacy_engine.attach(optimizer)

# Load the saved states
optimizer.load_state_dict(checkpoint["optimizer_state"])
privacy_engine.load_state_dict(checkpoint["privacy_engine_state"])

# Assign back to the PyTorch model (or the appropriate container)
pytorch_model.privacy_engine = privacy_engine

return custom_model


def save_to_file(path: Union[str, Path], model: Any) -> Any:
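
A hedged usage sketch of the new save/load above; Plugins().get(...) and the plugin names are assumed from synthcity's public plugin loader and are not part of this diff. Ordinary plugins still round-trip through plain cloudpickle, while "dpgan" additionally needs a fresh instance passed to load so the checkpoint can be restored into it.

from synthcity.plugins import Plugins  # assumed public loader, not part of this diff
from synthcity.utils.serialization import load, save

# Ordinary plugins round-trip exactly as before (plain cloudpickle).
model = Plugins().get("marginal_distributions")  # plugin name assumed
buff = save(model)
restored = load(buff)

# "dpgan" is listed in unloadable_plugins, so load() needs a fresh instance
# to restore the checkpointed state into (the DP optimizer cannot be
# cloudpickled directly).
dp_model = Plugins().get("dpgan")
dp_buff = save(dp_model)
dp_restored = load(dp_buff, custom_model=Plugins().get("dpgan"))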
2 changes: 1 addition & 1 deletion src/synthcity/version.py
@@ -1,4 +1,4 @@
__version__ = "0.2.10"
__version__ = "0.2.11"

MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
PATCH_VERSION = __version__.split(".")[-1]