Skip to content

Commit

Permalink
add closed-shell rep test and fixed test_ecp
Browse files Browse the repository at this point in the history
  • Loading branch information
YAY-C committed Dec 6, 2024
1 parent e3a951d commit 8a4afbb
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
Binary file added tests/data/H2O_spahm_b.npy
Binary file not shown.
18 changes: 16 additions & 2 deletions tests/test_spahm_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ def test_water():
for Xa, Xa_true in zip(X, X_true):
assert(np.linalg.norm(Xa-Xa_true) < 1e-8) # evaluating representation diff as norm (threshold = 1e-8)

def test_water_closed():
path = os.path.dirname(os.path.realpath(__file__))
xyz_in = path+'/data/H2O.xyz'
mols = utils.load_mols([xyz_in], [None], [0], 'minao')
dms = utils.mols_guess(mols, [xyz_in], 'LB', spin=[None])
X = bond.get_repr(mols, [xyz_in], 'LB', spin=[None], with_symbols=False, same_basis=False)
true_file = path+'/data/H2O_spahm_b.npy'
X_true = np.load(true_file)
print(X_true.shape)
assert(X_true.shape == X.shape)
for Xa, Xa_true in zip(X, X_true):
assert(np.linalg.norm(Xa-Xa_true) < 1e-8) # evaluating representation diff as norm (threshold = 1e-8)

def test_water_O_only():
path = os.path.dirname(os.path.realpath(__file__))
xyz_in = path+'/data/H2O.xyz'
Expand Down Expand Up @@ -47,8 +60,8 @@ def test_water_same_basis():
def test_ecp():
path = os.path.dirname(os.path.realpath(__file__))
xyz_in = path+'/data/I2.xyz'
mols = utils.load_mols([xyz_in], [0], [None], 'minao', ecp='def2-svp')
dms = utils.mols_guess(mols, [xyz_in], 'LB', spin=[None])
mols = utils.load_mols([xyz_in], [0], [0], 'minao', ecp='def2-svp')
dms = utils.mols_guess(mols, [xyz_in], 'LB', spin=[0])
X = bond.bond(mols, dms, same_basis=True)
X = np.squeeze(X) #contains a single elements but has shape (1,Nfeat)
X = np.hstack(X) # merging alpha-beta components for spin unrestricted representation #TODO: should be included into function not in main
Expand All @@ -75,4 +88,5 @@ def test_from_list():
if __name__ == '__main__':
test_water()
test_from_list()
test_water_closed()

0 comments on commit 8a4afbb

Please sign in to comment.