Skip to content

Commit

Permalink
tidy up
Browse files: browse the repository at this point in the history
  • Loading branch information
robsdavis committed Sep 12, 2023
1 parent aaa9625 commit e7834e2
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 58 deletions.
40 changes: 0 additions & 40 deletions src/synthcity/metrics/eval_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,47 +371,21 @@ def _evaluate_time_series_performance(
if self.use_cache(cache_file):
return load_from_file(cache_file)

# print("X_gt.train().unpack()")
(
id_static_gt,
id_temporal_gt,
id_observation_times_gt,
id_outcome_gt,
) = X_gt.train().unpack(as_numpy=True)
# print(
# 444444444444444444,
# "temp: ",
# id_temporal_gt.shape,
# "obs: ",
# id_observation_times_gt.shape,
# )
# print("X_gt.test().unpack()")
# print(X_gt.test().dataframe().shape)
(
ood_static_gt,
ood_temporal_gt,
ood_observation_times_gt,
ood_outcome_gt,
) = X_gt.test().unpack(as_numpy=True)
# print(
# 444444444444444444,
# "temp: ",
# ood_temporal_gt.shape,
# "obs: ",
# ood_observation_times_gt.shape,
# )
# print("X_syn.unpack()")
# print(X_syn.dataframe().shape)
static_syn, temporal_syn, observation_times_syn, outcome_syn = X_syn.unpack(
as_numpy=True
)
# print(
# 444444444444444444,
# "temp: ",
# temporal_syn.shape,
# "obs: ",
# observation_times_syn.shape,
# )

skf = KFold(
n_splits=self._n_folds, shuffle=True, random_state=self._random_state
Expand All @@ -432,7 +406,6 @@ def ts_eval_cbk(
outcome_test: np.ndarray,
) -> float:
try:
# print(777777777777, temporal_train.shape, observation_times_train.shape)
estimator = model(**model_args).fit(
static_train, temporal_train, observation_times_train, outcome_train
)
Expand All @@ -450,18 +423,12 @@ def ts_eval_cbk(
for train_idx, test_idx in skf.split(id_static_gt):
static_train_data = id_static_gt[train_idx]
temporal_train_data = id_temporal_gt[train_idx]
# print(88888888888888888888, end=" | ")
# print("temp: ", id_temporal_gt.shape, end=" | ")
# print("obs: ", id_observation_times_gt.shape)
# print(id_temporal_gt[train_idx])
# print(id_observation_times_gt[train_idx])
observation_times_train_data = id_observation_times_gt[train_idx]
outcome_train_data = id_outcome_gt[train_idx]
static_test_data = id_static_gt[test_idx]
temporal_test_data = id_temporal_gt[test_idx]
observation_times_test_data = id_observation_times_gt[test_idx]
outcome_test_data = id_outcome_gt[test_idx]
# print("call 1")
real_score = ts_eval_cbk(
static_train_data,
temporal_train_data,
Expand All @@ -472,7 +439,6 @@ def ts_eval_cbk(
observation_times_test_data,
outcome_test_data,
)
# print("call 2")
synth_score_id = ts_eval_cbk(
static_syn,
temporal_syn,
Expand All @@ -483,7 +449,6 @@ def ts_eval_cbk(
observation_times_test_data,
outcome_test_data,
)
# print("call 3")
synth_score_ood = ts_eval_cbk(
static_syn,
temporal_syn,
Expand Down Expand Up @@ -913,11 +878,6 @@ def evaluate(
X_gt: DataLoader,
X_syn: DataLoader,
) -> Dict:
# print("evaluate")
# print("gt - temp: ", len(X_gt.info()["temporal_features"]))
# print("gt - obs: ", X_gt.info()["window_len"])
# print("syn - temp: ", len(X_syn.info()["temporal_features"]))
# print("syn - obs: ", X_syn.info()["window_len"])
if self._task_type == "survival_analysis":
return self._evaluate_survival_model(
DeephitSurvivalAnalysis, {}, X_gt, X_syn
Expand Down
18 changes: 0 additions & 18 deletions src/synthcity/plugins/core/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,27 +930,10 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:
)
if as_numpy:
longest_observation_seq = max([len(seq) for seq in temporal_data])
# print(666666666666, "temp: ", np.asarray(temporal_data).shape)
# print(
# 66666666666666,
# "obs: ",
# ma.vstack(
# [
# ma.array(
# np.resize(ot, longest_observation_seq),
# mask=[True for i in range(len(ot))]
# + [False for j in range(longest_observation_seq - len(ot))],
# )
# for ot in observation_times
# ]
# ).shape,
# )
return (
np.asarray(static_data),
np.asarray(temporal_data),
# np.asarray(pd.concat(temporal_data)),
# masked array to handle variable-length sequences
# np.asarray(observation_times),
ma.vstack(
[
ma.array(
Expand All @@ -963,7 +946,6 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:
),
np.asarray(outcome),
)
# print("not numpy")
return (
static_data,
temporal_data,
Expand Down

0 comments on commit e7834e2

Please sign in to comment.