Skip to content

Commit

Permalink
Add REaLTabFormer to supported synthesizers (#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristid9 authored Dec 12, 2024
1 parent fcc853a commit 50ac89b
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,6 @@ tmp/
dask-worker-space
scripts/runs
scripts/datasets

# REaLTabFormer
rtf_checkpoints/
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies = [
'cloudpickle>=2.1.0',
'compress-pickle>=1.2.0',
'humanfriendly>=8.2',
"numpy>=1.21.0;python_version<'3.10'",
"numpy>=1.21.6;python_version<'3.10'",
"numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'",
"numpy>=1.26.0;python_version>='3.12'",
"pandas>=1.4.0;python_version<'3.11'",
Expand All @@ -42,10 +42,10 @@ dependencies = [
"scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
"scipy>=1.12.0;python_version>='3.12'",
'tabulate>=0.8.3,<0.9',
"torch>=1.9.0;python_version<'3.10'",
"torch>=1.12.1;python_version<'3.10'",
"torch>=2.0.0;python_version>='3.10' and python_version<'3.12'",
"torch>=2.2.0;python_version>='3.12'",
'tqdm>=4.29',
'tqdm>=4.66.3',
'XlsxWriter>=1.2.8',
'rdt>=1.13.1',
'sdmetrics>=0.17.0',
Expand All @@ -64,7 +64,9 @@ sdgym = { main = 'sdgym.cli.__main__:main' }

[project.optional-dependencies]
dask = ['dask', 'distributed']
realtabformer = ['realtabformer>=0.2.1', 'transformers<4.46']
test = [
'sdgym[realtabformer]',
'pytest>=6.2.5',
'pytest-cov>=2.6.0',
'jupyter>=1.0.0,<2',
Expand Down Expand Up @@ -231,4 +233,4 @@ convention = "google"

[tool.ruff.lint.pycodestyle]
max-doc-length = 100
max-line-length = 100
max-line-length = 100
1 change: 1 addition & 0 deletions sdgym/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def benchmark_single_table(
- ``CTGANSynthesizer``
- ``CopulaGANSynthesizer``
- ``TVAESynthesizer``
- ``RealTabFormerSynthesizer``
custom_synthesizers (list[class] or ``None``):
A list of custom synthesizer classes to use. These can be completely custom or
Expand Down
2 changes: 2 additions & 0 deletions sdgym/synthesizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from sdgym.synthesizers.identity import DataIdentity
from sdgym.synthesizers.column import ColumnSynthesizer
from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
from sdgym.synthesizers.sdv import (
CopulaGANSynthesizer,
CTGANSynthesizer,
Expand Down Expand Up @@ -38,4 +39,5 @@
'create_sdv_synthesizer_variant',
'create_sequential_synthesizer',
'SYNTHESIZER_MAPPING',
'RealTabFormerSynthesizer',
)
44 changes: 44 additions & 0 deletions sdgym/synthesizers/realtabformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""REaLTabFormer integration."""

import contextlib
import logging
from functools import partialmethod

import tqdm

from sdgym.synthesizers.base import BaselineSynthesizer


@contextlib.contextmanager
def prevent_tqdm_output():
    """Temporarily disable tqdm progress bars.

    While the context is active, ``tqdm.tqdm.__init__`` is replaced by a
    ``partialmethod`` that defaults ``disable=True``. On exit, including
    when the body raises, the original ``__init__`` is put back.
    """
    # ``tqdm`` is imported as a module in this file, so the progress-bar class is
    # ``tqdm.tqdm``. Patching ``tqdm.__init__`` would only set an attribute on the
    # module object and leave every progress bar enabled.
    progress_bar_class = tqdm.tqdm
    original_init = progress_bar_class.__init__
    progress_bar_class.__init__ = partialmethod(original_init, disable=True)
    try:
        yield
    finally:
        # Restore the saved original instead of stacking another partialmethod,
        # so repeated or nested use leaves the class exactly as it was found.
        progress_bar_class.__init__ = original_init


class RealTabFormerSynthesizer(BaselineSynthesizer):
    """Custom wrapper for the REaLTabFormer synthesizer to make it work with SDGym."""

    LOGGER = logging.getLogger(__name__)

    def _get_trained_synthesizer(self, data, metadata):
        """Fit a tabular REaLTabFormer model on ``data`` and return it.

        Args:
            data:
                Training data, passed unchanged to ``REaLTabFormer.fit``.
                Presumably a ``pandas.DataFrame``; confirm against the caller.
            metadata:
                Dataset metadata. Not used by this synthesizer.

        Returns:
            The fitted ``REaLTabFormer`` instance.

        Raises:
            ValueError:
                If the optional ``realtabformer`` package cannot be imported.
        """
        # Imported lazily because ``realtabformer`` is an optional extra.
        try:
            from realtabformer import REaLTabFormer
        except ImportError as exception:
            # Only a failed import means the extra is missing; any other error is
            # left to propagate instead of being reported as an install problem.
            raise ValueError(
                "In order to use 'RealTabFormerSynthesizer' you have to install the extra "
                "dependencies by running: pip install 'sdgym[realtabformer]'"
            ) from exception

        with prevent_tqdm_output():
            model = REaLTabFormer(model_type='tabular')
            model.fit(data, device='cpu')

        return model

    def _sample_from_synthesizer(self, synthesizer, n_sample):
        """Sample ``n_sample`` synthetic rows from the fitted model on the CPU."""
        return synthesizer.sample(n_sample, device='cpu')
27 changes: 27 additions & 0 deletions tests/integration/synthesizers/test_realtabformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import sys

import pytest

from sdgym import load_dataset
from sdgym.synthesizers import RealTabFormerSynthesizer


@pytest.mark.skipif(sys.platform.startswith('darwin'), reason='Test not supported on github MacOS')
def test_realtabformer_end_to_end():
    """Train on ``student_placements`` and check the sampled data has the same columns."""
    # Setup
    real_data, metadata = load_dataset(
        'single_table',
        'student_placements',
        limit_dataset_size=False,
    )
    synthesizer = RealTabFormerSynthesizer()

    # Run
    fitted_model = synthesizer.get_trained_synthesizer(real_data, metadata)
    synthetic = synthesizer.sample_from_synthesizer(fitted_model, n_samples=10)

    # Assert
    # Only the column count and column names are compared, not the row count.
    assert synthetic.shape[1] == real_data.shape[1], (
        f'Sampled data shape {synthetic.shape} does not match original data shape {real_data.shape}'
    )
    assert set(synthetic.columns) == set(real_data.columns)
21 changes: 20 additions & 1 deletion tests/integration/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import io
import re
import sys
import time

import numpy as np
Expand Down Expand Up @@ -49,6 +50,25 @@ def test_benchmark_single_table_basic_synthsizers():
] == quality_scores.index.tolist()


@pytest.mark.skipif(sys.platform.startswith('darwin'), reason='Test not supported on github MacOS')
def test_benchmark_single_table_realtabformer_no_metrics():
    """Benchmark ``RealTabFormerSynthesizer`` with no metrics and check the timing columns."""
    # Run
    results = sdgym.benchmark_single_table(
        synthesizers=['RealTabFormerSynthesizer'],
        sdv_datasets=['student_placements'],
        sdmetrics=[],
    )

    # Assert
    numeric_types = (int, float, complex)
    train_time = results['Train_Time'][0]
    sample_time = results['Sample_Time'][0]
    assert isinstance(train_time, numeric_types), 'Train_Time is not numerical'
    assert isinstance(sample_time, numeric_types), 'Sample_Time is not numerical'
    assert train_time >= 0
    assert sample_time >= 0


def test_benchmark_single_table_no_metrics():
"""Test it without metrics."""
# Run
Expand All @@ -62,7 +82,6 @@ def test_benchmark_single_table_no_metrics():
assert not output.empty
assert 'Train_Time' in output
assert 'Sample_Time' in output

# Expect no metric columns.
assert len(output.columns) == 10

Expand Down
63 changes: 63 additions & 0 deletions tests/unit/synthesizers/test_realtabformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests for the realtabformer module."""

from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import pytest

from sdgym.synthesizers import RealTabFormerSynthesizer


@pytest.fixture
def sample_data():
    """Return a ten-row frame with a single normally distributed ``num`` column."""
    row_count = 10
    return pd.DataFrame({'num': np.random.normal(size=row_count)})


class TestRealTabFormerSynthesizer:
    """Unit tests for RealTabFormerSynthesizer integration with SDGym."""

    @patch('realtabformer.REaLTabFormer')
    def test__get_trained_synthesizer(self, rtf_class_mock):
        """Test that ``_get_trained_synthesizer`` builds and fits the model.

        ``REaLTabFormer`` must be created with ``model_type='tabular'``, fitted on
        the given data with ``device='cpu'``, and returned to the caller.
        """
        # Setup
        fitted_mock = MagicMock()
        rtf_class_mock.return_value = fitted_mock
        training_data = MagicMock()
        metadata_mock = MagicMock()
        instance = RealTabFormerSynthesizer()

        # Run
        returned_model = instance._get_trained_synthesizer(training_data, metadata_mock)

        # Assert
        rtf_class_mock.assert_called_once_with(model_type='tabular')
        fitted_mock.fit.assert_called_once_with(training_data, device='cpu')
        assert returned_model == fitted_mock, 'Expected the trained model to be returned.'

    def test__sample_from_synthesizer(self):
        """Test that ``_sample_from_synthesizer`` forwards the count and returns the sample."""
        # Setup
        n_sample = 10
        model_mock = MagicMock()
        # Stand-in for the sampled frame; only its ``shape`` is inspected below.
        model_mock.sample.return_value = MagicMock(shape=(10, 5))
        instance = RealTabFormerSynthesizer()

        # Run
        sampled = instance._sample_from_synthesizer(model_mock, n_sample)

        # Assert
        model_mock.sample.assert_called_once_with(n_sample, device='cpu')
        assert sampled.shape[0] == n_sample, (
            f'Expected {n_sample} rows, but got {sampled.shape[0]}'
        )

0 comments on commit 50ac89b

Please sign in to comment.