Skip to content

Commit

Permalink
Adding a couple of unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag committed Jan 2, 2025
1 parent 2e9b92b commit 3c13ddc
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/laiss_resspect_classifier/test_laiss_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from sklearn.ensemble import RandomForestClassifier
from laiss_resspect_classifier.laiss_classifier import LaissRandomForest

def test_laiss_classifier():
"""Basic attribute check"""

laiss_rf = LaissRandomForest()
assert laiss_rf.n_estimators == 100
assert isinstance(laiss_rf.classifier, RandomForestClassifier)

def test_classifier_with_kwargs():
"""Check that kwargs are passed to the classifier"""

laiss_rf = LaissRandomForest(n_estimators=200, max_depth=13)
assert laiss_rf.n_estimators == 200
assert laiss_rf.classifier.max_depth == 13
12 changes: 12 additions & 0 deletions tests/laiss_resspect_classifier/test_laiss_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from laiss_resspect_classifier.laiss_feature_extractor import LaissFeatureExtractor

def test_laiss_feature_extractor():
"""Basic attribute check"""

feature_extractor = LaissFeatureExtractor()
assert feature_extractor.id_column == "ztf_object_id"
assert feature_extractor.label_column == "ideal_label"
assert feature_extractor.other_feature_names == []

assert hasattr(feature_extractor, "fit")
assert hasattr(feature_extractor, "fit_all")

0 comments on commit 3c13ddc

Please sign in to comment.