Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add REalTabFormer to supported synthesizers #360

Merged
merged 37 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a097ae0
CU-86b2czk01 Adding realtabformer integration
cristid9 Oct 24, 2024
5519ce4
CU-86b2czk01 Refactor realtabformer
cristid9 Nov 4, 2024
3301239
CU-86b2czk01 Adding unit test for realtabformer synthesizer
cristid9 Nov 4, 2024
1f409aa
CU-86b2czk01 Cleaning up changes
cristid9 Nov 4, 2024
f2c9128
CU-86b2czk01 Refactoring changes to comply with linter
cristid9 Nov 5, 2024
73721ee
CU-86b2czk01 Updating integration test for realtabformer integration
cristid9 Nov 5, 2024
e713afe
CU-86b2czk01 Fixing linter warnings
cristid9 Nov 5, 2024
eda43bc
CU-86b2czk01 Adding docstring comment in realtabformer
cristid9 Nov 5, 2024
2a30b14
CU-86b2czk01 Improving integration tests
cristid9 Nov 8, 2024
0211729
CU-86b2czk01 Adding unit tests
cristid9 Nov 12, 2024
59ea56a
CU-86b2czk01 Reworking unittests
cristid9 Nov 14, 2024
dd8abdd
CU-86b2czk01 Updating realtabformer error message
cristid9 Nov 14, 2024
cd87434
CU-86b2czk01 Improving code to match linter feedback
cristid9 Nov 18, 2024
e040aa8
CU-86b2czk01 Fixing linter issues
cristid9 Nov 19, 2024
0621033
CU-86b2czk01 Refactoring tests
cristid9 Nov 26, 2024
dea85c8
Addressing comments on pull request
cristid9 Nov 29, 2024
f2992b0
Updating workflows
cristid9 Dec 2, 2024
9872a1c
Rolling back toml file
cristid9 Dec 2, 2024
fd4af87
Updating macos version
cristid9 Dec 2, 2024
aa99634
Updating pytorch high watermark version
cristid9 Dec 3, 2024
4947be5
Updating pytorch high watermark version without quotes
cristid9 Dec 3, 2024
aee722a
Adding debug statements
cristid9 Dec 3, 2024
72ca286
Adding debug statements
cristid9 Dec 3, 2024
a2de668
Adding debug logger
cristid9 Dec 3, 2024
20d5d80
Enhance debugging statements
cristid9 Dec 3, 2024
ab04383
Set PYTORCH_ENABLE_MPS_FALLBACK programatically
cristid9 Dec 3, 2024
07d9a83
Set PYTORCH_MPS_HIGH_WATERMARK_RATIO programatically
cristid9 Dec 3, 2024
e0a1f54
Set PYTORCH_MPS_HIGH_WATERMARK_RATIO as float
cristid9 Dec 4, 2024
9011a90
Set PYTORCH end variables
cristid9 Dec 4, 2024
77c0e9a
Remove initialization of PYTORCH_MPS_HIGH_WATERMARK_RATIO
cristid9 Dec 4, 2024
e359f41
Skipping macos tests
cristid9 Dec 5, 2024
0a4768b
Boosting numerical check in integration test
cristid9 Dec 9, 2024
b84a641
Fixing python version for python 3.12
cristid9 Dec 9, 2024
6f45733
Cleaning up github workflows
cristid9 Dec 10, 2024
3c38070
Adjusting checks for training times in integration test
cristid9 Dec 10, 2024
562aaa3
Updating integration test
cristid9 Dec 11, 2024
6ec9e93
Addressing review comments
cristid9 Dec 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/minimum.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ on:
push:
pull_request:
types: [opened, reopened]

env:
PYTORCH_ENABLE_MPS_FALLBACK: "1"
PYTORCH_MPS_HIGH_WATERMARK_RATIO: 0.0
cristid9 marked this conversation as resolved.
Show resolved Hide resolved
jobs:
minimum:
runs-on: ${{ matrix.os }}
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/unit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ on:
pull_request:
types: [opened, reopened]

env:
PYTORCH_MPS_HIGH_WATERMARK_RATIO: 0.0

cristid9 marked this conversation as resolved.
Show resolved Hide resolved
jobs:
unit:
runs-on: ${{ matrix.os }}
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,6 @@ tmp/
dask-worker-space
scripts/runs
scripts/datasets

# ReaLTabFormer
rtf_checkpoints/
amontanez24 marked this conversation as resolved.
Show resolved Hide resolved
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies = [
'cloudpickle>=2.1.0',
'compress-pickle>=1.2.0',
'humanfriendly>=8.2',
"numpy>=1.21.0;python_version<'3.10'",
"numpy>=1.21.6;python_version<'3.10'",
"numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'",
"numpy>=1.26.0;python_version>='3.12'",
"pandas>=1.4.0;python_version<'3.11'",
Expand All @@ -42,10 +42,10 @@ dependencies = [
"scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
"scipy>=1.12.0;python_version>='3.12'",
'tabulate>=0.8.3,<0.9',
"torch>=1.9.0;python_version<'3.10'",
"torch>=1.12.1;python_version<'3.10'",
"torch>=2.0.0;python_version>='3.10' and python_version<'3.12'",
"torch>=2.2.0;python_version>='3.12'",
'tqdm>=4.29',
'tqdm>=4.66.3',
'XlsxWriter>=1.2.8',
'rdt>=1.13.1',
'sdmetrics>=0.17.0',
Expand All @@ -64,12 +64,16 @@ sdgym = { main = 'sdgym.cli.__main__:main' }

[project.optional-dependencies]
dask = ['dask', 'distributed']
realtabformer = ['realtabformer>=0.2.1', 'transformers<4.46']
test = [
'sdgym[realtabformer]',
'pytest>=6.2.5',
'pytest-cov>=2.6.0',
'jupyter>=1.0.0,<2',
'rundoc>=0.4.3,<0.5',
'tomli>=2.0.0,<3',
'realtabformer>=0.2.1',
'transformers<4.46'
cristid9 marked this conversation as resolved.
Show resolved Hide resolved
]
dev = [
'sdgym[dask, test]',
Expand Down Expand Up @@ -231,4 +235,4 @@ convention = "google"

[tool.ruff.lint.pycodestyle]
max-doc-length = 100
max-line-length = 100
max-line-length = 100
1 change: 1 addition & 0 deletions sdgym/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def benchmark_single_table(
- ``CTGANSynthesizer``
- ``CopulaGANSynthesizer``
- ``TVAESynthesizer``
- ``RealTabFormerSynthesizer``

custom_synthesizers (list[class] or ``None``):
A list of custom synthesizer classes to use. These can be completely custom or
Expand Down
2 changes: 2 additions & 0 deletions sdgym/synthesizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from sdgym.synthesizers.identity import DataIdentity
from sdgym.synthesizers.column import ColumnSynthesizer
from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
from sdgym.synthesizers.sdv import (
CopulaGANSynthesizer,
CTGANSynthesizer,
Expand Down Expand Up @@ -38,4 +39,5 @@
'create_sdv_synthesizer_variant',
'create_sequential_synthesizer',
'SYNTHESIZER_MAPPING',
'RealTabFormerSynthesizer',
)
44 changes: 44 additions & 0 deletions sdgym/synthesizers/realtabformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""REaLTabFormer integration."""

import contextlib
import logging
from functools import partialmethod

import tqdm

from sdgym.synthesizers.base import BaselineSynthesizer


@contextlib.contextmanager
def prevent_tqdm_output():
"""Temporarily disables tqdm m."""
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
try:
yield
finally:
tqdm.__init__ = partialmethod(tqdm.__init__, disable=False)


class RealTabFormerSynthesizer(BaselineSynthesizer):
"""Custom wrapper for the REaLTabFormer synthesizer to make it work with SDGym."""

LOGGER = logging.getLogger(__name__)

def _get_trained_synthesizer(self, data, metadata):
try:
from realtabformer import REaLTabFormer
except Exception as exception:
raise ValueError(
"In order to use 'RealTabFormerSynthesizer' you have to install the extra"
" dependencies by running pip install sdgym['realtabformer'] "
) from exception

with prevent_tqdm_output():
model = REaLTabFormer(model_type='tabular')
model.fit(data, device='cpu')
cristid9 marked this conversation as resolved.
Show resolved Hide resolved

return model

def _sample_from_synthesizer(self, synthesizer, n_sample):
"""Sample synthetic data with specified sample count."""
return synthesizer.sample(n_sample, device='cpu')
cristid9 marked this conversation as resolved.
Show resolved Hide resolved
27 changes: 27 additions & 0 deletions tests/integration/synthesizers/test_realtabformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import sys

import pytest

from sdgym import load_dataset
from sdgym.synthesizers import RealTabFormerSynthesizer


@pytest.mark.skipif(sys.platform.startswith('darwin'), reason='Test not supported on github MacOS')
def test_realtabformer_end_to_end():
"""Test it without metrics."""
# Setup
data, metadata_dict = load_dataset(
'single_table', 'student_placements', limit_dataset_size=False
)
realtabformer_instance = RealTabFormerSynthesizer()

# Run
trained_synthesizer = realtabformer_instance.get_trained_synthesizer(data, metadata_dict)
sampled_data = realtabformer_instance.sample_from_synthesizer(trained_synthesizer, n_samples=10)

# Assert
assert sampled_data.shape[1] == data.shape[1], (
f'Sampled data shape {sampled_data.shape} does not match original data shape {data.shape}'
)

assert set(sampled_data.columns) == set(data.columns)
21 changes: 20 additions & 1 deletion tests/integration/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import io
import re
import sys
import time

import numpy as np
Expand Down Expand Up @@ -49,6 +50,25 @@ def test_benchmark_single_table_basic_synthsizers():
] == quality_scores.index.tolist()


@pytest.mark.skipif(sys.platform.startswith('darwin'), reason='Test not supported on github MacOS')
def test_benchmark_single_table_realtabformer_no_metrics():
"""Test it without metrics."""
# Run
output = sdgym.benchmark_single_table(
synthesizers=['RealTabFormerSynthesizer'],
sdv_datasets=['student_placements'],
sdmetrics=[],
)
cristid9 marked this conversation as resolved.
Show resolved Hide resolved

# Assert
train_time = output['Train_Time'][0]
sample_time = output['Sample_Time'][0]
assert isinstance(train_time, (int, float, complex)), 'Train_Time is not numerical'
assert isinstance(sample_time, (int, float, complex)), 'Sample_Time is not numerical'
assert train_time >= 0
assert sample_time >= 0


def test_benchmark_single_table_no_metrics():
"""Test it without metrics."""
# Run
Expand All @@ -62,7 +82,6 @@ def test_benchmark_single_table_no_metrics():
assert not output.empty
assert 'Train_Time' in output
assert 'Sample_Time' in output

# Expect no metric columns.
assert len(output.columns) == 10

Expand Down
63 changes: 63 additions & 0 deletions tests/unit/synthesizers/test_realtabformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests for the realtabformer module."""

from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import pytest

from sdgym.synthesizers import RealTabFormerSynthesizer


@pytest.fixture
def sample_data():
"""Provide sample data for testing."""
n_samples = 10
num_values = np.random.normal(size=n_samples)

return pd.DataFrame({
'num': num_values,
})


class TestRealTabFormerSynthesizer:
"""Unit tests for RealTabFormerSynthesizer integration with SDGym."""

@patch('realtabformer.REaLTabFormer')
def test__get_trained_synthesizer(self, mock_real_tab_former):
"""Test _get_trained_synthesizer

Initializes REaLTabFormer and fits REaLTabFormer with
correct parameters.
"""
# Setup
mock_model = MagicMock()
mock_real_tab_former.return_value = mock_model
data = MagicMock()
metadata = MagicMock()
synthesizer = RealTabFormerSynthesizer()

# Run
result = synthesizer._get_trained_synthesizer(data, metadata)

# Assert
mock_real_tab_former.assert_called_once_with(model_type='tabular')
mock_model.fit.assert_called_once_with(data, device='cpu')
assert result == mock_model, 'Expected the trained model to be returned.'

def test__sample_from_synthesizer(self):
"""Test _sample_from_synthesizer generates data with the specified sample size."""
# Setup
trained_model = MagicMock()
trained_model.sample.return_value = MagicMock(shape=(10, 5)) # Mock sample data shape
n_sample = 10
synthesizer = RealTabFormerSynthesizer()

# Run
synthetic_data = synthesizer._sample_from_synthesizer(trained_model, n_sample)

# Assert
trained_model.sample.assert_called_once_with(n_sample, device='cpu')
assert synthetic_data.shape[0] == n_sample, (
f'Expected {n_sample} rows, but got {synthetic_data.shape[0]}'
)
Loading