Skip to content

Commit

Permalink
CU-86b2czk01 Fixing linter issues
Browse files Browse the repository at this point in the history
  • Loading branch information
cristid9 committed Nov 19, 2024
1 parent 14ac5fd commit 43d446e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
5 changes: 4 additions & 1 deletion sdgym/synthesizers/realtabformer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""REaLTabFormer integration."""

import contextlib
from functools import partialmethod

import tqdm

from functools import partialmethod
from sdgym.synthesizers.base import BaselineSynthesizer


@contextlib.contextmanager
def prevent_tqdm_output():
"""Temporarily disables tqdm m."""
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/synthesizers/test_realtabformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TestRealTabFormerSynthesizer:

def test_get_trained_synthesizer(self):
"""Test _get_trained_synthesizer initializes
and fits REaLTabFormer with correct parameters."""
and fits REaLTabFormer with correct parameters."""
with patch('realtabformer.REaLTabFormer') as MockREaLTabFormer:
# Setup
mock_model = MagicMock()
Expand All @@ -40,7 +40,7 @@ def test_get_trained_synthesizer(self):
# Assert
MockREaLTabFormer.assert_called_once_with(model_type='tabular')
mock_model.fit.assert_called_once_with(data, device='cpu')
assert result == mock_model, "Expected the trained model to be returned."
assert result == mock_model, 'Expected the trained model to be returned.'

def test_sample_from_synthesizer(self):
"""Test _sample_from_synthesizer generates data with the specified sample size."""
Expand All @@ -55,5 +55,6 @@ def test_sample_from_synthesizer(self):

# Assert
trained_model.sample.assert_called_once_with(n_sample, device='cpu')
assert synthetic_data.shape[0] == n_sample, \
f"Expected {n_sample} rows, but got {synthetic_data.shape[0]}"
assert synthetic_data.shape[0] == n_sample, (
f'Expected {n_sample} rows, but got {synthetic_data.shape[0]}'
)

0 comments on commit 43d446e

Please sign in to comment.