diff --git a/sktree/stats/_might.py b/sktree/stats/_might.py index 0fb8b66f2..3c9d79421 100644 --- a/sktree/stats/_might.py +++ b/sktree/stats/_might.py @@ -151,7 +151,7 @@ def statistic( posterior_final[:, 0], posterior_final[:, 1], max_fpr=self.limit ) elif stat == "MI": - H_YX = np.mean(entropy(posterior_final[:, 1], base=np.exp(1), axis=1)) + H_YX = np.mean(entropy(posterior_final[:, 1], base=np.exp(1))) _, counts = np.unique(posterior_final[:, 0], return_counts=True) H_Y = entropy(counts, base=np.exp(1)) self.stat = max(H_Y - H_YX, 0)