Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add predict_sklearn method #535

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,30 @@ def predict(self, idata, kind="mean", data=None, inplace=True, include_group_spe
else:
return idata

def predict_sklearn(self, idata, predict_variable, data=None):
"""
Produce point estimate predictions for each of the test data points in 'data'
(or the training data, if no test data is provided) which returns a numpy
array of the predicted means for each data point for the target variable
(i.e. the same return format as the sklearn 'predict' API)

Parameters
----------
idata: InferenceData
The ``InferenceData`` instance returned by ``.fit()``.
predict_variable: string
The name of the target variable fit by the model
data: pandas.DataFrame or None
An optional data frame with values for the predictors that are used to obtain
out-of-sample predictions. If omitted, the original dataset is used.

Returns
-------
ndarray
"""
self.predict(idata, kind="pps", data=data)
return idata.posterior_predictive[predict_variable].mean(("chain", "draw")).values

def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png"):
"""
Produce a graphviz Digraph from a built Bambi model.
Expand Down
Loading