Skip to content

Commit

Permalink
Merge pull request #346 from WMD-group/change_oxi_defaults
Browse files Browse the repository at this point in the history
Change default oxidation states from SMACT14 to ICSD24
  • Loading branch information
AntObi authored Dec 2, 2024
2 parents 423a0bf + 59bf75c commit 0728f46
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 45 deletions.
1 change: 1 addition & 0 deletions docs/tutorials/crystal_space.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
" max_atomic_num=103,\n",
" num_processes=8,\n",
" save_path=\"data/binary/df_binary_label.pkl\",\n",
" oxidation_states_set=\"smact14\"\n",
")"
]
},
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/smact_validity_of_GNoMe.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions smact/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Element:
Element.SSEPauling (float) : SSE based on regression fit with Pauling electronegativity
Element.oxidation_states (list) : Default list of allowed oxidation states for use in SMACT
Element.oxidation_states (list) : Default list of allowed oxidation states for use in SMACT. In >3.0, these are the ICSD24 set. In <3.0, these are the SMACT14 set.
Element.oxidation_states_smact14 (list): Original list of oxidation states that were manually compiled for SMACT in 2014 (default in SMACT < 3.0)
Expand Down Expand Up @@ -179,7 +179,7 @@ def __init__(self, symbol: str, oxi_states_custom_filepath: str | None = None):
("number", dataset["Z"]),
(
"oxidation_states",
data_loader.lookup_element_oxidation_states(symbol),
data_loader.lookup_element_oxidation_states_icsd24(symbol),
),
(
"oxidation_states_smact14",
Expand Down
16 changes: 7 additions & 9 deletions smact/screening.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,7 @@ def smact_filter(
threshold: int | None = 8,
stoichs: list[list[int]] | None = None,
species_unique: bool = True,
oxidation_states_set: str = "smact14",
comp_tuple: bool = False,
oxidation_states_set: str = "icsd24",
) -> list[tuple[str, int, int]] | list[tuple[str, int]]:
"""Function that applies the charge neutrality and electronegativity
tests in one go for simple application in external scripts that
Expand All @@ -348,12 +347,11 @@ def smact_filter(
Args:
----
els (tuple/list): A list of smact.Element objects
threshold (int): Threshold for stoichiometry limit, default = 8
els (tuple/list): A list of smact.Element objects.
threshold (int): Threshold for stoichiometry limit, default = 8.
stoichs (list[int]): A selection of valid stoichiometric ratios for each site.
species_unique (bool): Whether or not to consider elements in different oxidation states as unique in the results.
oxidation_states_set (string): A string to choose which set of oxidation states should be chosen. Options are 'smact14', 'icsd16',"icsd24", 'pymatgen_sp' and 'wiki' for the 2014 SMACT default, 2016 ICSD, 2024 ICSD, pymatgen structure predictor and Wikipedia (https://en.wikipedia.org/wiki/Template:List_of_oxidation_states_of_the_elements) oxidation states respectively. A filepath to an oxidation states text file can also be supplied as well.
comp_tuple (bool): Whether or not to return the results as a named tuple of elements and stoichiometries (True) or as a normal tuple of elements and stoichiometries (False).
Returns:
-------
Expand Down Expand Up @@ -438,7 +436,7 @@ def smact_validity(
composition: pymatgen.core.Composition | str,
use_pauling_test: bool = True,
include_alloys: bool = True,
oxidation_states_set: str | bytes | os.PathLike = "smact14",
oxidation_states_set: str = "icsd24",
) -> bool:
"""
Check if a composition is valid according to the SMACT rules.
Expand All @@ -454,7 +452,7 @@ def smact_validity(
composition (Union[pymatgen.core.Composition, str]): Composition/formula to check. This can be a pymatgen Composition object or a string.
use_pauling_test (bool): Whether to use the Pauling electronegativity test
include_alloys (bool): If True, compositions which only contain metal elements will be considered valid without further checks.
oxidation_states_set (Union[str, bytes, os.PathLike]): A string to choose which set of
oxidation_states_set (str): A string to choose which set of
oxidation states should be chosen for charge-balancing. Options are 'smact14', 'icsd14', 'icsd24',
'pymatgen_sp' and 'wiki' for the 2014 SMACT default, 2016 ICSD, 2024 ICSD, pymatgen structure predictor and Wikipedia
(https://en.wikipedia.org/wiki/Template:List_of_oxidation_states_of_the_elements) oxidation states respectively.
Expand Down Expand Up @@ -486,13 +484,13 @@ def smact_validity(
smact_elems = [e[1] for e in space.items()]
electronegs = [e.pauling_eneg for e in smact_elems]

if oxidation_states_set == "smact14" or oxidation_states_set is None:
if oxidation_states_set == "smact14":
ox_combos = [e.oxidation_states_smact14 for e in smact_elems]
elif oxidation_states_set == "icsd16":
ox_combos = [e.oxidation_states_icsd16 for e in smact_elems]
elif oxidation_states_set == "pymatgen_sp":
ox_combos = [e.oxidation_states_sp for e in smact_elems]
elif oxidation_states_set == "icsd24":
elif oxidation_states_set == "icsd24" or oxidation_states_set is None: # Default
ox_combos = [e.oxidation_states_icsd24 for e in smact_elems]
elif os.path.exists(oxidation_states_set):
ox_combos = [oxi_custom(e.symbol, oxidation_states_set) for e in smact_elems]
Expand Down
52 changes: 36 additions & 16 deletions smact/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def test_element_dictionary(self):
newlist = ["O", "Rb", "W"]
dictionary = smact.element_dictionary(newlist, TEST_OX_STATES)
self.assertEqual(dictionary["O"].crustal_abundance, 461000.0)
self.assertEqual(dictionary["Rb"].oxidation_states, [-1, 1])
self.assertEqual(dictionary["Rb"].oxidation_states_smact14, [-1, 1])
self.assertEqual(dictionary["Rb"].oxidation_states, [1])
self.assertEqual(dictionary["Rb"].oxidation_states_custom, [-1, 1])
self.assertEqual(dictionary["W"].name, "Tungsten")
self.assertTrue("Rn" in smact.element_dictionary())
Expand Down Expand Up @@ -335,31 +336,50 @@ def test_ml_rep_generator(self):
self.assertEqual(smact.screening.ml_rep_generator([Pb, O], [1, 2]), PbO2_ml)

def test_smact_filter(self):
oxidation_states_sets = ["smact14", "icsd24"]
oxidation_states_sets_results = {
"smact14": {
"thresh_2": [
(("Na", "Fe", "Cl"), (1, -1, -1), (2, 1, 1)),
(("Na", "Fe", "Cl"), (1, 1, -1), (1, 1, 2)),
]
},
"icsd24": {"thresh_2": [(("Na", "Fe", "Cl"), (1, 1, -1), (1, 1, 2))]},
}

Na, Fe, Cl = (smact.Element(label) for label in ("Na", "Fe", "Cl"))
result = smact.screening.smact_filter([Na, Fe, Cl], threshold=2)
self.assertEqual(
[(r[0], r[1], r[2]) for r in result],
[
(("Na", "Fe", "Cl"), (1, -1, -1), (2, 1, 1)),
(("Na", "Fe", "Cl"), (1, 1, -1), (1, 1, 2)),
],
)

for ox_state_set in oxidation_states_sets:
with self.subTest(ox_state_set=ox_state_set):
output = smact.screening.smact_filter([Na, Fe, Cl], threshold=2, oxidation_states_set=ox_state_set)
self.assertEqual(
[(r[0], r[1], r[2]) for r in output],
oxidation_states_sets_results[ox_state_set]["thresh_2"],
)

# Test that reading the oxidation states from a file produces the same results
self.assertEqual(
result,
smact.screening.smact_filter([Na, Fe, Cl], threshold=2, oxidation_states_set="smact14"),
smact.screening.smact_filter([Na, Fe, Cl], threshold=2, oxidation_states_set=TEST_OX_STATES),
)

self.assertEqual(
set(smact.screening.smact_filter([Na, Fe, Cl], threshold=2, species_unique=False)),
set(
smact.screening.smact_filter(
[Na, Fe, Cl], threshold=2, species_unique=False, oxidation_states_set="smact14"
)
),
{
(("Na", "Fe", "Cl"), (2, 1, 1)),
(("Na", "Fe", "Cl"), (1, 1, 2)),
},
)

self.assertEqual(len(smact.screening.smact_filter([Na, Fe, Cl], threshold=8)), 77)
self.assertEqual(
len(smact.screening.smact_filter([Na, Fe, Cl], threshold=8, oxidation_states_set="smact14")), 77
)

result = smact.screening.smact_filter([Na, Fe, Cl], stoichs=[[1], [1], [4]])
result = smact.screening.smact_filter([Na, Fe, Cl], stoichs=[[1], [1], [4]], oxidation_states_set="smact14")
self.assertEqual(
[(r[0], r[1], r[2]) for r in result],
[
Expand All @@ -368,7 +388,7 @@ def test_smact_filter(self):
)
stoichs = [list(range(1, 5)), list(range(1, 5)), list(range(1, 10))]
self.assertEqual(
len(smact.screening.smact_filter([Na, Fe, Cl], stoichs=stoichs)),
len(smact.screening.smact_filter([Na, Fe, Cl], stoichs=stoichs, oxidation_states_set="smact14")),
45,
)

Expand All @@ -380,8 +400,8 @@ def test_smact_validity(self):
# Test for single element
self.assertTrue(smact.screening.smact_validity("Al"))

# Test for MgB2 which is invalid for the default oxi states but valid for the icsd states
self.assertFalse(smact.screening.smact_validity("MgB2"))
# Test for MgB2 which is invalid for the smact14 oxi states but valid for the icsd states
self.assertFalse(smact.screening.smact_validity("MgB2", oxidation_states_set="smact14"))
self.assertTrue(smact.screening.smact_validity("MgB2", oxidation_states_set="icsd16"))
self.assertFalse(smact.screening.smact_validity("MgB2", oxidation_states_set="pymatgen_sp"))
self.assertTrue(smact.screening.smact_validity("MgB2", oxidation_states_set="wiki"))
Expand Down
38 changes: 24 additions & 14 deletions smact/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,30 @@ def test_convert_formula(self):

def test_generate_composition_with_smact(self):
save_dir = "data/binary/df_binary_label.pkl"
smact_df = generate_composition_with_smact.generate_composition_with_smact(
num_elements=2,
max_stoich=1,
max_atomic_num=10,
save_path=save_dir,
)
self.assertIsInstance(smact_df, pd.DataFrame)
self.assertTrue(len(smact_df) > 0)

# Check if the data was saved to disk
self.assertTrue(os.path.exists(save_dir))

# Clean up
shutil.rmtree("data")
oxidation_states_sets = ["smact14", "icsd24"]
oxidation_states_sets_dict = {
"smact14": {"smact_allowed": 388},
"icsd24": {"smact_allowed": 342},
}
for ox_states in oxidation_states_sets:
with self.subTest(ox_states=ox_states):
smact_df = generate_composition_with_smact.generate_composition_with_smact(
num_elements=2,
max_stoich=3,
max_atomic_num=20,
save_path=save_dir,
oxidation_states_set=ox_states,
)
self.assertIsInstance(smact_df, pd.DataFrame)
self.assertTrue(len(smact_df) == 1330)
self.assertTrue(
smact_df["smact_allowed"].sum() == oxidation_states_sets_dict[ox_states]["smact_allowed"]
)
# Check if the data was saved to disk
self.assertTrue(os.path.exists(save_dir))

# Clean up
shutil.rmtree("data")

@pytest.mark.skipif(
sys.platform == "win32" or not (os.environ.get("MP_API_KEY") or SETTINGS.get("PMG_MAPI_KEY")),
Expand Down
7 changes: 6 additions & 1 deletion smact/utils/crystal_space/generate_composition_with_smact.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def generate_composition_with_smact(
max_atomic_num: int = 103,
num_processes: int | None = None,
save_path: str | None = None,
oxidation_states_set: str = "icsd24",
) -> pd.DataFrame:
"""
Generate all possible compositions of a given number of elements and
Expand All @@ -55,6 +56,7 @@ def generate_composition_with_smact(
max_atomic_num (int): the maximum atomic number. Defaults to 103.
num_processes (int): the number of processes to use. Defaults to None.
save_path (str): the path to save the results. Defaults to None.
oxidation_states_set (str): the oxidation states set to use. Options are "smact14", "icsd16", "icsd24", "pymatgen_sp" or a filepath to a custom oxidation states list. For reproducing the Faraday Discussions results, use "smact14".
Returns:
df (pd.DataFrame): A DataFrame of SMACT-generated compositions with boolean smact_allowed column.
Expand Down Expand Up @@ -104,7 +106,10 @@ def generate_composition_with_smact(
pool = multiprocessing.Pool(processes=multiprocessing.cpu_count() if num_processes is None else num_processes)
results = list(
tqdm(
pool.imap_unordered(partial(smact_filter, threshold=max_stoich), compounds_pauling),
pool.imap_unordered(
partial(smact_filter, threshold=max_stoich, oxidation_states_set=oxidation_states_set),
compounds_pauling,
),
total=len(compounds_pauling),
)
)
Expand Down

0 comments on commit 0728f46

Please sign in to comment.