Skip to content

Commit

Permalink
still refining regress with multiple outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jkitchin committed Jul 7, 2024
1 parent 7cf6c20 commit b29b48c
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions pycse/PYCSE.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,16 @@ def regress(A, y, alpha=0.05, *args, **kwargs):
CI = sT * se

# bint is a little tricky, and depends on the shape of the output.
if len(y.shape) == 1:
bint = np.array([(b - CI, b + CI)])
else:
nrows, ncols = y.shape
bint = []
for i in range(ncols):
bint += [[b - CI, b + CI]]
bint = np.array(bint)
bint = np.array([(b - CI, b + CI)])
# if len(y.shape) == 1:
# bint = np.array([(b - CI, b + CI)])
# else:
# nrows, ncols = y.shape
# bint = []
# print(f'bint: {b.shape}, {CI.shape}')
# for i in range(ncols):
# bint += [[b[:, i] - CI[:, i], b[:, i] + CI[:, i]]]
# bint = np.array(bint)

return (b, bint.squeeze(), se)

Expand Down Expand Up @@ -193,7 +195,12 @@ def predict(X, y, pars, XX, alpha=0.05, ub=1e-5, ef=1.05):
pred_se = np.sqrt([_mse * np.diag(gprime @ I_fisher @ gprime.T) for _mse in mse]).T
# This happens if mse is a single number
except TypeError:
pred_se = np.sqrt(mse * np.diag(gprime @ I_fisher @ gprime.T)).T
# you need at least 1d to get a diagonal. This line is needed because
# there is a case where there is one prediction where this product leads
# to a scalar quantity and we need to upgrade it to 1d to avoid an
# error.
gig = np.atleast_1d(gprime @ I_fisher @ gprime.T)
pred_se = np.sqrt(mse * np.diag(gig)).T

tval = t.ppf(1.0 - alpha / 2.0, dof)

Expand Down

0 comments on commit b29b48c

Please sign in to comment.