Skip to content

Commit

Permalink
tests: add test for evaluate output by data_type
Browse files Browse the repository at this point in the history
  • Loading branch information
Tveten committed Nov 24, 2024
1 parent d7d8bc9 commit cf5d568
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions skchange/tests/test_all_interval_evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from skchange.anomaly_scores import ANOMALY_SCORES
from skchange.change_scores import CHANGE_SCORES
from skchange.costs import COSTS
from skchange.datasets import generate_alternating_data

INTERVAL_EVALUATORS = COSTS + CHANGE_SCORES + ANOMALY_SCORES

Expand Down Expand Up @@ -43,6 +44,35 @@ def test_evaluator_evaluate(Evaluator):
assert len(results) == len(intervals)


@pytest.mark.parametrize("Evaluator", INTERVAL_EVALUATORS)
def test_evaluator_evaluate_by_data_type(Evaluator):
evaluator = Evaluator.create_test_instance()
n_segments = 1
seg_len = 50
p = 3
df = generate_alternating_data(
n_segments=n_segments,
mean=20,
segment_length=seg_len,
p=p,
random_state=15,
)

evaluator.fit(df)
interval1 = np.linspace(0, 10, evaluator.expected_interval_entries, dtype=int)
interval2 = np.linspace(10, 20, evaluator.expected_interval_entries, dtype=int)
intervals = np.array([interval1, interval2])

results = evaluator.evaluate(intervals)

if evaluator.data_type == "univariate":
assert results.shape == (2, p)
elif evaluator.data_type == "multivariate":
assert results.shape == (2, 1)
else:
raise ValueError("Invalid scitype:evaluator tag.")


@pytest.mark.parametrize("Evaluator", INTERVAL_EVALUATORS)
def test_evaluator_invalid_intervals(Evaluator):
evaluator = Evaluator.create_test_instance()
Expand Down

0 comments on commit cf5d568

Please sign in to comment.