diff --git a/tests/test_trainer.py b/tests/test_trainer.py index a1a553747..cbcd28e26 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -713,7 +713,7 @@ def test_predict(self): trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, label_names=["labels"]) preds = trainer.predict(trainer.eval_dataset).predictions x = trainer.eval_dataset.x - self.assertTrue(len(preds), 2) + self.assertEqual(len(preds), 2) self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5)) self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))