Skip to content

Commit

Permalink
Merge pull request #9 from YAY-C/master
Browse files Browse the repository at this point in the history
fixed issue with goptions --> wraper
  • Loading branch information
briling authored Nov 22, 2022
2 parents 999d8e1 + cd99237 commit 8daaa2c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
9 changes: 6 additions & 3 deletions qstack/regression/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ def hyper_loop(sigma, eta):
print(s, e, mean, std, flush=True)
errors.append((mean, std, e, s))
return errors

kernel = get_kernel(akernel, [gkernel, gdict])
if gkernel == None:
gwrap = None
else:
gwrap = [gkernel, gdict]
kernel = get_kernel(akernel, gwrap)
if read_kernel is False:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=0)
else:
Expand Down Expand Up @@ -94,7 +97,7 @@ def main():
parser.add_argument('--y', type=str, dest='prop', required=True, help='path to the properties file')
parser.add_argument('--test', type=float, dest='test_size', default=defaults.test_size, help='test set fraction (default='+str(defaults.test_size)+')')
parser.add_argument('--akernel', type=str, dest='akernel', default=defaults.kernel, help='local kernel type (G for Gaussian, L for Laplacian, myL for Laplacian for open-shell systems) (default '+defaults.kernel+')')
parser.add_argument('--gkernel', type=str, dest='gkernel', default=defaults.gkernel, help='global kernel type (avg for average kernel, rem for REMatch kernel) (default '+defaults.gkernel+')')
parser.add_argument('--gkernel', type=str, dest='gkernel', default=defaults.gkernel, help='global kernel type (avg for average kernel, rem for REMatch kernel) (default )')
parser.add_argument('--gdict', nargs='*', action=ParseKwargs, dest='gdict', default=defaults.gdict, help='dictionary like input string to initialize global kernel parameters')
parser.add_argument('--splits', type=int, dest='splits', default=defaults.splits, help='k in k-fold cross validation (default='+str(defaults.n_rep)+')')
parser.add_argument('--print', type=int, dest='printlevel', default=0, help='printlevel')
Expand Down
3 changes: 2 additions & 1 deletion qstack/regression/kernel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __call__(self, parser, namespace, values, option_string=None):
sigma=32.0,
eta=1e-5,
kernel='L',
gkernel='avg',
gkernel=None,
gdict={'alpha':1.0, 'normalize':1},
test_size=0.2,
n_rep=5,
Expand Down Expand Up @@ -162,6 +162,7 @@ def get_global_K(X, Y, sigma, local_kernel, global_kernel, options):
print(f"Final global kernel has size : {K_global.shape}", flush=True)
return K_global


def my_laplacian_kernel(X, Y, gamma):
""" Compute Laplacian kernel between X and Y """
def cdist(X, Y):
Expand Down
2 changes: 1 addition & 1 deletion qstack/regression/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def regression(X, y, read_kernel=False, sigma=defaults.sigma, eta=defaults.eta,

maes_all = []
for size in train_size:
size_train = int(np.floor(len(y_train)*size))
size_train = int(np.floor(len(y_train)*size)) if size < 1.0 else size
maes = []
for rep in range(n_rep):
train_idx = np.random.choice(all_indices_train, size = size_train, replace=False)
Expand Down

0 comments on commit 8daaa2c

Please sign in to comment.