-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add REaLTabFormer to supported synthesizers (#360)
- Loading branch information
Showing
8 changed files
with
166 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -112,3 +112,6 @@ tmp/ | |
dask-worker-space | ||
scripts/runs | ||
scripts/datasets | ||
|
||
# REaLTabFormer | ||
rtf_checkpoints/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""REaLTabFormer integration.""" | ||
|
||
import contextlib | ||
import logging | ||
from functools import partialmethod | ||
|
||
import tqdm | ||
|
||
from sdgym.synthesizers.base import BaselineSynthesizer | ||
|
||
|
||
@contextlib.contextmanager
def prevent_tqdm_output():
    """Temporarily disable tqdm progress bars.

    While the context is active, ``tqdm.tqdm.__init__`` is replaced by a
    ``partialmethod`` that defaults ``disable=True``, so progress bars that
    are built through that initializer without an explicit ``disable``
    argument are created disabled. The original initializer is put back on
    exit, even if the body raises.
    """
    # ``tqdm`` is the imported module; the progress bar class is ``tqdm.tqdm``.
    # Assigning ``__init__`` on the module itself would leave the class untouched.
    progress_bar_class = tqdm.tqdm
    original_init = progress_bar_class.__init__
    progress_bar_class.__init__ = partialmethod(original_init, disable=True)
    try:
        yield
    finally:
        # Restore the saved initializer rather than stacking another partial
        # on top of the patched one.
        progress_bar_class.__init__ = original_init
|
||
|
||
class RealTabFormerSynthesizer(BaselineSynthesizer):
    """SDGym baseline wrapper that trains and samples a REaLTabFormer model on CPU."""

    LOGGER = logging.getLogger(__name__)

    def _get_trained_synthesizer(self, data, metadata):
        """Fit a tabular REaLTabFormer model on ``data``.

        Args:
            data:
                Training data, passed unchanged to ``REaLTabFormer.fit``.
            metadata:
                Accepted as part of the method signature; not used here.

        Returns:
            The fitted ``REaLTabFormer`` instance.

        Raises:
            ValueError:
                If importing ``realtabformer`` raises any exception.
        """
        install_hint = (
            "In order to use 'RealTabFormerSynthesizer' you have to install the extra"
            " dependencies by running pip install sdgym['realtabformer'] "
        )

        # Imported lazily: the package comes from an extra install, so a
        # failed import is turned into an installation hint.
        try:
            from realtabformer import REaLTabFormer
        except Exception as import_error:
            raise ValueError(install_hint) from import_error

        # Both construction and fitting run under the tqdm-silencing context.
        with prevent_tqdm_output():
            fitted_model = REaLTabFormer(model_type='tabular')
            fitted_model.fit(data, device='cpu')

        return fitted_model

    def _sample_from_synthesizer(self, synthesizer, n_sample):
        """Request ``n_sample`` samples from ``synthesizer`` on the CPU device."""
        synthetic_data = synthesizer.sample(n_sample, device='cpu')
        return synthetic_data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import sys | ||
|
||
import pytest | ||
|
||
from sdgym import load_dataset | ||
from sdgym.synthesizers import RealTabFormerSynthesizer | ||
|
||
|
||
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason='Test not supported on github MacOS')
def test_realtabformer_end_to_end():
    """Train on ``student_placements`` and sample, checking only the resulting columns."""
    # Setup
    data, metadata_dict = load_dataset(
        'single_table', 'student_placements', limit_dataset_size=False
    )
    synthesizer = RealTabFormerSynthesizer()

    # Run
    fitted = synthesizer.get_trained_synthesizer(data, metadata_dict)
    sampled_data = synthesizer.sample_from_synthesizer(fitted, n_samples=10)

    # Assert
    # Only the column count and the column names are compared with the real data.
    expected_column_count = data.shape[1]
    assert sampled_data.shape[1] == expected_column_count, (
        f'Sampled data shape {sampled_data.shape} does not match original data shape {data.shape}'
    )

    expected_columns = set(data.columns)
    assert set(sampled_data.columns) == expected_columns
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
"""Tests for the realtabformer module.""" | ||
|
||
from unittest.mock import MagicMock, patch | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from sdgym.synthesizers import RealTabFormerSynthesizer | ||
|
||
|
||
@pytest.fixture
def sample_data():
    """Return a one-column DataFrame (``num``) holding 10 normally distributed values."""
    row_count = 10
    return pd.DataFrame({'num': np.random.normal(size=row_count)})
|
||
|
||
class TestRealTabFormerSynthesizer:
    """Unit tests for the ``RealTabFormerSynthesizer`` wrapper, using mocked models."""

    @patch('realtabformer.REaLTabFormer')
    def test__get_trained_synthesizer(self, mock_real_tab_former):
        """Test that ``_get_trained_synthesizer`` builds, fits and returns the model.

        The patched ``REaLTabFormer`` must be created with ``model_type='tabular'``
        and its ``fit`` called with the data on the CPU device.
        """
        # Setup
        expected_model = MagicMock()
        mock_real_tab_former.return_value = expected_model
        training_data = MagicMock()
        training_metadata = MagicMock()
        instance = RealTabFormerSynthesizer()

        # Run
        returned_model = instance._get_trained_synthesizer(training_data, training_metadata)

        # Assert
        mock_real_tab_former.assert_called_once_with(model_type='tabular')
        expected_model.fit.assert_called_once_with(training_data, device='cpu')
        assert returned_model == expected_model, 'Expected the trained model to be returned.'

    def test__sample_from_synthesizer(self):
        """Test that ``_sample_from_synthesizer`` forwards the sample count and device."""
        # Setup
        n_sample = 10
        fake_sampled_frame = MagicMock(shape=(10, 5))  # Stand-in carrying only a shape
        fitted_model = MagicMock()
        fitted_model.sample.return_value = fake_sampled_frame
        instance = RealTabFormerSynthesizer()

        # Run
        synthetic_data = instance._sample_from_synthesizer(fitted_model, n_sample)

        # Assert
        fitted_model.sample.assert_called_once_with(n_sample, device='cpu')
        assert synthetic_data.shape[0] == n_sample, (
            f'Expected {n_sample} rows, but got {synthetic_data.shape[0]}'
        )