Commit

Metrics: ensure ordinal encoder of classes is the same in real and synthetic datasets [type:bug] (#257)

* reuse encoders

* ensure categorical encoder is trained on real and synthetic

* better transformer

* remove unnecessary imports

* better error message

* compatibility with DDIM
bvanbreugel authored Mar 11, 2024
1 parent 14d67ee commit 41e6e5a
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions src/synthcity/metrics/eval.py
@@ -200,15 +200,26 @@ def evaluate(
         if metrics is None:
             metrics = Metrics.list()

-        X_gt, _ = X_gt.encode()
-        X_syn, _ = X_syn.encode()
+        """
+        We need to encode the categorical data in the real and synthetic data.
+        To ensure each category in the two datasets is mapped to the same one-hot vector, we merge X_syn into X_gt for computing the encoder.
+        TODO: Check whether the optional datasets also need to be taken into account when getting the encoder.
+        """
+        X_gt_df = X_gt.dataframe()
+        X_syn_df = X_syn.dataframe()
+        X_enc = create_from_info(pd.concat([X_gt_df, X_syn_df]), X_gt.info())
+        _, encoders = X_enc.encode()
+
+        # now we encode the data
+        X_gt, _ = X_gt.encode(encoders)
+        X_syn, _ = X_syn.encode(encoders)

         if X_train:
-            X_train, _ = X_train.encode()
+            X_train, _ = X_train.encode(encoders)
         if X_ref_syn:
-            X_ref_syn, _ = X_ref_syn.encode()
+            X_ref_syn, _ = X_ref_syn.encode(encoders)
         if X_augmented:
-            X_augmented, _ = X_augmented.encode()
+            X_augmented, _ = X_augmented.encode(encoders)

         scores = ScoreEvaluator()

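The bug this change addresses can be illustrated with a toy, pure-Python sketch. It is not synthcity's encoder: `fit_ordinal` and `transform` are hypothetical helpers that stand in for fitting and applying a categorical encoder. The sketch assumes the synthetic data happens to be missing one category, which is enough for two independently fitted encoders to assign different codes to the same label.

```python
from typing import Dict, Hashable, Iterable, List


def fit_ordinal(values: Iterable[Hashable]) -> Dict[Hashable, int]:
    """Hypothetical helper: map each distinct category to an integer code (sorted order)."""
    return {cat: code for code, cat in enumerate(sorted(set(values)))}


def transform(values: Iterable[Hashable], mapping: Dict[Hashable, int]) -> List[int]:
    """Hypothetical helper: apply a fitted mapping to a column of categories."""
    return [mapping[v] for v in values]


real = ["red", "green", "blue", "green"]
synthetic = ["green", "red", "red"]  # assumption: "blue" was never generated

# Fitting one encoder per dataset gives inconsistent codes for the same label.
real_map = fit_ordinal(real)
syn_map = fit_ordinal(synthetic)
codes_disagree = real_map["green"] != syn_map["green"]

# Fitting a single encoder on the concatenated data, then reusing it for both
# datasets, keeps the codes aligned (the idea behind merging X_syn into X_gt).
shared = fit_ordinal(real + synthetic)
real_codes = transform(real, shared)
syn_codes = transform(synthetic, shared)
```

With separate encoders, any metric that compares encoded columns would be comparing mismatched categories; the shared mapping avoids that at the cost of fitting on both datasets at once.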
