From 5bbaad69cf81ad25f0e7a1ef2d45c0e53adcf974 Mon Sep 17 00:00:00 2001 From: Anthony Onwuli Date: Mon, 19 Aug 2024 15:24:05 +0100 Subject: [PATCH] Ruff fixes --- .pre-commit-config.yaml | 33 +-- pyproject.toml | 84 +++++- setup.py | 15 +- smact/__init__.py | 191 ++++++------- smact/benchmarking/pymatgen_benchmark.py | 3 + smact/benchmarking/smact_benchmark.py | 6 +- smact/benchmarking/utilities.py | 8 +- smact/builder.py | 25 +- smact/data_loader.py | 191 ++++++------- smact/distorter.py | 32 ++- smact/dopant_prediction/__init__.py | 2 + smact/dopant_prediction/doper.py | 121 ++++----- smact/lattice.py | 17 +- smact/lattice_parameters.py | 98 +++++-- smact/mainpage.py | 2 +- smact/oxidation_states.py | 91 +++---- smact/properties.py | 62 ++--- smact/screening.py | 252 +++++++++--------- smact/structure_prediction/__init__.py | 2 + smact/structure_prediction/database.py | 97 ++++--- smact/structure_prediction/mutation.py | 123 +++++---- smact/structure_prediction/prediction.py | 194 ++++++-------- .../probability_models.py | 34 ++- smact/structure_prediction/structure.py | 186 ++++++------- smact/structure_prediction/utilities.py | 38 ++- smact/tests/test_core.py | 118 ++------ smact/tests/test_doper.py | 28 +- smact/tests/test_structure.py | 141 +++------- smact/utils/band_gap_simple.py | 20 +- smact/utils/download_compounds_with_mp_api.py | 18 +- .../utils/generate_composition_with_smact.py | 50 ++-- smact/utils/plot_embedding.py | 13 +- utils/bandgap.py | 38 --- 33 files changed, 1142 insertions(+), 1191 deletions(-) delete mode 100755 utils/bandgap.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3ff89c33..94a0cd07 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,24 +1,11 @@ repos: - - repo: https://github.com/timothycrosley/isort - rev: "5.12.0" - hooks: - - id: isort - additional_dependencies: [toml] - args: ["--profile", "black", "--filter-files","--line-length=80"] - - repo: https://github.com/psf/black - rev: "23.1.0" - hooks: - - id: black-jupyter - args: [--line-length=80] - - repo: https://github.com/asottile/pyupgrade - rev: v3.3.1 - hooks: - - id: pyupgrade - args: [--py38-plus] - - repo: https://github.com/nbQA-dev/nbQA - rev: "1.6.1" - hooks: - - id: nbqa-pyupgrade - additional_dependencies: [pyupgrade==3.3.1] - args: [--py38-plus] - + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.1 + hooks: + # Run the linting tool + - id: ruff + types_or: [ python, pyi ] + args: [--fix] + # Run the formatter + - id: ruff-format + types_or: [ python, pyi ] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 062dedf2..e1f546f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,5 +7,85 @@ build-backend = "setuptools.build_meta" version_variable = "setup.py:__version__" version_source = "tag" -[tool.black] -line-length = 79 \ No newline at end of file +[tool.ruff] +target-version = "py39" +line-length = 120 +force-exclude = true + +[tool.ruff.lint] +select = ["ALL"] +ignore = [ + # Rule families + "ANN", # flake8-annotations (not ready, require types for ALL args) + "ARG", # Check for unused function arguments + "BLE", # General catch of Exception + "C90", # Check for functions with a high McCabe complexity + "COM", # flake8-commas (conflict with line wrapper) + "CPY", # Missing copyright notice at top of file (need preview mode) + "EM", # Format nice error messages + "ERA", # Check for commented-out code + "FIX", # Check for FIXME, TODO and other developer notes + "FURB", # refurb (need preview mode, too many preview errors) + "G", # Validate logging format strings + "INP", # Ban PEP-420 implicit namespace packages + "N", # PEP8-naming (many var/arg names are intended) + "PTH", # Prefer pathlib over os.path + "SLF", # Access "private" class members + "T20", # Check for print/pprint + "TD", # TODO tags related + + # Single rules + "B023", # Function definition does not bind loop variable + "B028", # No explicit stacklevel keyword argument found + "B904", # Within an except clause, raise exceptions with ... + "C408", # unnecessary-collection-call + "D105", # Missing docstring in magic method + "D205", # One blank line required between summary line and description + "D212", # Multi-line docstring summary should start at the first line + "E501", # Line too long + "E722", # Do not use bare `except` TODO fix this + "FBT001", # Boolean-typed positional argument in function definition + "FBT002", # Boolean default positional argument in function + "ISC001", + "NPY201", # TODO: enable after migration to NumPy 2.0 + "PD901", # pandas-df-variable-name + "PERF203", # Use of try-except in for/while loop + "PERF401", # Replace "for" loops with list comprehension + "PLR0911", # Too many return statements + "PLR0912", # Too many branches + "PLR0913", # Too many arguments + "PLR0915", # Too many statements + "PLR2004", # Magic-value-comparison TODO fix these + "PLW2901", # Outer for loop variable overwritten by inner assignment target + "PLW0603", # Using the global statement to update `_el_ox_states_wiki` is discouraged TODO fix these + "PT009", # Use a regular `assert` instead of unittest-style `assertAlmostEqual` + "PT011", # `pytest.raises(ValueError)` is too broad, set the `match` parameter or use a more specific exception TODO fix these + "PT013", # Incorrect import of pytest + "RET505", # Unnecessary `else` after `return` statement + "S101", # Use of "assert" + "S110", # Log for try-except-pass + "S112", # Log for try-except-continue + "S301", # `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue + "S311", # Use random module for cryptographic purposes + "S314", # Replace xml with defusedxml to avoid XML attacks + "S603", # Check source for use of "subprocess" call + "S607", # Start process with relative path + "S608", # Possible SQL injection vector through string-based query construction + "SIM105", # Use contextlib.suppress() instead of try-except-pass + "TRY002", # Create your own exception TODO fix these + "TRY003", # Avoid specifying long messages outside the exception class + "TRY300", # Check for return statements in try blocks + "TRY301", # Check for raise statements within try blocks + "E741", # Ambigous variable +] +exclude = ["docs/conf.py","smact/utils/*", "docs/*"] +pydocstyle.convention = "google" +isort.required-imports = ["from __future__ import annotations"] + +[tool.ruff.format] +docstring-code-format = true + + +[tool.ruff.lint.per-file-ignores] +"smact/utils/*" = ["D"] +"smact/tests/*" = ["D"] diff --git a/setup.py b/setup.py index 36ca0a73..4e164776 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,29 @@ -#!/usr/bin/env python +"""Installation for SMACT.""" + +from __future__ import annotations __author__ = "Daniel W. Davies" __author_email__ = "d.w.davies@imperial.ac.uk" -__copyright__ = ( - "Copyright Daniel W. Davies, Adam J. Jackson, Keith T. Butler (2019)" -) +__copyright__ = "Copyright Daniel W. Davies, Adam J. Jackson, Keith T. Butler (2019)" __version__ = "2.6" __maintainer__ = "Anthony O. Onwuli" __maintaier_email__ = "anthony.onwuli16@imperial.ac.uk" __date__ = "July 10 2024" import os -import unittest -from setuptools import Extension, setup +from setuptools import setup module_dir = os.path.dirname(os.path.abspath(__file__)) +with open(os.path.join(module_dir, "README.md")) as f: + long_description = f.read() if __name__ == "__main__": setup( name="SMACT", version=__version__, description="Semiconducting Materials by Analogy and Chemical Theory", - long_description=open(os.path.join(module_dir, "README.md")).read(), + long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/WMD-group/SMACT", author=__author__, diff --git a/smact/__init__.py b/smact/__init__.py index e19dc51b..8e2676c7 100644 --- a/smact/__init__.py +++ b/smact/__init__.py @@ -1,26 +1,32 @@ """ -Semiconducting Materials from Analogy and Chemical Theory +Semiconducting Materials from Analogy and Chemical Theory. A collection of fast screening tools from elemental data """ +from __future__ import annotations + import itertools import warnings from math import gcd from operator import mul as multiply from os import path -from typing import Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING import pandas as pd module_directory = path.abspath(path.dirname(__file__)) data_directory = path.join(module_directory, "data") # get correct path for datafiles when called from another directory -from smact import data_loader +from smact import data_loader # noqa: E402 + +if TYPE_CHECKING: + from collections.abc import Iterable class Element: - """Collection of standard elemental properties for given element. + """ + Collection of standard elemental properties for given element. Data is drawn from "data/element.txt", part of the Open Babel package. @@ -29,6 +35,7 @@ class Element: "Species" class. Attributes: + ---------- Element.symbol (string) : Elemental symbol used to retrieve data Element.name (string) : Full name of element @@ -74,17 +81,18 @@ class Element: Element.HHI_r (float) : Hirfindahl-Hirschman Index for elemental reserves Raises: + ------ NameError: Element not found in element.txt Warning: Element not found in Eigenvalues.csv """ - def __init__( - self, symbol: str, oxi_states_custom_filepath: Optional[str] = None - ): - """Initialise Element class + def __init__(self, symbol: str, oxi_states_custom_filepath: str | None = None): + """ + Initialise Element class. Args: + ---- symbol (str): Chemical element symbol (e.g. 'Fe') oxi_states_custom_filepath (str): Path to custom oxidation states file @@ -92,16 +100,12 @@ def __init__( # Get the oxidation states from the custom file if it exists if oxi_states_custom_filepath: try: - self._oxidation_states_custom = ( - data_loader.lookup_element_oxidation_states_custom( - symbol, oxi_states_custom_filepath - ) + self._oxidation_states_custom = data_loader.lookup_element_oxidation_states_custom( + symbol, oxi_states_custom_filepath ) self.oxidation_states_custom = self._oxidation_states_custom except TypeError: - warnings.warn( - "Custom oxidation states file not found. Please check the file path." - ) + warnings.warn("Custom oxidation states file not found. Please check the file path.") self.oxidation_states_custom = None else: self.oxidation_states_custom = None @@ -109,36 +113,25 @@ def __init__( dataset = data_loader.lookup_element_data(self.symbol, copy=False) - if dataset == None: + if dataset is None: raise NameError(f"Elemental data for {symbol} not found.") # Set coordination-environment data from the Shannon-radius data. # As above, it is safe to use copy = False with this Get* function. - shannon_data = data_loader.lookup_element_shannon_radius_data( - symbol, copy=False - ) + shannon_data = data_loader.lookup_element_shannon_radius_data(symbol, copy=False) - if shannon_data != None: - coord_envs = [row["coordination"] for row in shannon_data] - else: - coord_envs = None + coord_envs = [row["coordination"] for row in shannon_data] if shannon_data is not None else None HHI_scores = data_loader.lookup_element_hhis(symbol) - if HHI_scores == None: + if HHI_scores is None: HHI_scores = (None, None) sse_data = data_loader.lookup_element_sse_data(symbol) - if sse_data: - sse = sse_data["SolidStateEnergy"] - else: - sse = None + sse = sse_data["SolidStateEnergy"] if sse_data else None sse_Pauling_data = data_loader.lookup_element_sse_pauling_data(symbol) - if sse_Pauling_data: - sse_Pauling = sse_Pauling_data["SolidStateEnergyPauling"] - else: - sse_Pauling = None + sse_Pauling = sse_Pauling_data["SolidStateEnergyPauling"] if sse_Pauling_data else None for attribute, value in ( ("coord_envs", coord_envs), @@ -181,7 +174,7 @@ def __init__( class Species(Element): """ - Class providing data for elements in a given chemical environment + Class providing data for elements in a given chemical environment. In addition to the standard properties from the periodic table (inherited from the Element class), Species objects use the @@ -192,6 +185,7 @@ class Species(Element): Baloch, A.A., Alqahtani, S.M., Mumtaz, F., Muqaibel, A.H., Rashkeev, S.N. and Alharbi, F.H., 2021. Extending Shannon's ionic radii database using machine learning. Physical Review Materials, 5(4), p.043804. Attributes: + ---------- Species.symbol: Elemental symbol used to retrieve data Species.name: Full name of element @@ -218,6 +212,7 @@ class Species(Element): Species.average_ionic_radius: An average ionic radius for the species. The average is taken over all coordination environments. Raises: + ------ NameError: Element not found in element.txt Warning: Element not found in Eigenvalues.csv @@ -230,6 +225,18 @@ def __init__( coordination: int = 4, radii_source: str = "shannon", ): + """ + Initialise Species class. + + Args: + ---- + symbol (str): Chemical element symbol (e.g. 'Fe') + oxidation (int): Oxidation state of species + coordination (int): Coordination number of species + radii_source (str): Source of shannon radii data. Choose 'shannon' for + the default shannon radii set or 'extended' for the machine-learnt shannon radii set + + """ Element.__init__(self, symbol) self.oxidation = oxidation @@ -240,29 +247,17 @@ def __init__( self.shannon_radius = None if radii_source == "shannon": - shannon_data = data_loader.lookup_element_shannon_radius_data( - symbol - ) + shannon_data = data_loader.lookup_element_shannon_radius_data(symbol) elif radii_source == "extended": - shannon_data = ( - data_loader.lookup_element_shannon_radius_data_extendedML( - symbol - ) - ) + shannon_data = data_loader.lookup_element_shannon_radius_data_extendedML(symbol) else: - print( - "Data source not recognised. Please select 'shannon' or 'extended'. " - ) + print("Data source not recognised. Please select 'shannon' or 'extended'. ") if shannon_data: for dataset in shannon_data: - if ( - dataset["charge"] == oxidation - and str(coordination) - == dataset["coordination"].split("_")[0] - ): + if dataset["charge"] == oxidation and str(coordination) == dataset["coordination"].split("_")[0]: self.shannon_radius = dataset["crystal_radius"] # Get ionic radius @@ -270,11 +265,7 @@ def __init__( if shannon_data: for dataset in shannon_data: - if ( - dataset["charge"] == oxidation - and str(coordination) - == dataset["coordination"].split("_")[0] - ): + if dataset["charge"] == oxidation and str(coordination) == dataset["coordination"].split("_")[0]: self.ionic_radius = dataset["ionic_radius"] # Get the average shannon and ionic radii @@ -286,9 +277,7 @@ def __init__( shannon_data_df = pd.DataFrame(shannon_data) # Get the rows corresponding to the oxidation state of the species - charge_rows = shannon_data_df.loc[ - shannon_data_df["charge"] == oxidation - ] + charge_rows = shannon_data_df.loc[shannon_data_df["charge"] == oxidation] # Get the mean self.average_shannon_radius = charge_rows["crystal_radius"].mean() @@ -307,13 +296,13 @@ def __init__( self.SSE_2015 = None -def ordered_elements(x: int, y: int) -> List[str]: +def ordered_elements(x: int, y: int) -> list[str]: """ Return a list of element symbols, ordered by proton number in the range x -> y Args: x,y : integers Returns: - list: Ordered list of element symbols + list: Ordered list of element symbols. """ with open(path.join(data_directory, "ordered_periodic.txt")) as f: data = f.readlines() @@ -330,45 +319,50 @@ def ordered_elements(x: int, y: int) -> List[str]: def element_dictionary( - elements: Optional[Iterable[str]] = None, - oxi_states_custom_filepath: Optional[str] = None, + elements: Iterable[str] | None = None, + oxi_states_custom_filepath: str | None = None, ): """ - Create a dictionary of initialised smact.Element objects + Create a dictionary of initialised smact.Element objects. Accessing an Element from a dict is significantly faster than repeadedly initialising them on-demand within nested loops. Args: + ---- elements (iterable of strings) : Elements to include. If None, use all elements up to 103. oxi_states_custom_filepath (str): Path to custom oxidation states file Returns: + ------- dict: Dictionary with element symbols as keys and smact.Element objects as data + """ - if elements == None: + if elements is None: elements = ordered_elements(1, 103) if oxi_states_custom_filepath: - return { - symbol: Element(symbol, oxi_states_custom_filepath) - for symbol in elements - } + return {symbol: Element(symbol, oxi_states_custom_filepath) for symbol in elements} else: return {symbol: Element(symbol) for symbol in elements} def are_eq(A: list, B: list, tolerance: float = 1e-4): - """Check two arrays for tolerance [1,2,3]==[1,2,3]; but [1,3,2]!=[1,2,3] + """ + Check two arrays for tolerance [1,2,3]==[1,2,3]; but [1,3,2]!=[1,2,3]. Args: - A, B (lists): 1-D list of values for approximate equality comparison - tolerance: numerical precision for equality condition + ---- + A (list): 1-D list of values for approximate equality comparison + B (list): 1-D list of values for approximate equality comparison + tolerance (float): numerical precision for equality condition Returns: + ------- boolean + """ are_eq = True if len(A) != len(B): @@ -382,54 +376,60 @@ def are_eq(A: list, B: list, tolerance: float = 1e-4): return are_eq -def lattices_are_same(lattice1, lattice2, tolerance=1e-4): - """Checks for the equivalence of two lattices +def lattices_are_same(lattice1, lattice2, tolerance: float = 1e-4): + """ + Checks for the equivalence of two lattices. Args: - lattice1,lattice2 : ASE crystal class + ---- + lattice1: ASE crystal class + lattice2: ASE crystal class + tolerance (float): numerical precision for equality condition + Returns: + ------- boolean + """ lattices_are_same = False i = 0 for site1 in lattice1: for site2 in lattice2: - if site1.symbol == site2.symbol: - if are_eq(site1.position, site2.position, tolerance=tolerance): - i += 1 + if site1.symbol == site2.symbol and are_eq(site1.position, site2.position, tolerance=tolerance): + i += 1 if i == len(lattice1): lattices_are_same = True return lattices_are_same def _gcd_recursive(*args: Iterable[int]): - """ - Get the greatest common denominator among any number of ints - """ + """Get the greatest common denominator among any number of ints.""" if len(args) == 2: return gcd(*args) else: return gcd(args[0], _gcd_recursive(*args[1:])) -def _isneutral(oxidations: Tuple[int, ...], stoichs: Tuple[int, ...]): +def _isneutral(oxidations: tuple[int, ...], stoichs: tuple[int, ...]): """ - Check if set of oxidation states is neutral in given stoichiometry + Check if set of oxidation states is neutral in given stoichiometry. Args: + ---- oxidations (tuple): Oxidation states of a set of oxidised elements stoichs (tuple): Stoichiometry values corresponding to `oxidations` + """ - return 0 == sum(map(multiply, oxidations, stoichs)) + return sum(map(multiply, oxidations, stoichs)) == 0 def neutral_ratios_iter( - oxidations: List[int], - stoichs: Union[bool, List[List[int]]] = False, - threshold: Optional[int] = 5, + oxidations: list[int], + stoichs: bool | list[list[int]] = False, + threshold: int | None = 5, ): """ - Iterator for charge-neutral stoichiometries + Iterator for charge-neutral stoichiometries. Given a list of oxidation states of arbitrary length, yield ratios in which these form a charge-neutral compound. Stoichiometries may be provided as a @@ -437,12 +437,15 @@ def neutral_ratios_iter( otherwise all unique ratios are tried up to a threshold coefficient. Args: + ---- oxidations : list of integers stoichs : stoichiometric ratios for each site (if provided) threshold : single threshold to go up to if stoichs are not provided Yields: + ------ tuple: ratio that gives neutrality + """ if not stoichs: stoichs = [list(range(1, threshold + 1))] * len(oxidations) @@ -458,12 +461,12 @@ def neutral_ratios_iter( def neutral_ratios( - oxidations: List[int], - stoichs: Union[bool, List[List[int]]] = False, + oxidations: list[int], + stoichs: bool | list[list[int]] = False, threshold=5, ): """ - Get a list of charge-neutral compounds + Get a list of charge-neutral compounds. Given a list of oxidation states of arbitrary length, yield ratios in which these form a charge-neutral compound. Stoichiometries may be provided as a @@ -475,6 +478,7 @@ def neutral_ratios( threshold. Args: + ---- oxidations (list of ints): Oxidation state of each site stoichs (list of positive ints): A selection of valid stoichiometric ratios for each site @@ -483,6 +487,7 @@ def neutral_ratios( to this value will be tried. Returns: + ------- (exists, allowed_ratios) (tuple): exists *bool*: @@ -491,13 +496,9 @@ def neutral_ratios( allowed_ratios *list of tuples*: Ratios of atoms in given oxidation states which yield a charge-neutral structure + """ - allowed_ratios = [ - x - for x in neutral_ratios_iter( - oxidations, stoichs=stoichs, threshold=threshold - ) - ] + allowed_ratios = list(neutral_ratios_iter(oxidations, stoichs=stoichs, threshold=threshold)) return (len(allowed_ratios) > 0, allowed_ratios) diff --git a/smact/benchmarking/pymatgen_benchmark.py b/smact/benchmarking/pymatgen_benchmark.py index 110983a4..8ce27db6 100644 --- a/smact/benchmarking/pymatgen_benchmark.py +++ b/smact/benchmarking/pymatgen_benchmark.py @@ -1,5 +1,7 @@ """Benchmarking functions for pymatgen.""" +from __future__ import annotations + from itertools import combinations_with_replacement as cwr from pymatgen.analysis.structure_prediction.substitution_probability import ( @@ -34,4 +36,5 @@ def __pair_corr(self): @timeit(delim=True, n=100) def probability_test_run(): + """Run all tests.""" ProbabilityBenchmarker().run_tests() diff --git a/smact/benchmarking/smact_benchmark.py b/smact/benchmarking/smact_benchmark.py index d60bece1..a9573e51 100644 --- a/smact/benchmarking/smact_benchmark.py +++ b/smact/benchmarking/smact_benchmark.py @@ -1,6 +1,9 @@ """SMACT benchmarking.""" -from ..structure_prediction.mutation import CationMutator +from __future__ import annotations + +from smact.structure_prediction.mutation import CationMutator + from .utilities import timeit @@ -26,4 +29,5 @@ def __pair_corr(self): @timeit(delim=True, n=100) def mutator_test_run(): + """Run benchmark tests for CationMutator.""" MutatorBenchmarker().run_tests() diff --git a/smact/benchmarking/utilities.py b/smact/benchmarking/utilities.py index 05ad2a8a..a52bd921 100644 --- a/smact/benchmarking/utilities.py +++ b/smact/benchmarking/utilities.py @@ -1,5 +1,7 @@ """Benchmarking utilities.""" +from __future__ import annotations + import functools import logging from statistics import mean @@ -25,9 +27,7 @@ def wrapper_timeit(*args, **kwargs): value = func(*args, **kwargs) times.append(time() - t0) - logging.info( - f"{func.__name__} -- Average over {n} repeats = {mean(times)}s" - ) + logging.info(f"{func.__name__} -- Average over {n} repeats = {mean(times)}s") if delim: logging.info("-" * DELIM_LENGTH) @@ -35,7 +35,7 @@ def wrapper_timeit(*args, **kwargs): return wrapper_timeit - fname = "benchmark.log" if not fname else fname + fname = fname if fname else "benchmark.log" logging.basicConfig(filename=fname) if _func is None: diff --git a/smact/builder.py b/smact/builder.py index d79695bc..0c99f03e 100644 --- a/smact/builder.py +++ b/smact/builder.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +"""A collection of functions for building certain lattice types.""" # Using the ase spacegroup module this can build the structure, from # the composition, as defined in the smact_lattice module. # TODO: @@ -6,27 +6,33 @@ # Add further types, Spinnel, Flourite, Delafossite .... # Implement Structure class, c.f. dev_docs. +from __future__ import annotations from ase.spacegroup import crystal from smact.lattice import Lattice, Site -def cubic_perovskite( - species, cell_par=[6, 6, 6, 90, 90, 90], repetitions=[1, 1, 1] -): +def cubic_perovskite(species, cell_par=None, repetitions=None): """ Build a perovskite cell using the crystal function in ASE. Args: + ---- species (str): Element symbols cell_par (list): Six floats/ints specifying 3 unit cell lengths and 3 unit cell angles. repetitions (list): Three floats specifying the expansion of the cell in x,y,z directions. + Returns: + ------- SMACT Lattice object of the unit cell, ASE crystal system of the unit cell. """ + if repetitions is None: + repetitions = [1, 1, 1] + if cell_par is None: + cell_par = [6, 6, 6, 90, 90, 90] system = crystal( (species), basis=[(0, 0, 0), (0.5, 0.5, 0.5), (0.5, 0.5, 0)], @@ -43,19 +49,26 @@ def cubic_perovskite( return Lattice(sites_list, oxidation_states), system -def wurtzite(species, cell_par=[2, 2, 6, 90, 90, 120], repetitions=[1, 1, 1]): +def wurtzite(species, cell_par=None, repetitions=None): """ Build a wurzite cell using the crystal function in ASE. Args: + ---- species (str): Element symbols cell_par (list): Six floats/ints specifying 3 unit cell lengths and 3 unit cell angles. repetitions (list): Three floats specifying the expansion of the cell in x,y,z directions. + Returns: + ------- SMACT Lattice object of the unit cell, ASE crystal system of the unit cell. """ + if repetitions is None: + repetitions = [1, 1, 1] + if cell_par is None: + cell_par = [2, 2, 6, 90, 90, 120] system = crystal( (species), basis=[(2.0 / 3.0, 1.0 / 3.0, 0), (2.0 / 3.0, 1.0 / 3.0, 5.0 / 8.0)], @@ -65,7 +78,7 @@ def wurtzite(species, cell_par=[2, 2, 6, 90, 90, 120], repetitions=[1, 1, 1]): ) sites_list = [] - oxidation_states = [[1], [2], [3], [4]] + [[-1], [-2], [-3], [-4]] + oxidation_states = [[1], [2], [3], [4], [-1], [-2], [-3], [-4]] for site in zip(system.get_scaled_positions(), oxidation_states): sites_list.append(Site(site[0], site[1])) diff --git a/smact/data_loader.py b/smact/data_loader.py index 79e0dd8a..f7250b3f 100644 --- a/smact/data_loader.py +++ b/smact/data_loader.py @@ -10,6 +10,8 @@ are used in the background and it is not necessary to use them directly. """ +from __future__ import annotations + import csv import os @@ -21,22 +23,24 @@ def set_warnings(enable=True): - """Set verbose warning messages on and off. + """ + Set verbose warning messages on and off. In order to see any of the warnings, this function needs to be called _before_ the first call to the smact.Element() constructor. Args: + ---- enable (bool) : print verbose warning messages. - """ + """ global _print_warnings _print_warnings = enable def _get_data_rows(filename): - """Generator for datafile entries by row""" + """Generator for datafile entries by row.""" with open(filename) as file: for line in file: line = line.strip() @@ -63,6 +67,7 @@ def lookup_element_oxidation_states(symbol, copy=True): most exhaustive list. Args: + ---- symbol (str) : the atomic symbol of the element to look up. copy (Optional(bool)): if True (default), return a copy of the oxidation-state list, rather than a reference to the cached @@ -70,23 +75,20 @@ def lookup_element_oxidation_states(symbol, copy=True): and where the list will not be modified! Returns: + ------- list: List of known oxidation states for the element. Returns None if oxidation states for the Element were not found in the external data. - """ + """ global _el_ox_states if _el_ox_states is None: _el_ox_states = {} - for items in _get_data_rows( - os.path.join(data_directory, "oxidation_states.txt") - ): - _el_ox_states[items[0]] = [ - int(oxidationState) for oxidationState in items[1:] - ] + for items in _get_data_rows(os.path.join(data_directory, "oxidation_states.txt")): + _el_ox_states[items[0]] = [int(oxidationState) for oxidationState in items[1:]] if symbol in _el_ox_states: if copy: @@ -94,15 +96,12 @@ def lookup_element_oxidation_states(symbol, copy=True): # deep copy. The elements of the lists are integers, which are # "value types" in Python. - return [oxidationState for oxidationState in _el_ox_states[symbol]] + return list(_el_ox_states[symbol]) else: return _el_ox_states[symbol] else: if _print_warnings: - print( - "WARNING: Oxidation states for element {} " - "not found.".format(symbol) - ) + print(f"WARNING: Oxidation states for element {symbol} " "not found.") return None @@ -116,6 +115,7 @@ def lookup_element_oxidation_states_icsd(symbol, copy=True): in the ICSD (and judged to be non-spurious). Args: + ---- symbol (str) : the atomic symbol of the element to look up. copy (Optional(bool)): if True (default), return a copy of the oxidation-state list, rather than a reference to the cached @@ -123,39 +123,31 @@ def lookup_element_oxidation_states_icsd(symbol, copy=True): and where the list will not be modified! Returns: + ------- list: List of known oxidation states for the element. Return None if oxidation states for the Element were not found in the external data. - """ + """ global _el_ox_states_icsd if _el_ox_states_icsd is None: _el_ox_states_icsd = {} - for items in _get_data_rows( - os.path.join(data_directory, "oxidation_states_icsd.txt") - ): - _el_ox_states_icsd[items[0]] = [ - int(oxidationState) for oxidationState in items[1:] - ] + for items in _get_data_rows(os.path.join(data_directory, "oxidation_states_icsd.txt")): + _el_ox_states_icsd[items[0]] = [int(oxidationState) for oxidationState in items[1:]] if symbol in _el_ox_states_icsd: if copy: # _el_ox_states_icsd stores lists -> if copy is set, make an implicit # deep copy. The elements of the lists are integers, which are # "value types" in Python. - return [ - oxidationState for oxidationState in _el_ox_states_icsd[symbol] - ] + return list(_el_ox_states_icsd[symbol]) else: return _el_ox_states_icsd[symbol] else: if _print_warnings: - print( - "WARNING: Oxidation states for element {}" - "not found.".format(symbol) - ) + print(f"WARNING: Oxidation states for element {symbol}" "not found.") return None @@ -169,6 +161,7 @@ def lookup_element_oxidation_states_sp(symbol, copy=True): are in the Pymatgen default lambda table for structure prediction. Args: + ---- symbol (str) : the atomic symbol of the element to look up. copy (Optional(bool)): if True (default), return a copy of the oxidation-state list, rather than a reference to the cached @@ -176,23 +169,20 @@ def lookup_element_oxidation_states_sp(symbol, copy=True): and where the list will not be modified! Returns: + ------- list: List of known oxidation states for the element. Return None if oxidation states for the Element were not found in the external data. - """ + """ global _el_ox_states_sp if _el_ox_states_sp is None: _el_ox_states_sp = {} - for items in _get_data_rows( - os.path.join(data_directory, "oxidation_states_SP.txt") - ): - _el_ox_states_sp[items[0]] = [ - int(oxidationState) for oxidationState in items[1:] - ] + for items in _get_data_rows(os.path.join(data_directory, "oxidation_states_SP.txt")): + _el_ox_states_sp[items[0]] = [int(oxidationState) for oxidationState in items[1:]] if symbol in _el_ox_states_sp: if copy: @@ -200,17 +190,12 @@ def lookup_element_oxidation_states_sp(symbol, copy=True): # deep copy. The elements of the lists are integers, which are # "value types" in Python. - return [ - oxidationState for oxidationState in _el_ox_states_sp[symbol] - ] + return list(_el_ox_states_sp[symbol]) else: return _el_ox_states_sp[symbol] else: if _print_warnings: - print( - "WARNING: Oxidation states for element {} " - "not found.".format(symbol) - ) + print(f"WARNING: Oxidation states for element {symbol} " "not found.") return None @@ -224,6 +209,7 @@ def lookup_element_oxidation_states_wiki(symbol, copy=True): are on Wikipedia (https://en.wikipedia.org/wiki/Template:List_of_oxidation_states_of_the_elements). Args: + ---- symbol (str) : the atomic symbol of the element to look up. copy (Optional(bool)): if True (default), return a copy of the oxidation-state list, rather than a reference to the cached @@ -231,23 +217,20 @@ def lookup_element_oxidation_states_wiki(symbol, copy=True): and where the list will not be modified! Returns: + ------- list: List of known oxidation states for the element. Return None if oxidation states for the Element were not found in the external data. - """ + """ global _el_ox_states_wiki if _el_ox_states_wiki is None: _el_ox_states_wiki = {} - for items in _get_data_rows( - os.path.join(data_directory, "oxidation_states_wiki.txt") - ): - _el_ox_states_wiki[items[0]] = [ - int(oxidationState) for oxidationState in items[1:] - ] + for items in _get_data_rows(os.path.join(data_directory, "oxidation_states_wiki.txt")): + _el_ox_states_wiki[items[0]] = [int(oxidationState) for oxidationState in items[1:]] if symbol in _el_ox_states_wiki: if copy: @@ -255,17 +238,12 @@ def lookup_element_oxidation_states_wiki(symbol, copy=True): # deep copy. The elements of the lists are integers, which are # "value types" in Python. - return [ - oxidationState for oxidationState in _el_ox_states_wiki[symbol] - ] + return list(_el_ox_states_wiki[symbol]) else: return _el_ox_states_wiki[symbol] else: if _print_warnings: - print( - "WARNING: Oxidation states for element {} " - "not found.".format(symbol) - ) + print(f"WARNING: Oxidation states for element {symbol} " "not found.") return None @@ -278,28 +256,30 @@ def lookup_element_oxidation_states_custom(symbol, filepath, copy=True): The oxidation states list is specified by the user in a text file. Args: + ---- symbol (str) : the atomic symbol of the element to look up. + filepath (str) : the path to the text file containing the + oxidation states data. copy (Optional(bool)): if True (default), return a copy of the oxidation-state list, rather than a reference to the cached data -- only use copy=False in performance-sensitive code and where the list will not be modified! Returns: + ------- list: List of known oxidation states for the element. Return None if oxidation states for the Element were not found in the external data. - """ + """ global _el_ox_states_custom if _el_ox_states_custom is None: _el_ox_states_custom = {} for items in _get_data_rows(filepath): - _el_ox_states_custom[items[0]] = [ - int(oxidationState) for oxidationState in items[1:] - ] + _el_ox_states_custom[items[0]] = [int(oxidationState) for oxidationState in items[1:]] if symbol in _el_ox_states_custom: if copy: @@ -307,18 +287,12 @@ def lookup_element_oxidation_states_custom(symbol, filepath, copy=True): # deep copy. The elements of the lists are integers, which are # "value types" in Python. - return [ - oxidationState - for oxidationState in _el_ox_states_custom[symbol] - ] + return list(_el_ox_states_custom[symbol]) else: return _el_ox_states_custom[symbol] else: if _print_warnings: - print( - "WARNING: Oxidation states for element {} " - "not found.".format(symbol) - ) + print(f"WARNING: Oxidation states for element {symbol} " "not found.") return None @@ -332,15 +306,17 @@ def lookup_element_hhis(symbol): Retrieve the HHI_R and HHI_p scores for an element. Args: + ---- symbol : the atomic symbol of the element to look up. Returns: + ------- tuple : (HHI_p, HHI_R) Return None if values for the elements were not found in the external data. - """ + """ global _element_hhis if _element_hhis is None: @@ -362,9 +338,7 @@ def lookup_element_hhis(symbol): return _element_hhis[symbol] else: if _print_warnings: - print( - "WARNING: HHI data for element " "{} not found.".format(symbol) - ) + print(f"WARNING: HHI data for element {symbol} not found.") return None @@ -374,7 +348,7 @@ def lookup_element_hhis(symbol): _element_data = None -def lookup_element_data(symbol, copy=True): +def lookup_element_data(symbol: str, copy: bool = True): """ Retrieve tabulated data for an element. @@ -384,16 +358,18 @@ def lookup_element_data(symbol, copy=True): constructed from the data table and cached before returning it. Args: + ---- symbol (str) : Atomic symbol for lookup - - copy (Optional(bool)) : if True (default), return a copy of the + copy (bool) : if True (default), return a copy of the data dictionary, rather than a reference to the cached object -- only used copy=False in performance-sensitive code and where you are certain the dictionary will not be modified! - Returns (dict) : Dictionary of data for given element, keyed by - column headings from data/element_data.txt. + Returns: + ------- + dict: Dictionary of data for given element, keyed by column headings from data/element_data.txt. + """ global _element_data if _element_data is None: @@ -412,9 +388,7 @@ def lookup_element_data(symbol, copy=True): "ion_pot", "dipol", ) - for items in _get_data_rows( - os.path.join(data_directory, "element_data.txt") - ): + for items in _get_data_rows(os.path.join(data_directory, "element_data.txt")): # First two columns are strings and should be left intact # Everything else is numerical and should be cast to a float # or, if not clearly a number, to None @@ -435,7 +409,7 @@ def lookup_element_data(symbol, copy=True): return _element_data[symbol] else: if _print_warnings: - print("WARNING: Elemental data for {}" " not found.".format(symbol)) + print(f"WARNING: Elemental data for {symbol} not found.") print(_element_data) return None @@ -453,6 +427,7 @@ def lookup_element_shannon_radius_data(symbol, copy=True): environments of an element. Args: + ---- symbol (str) : the atomic symbol of the element to look up. copy (Optional(bool)): if True (default), return a copy of the data @@ -461,6 +436,7 @@ def lookup_element_shannon_radius_data(symbol, copy=True): you are certain the dictionary will not be modified! Returns: + ------- list: Shannon radii datasets. @@ -479,8 +455,8 @@ def lookup_element_shannon_radius_data(symbol, copy=True): *float* comment *str* - """ + """ global _element_shannon_radii_data if _element_shannon_radii_data is None: @@ -525,10 +501,7 @@ def lookup_element_shannon_radius_data(symbol, copy=True): return _element_shannon_radii_data[symbol] else: if _print_warnings: - print( - "WARNING: Shannon-radius data for element {} not " - "found.".format(symbol) - ) + print(f"WARNING: Shannon-radius data for element {symbol} not " "found.") return None @@ -554,6 +527,7 @@ def lookup_element_shannon_radius_data_extendedML(symbol, copy=True): arXiv preprint arXiv:2101.00269. Args: + ---- symbol (str) : the atomic symbol of the element to look up. copy (Optional(bool)): if True (default), return a copy of the data @@ -562,6 +536,7 @@ def lookup_element_shannon_radius_data_extendedML(symbol, copy=True): you are certain the dictionary will not be modified! Returns: + ------- list: Extended Shannon radii datasets. @@ -578,16 +553,14 @@ def lookup_element_shannon_radius_data_extendedML(symbol, copy=True): *float* comment *str* - """ + """ global _element_shannon_radii_data_extendedML if _element_shannon_radii_data_extendedML is None: _element_shannon_radii_data_extendedML = {} - with open( - os.path.join(data_directory, "shannon_radii_ML_extended.csv") - ) as file: + with open(os.path.join(data_directory, "shannon_radii_ML_extended.csv")) as file: reader = csv.reader(file) # Skip the first row (headers). @@ -621,18 +594,12 @@ def lookup_element_shannon_radius_data_extendedML(symbol, copy=True): # function on each element. # The dictionary values are all Python "value types", so # nothing further is required to make a deep copy. - return [ - item.copy() - for item in _element_shannon_radii_data_extendedML[symbol] - ] + return [item.copy() for item in _element_shannon_radii_data_extendedML[symbol]] else: return _element_shannon_radii_data_extendedML[symbol] else: if _print_warnings: - print( - "WARNING: Extended Shannon-radius data for element {} not " - "found.".format(symbol) - ) + print(f"WARNING: Extended Shannon-radius data for element {symbol} not " "found.") return None @@ -650,9 +617,11 @@ def lookup_element_sse_data(symbol): DOI: 10.1021/ja204670s Args: + ---- symbol : the atomic symbol of the element to look up. Returns: + ------- list : SSE datasets for the element, or None if the element was not found among the external data. @@ -670,8 +639,8 @@ def lookup_element_sse_data(symbol): *str* SolidStateRenormalisationEnergy *float* - """ + """ global _element_ssedata if _element_ssedata is None: @@ -696,10 +665,7 @@ def lookup_element_sse_data(symbol): return _element_ssedata[symbol] else: if _print_warnings: - print( - "WARNING: Solid-state energy data for element {} not" - " found.".format(symbol) - ) + print(f"WARNING: Solid-state energy data for element {symbol} not" " found.") return None @@ -719,6 +685,7 @@ def lookup_element_sse2015_data(symbol, copy=True): pp138-144, DOI: 10.1016/j.jssc.2015.07.037. Args: + ---- symbol : the atomic symbol of the element to look up. copy: if True (default), return a copy of the data dictionary, rather than a reference to a cached object -- only use @@ -726,6 +693,7 @@ def lookup_element_sse2015_data(symbol, copy=True): certain the dictionary will not be modified! Returns: + ------- list : SSE datasets for the element, or None if the element was not found among the external data. @@ -735,8 +703,8 @@ def lookup_element_sse2015_data(symbol, copy=True): *int* SolidStateEnergy2015 *float* SSE2015 - """ + """ global _element_sse2015_data if _element_sse2015_data is None: @@ -768,10 +736,7 @@ def lookup_element_sse2015_data(symbol, copy=True): return _element_sse2015_data[symbol] else: if _print_warnings: - print( - "WARNING: Solid-state energy (revised 2015) data for " - "element {} not found.".format(symbol) - ) + print(f"WARNING: Solid-state energy (revised 2015) data for element {symbol} not found.") return None @@ -783,7 +748,8 @@ def lookup_element_sse2015_data(symbol, copy=True): def lookup_element_sse_pauling_data(symbol): - """Retrieve Pauling SSE data + """ + Retrieve Pauling SSE data. Retrieve the solid-state energy (SSEPauling) data for an element from the regression fit when SSE2015 is plotted against Pauling @@ -791,13 +757,14 @@ def lookup_element_sse_pauling_data(symbol): pp138-144, DOI: 10.1016/j.jssc.2015.07.037 Args: + ---- symbol (str) : the atomic symbol of the element to look up. Returns: A dictionary containing the SSE2015 dataset for the element, or None if the element was not found among the external data. - """ + """ global _element_ssepauling_data if _element_ssepauling_data is None: @@ -818,7 +785,7 @@ def lookup_element_sse_pauling_data(symbol): print( "WARNING: Solid-state energy data from Pauling " " electronegativity regression fit for " - " element {} not found.".format(symbol) + f" element {symbol} not found." ) return None diff --git a/smact/distorter.py b/smact/distorter.py index a5caf05d..1a170d13 100644 --- a/smact/distorter.py +++ b/smact/distorter.py @@ -9,6 +9,8 @@ for equivalence. """ +from __future__ import annotations + import copy import smact @@ -29,28 +31,33 @@ def get_sg(lattice): Get the space-group of the system. Args: + ---- lattice: the ASE crystal class Returns: sg (int): integer number of the spacegroup + """ spacegroup = spglib.get_spacegroup(lattice, symprec=1e-5) space_split = spacegroup.split() spg_num = space_split[1].replace("(", "").replace(")", "") - sg = Spacegroup(int(spg_num)) - return sg + return Spacegroup(int(spg_num)) def get_inequivalent_sites(sub_lattice, lattice): - """Given a sub lattice, returns symmetry unique sites for substitutions. + """ + Given a sub lattice, returns symmetry unique sites for substitutions. Args: + ---- sub_lattice (list of lists): array containing Cartesian coordinates of the sub-lattice of interest lattice (ASE crystal): the total lattice Returns: + ------- List of sites + """ sg = get_sg(lattice) inequivalent_sites = [] @@ -59,30 +66,34 @@ def get_inequivalent_sites(sub_lattice, lattice): # Check against the existing members of the list of inequivalent sites if len(inequivalent_sites) > 0: for inequiv_site in inequivalent_sites: - if smact.are_eq(site, inequiv_site) == True: + if smact.are_eq(site, inequiv_site) is True: new_site = False # Check against symmetry related members of the list of inequivalent sites equiv_inequiv_sites, _ = sg.equivalent_sites(inequiv_site) for equiv_inequiv_site in equiv_inequiv_sites: - if smact.are_eq(site, equiv_inequiv_site) == True: + if smact.are_eq(site, equiv_inequiv_site) is True: new_site = False - if new_site == True: + if new_site is True: inequivalent_sites.append(site) return inequivalent_sites def make_substitution(lattice, site, new_species): - """Change atomic species on lattice site to new_species. + """ + Change atomic species on lattice site to new_species. Args: + ---- lattice (ASE crystal): Input lattice site (list): Cartesian coordinates of the substitution site new_species (str): New species Returns: + ------- lattice + """ i = 0 # NBNBNBNB It is necessary to use deepcopy for objects, otherwise changes applied to a clone @@ -97,17 +108,20 @@ def make_substitution(lattice, site, new_species): def build_sub_lattice(lattice, symbol): - """Generate a sub-lattice of the lattice based on equivalent atomic species. + """ + Generate a sub-lattice of the lattice based on equivalent atomic species. Args: + ---- lattice (ASE crystal class): Input lattice symbol (string): Symbol of species identifying sub-lattice Returns: + ------- list of lists: sub_lattice: Cartesian coordinates of the sub-lattice of symbol - """ + """ sub_lattice = [] i = 0 atomic_labels = lattice.get_chemical_symbols() diff --git a/smact/dopant_prediction/__init__.py b/smact/dopant_prediction/__init__.py index 9f21268f..e4323a9c 100644 --- a/smact/dopant_prediction/__init__.py +++ b/smact/dopant_prediction/__init__.py @@ -1,5 +1,7 @@ """Minimalist dopant prediction tools for materials design.""" +from __future__ import annotations + import logging __author__ = "Chloe (Jiwoo) Lee (이지우)" diff --git a/smact/dopant_prediction/doper.py b/smact/dopant_prediction/doper.py index edd32a60..94d0ed07 100644 --- a/smact/dopant_prediction/doper.py +++ b/smact/dopant_prediction/doper.py @@ -1,6 +1,12 @@ +"""The dopant prediction module facilitates high-throughput prediction of p-type and n-type dopants +in multi-component solids. The search and ranking process is based on electronic filters +(e.g. accessible oxidation states) and chemical filters (e.g. difference in ionic radius). +""" + +from __future__ import annotations + import os from itertools import groupby -from typing import List, Optional, Tuple, Type import numpy as np from pymatgen.util import plotting @@ -19,20 +25,21 @@ class Doper: """ A class to search for n & p type dopants - Methods: get_dopants, plot_dopants + Methods: get_dopants, plot_dopants. """ def __init__( self, - original_species: Tuple[str, ...], - filepath: Optional[str] = None, - embedding: Optional[str] = None, + original_species: tuple[str, ...], + filepath: str | None = None, + embedding: str | None = None, use_probability: bool = True, ): """ - Initialise the `Doper` class with a tuple of species + Initialise the `Doper` class with a tuple of species. Args: + ---- original_species: See :class:`~.Doper`. filepath (str): Path to a JSON file containing lambda table data. embedding (str): Name of the species embedding to use. Currently only 'skipspecies' is supported. @@ -44,15 +51,11 @@ def __init__( # filepath and embedding are mutually exclusive # check if both are provided if filepath and embedding: - raise ValueError( - "Only one of filepath or embedding should be provided" - ) + raise ValueError("Only one of filepath or embedding should be provided") if embedding and embedding != "skipspecies": raise ValueError(f"Embedding {embedding} is not supported") if embedding: - self.cation_mutator = mutation.CationMutator.from_json( - SKIPSSPECIES_COSINE_SIM_PATH - ) + self.cation_mutator = mutation.CationMutator.from_json(SKIPSSPECIES_COSINE_SIM_PATH) elif filepath: self.cation_mutator = mutation.CationMutator.from_json(filepath) else: @@ -60,18 +63,14 @@ def __init__( self.cation_mutator = mutation.CationMutator.from_json(filepath) self.possible_species = list(self.cation_mutator.specs) self.lambda_threshold = self.cation_mutator.alpha("X", "Y") - self.threshold = ( - 1 - / self.cation_mutator.Z - * np.exp(self.cation_mutator.alpha("X", "Y")) - ) + self.threshold = 1 / self.cation_mutator.Z * np.exp(self.cation_mutator.alpha("X", "Y")) self.use_probability = use_probability self.results = None def _get_selectivity( self, - data_list: List[smact.Element], - cations: List[smact.Element], + data_list: list[smact.Element], + cations: list[smact.Element], sub, ): data = data_list.copy() @@ -83,9 +82,7 @@ def _get_selectivity( sum_prob = sub_prob for cation in cations: if cation != original_specie: - sum_prob += self._calculate_species_sim_prob( - cation, selected_site - ) + sum_prob += self._calculate_species_sim_prob(cation, selected_site) selectivity = sub_prob / sum_prob selectivity = round(selectivity, 2) @@ -99,26 +96,27 @@ def _merge_dicts(self, keys, dopants_list, groupby_list): merged_values = dict() merged_values["sorted"] = dopants for key, value in group.items(): - merged_values[key] = sorted( - value, key=lambda x: x[2], reverse=True - ) + merged_values[key] = sorted(value, key=lambda x: x[2], reverse=True) merged_dict[k] = merged_values return merged_dict def _get_dopants( self, - specie_ions: List[str], + specie_ions: list[str], ion_type: str, ): """ Get possible dopants for a given list of elements and dopants. Args: + ---- specie_ions (List[str]): List of original species (anions or cations) as strings. ion_type (str): Identify which species to check. Returns: + ------- List[str]: List of possible dopants. + """ poss_n_type = set() poss_p_type = set() @@ -145,31 +143,38 @@ def _calculate_species_sim_prob(self, species1, species2): Calculate the similarity/probability between two species. Args: + ---- species1 (str): The first species. species2 (str): The second species. Returns: + ------- float: The similarity between the two species. + """ if self.use_probability: return self.cation_mutator.sub_prob(species1, species2) else: return self.cation_mutator.get_lambda(species1, species2) - def get_dopants( - self, num_dopants: int = 5, get_selectivity=True, group_by_charge=True - ) -> dict: + def get_dopants(self, num_dopants: int = 5, get_selectivity=True, group_by_charge=True) -> dict: """ + Get the top num_dopants dopant suggestions for n- and p-type dopants. + Args: + ---- num_dopants (int): The number of suggestions to return for n- and p-type dopants. get_selectivity (bool): Whether to calculate the selectivity of the dopants. group_by_charge (bool): Whether to group the dopants by charge. + Returns: + ------- (dict): Dopant suggestions, given as a dictionary with keys - "n-type cation substitutions", "p-type cation substitutions", "n-type anion substitutions", "p-type anion substitutions". + "n-type cation substitutions", "p-type cation substitutions", "n-type anion substitutions", "p-type anion substitutions". Examples: - >>> test = Doper(('Ti4+','O2-')) + -------- + >>> test = Doper(("Ti4+", "O2-")) >>> print(test.get_dopants(num_dopants=2)) {'n-type anion substitutions': {'-1': [['F1-', 'O2-', 0.01508116810515677, 1.0], ['Cl1-','O2-', 0.004737202729901607, 1.0]], @@ -247,7 +252,6 @@ def get_dopants( 1.0]]}} """ - cations, anions = [], [] for ion in self.original_species: @@ -334,32 +338,22 @@ def get_dopants( sub = "cation" if i > 1: sub = "anion" - dopants_lists[i] = self._get_selectivity( - dopants_lists[i], cations, sub - ) + dopants_lists[i] = self._get_selectivity(dopants_lists[i], cations, sub) # if groupby - groupby_lists = [ - dict() - ] * 4 # create list of empty dict length of 4 (n-cat, p-cat, n-an, p-an) + groupby_lists = [dict()] * 4 # create list of empty dict length of 4 (n-cat, p-cat, n-an, p-an) # in case group_by_charge = False if group_by_charge: for i, dl in enumerate(dopants_lists): # groupby first element charge dl = sorted(dl, key=lambda x: utilities.parse_spec(x[0])[1]) - grouped_data = groupby( - dl, key=lambda x: utilities.parse_spec(x[0])[1] - ) - grouped_top_data = { - str(k): list(g)[:num_dopants] for k, g in grouped_data - } + grouped_data = groupby(dl, key=lambda x: utilities.parse_spec(x[0])[1]) + grouped_top_data = {str(k): list(g)[:num_dopants] for k, g in grouped_data} groupby_lists[i] = grouped_top_data del grouped_data # select top n elements - dopants_lists = [ - dopants_list[:num_dopants] for dopants_list in dopants_lists - ] + dopants_lists = [dopants_list[:num_dopants] for dopants_list in dopants_lists] keys = [ "n-type cation substitutions", @@ -376,27 +370,24 @@ def get_dopants( def plot_dopants(self, cmap: str = "YlOrRd") -> None: """ Plot the dopant suggestions using the periodic table heatmap. + Args: + ---- cmap (str): The colormap to use for the heatmap. + Returns: + ------- None + """ - assert ( - self.results - ), "Dopants are not calculated. Run get_dopants first." + assert self.results, "Dopants are not calculated. Run get_dopants first." - for dopant_type, dopants in self.results.items(): + for dopants in self.results.values(): # due to selectivity option if self.len_list == 3: - dict_results = { - utilities.parse_spec(x)[0]: y - for x, _, y in dopants.get("sorted") - } + dict_results = {utilities.parse_spec(x)[0]: y for x, _, y in dopants.get("sorted")} else: - dict_results = { - utilities.parse_spec(x)[0]: y - for x, _, y, _ in dopants.get("sorted") - } + dict_results = {utilities.parse_spec(x)[0]: y for x, _, y, _ in dopants.get("sorted")} plotting.periodic_table_heatmap( elemental_data=dict_results, cmap=cmap, @@ -411,6 +402,14 @@ def _format_number(self, num_str): @property def to_table(self): + """ + Print the dopant suggestions in a tabular format. + + Returns: + ------- + None + + """ if not self.results: print("No data available") return @@ -423,9 +422,7 @@ def to_table(self): for k, v in dopants.items(): kind = k if k == "sorted" else self._format_number(k) print("\033[96m" + str(kind) + "\033[0m") - enumerated_data = [ - [i + 1] + sublist for i, sublist in enumerate(v) - ] + enumerated_data = [[i + 1, *sublist] for i, sublist in enumerate(v)] print( tabulate( enumerated_data, diff --git a/smact/lattice.py b/smact/lattice.py index f8f33d7d..a4e5b861 100755 --- a/smact/lattice.py +++ b/smact/lattice.py @@ -1,12 +1,12 @@ #!/usr/bin/env python +"""This module defines the Lattice and Site classes for crystal structures.""" -import numpy as np - -import smact +from __future__ import annotations class Lattice: - """A unique set of Sites. + """ + A unique set of Sites. Lattice objects define a general crystal structure, with a space group and a collection of Site objects. These Site objects have their own fractional @@ -17,6 +17,7 @@ class Lattice: Environment. Attributes: + ---------- basis_sites: A list of Site objects [SiteA, SiteB, SiteC, ...] comprising the basis sites in Cartesian coordinates @@ -27,11 +28,13 @@ class Lattice: Structurbericht identity, if applicable (e.g. 'B1') Methods: + ------- lattice_vector_calc(): """ def __init__(self, sites, space_group=1, strukturbericht=False): + """Initialize the Lattice object.""" self.sites = sites self.space_group = space_group self.strukturbericht = strukturbericht @@ -44,11 +47,15 @@ class Site: The Site object is primarily used within Lattice objects. Attributes: + ---------- position: A list of fractional coordinates [x,y,z] oxidation_states: A list of possible oxidation states e.g. [-1,0,1] """ - def __init__(self, position, oxidation_states=[0]): + def __init__(self, position, oxidation_states=None): + """Initialize the Site object.""" + if oxidation_states is None: + oxidation_states = [0] self.position = position self.oxidation_states = oxidation_states diff --git a/smact/lattice_parameters.py b/smact/lattice_parameters.py index e4be7bf0..067b7456 100644 --- a/smact/lattice_parameters.py +++ b/smact/lattice_parameters.py @@ -1,23 +1,28 @@ -#!/usr/bin/env python """ This module can be used to calculate roughly the lattice parameters of a lattice type, based on the radii of the species on each site. """ + +from __future__ import annotations + import numpy as np def cubic_perovskite(shannon_radius): # Cubic Pervoskite - """The lattice parameters of the cubic perovskite structure. + """ + The lattice parameters of the cubic perovskite structure. Args: + ---- shannon_radius (list) : The radii of the a,b,c ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) - """ + """ limiting_factors = [2 * sum(shannon_radius[1:])] a = max(limiting_factors) b = a @@ -30,15 +35,19 @@ def cubic_perovskite(shannon_radius): # Cubic Pervoskite def wurtzite(shannon_radius): - """The lattice parameters of the wurtzite structure. + """ + The lattice parameters of the wurtzite structure. Args: + ---- shannon_radius (list) : The radii of the a,b ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ shannon_radius.sort(reverse=True) # Geometry assumes atom A is larger # "Ideal" wurtzite structure @@ -60,24 +69,26 @@ def wurtzite(shannon_radius): # 0.817 is sin(109.6/2) a = 2 * 0.817 * (shannon_radius[0] + shannon_radius[1]) b = a - c = (shannon_radius[0] + shannon_radius[1]) * ( - 2 + 2 * 0.335 - ) # 0.335 is sin(109.6-90) + c = (shannon_radius[0] + shannon_radius[1]) * (2 + 2 * 0.335) # 0.335 is sin(109.6-90) # inner_space = a * (6**0.5) - (4*shannon_radius[0]) return a, b, c, alpha, beta, gamma # A1# def fcc(covalent_radius): - """The lattice parameters of the A1. + """ + The lattice parameters of the A1. - Args: - shannon_radius (list) : The radii of the a ions + Args: + ---- + covalent_radius (list) : The radii of the a ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ a = 2 * 2**0.5 * covalent_radius b = 2 * 2**0.5 * covalent_radius @@ -90,15 +101,19 @@ def fcc(covalent_radius): # A2# def bcc(covalent_radius): - """The lattice parameters of the A2. + """ + The lattice parameters of the A2. Args: - shannon_radius (list) : The radii of the a ions + ---- + covalent_radius (list) : The radii of the a ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ a = 4 * covalent_radius / np.sqrt(3) b = a @@ -111,15 +126,19 @@ def bcc(covalent_radius): # A3# def hcp(covalent_radius): - """The lattice parameters of the hcp. + """ + The lattice parameters of the hcp. Args: - shannon_radius (list) : The radii of the a ions + ---- + covalent_radius (list) : The radii of the a ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ a = 2 * covalent_radius b = a @@ -132,15 +151,19 @@ def hcp(covalent_radius): # A4# def diamond(covalent_radius): - """The lattice parameters of the diamond. + """ + The lattice parameters of the diamond. Args: - shannon_radius (list) : The radii of the a ions + ---- + covalent_radius (list) : The radii of the a ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ a = 8 * covalent_radius / np.sqrt(3) b = a @@ -153,15 +176,19 @@ def diamond(covalent_radius): # A5# def bct(covalent_radius): - """The lattice parameters of the bct. + """ + The lattice parameters of the bct. Args: - shannon_radius (list) : The radii of the a ions + ---- + covalent_radius (list) : The radii of the a ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ a = 3.86 * covalent_radius b = a @@ -174,15 +201,19 @@ def bct(covalent_radius): # B1 def rocksalt(shannon_radius): - """The lattice parameters of rocksalt. + """ + The lattice parameters of rocksalt. Args: + ---- shannon_radius (list) : The radii of the a,b ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ limiting_factors = [ 2 * 2**0.2 * shannon_radius[0], @@ -200,15 +231,19 @@ def rocksalt(shannon_radius): # B2 def b2(shannon_radius): - """The lattice parameters of b2. + """ + The lattice parameters of b2. Args: + ---- shannon_radius (list) : The radii of the a,b ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ limiting_factors = [ 2 * (shannon_radius[0] + shannon_radius[0]) / np.sqrt(3), @@ -226,14 +261,19 @@ def b2(shannon_radius): # B3 def zincblende(shannon_radius): - """The lattice parameters of Zinc Blende. - Args: + """ + The lattice parameters of Zinc Blende. + + Args: + ---- shannon_radius (list) : The radii of the a,b ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ limiting_factors = [ 2 * (max(shannon_radius) * np.sqrt(2)), @@ -255,15 +295,19 @@ def zincblende(shannon_radius): # B10 def b10(shannon_radius): # Litharge - """The lattice parameters of Litharge + """ + The lattice parameters of Litharge. Args: + ---- shannon_radius (list) : The radii of the a,b ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ limiting_factors = [ 4 * (max(shannon_radius)) / np.sqrt(2), @@ -279,15 +323,19 @@ def b10(shannon_radius): # Litharge def stuffed_wurtzite(shannon_radii): - """The stuffed wurtzite structure (e.g. LiGaGe) space group P63/mc. + """ + The stuffed wurtzite structure (e.g. LiGaGe) space group P63/mc. Args: - shannon_radius (list) : The radii of the a,b,c ions + ---- + shannon_radii (list) : The radii of the a,b,c ions Returns: + ------- (tuple): float values of lattics constants and angles (a, b, c, alpha, beta, gamma) + """ rac = shannon_radii[2] + shannon_radii[1] x = rac * np.sin(np.radians(19.5)) diff --git a/smact/mainpage.py b/smact/mainpage.py index a2711070..ea7d9286 100644 --- a/smact/mainpage.py +++ b/smact/mainpage.py @@ -1,5 +1,5 @@ r""" -\mainpage +\mainpage. This is the mainpage of the smact package diff --git a/smact/oxidation_states.py b/smact/oxidation_states.py index 300b28ef..7656a365 100644 --- a/smact/oxidation_states.py +++ b/smact/oxidation_states.py @@ -6,16 +6,17 @@ Prediction - DOI: 10.1039/C8FD00032H. """ +from __future__ import annotations + import json from os import path -from typing import Dict, Optional, Tuple from numpy import mean from pymatgen.core import Structure from pymatgen.core.periodic_table import Species as pmgSpecies from pymatgen.core.periodic_table import get_el_sp -from smact import Element, Species, data_directory +from smact import Species, data_directory class Oxidation_state_probability_finder: @@ -24,21 +25,19 @@ class Oxidation_state_probability_finder: to compute the likelihood of metal species existing in solids in the presence of certain anions. """ - def __init__( - self, probability_table: Optional[Dict[Tuple[str, str], float]] = None - ): + def __init__(self, probability_table: dict[tuple[str, str], float] | None = None): """ + Initialise the oxidation state probability finder. + Args: + ---- probability_table (dict): Lookup table to get probabilities for anion-cation pairs. Must be of the format {(anion,cation): probability, ...} e.g. {('F-1', 'Li1'): 1.0,...}. If none, the default table is loaded from the data directory. + """ - if probability_table == None: - with open( - path.join( - data_directory, "oxidation_state_probability_table.json" - ) - ) as f: + if probability_table is None: + with open(path.join(data_directory, "oxidation_state_probability_table.json")) as f: probability_data = json.load(f) # Put data into the required format probability_table = {} @@ -47,8 +46,8 @@ def __init__( self._probability_table = probability_table # Define set of species for which we have data - included_anions = {i[0] for i in self._probability_table.keys()} - included_cations = {i[1] for i in self._probability_table.keys()} + included_anions = {i[0] for i in self._probability_table} + included_cations = {i[1] for i in self._probability_table} included_species = list(included_anions) + list(included_cations) self._included_species = included_species @@ -60,10 +59,12 @@ def _generate_lookup_key(self, species1: Species, species2: Species): Internal function to generate keys to lookup table. Args: + ---- species1 (smact.Species): Species species2 (smact.Species): Species Returns: + ------- table_key (tuple): For looking up probability in the form (an_key, cat_key). """ @@ -82,82 +83,71 @@ def _generate_lookup_key(self, species1: Species, species2: Species): an_key = "".join([anion.symbol, str(int(anion.oxidation))]) # Check that both the species are included in the probability table - if not all( - elem in self._included_species for elem in [an_key, cat_key] - ): - raise NameError( - f"One or both of [{cat_key}, {an_key}] are not in the probability table." - ) + if not all(elem in self._included_species for elem in [an_key, cat_key]): + raise NameError(f"One or both of [{cat_key}, {an_key}] are not in the probability table.") - table_key = (an_key, cat_key) - return table_key + return (an_key, cat_key) def pair_probability(self, species1: Species, species2: Species) -> float: - """ + r""" Get the anion-cation oxidation state probability for a provided pair of smact Species. i.e. :math:`P_{SA}=\\frac{N_{SX}}{N_{MX}}` in the original paper (DOI:10.1039/C8FD00032H). Args: + ---- species1 (smact.Species): Cation or anion species species2 (smact.Species): Cation or anion species Returns: + ------- prob (float): Species-anion probability """ # Generate lookup table key and use it to look up probability probability_table_key = self._generate_lookup_key(species1, species2) - prob = self._probability_table[probability_table_key] - return prob + return self._probability_table[probability_table_key] def get_included_species(self): - """ - Returns a list of species for which there exists data in the probability table used. - """ + """Returns a list of species for which there exists data in the probability table used.""" return self._included_species - def compound_probability( - self, structure: Structure, ignore_stoichiometry: bool = True - ) -> float: + def compound_probability(self, structure: Structure, ignore_stoichiometry: bool = True) -> float: """ calculate overall probability for structure or composition. Args: + ---- structure (pymatgen.Structure): Compound for which the probability score will be generated. Can also be a list of pymatgen or SMACT Species. ignore_stoichiometry (bool): Whether to weight probabilities by stoichiometry. Defaults to false as decribed in the original paper. Returns: + ------- compound_prob (float): Compound probability - """ + """ # Convert input to list of SMACT Species - if type(structure) == list: + if isinstance(structure, list): if all(isinstance(i, Species) for i in structure): pass elif all(isinstance(i, pmgSpecies) for i in structure): structure = [Species(i.symbol, i.oxi_state) for i in structure] else: - raise TypeError( - "Input requires a list of SMACT or Pymatgen species." - ) - elif type(structure) == Structure: + raise TypeError("Input requires a list of SMACT or Pymatgen species.") + elif isinstance(structure, Structure): species = structure.species if not all(isinstance(i, pmgSpecies) for i in species): raise TypeError("Structure must have oxidation states.") - else: - structure = [ - Species( - get_el_sp(i.species_string).symbol, - get_el_sp(i.species_string).oxi_state, - ) - for i in structure - ] + structure = [ + Species( + get_el_sp(i.species_string).symbol, + get_el_sp(i.species_string).oxi_state, + ) + for i in structure + ] else: - raise TypeError( - "Input requires a list of SMACT or Pymatgen Species or a Structure." - ) + raise TypeError("Input requires a list of SMACT or Pymatgen Species or a Structure.") # Put most electonegative element last in list by sorting by electroneg structure.sort(key=lambda x: x.pauling_eneg) @@ -172,8 +162,5 @@ def compound_probability( species_pairs = list(set(species_pairs)) # Do the maths - pair_probs = [ - self.pair_probability(pair[0], pair[1]) for pair in species_pairs - ] - compound_prob = mean(pair_probs) - return compound_prob + pair_probs = [self.pair_probability(pair[0], pair[1]) for pair in species_pairs] + return mean(pair_probs) diff --git a/smact/properties.py b/smact/properties.py index 55e33700..2dead3e8 100644 --- a/smact/properties.py +++ b/smact/properties.py @@ -1,35 +1,38 @@ -from typing import List, Optional, Union +"""A collection of tools for estimating physical properties based on chemical composition.""" + +from __future__ import annotations import numpy as np import smact -def eneg_mulliken(element: Union[smact.Element, str]) -> float: - """Get Mulliken electronegativity from the IE and EA. +def eneg_mulliken(element: smact.Element | str) -> float: + """ + Get Mulliken electronegativity from the IE and EA. Arguments: - symbol (smact.Element or str): Element object or symbol + --------- + element (smact.Element or str): Element object or symbol Returns: + ------- mulliken (float): Mulliken electronegativity """ - if type(element) == str: + if isinstance(element, str): element = smact.Element(element) - elif type(element) != smact.Element: - raise Exception(f"Unexpected type: {type(element)}") + elif not isinstance(element, smact.Element): + raise TypeError(f"Unexpected type: {type(element)}") - mulliken = (element.ionpot + element.e_affinity) / 2.0 - - return mulliken + return (element.ionpot + element.e_affinity) / 2.0 def band_gap_Harrison( anion: str, cation: str, verbose: bool = False, - distance: Optional[Union[float, str]] = None, + distance: float | str | None = None, ) -> float: """ Estimates the band gap from elemental data. @@ -39,19 +42,19 @@ def band_gap_Harrison( Solids: The Physics of the Chemical Bond". Args: - Anion (str): Element symbol of the dominant anion in the system - - Cation (str): Element symbol of the the dominant cation in the system - Distance (float or str): Nuclear separation between anion and cation + ---- + anion (str): Element symbol of the dominant anion in the system + cation (str): Element symbol of the the dominant cation in the system + distance (float or str): Nuclear separation between anion and cation i.e. sum of ionic radii verbose (bool) : An optional True/False flag. If True, additional - information is printed to the standard output. [Defult: False] + information is printed to the standard output. [Default: False] - Returns : + Returns: + ------- Band_gap (float): Band gap in eV """ - # Set constants hbarsq_over_m = 7.62 @@ -85,11 +88,12 @@ def band_gap_Harrison( def compound_electroneg( verbose: bool = False, - elements: List[Union[str, smact.Element]] = None, - stoichs: List[Union[int, float]] = None, + elements: list[str | smact.Element] | None = None, + stoichs: list[int | float] | None = None, source: str = "Mulliken", ) -> float: - """Estimate electronegativity of compound from elemental data. + """ + Estimate electronegativity of compound from elemental data. Uses Mulliken electronegativity by default, which uses elemental ionisation potentials and electron affinities. Alternatively, can @@ -102,6 +106,7 @@ def compound_electroneg( X_Cu2S = (X_Cu * X_Cu * C_S)^(1/3) Args: + ---- elements (list) : Elements given as standard elemental symbols. stoichs (list) : Stoichiometries, given as integers or floats. verbose (bool) : An optional True/False flag. If True, additional information @@ -111,17 +116,16 @@ def compound_electroneg( rescaled to a Mulliken-like scale. Returns: + ------- Electronegativity (float) : Estimated electronegativity (no units). """ - if type(elements[0]) == str: + if isinstance(elements[0], str): elementlist = [smact.Element(i) for i in elements] - elif type(elements[0]) == smact.Element: + elif isinstance(elements[0], smact.Element): elementlist = elements else: - raise Exception( - "Please supply a list of element symbols or SMACT Element objects" - ) + raise TypeError("Please supply a list of element symbols or SMACT Element objects") stoichslist = stoichs # Convert stoichslist from string to float @@ -135,9 +139,7 @@ def compound_electroneg( elif source == "Pauling": elementlist = [(2.86 * el.pauling_eneg) for el in elementlist] else: - raise Exception( - f"Electronegativity type '{source}'", "is not recognised" - ) + raise Exception(f"Electronegativity type '{source}'", "is not recognised") # Print optional list of element electronegativities. # This may be a useful sanity check in case of a suspicious result. @@ -146,7 +148,7 @@ def compound_electroneg( # Raise each electronegativity to its appropriate power # to account for stoichiometry. - for i in range(0, len(elementlist)): + for i in range(len(elementlist)): elementlist[i] = [elementlist[i] ** stoichslist[i]] # Calculate geometric mean (n-th root of product) diff --git a/smact/screening.py b/smact/screening.py index d3e45945..eb979907 100644 --- a/smact/screening.py +++ b/smact/screening.py @@ -1,12 +1,16 @@ +"""A collection of tools for estimating physical properties +based on chemical composition. +""" + +from __future__ import annotations + import itertools import os import warnings -from collections import namedtuple from itertools import combinations -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING import numpy as np -import pymatgen from pymatgen.core import Composition import smact @@ -15,29 +19,26 @@ lookup_element_oxidation_states_custom as oxi_custom, ) -# Use named tuple to improve readability of smact_filter outputs -_allowed_compositions = namedtuple( - "Composition", ["element_symbols", "oxidation_states", "stoichiometries"] -) -_allowed_compositions_nonunique = namedtuple( - "Composition", ["element_symbols", "stoichiometries"] -) +if TYPE_CHECKING: + import pymatgen def pauling_test( - oxidation_states: List[int], - electronegativities: List[float], - symbols: List[str] = [], + oxidation_states: list[int], + electronegativities: list[float], + symbols: list[str] | None = None, repeat_anions: bool = True, repeat_cations: bool = True, threshold: float = 0.0, ): - """Check if a combination of ions makes chemical sense, + """ + Check if a combination of ions makes chemical sense, (i.e. positive ions should be of lower electronegativity). Args: - ox (list): oxidation states of elements in the compound - paul (list): the corresponding Pauling electronegativities + ---- + oxidation_states (list): oxidation states of elements in the compound + electronegativities (list): the corresponding Pauling electronegativities of the elements in the compound symbols (list) : chemical symbols of each site threshold (float): a tolerance for the allowed deviation from @@ -47,39 +48,37 @@ def pauling_test( repeat_cations : as above, but for cations Returns: + ------- bool: True if anions are more electronegative than cations, otherwise False - """ + """ + if symbols is None: + symbols = [] if repeat_anions and repeat_cations and threshold == 0.0: return eneg_states_test(oxidation_states, electronegativities) elif repeat_anions and repeat_cations: - return eneg_states_test_threshold( - oxidation_states, electronegativities, threshold=threshold - ) + return eneg_states_test_threshold(oxidation_states, electronegativities, threshold=threshold) - else: - if _no_repeats( - oxidation_states, - symbols, - repeat_anions=repeat_anions, - repeat_cations=repeat_cations, - ): - if threshold == 0.0: - return eneg_states_test(oxidation_states, electronegativities) - else: - return eneg_states_test_threshold( - oxidation_states, electronegativities, threshold=threshold - ) + elif _no_repeats( + oxidation_states, + symbols, + repeat_anions=repeat_anions, + repeat_cations=repeat_cations, + ): + if threshold == 0.0: + return eneg_states_test(oxidation_states, electronegativities) else: - return False + return eneg_states_test_threshold(oxidation_states, electronegativities, threshold=threshold) + else: + return False def _no_repeats( - oxidation_states: List[int], - symbols: List[str], + oxidation_states: list[int], + symbols: list[str], repeat_anions: bool = False, repeat_cations: bool = False, ): @@ -87,6 +86,7 @@ def _no_repeats( Check if any anion or cation appears twice. Args: + ---- oxidation_states (list): oxidation states of species symbols (list): chemical symbols corresponding to oxidation states @@ -95,7 +95,9 @@ def _no_repeats( repeat_cations (bool): if True, cations may be repeated (e.g. Cu in +1 and +2 states) - Returns: bool + Returns: + ------- + bool: True if no anion or cation is repeated, False otherwise """ if repeat_anions is False and repeat_cations is False: return len(symbols) == len(set(symbols)) @@ -106,28 +108,30 @@ def _no_repeats( cations.append(symbol) else: anions.append(symbol) - if not repeat_anions and len(anions) != len(set(anions)): - return False - elif not repeat_cations and len(cations) != len(set(cations)): - return False - else: - return True + return not ( + not repeat_anions + and len(anions) != len(set(anions)) + or not repeat_cations + and len(cations) != len(set(cations)) + ) def pauling_test_old( - ox: List[int], - paul: List[float], - symbols: List[str], + ox: list[int], + paul: list[float], + symbols: list[str], repeat_anions: bool = True, repeat_cations: bool = True, threshold: float = 0.0, ): - """Check if a combination of ions makes chemical sense, + """ + Check if a combination of ions makes chemical sense, (i.e. positive ions should be of lower Pauling electronegativity). This function should give the same results as pauling_test but is not optimised for speed. Args: + ---- ox (list): oxidation states of the compound paul (list): the corresponding Pauling electronegativities of the elements in the compound @@ -139,6 +143,7 @@ def pauling_test_old( repeat_cations : as above, but for cations. Returns: + ------- (bool): True if anions are more electronegative than cations, otherwise False @@ -175,15 +180,12 @@ def pauling_test_old( return False if max(positive) == min(negative): return False - if max(positive) - min(negative) > threshold: - return False - else: - return True + return not max(positive) - min(negative) > threshold -def eneg_states_test(ox_states: List[int], enegs: List[float]): +def eneg_states_test(ox_states: list[int], enegs: list[float]): """ - Internal function for checking electronegativity criterion + Internal function for checking electronegativity criterion. This implementation is fast as it 'short-circuits' as soon as it finds an invalid combination. However it may be that in some cases @@ -191,33 +193,36 @@ def eneg_states_test(ox_states: List[int], enegs: List[float]): this method and eneg_states_test_alternate. Args: + ---- ox_states (list): oxidation states corresponding to species in compound enegs (list): Electronegativities corresponding to species in compound Returns: + ------- bool : True if anions are more electronegative than cations, otherwise False """ - for (ox1, eneg1), (ox2, eneg2) in combinations( - list(zip(ox_states, enegs)), 2 - ): - if eneg1 is None or eneg2 is None: - return False - elif (ox1 > 0) and (ox2 < 0) and (eneg1 >= eneg2): - return False - elif (ox1 < 0) and (ox2 > 0) and (eneg1 <= eneg2): + for (ox1, eneg1), (ox2, eneg2) in combinations(list(zip(ox_states, enegs)), 2): + if ( + eneg1 is None + or eneg2 is None + or (ox1 > 0) + and (ox2 < 0) + and (eneg1 >= eneg2) + or (ox1 < 0) + and (ox2 > 0) + and (eneg1 <= eneg2) + ): return False - else: - return True + return True -def eneg_states_test_threshold( - ox_states: List[int], enegs: List[float], threshold: Optional[float] = 0 -): - """Internal function for checking electronegativity criterion +def eneg_states_test_threshold(ox_states: list[int], enegs: list[float], threshold: float | None = 0): + """ + Internal function for checking electronegativity criterion. This implementation is fast as it 'short-circuits' as soon as it finds an invalid combination. However it may be that in some cases @@ -228,6 +233,7 @@ def eneg_states_test_threshold( relaxed somewhat. Args: + ---- ox_states (list): oxidation states corresponding to species in compound enegs (list): Electronegativities corresponding to species in @@ -236,34 +242,40 @@ def eneg_states_test_threshold( the Pauling criterion Returns: + ------- bool : True if anions are more electronegative than cations, otherwise False """ - for (ox1, eneg1), (ox2, eneg2) in combinations( - list(zip(ox_states, enegs)), 2 - ): - if (ox1 > 0) and (ox2 < 0) and ((eneg1 - eneg2) > threshold): - return False - elif (ox1 < 0) and (ox2 > 0) and (eneg2 - eneg1) > threshold: + for (ox1, eneg1), (ox2, eneg2) in combinations(list(zip(ox_states, enegs)), 2): + if ( + (ox1 > 0) + and (ox2 < 0) + and ((eneg1 - eneg2) > threshold) + or (ox1 < 0) + and (ox2 > 0) + and (eneg2 - eneg1) > threshold + ): return False - else: - return True + return True -def eneg_states_test_alternate(ox_states: List[int], enegs: List[float]): - """Internal function for checking electronegativity criterion +def eneg_states_test_alternate(ox_states: list[int], enegs: list[float]): + """ + Internal function for checking electronegativity criterion. This implementation appears to be slightly slower than eneg_states_test, but further testing is needed. Args: + ---- ox_states (list): oxidation states corresponding to species in compound enegs (list): Electronegativities corresponding to species in compound Returns: + ------- bool : True if anions are more electronegative than cations, otherwise False @@ -278,10 +290,11 @@ def eneg_states_test_alternate(ox_states: List[int], enegs: List[float]): def ml_rep_generator( - composition: Union[List[Element], List[str]], - stoichs: Optional[List[int]] = None, + composition: list[Element] | list[str], + stoichs: list[int] | None = None, ): - """Function to take a composition of Elements and return a + """ + Function to take a composition of Elements and return a list of values between 0 and 1 that describes the composition, useful for machine learning. @@ -293,49 +306,52 @@ def ml_rep_generator( Inspired by the representation used by Legrain et al. DOI: 10.1021/acs.chemmater.7b00789 Args: + ---- composition (list): Element objects in composition OR symbols of elements in composition stoichs (list): Corresponding stoichiometries in the composition Returns: + ------- norm (list): List of floats representing the composition that sum to one """ - if stoichs == None: + if stoichs is None: stoichs = [1 for i, el in enumerate(composition)] ML_rep = [0 for i in range(1, 103)] - if type(composition[0]) == Element: + if isinstance(composition[0], Element): for element, stoich in zip(composition, stoichs): ML_rep[int(element.number) - 1] += stoich else: for element, stoich in zip(composition, stoichs): ML_rep[int(Element(element).number) - 1] += stoich - norm = [float(i) / sum(ML_rep) for i in ML_rep] - return norm + return [float(i) / sum(ML_rep) for i in ML_rep] def smact_filter( - els: Union[Tuple[Element], List[Element]], - threshold: Optional[int] = 8, - stoichs: Optional[List[List[int]]] = None, + els: tuple[Element] | list[Element], + threshold: int | None = 8, + stoichs: list[list[int]] | None = None, species_unique: bool = True, oxidation_states_set: str = "default", - comp_tuple: bool = False, -) -> Union[List[Tuple[str, int, int]], List[Tuple[str, int]]]: - """Function that applies the charge neutrality and electronegativity +) -> list[tuple[str, int, int]] | list[tuple[str, int]]: + """ + Function that applies the charge neutrality and electronegativity tests in one go for simple application in external scripts that wish to apply the general 'smact test'. Args: + ---- els (tuple/list): A list of smact.Element objects threshold (int): Threshold for stoichiometry limit, default = 8 stoichs (list[int]): A selection of valid stoichiometric ratios for each site. species_unique (bool): Whether or not to consider elements in different oxidation states as unique in the results. oxidation_states_set (string): A string to choose which set of oxidation states should be chosen. Options are 'default', 'icsd', 'pymatgen' and 'wiki' for the default, icsd, pymatgen structure predictor and Wikipedia (https://en.wikipedia.org/wiki/Template:List_of_oxidation_states_of_the_elements) oxidation states respectively. A filepath to an oxidation states text file can also be supplied as well. - comp_tuple (bool): Whether or not to return the results as a named tuple of elements and stoichiometries (True) or as a normal tuple of elements and stoichiometries (False). + Returns: + ------- allowed_comps (list): Allowed compositions for that chemical system in the form [(elements), (oxidation states), (ratios)] if species_unique=True or in the form [(elements), (ratios)] if species_unique=False. @@ -343,8 +359,8 @@ def smact_filter( Example usage: >>> from smact.screening import smact_filter >>> from smact import Element - >>> els = (Element('Cs'), Element('Pb'), Element('I')) - >>> comps = smact_filter(els, threshold =5 ) + >>> els = (Element("Cs"), Element("Pb"), Element("I")) + >>> comps = smact_filter(els, threshold=5) >>> for comp in comps: >>> print(comp) [('Cs', 'Pb', 'I'), (1, -4, -1), (5, 1, 1)] @@ -357,14 +373,13 @@ def smact_filter( Example (using stoichs): >>> from smact.screening import smact_filter >>> from smact import Element - >>> comps = smact_filter(els, stoichs = [[1],[1],[3]], comp_tuple=True ) + >>> comps = smact_filter(els, stoichs=[[1], [1], [3]]) >>> for comp in comps: >>> print(comp) - Composition(element_symbols=('Cs', 'Pb', 'I'), oxidation_states=(1, 2, -1), stoichiometries=(1, 1, 3)) + [('Cs', 'Pb', 'I'), (1, 2, -1), (1, 1, 3)] """ - compositions = [] # Get symbols and electronegativities @@ -396,47 +411,36 @@ def smact_filter( for ox_states in itertools.product(*ox_combos): # Test for charge balance - cn_e, cn_r = neutral_ratios( - ox_states, stoichs=stoichs, threshold=threshold - ) + cn_e, cn_r = neutral_ratios(ox_states, stoichs=stoichs, threshold=threshold) # Electronegativity test if cn_e: electroneg_OK = pauling_test(ox_states, electronegs) if electroneg_OK: for ratio in cn_r: - compositions.append( - _allowed_compositions(symbols, ox_states, ratio) - if comp_tuple - else (symbols, ox_states, ratio) - ) + compositions.append((symbols, ox_states, ratio)) # Return list depending on whether we are interested in unique species combinations # or just unique element combinations. if species_unique: return compositions else: - if comp_tuple: - compositions = [ - _allowed_compositions_nonunique(i[0], i[2]) - for i in compositions - ] - else: - compositions = [(i[0], i[2]) for i in compositions] - compositions = list(set(compositions)) - return compositions + compositions = [(i[0], i[2]) for i in compositions] + return list(set(compositions)) def smact_validity( - composition: Union[pymatgen.core.Composition, str], + composition: pymatgen.core.Composition | str, use_pauling_test: bool = True, include_alloys: bool = True, - oxidation_states_set: Union[str, bytes, os.PathLike] = "default", + oxidation_states_set: str | bytes | os.PathLike = "default", ) -> bool: - """Check if a composition is valid according to the SMACT rules. + """ + Check if a composition is valid according to the SMACT rules. Composition is considered valid if it passes the charge neutrality test and the Pauling electronegativity test. Args: + ---- composition (Union[pymatgen.core.Composition, str]): Composition/formula to check. This can be a pymatgen Composition object or a string. use_pauling_test (bool): Whether to use the Pauling electronegativity test include_alloys (bool): If True, compositions which only contain metal elements will be considered valid without further checks. @@ -447,6 +451,7 @@ def smact_validity( A filepath to an oxidation states text file can also be supplied. Returns: + ------- bool: True if the composition is valid, False otherwise """ @@ -478,9 +483,7 @@ def smact_validity( elif oxidation_states_set == "pymatgen": ox_combos = [e.oxidation_states_sp for e in smact_elems] elif os.path.exists(oxidation_states_set): - ox_combos = [ - oxi_custom(e.symbol, oxidation_states_set) for e in smact_elems - ] + ox_combos = [oxi_custom(e.symbol, oxidation_states_set) for e in smact_elems] elif oxidation_states_set == "wiki": warnings.warn( "This set of oxidation states is sourced from Wikipedia. The results from using this set could be questionable and should not be used unless you know what you are doing and have inspected the oxidation states.", @@ -500,9 +503,7 @@ def smact_validity( for ox_states in itertools.product(*ox_combos): stoichs = [(c,) for c in count] # Test for charge balance - cn_e, cn_r = smact.neutral_ratios( - ox_states, stoichs=stoichs, threshold=threshold - ) + cn_e, cn_r = smact.neutral_ratios(ox_states, stoichs=stoichs, threshold=threshold) # Electronegativity test if cn_e: if use_pauling_test: @@ -515,10 +516,7 @@ def smact_validity( electroneg_OK = True if electroneg_OK: for ratio in cn_r: - compositions.append(tuple([elem_symbols, ox_states, ratio])) + compositions.append((elem_symbols, ox_states, ratio)) compositions = [(i[0], i[2]) for i in compositions] compositions = list(set(compositions)) - if len(compositions) > 0: - return True - else: - return False + return len(compositions) > 0 diff --git a/smact/structure_prediction/__init__.py b/smact/structure_prediction/__init__.py index e3298a35..dd2016e7 100644 --- a/smact/structure_prediction/__init__.py +++ b/smact/structure_prediction/__init__.py @@ -1,5 +1,7 @@ """Minimalist ionic compound prediction tools for materials design.""" +from __future__ import annotations + import logging __author__ = "Alexander Moriarty" diff --git a/smact/structure_prediction/database.py b/smact/structure_prediction/database.py index 42f0d764..f29271f3 100644 --- a/smact/structure_prediction/database.py +++ b/smact/structure_prediction/database.py @@ -1,7 +1,8 @@ """Tools for database interfacing for high throughput IO.""" +from __future__ import annotations + import itertools -from multiprocessing import Pool from operator import itemgetter try: @@ -12,24 +13,30 @@ pathos_available = False import sqlite3 -from typing import Dict, Generator, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING -import pymatgen from pymatgen.ext.matproj import MPRester from . import logger from .structure import SmactStructure from .utilities import convert_next_gen_mprest_data, get_sign +if TYPE_CHECKING: + from collections.abc import Sequence + + import pymatgen + class StructureDB: - """SQLite Structure Database interface. + """ + SQLite Structure Database interface. Acts as a context manager for database interfacing and wraps several useful SQLite commands within methods. Attributes: + ---------- db: The database name. conn: The database connection. Only open when used as a context manager. @@ -37,9 +44,10 @@ class StructureDB: when class implemented as context manager. Examples: + -------- Connecting to a database in memory: - >>> DB = StructureDB(':memory:') + >>> DB = StructureDB(":memory:") >>> with DB as c: ... _ = c.execute("CREATE TABLE test (id, val)") ... c.execute("SELECT * FROM test").fetchall() @@ -52,9 +60,11 @@ class StructureDB: """ def __init__(self, db: str): - """Set database name. + """ + Set database name. Args: + ---- db (str): The name of the database. Can also be ':memory:' to connect to a database in RAM. @@ -62,9 +72,11 @@ def __init__(self, db: str): self.db = db def __enter__(self) -> sqlite3.Cursor: - """Initialize database connection. + """ + Initialize database connection. Returns: + ------- An SQLite cursor for interfacing with the database. """ @@ -74,7 +86,8 @@ def __enter__(self) -> sqlite3.Cursor: return self.cur def __exit__(self, exc_type, *args): - """Close database connection. + """ + Close database connection. Commits all changes before closing. Alternatively, rolls back any changes if an exception @@ -91,19 +104,20 @@ def __exit__(self, exc_type, *args): def add_mp_icsd( self, table: str, - mp_data: Optional[ - List[Dict[str, Union[pymatgen.core.Structure, str]]] - ] = None, - mp_api_key: Optional[str] = None, + mp_data: list[dict[str, pymatgen.core.Structure | str]] | None = None, + mp_api_key: str | None = None, ) -> int: - """Add a table populated with Materials Project-hosted ICSD structures. + """ + Add a table populated with Materials Project-hosted ICSD structures. Note: + ---- This is very computationally expensive for large datasets and will not likely run on a laptop. If possible, download a pre-constructed database. Args: + ---- table (str): The name of the table to add. mp_data: The Materials Project data to parse. If this is None, data will be downloaded. Downloading data needs `mp_api_key` to be set. @@ -111,6 +125,7 @@ def add_mp_icsd( is None. Returns: + ------- The number of structs added. """ @@ -122,9 +137,7 @@ def add_mp_icsd( properties=["structure", "material_id"], ) except NotImplementedError: - docs = m.summary.search( - theoretical=False, fields=["structure", "material_id"] - ) + docs = m.summary.search(theoretical=False, fields=["structure", "material_id"]) data = [convert_next_gen_mprest_data(doc) for doc in docs] else: data = mp_data @@ -140,9 +153,11 @@ def add_mp_icsd( return self.add_structs(parse_iter, table, commit_after_each=True) def add_table(self, table: str): - """Add a table to the database. + """ + Add a table to the database. Args: + ---- table: The name of the table to add """ @@ -153,9 +168,11 @@ def add_table(self, table: str): ) def add_struct(self, struct: SmactStructure, table: str): - """Add a SmactStructure to a table. + """ + Add a SmactStructure to a table. Args: + ---- struct: The :class:`~.SmactStructure` to add. table: The name of the table to add the structure to. @@ -169,11 +186,13 @@ def add_structs( self, structs: Sequence[SmactStructure], table: str, - commit_after_each: Optional[bool] = False, + commit_after_each: bool | None = False, ) -> int: - """Add several SmactStructures to a table. + """ + Add several SmactStructures to a table. Args: + ---- structs: Iterable of :class:`~.SmactStructure` s to add to table. table: The name of the table to add the structs to. commit_after_each (bool, optional): Whether to commit the addition @@ -184,6 +203,7 @@ def add_structs( Defaults to False. Returns: + ------- The number of structures added. """ @@ -202,15 +222,18 @@ def add_structs( return num - def get_structs(self, composition: str, table: str) -> List[SmactStructure]: - """Get SmactStructures for a given composition. + def get_structs(self, composition: str, table: str) -> list[SmactStructure]: + """ + Get SmactStructures for a given composition. Args: + ---- composition: The composition to search for. See :meth:`SmactStructure.composition`. table: The name of the table in which to search. Returns: + ------- A list of :class:`~.SmactStructure` s. """ @@ -224,16 +247,19 @@ def get_structs(self, composition: str, table: str) -> List[SmactStructure]: def get_with_species( self, - species: List[Tuple[str, int]], + species: list[tuple[str, int]], table: str, - ) -> List[SmactStructure]: - """Get SmactStructures containing given species. + ) -> list[SmactStructure]: + """ + Get SmactStructures containing given species. Args: + ---- species: A list of species as tuples, in (element, charge) format. table: The name of the table from which to get the species. Returns: + ------- A list of :class:`SmactStructure` s in the table that contain the species. """ @@ -244,11 +270,7 @@ def get_with_species( species.sort(key=itemgetter(0)) # Generate a list of [element1, charge1, sign1, element2, ...] - vals = list( - itertools.chain.from_iterable( - [x[0], abs(x[1]), get_sign(x[1])] for x in species - ) - ) + vals = list(itertools.chain.from_iterable([x[0], abs(x[1]), get_sign(x[1])] for x in species)) glob_form = glob.format(*vals) @@ -263,12 +285,14 @@ def get_with_species( def parse_mprest( - data: Dict[str, Union[pymatgen.core.Structure, str]], + data: dict[str, pymatgen.core.Structure | str], determine_oxi: str = "BV", ) -> SmactStructure: - """Parse MPRester query data to generate structures. + """ + Parse MPRester query data to generate structures. Args: + ---- data: A dictionary containing the keys 'structure' and 'material_id', with the associated values. determine_oxi (str): The method to determine the assignments oxidation states in the structure. @@ -276,6 +300,7 @@ def parse_mprest( ICSD statistics or trial both sequentially, respectively. Returns: + ------- An oxidation-state-decorated :class:`SmactStructure`. """ @@ -285,11 +310,7 @@ def parse_mprest( data = convert_next_gen_mprest_data(data) try: - return SmactStructure.from_py_struct( - data["structure"], determine_oxi="BV" - ) + return SmactStructure.from_py_struct(data["structure"], determine_oxi="BV") except: # Couldn't decorate with oxidation states - logger.warn( - f"Couldn't decorate {data['material_id']} with oxidation states." - ) + logger.warn(f"Couldn't decorate {data['material_id']} with oxidation states.") diff --git a/smact/structure_prediction/mutation.py b/smact/structure_prediction/mutation.py index 41667506..c17b8d5c 100644 --- a/smact/structure_prediction/mutation.py +++ b/smact/structure_prediction/mutation.py @@ -1,23 +1,29 @@ """Tools for handling ion mutation.""" +from __future__ import annotations + import itertools import json import os -import re from copy import deepcopy from operator import itemgetter -from typing import Callable, Generator, Optional, Tuple +from typing import TYPE_CHECKING, Callable import numpy as np import pandas as pd import pymatgen.analysis.structure_prediction as pymatgen_sp -from .structure import SmactStructure from .utilities import parse_spec +if TYPE_CHECKING: + from collections.abc import Generator + + from .structure import SmactStructure + class CationMutator: - """Handles cation mutation of SmactStructures based on substitution probability. + """ + Handles cation mutation of SmactStructures based on substitution probability. Based on the algorithm presented in: Hautier, G., Fischer, C., Ehrlacher, V., Jain, A., and Ceder, G. (2011) @@ -30,11 +36,13 @@ class CationMutator: def __init__( self, lambda_df: pd.DataFrame, - alpha: Optional[Callable[[str, str], float]] = (lambda s1, s2: -5.0), + alpha: Callable[[str, str], float] | None = (lambda s1, s2: -5.0), ): - """Assign attributes and get lambda table. + """ + Assign attributes and get lambda table. Args: + ---- lambda_df: A pandas DataFrame, with column and index labels as species strings and lambda values as entries. alpha: A function to call to fill in missing lambda values. @@ -47,11 +55,7 @@ def __init__( """ self.lambda_tab = lambda_df - self.specs = set( - itertools.chain.from_iterable( - set(getattr(self.lambda_tab, x)) for x in ["columns", "index"] - ) - ) + self.specs = set(itertools.chain.from_iterable(set(getattr(self.lambda_tab, x)) for x in ["columns", "index"])) self.alpha = alpha @@ -62,12 +66,14 @@ def __init__( @staticmethod def from_json( - lambda_json: Optional[str] = None, - alpha: Optional[Callable[[str, str], float]] = (lambda s1, s2: -5.0), + lambda_json: str | None = None, + alpha: Callable[[str, str], float] | None = (lambda s1, s2: -5.0), ): - """Create a CationMutator instance from a DataFrame. + """ + Create a CationMutator instance from a DataFrame. Args: + ---- lambda_json (str, optional): JSON-style representation of the lambda table. This is a list of entries, containing pairs and their associated lambda values. @@ -77,6 +83,7 @@ def from_json( alpha: See :meth:`__init__`. Returns: + ------- A :class:`CationMutator` instance. """ @@ -98,12 +105,13 @@ def from_json( lambda_dat = [tuple(x) for x in lambda_dat] lambda_df = pd.DataFrame(lambda_dat) - lambda_df = lambda_df.pivot(index=0, columns=1, values=2) + lambda_df = lambda_df.pivot_table(index=0, columns=1, values=2) return CationMutator(lambda_df, alpha) def _populate_lambda(self): - """Populate lambda table. + """ + Populate lambda table. Ensures no values are NaN and performs alpha calculations, such that an entry exists for every possible species @@ -116,18 +124,18 @@ def _populate_lambda(self): def add_alpha(s1, s2): """Add an alpha value to the lambda table at both coordinates.""" a = self.alpha(s1, s2) - self.lambda_tab.at[s1, s2] = a - self.lambda_tab.at[s2, s1] = a + self.lambda_tab.loc[s1, s2] = a + self.lambda_tab.loc[s2, s1] = a def mirror_lambda(s1, s2): """Mirror the lambda value at (s2, s1) into (s1, s2).""" - self.lambda_tab.at[s1, s2] = self.lambda_tab.at[s2, s1] + self.lambda_tab.loc[s1, s2] = self.lambda_tab.loc[s2, s1] for s1, s2 in pairs: try: - if np.isnan(self.lambda_tab.at[s1, s2]): + if np.isnan(self.lambda_tab.loc[s1, s2]): try: - if not np.isnan(self.lambda_tab.at[s2, s1]): + if not np.isnan(self.lambda_tab.loc[s2, s1]): mirror_lambda(s1, s2) else: add_alpha(s1, s2) @@ -137,7 +145,7 @@ def mirror_lambda(s1, s2): mirror_lambda(s2, s1) except KeyError: try: - if np.isnan(self.lambda_tab.at[s2, s1]): + if np.isnan(self.lambda_tab.loc[s2, s1]): add_alpha(s1, s2) else: mirror_lambda(s1, s2) @@ -149,31 +157,37 @@ def mirror_lambda(s1, s2): self.lambda_tab = self.lambda_tab[idx] def get_lambda(self, s1: str, s2: str) -> float: - """Get lambda values corresponding to a pair of species. + """ + Get lambda values corresponding to a pair of species. Args: + ---- s1 (str): One of the species. s2 (str): The other species. Returns: + ------- lambda (float): The lambda value, if it exists in the table. Otherwise, the alpha value for the two species. """ if {s1, s2} <= self.specs: - return self.lambda_tab.at[s1, s2] + return self.lambda_tab.loc[s1, s2] return self.alpha(s1, s2) def get_lambdas(self, species: str) -> pd.Series: - """Get all the lambda values associated with a species. + """ + Get all the lambda values associated with a species. Args: + ---- species (str): The species for which to get the lambda values. Returns: + ------- A pandas Series, with index-labelled lambda entries. """ @@ -188,7 +202,8 @@ def _mutate_structure( init_species: str, final_species: str, ) -> SmactStructure: - """Mutate a species within a SmactStructure. + """ + Mutate a species within a SmactStructure. Replaces all instances of the species within the structure. Every site occupied by the species @@ -199,16 +214,20 @@ def _mutate_structure( Requires the species to have the same charge. Note: + ---- Creates a deepcopy of the supplied structure, such that the original instance is not modified. Args: + ---- + structure (SmactStructure): The structure to mutate. init_species (str): The species within the structure to mutate. final_species (str): The species to replace the initial species with. Returns: + ------- A :class:`.~SmactStructure`, with the species mutated. @@ -237,12 +256,8 @@ def _mutate_structure( # Replace sites struct_buff.sites[final_species] = struct_buff.sites.pop(init_species) # And sort - species_strs = struct_buff._format_style("{ele}{charge}{sign}").split( - " " - ) - struct_buff.sites = { - spec: struct_buff.sites[spec] for spec in species_strs - } + species_strs = struct_buff._format_style("{ele}{charge}{sign}").split(" ") + struct_buff.sites = {spec: struct_buff.sites[spec] for spec in species_strs} return struct_buff @@ -252,12 +267,15 @@ def _nary_mutate_structure( init_species: list, final_species: list, ) -> SmactStructure: - """Perform a n-ary mutation of a SmactStructure (n>1). + """ + Perform a n-ary mutation of a SmactStructure (n>1). Replaces all instances of a group of species within the structure. Stoichiometry is maintained. Charge neutrality is preserved, but the species pair do not need the same charge. Args: + ---- + structure (SmactStructure): The structure to mutate. init_species (list): A list of species within the structure to mutate. final_species (list): The list of species to replace the initial species with @@ -268,9 +286,7 @@ def _nary_mutate_structure( struct_buff = deepcopy(structure) init_spec_tup_list = [parse_spec(i) for i in init_species] struct_spec_tups = list(map(itemgetter(0, 1), struct_buff.species)) - spec_loc = [ - struct_spec_tups.index(init_spec_tup_list[i]) for i in range(n) - ] + spec_loc = [struct_spec_tups.index(init_spec_tup_list[i]) for i in range(n)] final_spec_tup_list = [parse_spec(i) for i in final_species] @@ -291,17 +307,11 @@ def _nary_mutate_structure( # Replace sites for i in range(n): - struct_buff.sites[final_species[i]] = struct_buff.sites.pop( - init_species[i] - ) + struct_buff.sites[final_species[i]] = struct_buff.sites.pop(init_species[i]) # And sort - species_strs = struct_buff._format_style("{ele}{charge}{sign}").split( - " " - ) - struct_buff.sites = { - spec: struct_buff.sites[spec] for spec in species_strs - } + species_strs = struct_buff._format_style("{ele}{charge}{sign}").split(" ") + struct_buff.sites = {spec: struct_buff.sites[spec] for spec in species_strs} return struct_buff @@ -310,7 +320,8 @@ def sub_prob(self, s1: str, s2: str) -> float: return np.exp(self.get_lambda(s1, s2)) / self.Z def sub_probs(self, s1: str) -> pd.Series: - """Determine the substitution probabilities of a species with others. + """ + Determine the substitution probabilities of a species with others. Determines the probability of substitution of the species with every species in the lambda table. @@ -359,9 +370,7 @@ def same_spec_probs(self) -> pd.Series: def same_spec_cond_probs(self) -> pd.Series: """Calculate the same species conditional substiution probabilities.""" - return np.exp(self.lambda_tab.to_numpy().diagonal()) / np.exp( - self.lambda_tab - ).sum(axis=0) + return np.exp(self.lambda_tab.to_numpy().diagonal()) / np.exp(self.lambda_tab).sum(axis=0) def pair_corr(self, s1: str, s2: str) -> float: """Determine the pair correlation of two ionic species.""" @@ -372,12 +381,11 @@ def pair_corr(self, s1: str, s2: str) -> float: def cond_sub_prob(self, s1: str, s2: str) -> float: """Calculate the probability of substitution of one species with another.""" - return ( - np.exp(self.get_lambda(s1, s2)) / np.exp(self.get_lambdas(s2)).sum() - ) + return np.exp(self.get_lambda(s1, s2)) / np.exp(self.get_lambdas(s2)).sum() def cond_sub_probs(self, s1: str) -> pd.Series: - """Calculate the probabilities of substitution of a given species. + """ + Calculate the probabilities of substitution of a given species. Calculates probabilities of substitution of given species with all others in the lambda table. @@ -391,16 +399,19 @@ def cond_sub_probs(self, s1: str) -> pd.Series: def unary_substitute( self, structure: SmactStructure, - thresh: Optional[float] = 1e-5, - ) -> Generator[Tuple[SmactStructure, float], None, None]: - """Find all structures with 1 substitution with probability above a threshold. + thresh: float | None = 1e-5, + ) -> Generator[tuple[SmactStructure, float], None, None]: + """ + Find all structures with 1 substitution with probability above a threshold. Args: + ---- structure: A :class:`SmactStructure` instance from which to generate compounds. thresh (float): The probability threshold; discard all substitutions that have a probability to generate a naturally-occuring compound less than this. Yields: + ------ Tuples of (:class:`SmactStructure`, probability). """ diff --git a/smact/structure_prediction/prediction.py b/smact/structure_prediction/prediction.py index 5fedfd28..db7c8822 100644 --- a/smact/structure_prediction/prediction.py +++ b/smact/structure_prediction/prediction.py @@ -1,6 +1,8 @@ -"""Structure prediction implementation. +""" +Structure prediction implementation. Todo: +---- * Test with a fully populated database. * Implement n-ary substitution probabilities; at the moment, only zero- and single-species @@ -8,19 +10,26 @@ """ +from __future__ import annotations + import itertools -from typing import Generator, List, Optional, Tuple +from typing import TYPE_CHECKING import numpy as np -from .database import StructureDB -from .mutation import CationMutator -from .structure import SmactStructure from .utilities import parse_spec, unparse_spec +if TYPE_CHECKING: + from collections.abc import Generator + + from .database import StructureDB + from .mutation import CationMutator + from .structure import SmactStructure + class StructurePredictor: - """Provides structure prediction functionality. + """ + Provides structure prediction functionality. Implements a statistically-based model for determining likely structures of a given composition, based on a @@ -35,12 +44,12 @@ class StructurePredictor: """ - def __init__( - self, mutator: CationMutator, struct_db: StructureDB, table: str - ): - """Initialize class. + def __init__(self, mutator: CationMutator, struct_db: StructureDB, table: str): + """ + Initialize class. Args: + ---- mutator: A :class:`CationMutator` for probability calculations. struct_db: A :class:`StructureDB` from which to read strucutures to attempt to mutate. @@ -53,13 +62,15 @@ def __init__( def predict_structs( self, - species: List[Tuple[str, int]], - thresh: Optional[float] = 1e-3, - include_same: Optional[bool] = True, - ) -> Generator[Tuple[SmactStructure, float, SmactStructure], None, None]: - """Predict structures for a combination of species. + species: list[tuple[str, int]], + thresh: float | None = 1e-3, + include_same: bool | None = True, + ) -> Generator[tuple[SmactStructure, float, SmactStructure], None, None]: + """ + Predict structures for a combination of species. Args: + ---- species: A list of (element, charge). The constituent species of the target compound. thresh: The probability threshold, below which to discard @@ -69,6 +80,7 @@ def predict_structs( same species. Defaults to True. Yields: + ------ Potential structures, as tuples of (structure, probability, parent). """ @@ -82,9 +94,9 @@ def predict_structs( sub_spec = itertools.combinations(species, len(species) - 1) sub_spec = list(map(list, sub_spec)) - potential_unary_parents: List[List[SmactStructure]] = list( + potential_unary_parents: list[list[SmactStructure]] = [ self.db.get_with_species(specs, self.table) for specs in sub_spec - ) + ] for spec_idx, parents in enumerate(potential_unary_parents): # Get missing ion @@ -110,20 +122,9 @@ def predict_structs( # Determine probability # Get species to be substituted # Ensure only 1 species is obtained - if ( - len( - set(parent.get_spec_strs()) - - set(map(unparse_spec, species)) - - {diff_spec_str} - ) - > 1 - ): + if len(set(parent.get_spec_strs()) - set(map(unparse_spec, species)) - {diff_spec_str}) > 1: continue - (alt_spec,) = ( - set(parent.get_spec_strs()) - - set(map(unparse_spec, species)) - - {diff_spec_str} - ) + (alt_spec,) = set(parent.get_spec_strs()) - set(map(unparse_spec, species)) - {diff_spec_str} if parse_spec(alt_spec)[1] != diff_spec[1]: # Different charge @@ -137,30 +138,28 @@ def predict_structs( if p > thresh: try: - mutated = self.cm._mutate_structure( - parent, alt_spec, diff_spec_str - ) + self.cm._mutate_structure(parent, alt_spec, diff_spec_str) except ValueError: # Poorly decorated continue yield ( - self.cm._mutate_structure( - parent, alt_spec, diff_spec_str - ), + self.cm._mutate_structure(parent, alt_spec, diff_spec_str), p, parent, ) def nary_predict_structs( self, - species: List[Tuple[str, int]], - n_ary: Optional[int] = 2, - thresh: Optional[float] = 1e-3, - include_same: Optional[bool] = True, - ) -> Generator[Tuple[SmactStructure, float, SmactStructure], None, None]: - """Predicts structures for a combination of species. + species: list[tuple[str, int]], + n_ary: int | None = 2, + thresh: float | None = 1e-3, + include_same: bool | None = True, + ) -> Generator[tuple[SmactStructure, float, SmactStructure], None, None]: + """ + Predicts structures for a combination of species. Args: + ---- species: A list of (element, charge). The constituent species of the target compound. thresh: The probability threshold, below which to discard predictions. @@ -169,9 +168,10 @@ def nary_predict_structs( i.e. structures containing all the same species. Yields: + ------ Potential structures, as tuples of (structure, probability, parent). - """ + """ if include_same: for identical in self.db.get_with_species(species, self.table): yield (identical, 1.0, identical) @@ -182,9 +182,9 @@ def nary_predict_structs( sub_species = itertools.combinations(species, len(species) - n_ary) sub_species = list(map(list, sub_species)) - potential_nary_parents: List[List[SmactStructure]] = list( + potential_nary_parents: list[list[SmactStructure]] = [ self.db.get_with_species(specs, self.table) for specs in sub_species - ) + ] for spec_idx, parents in enumerate(potential_nary_parents): # Get missing ions @@ -197,76 +197,58 @@ def nary_predict_structs( diff_sub_probs = [self.cm.cond_sub_probs(i) for i in diff_spec_str] - for parent in parents: - # print("testing parent") - # Filter out any structures with identical species - if n_ary == 1: - if parent.has_species(diff_species[0]): - continue - elif n_ary == 2: - if parent.has_species(diff_species[0]) and parent.has_species( - diff_species[1] - ): - continue - elif n_ary == 3: - if ( + for parent in parents: + # print("testing parent") + # Filter out any structures with identical species + if n_ary == 1: + if parent.has_species(diff_species[0]): + continue + elif n_ary == 2: + if parent.has_species(diff_species[0]) and parent.has_species(diff_species[1]): + continue + elif n_ary == 3 and ( parent.has_species(diff_species[0]) and parent.has_species(diff_species[1]) and parent.has_species(diff_species[2]) ): continue - # Ensure parent has as many species as target - if len(parent.species) != len(species): - continue + # Ensure parent has as many species as target + if len(parent.species) != len(species): + continue - # Determine probability - # Get species to be substituted - # Ensure n species are obtained - - if ( - len( - set(parent.get_spec_strs()) - - set(map(unparse_spec, species)) - - set(diff_species) - ) - != n_ary - ): - continue - alt_spec = list( - set(parent.get_spec_strs()) - - set(map(unparse_spec, species)) - - set(diff_species) - ) - - # Need to consider p(A,X)p(B,Y) and p(A,Y)p(B,X) - # if utilities.parse_spec(alt_spec_1)[1] != diff_species_1[1] and utilities.parse_spec(alt_spec_2)[1] != diff_species_2[1] : - # Different charge - # continue + # Determine probability + # Get species to be substituted + # Ensure n species are obtained - try: - p = [] - for i in range(n_ary): - p.append(diff_sub_probs[i].loc[alt_spec[i]]) - except: - # Not in the Series - continue + if len(set(parent.get_spec_strs()) - set(map(unparse_spec, species)) - set(diff_species)) != n_ary: + continue + alt_spec = list(set(parent.get_spec_strs()) - set(map(unparse_spec, species)) - set(diff_species)) - p = np.prod(p) + # Need to consider p(A,X)p(B,Y) and p(A,Y)p(B,X) + # if utilities.parse_spec(alt_spec_1)[1] != diff_species_1[1] and utilities.parse_spec(alt_spec_2)[1] != diff_species_2[1] : + # Different charge + # continue - if p > thresh: try: - mutated = self.cm._nary_mutate_structure( - parent, alt_spec, diff_spec_str - ) - - except ValueError: - # Poorly decorated + p = [] + for i in range(n_ary): + p.append(diff_sub_probs[i].loc[alt_spec[i]]) + except: + # Not in the Series continue - yield ( - self.cm._nary_mutate_structure( - parent, alt_spec, diff_spec_str - ), - p, - parent, - ) + + p = np.prod(p) + + if p > thresh: + try: + self.cm._nary_mutate_structure(parent, alt_spec, diff_spec_str) + + except ValueError: + # Poorly decorated + continue + yield ( + self.cm._nary_mutate_structure(parent, alt_spec, diff_spec_str), + p, + parent, + ) diff --git a/smact/structure_prediction/probability_models.py b/smact/structure_prediction/probability_models.py index 4b6a8bb2..64900f5c 100644 --- a/smact/structure_prediction/probability_models.py +++ b/smact/structure_prediction/probability_models.py @@ -1,4 +1,5 @@ -"""Probability models for species substitution. +""" +Probability models for species substitution. Implements base class :class:`SubstitutionModel`, which can be extended to allow for development of new @@ -6,6 +7,7 @@ :class:`RadiusModel`, is also implemented. Todo: +---- * Allow for parallelism in lambda table calculations by implementing a `sub_probs` abstractmethod that :meth:`SubstitutionModel.gen_lambda` uses, @@ -13,10 +15,11 @@ """ +from __future__ import annotations + import abc import os from itertools import combinations_with_replacement -from typing import Dict, List, Optional import pandas as pd @@ -30,24 +33,30 @@ class SubstitutionModel(abc.ABC): @abc.abstractmethod def sub_prob(self, s1: str, s2: str) -> float: - """Calculate the probability of substituting species s1 for s2. + """ + Calculate the probability of substituting species s1 for s2. Args: + ---- s1: The species being substituted. s2: The species substituting. Returns: + ------- The probability of substitution. """ - def gen_lambda(self, species: List[str]) -> pd.DataFrame: - """Generate a lambda table for a list of species. + def gen_lambda(self, species: list[str]) -> pd.DataFrame: + """ + Generate a lambda table for a list of species. Args: + ---- species: A list of species strings. Returns: + ------- A pivot table-style DataFrame containing lambda values for every possible species pair. @@ -62,14 +71,15 @@ def gen_lambda(self, species: List[str]) -> pd.DataFrame: lambda_tab.append((s2, s1, prob)) df = pd.DataFrame(lambda_tab) - return df.pivot(index=0, columns=1, values=2) + return df.pivot_table(index=0, columns=1, values=2) class RadiusModel(SubstitutionModel): """Substitution probability model based on Shannon radii.""" def __init__(self): - r"""Parse Shannon radii data file. + r""" + Parse Shannon radii data file. Also calculates "spring constant", _k_, based on maximum difference in Shannon radii: @@ -82,13 +92,11 @@ def __init__(self): self.shannon_data = pd.read_csv(shannon_file, index_col=0) - self.k = ( - self.shannon_data["ionic_radius"].max() - - self.shannon_data["ionic_radius"].min() - ) ** -2 + self.k = (self.shannon_data["ionic_radius"].max() - self.shannon_data["ionic_radius"].min()) ** -2 def sub_prob(self, s1, s2): - r"""Calculate the probability of substituting species s1 for s2. + r""" + Calculate the probability of substituting species s1 for s2. Based on the difference in Shannon radii, the probability is assumed to be: @@ -97,10 +105,12 @@ def sub_prob(self, s1, s2): p = 1 - k \Delta r^2. Args: + ---- s1: The species being substituted. s2: The species substituting. Returns: + ------- The probability of substitution. """ diff --git a/smact/structure_prediction/structure.py b/smact/structure_prediction/structure.py index 0f072b35..d409dac2 100644 --- a/smact/structure_prediction/structure.py +++ b/smact/structure_prediction/structure.py @@ -1,12 +1,12 @@ """Minimalist structure representation for comprehensible manipulation.""" -import logging +from __future__ import annotations + import re from collections import defaultdict from functools import reduce from math import gcd from operator import itemgetter -from typing import Dict, List, Optional, Tuple, Union import numpy as np import pymatgen @@ -18,12 +18,12 @@ import smact -from . import logger from .utilities import convert_next_gen_mprest_data, get_sign class SmactStructure: - """SMACT implementation inspired by pymatgen Structure class. + """ + SMACT implementation inspired by pymatgen Structure class. Handles basic structural and compositional information for a compound. Includes a lossless POSCAR-style specification for storing structures, @@ -31,6 +31,7 @@ class SmactStructure: from the `Materials Project `_. Attributes: + ---------- species: A list of tuples describing the composition of the structure, stored as (element, oxidation, stoichiometry). The list is sorted alphabetically based on element symbol, and identical elements @@ -40,7 +41,7 @@ class SmactStructure: representation of the species and coords is a list of position vectors, given as lists of length 3. For example: - >>> s = SmactStructure.from_file('tests/files/NaCl.txt') + >>> s = SmactStructure.from_file("tests/files/NaCl.txt") >>> s.sites {'Cl1-': [[2.323624165, 1.643050405, 4.02463512]], 'Na1+': [[0.0, 0.0, 0.0]]} @@ -50,15 +51,17 @@ class SmactStructure: def __init__( self, - species: List[Union[Tuple[str, int, int], Tuple[smact.Species, int]]], + species: list[tuple[str, int, int] | tuple[smact.Species, int]], lattice_mat: np.ndarray, - sites: Dict[str, List[List[float]]], - lattice_param: Optional[float] = 1.0, - sanitise_species: Optional[bool] = True, + sites: dict[str, list[list[float]]], + lattice_param: float | None = 1.0, + sanitise_species: bool | None = True, ): - """Initialize structure with constituent species. + """ + Initialize structure with constituent species. Args: + ---- species: See :class:`~.SmactStructure`. May be supplied as either a list of (element, oxidation, stoichiometry) or (:class:`~smact.Species`, stoichiometry). lattice_mat: See :class:`~.SmactStructure`. @@ -69,20 +72,17 @@ def __init__( :meth:`~.from_mp`. """ - self.species = ( - self._sanitise_species(species) if sanitise_species else species - ) + self.species = self._sanitise_species(species) if sanitise_species else species self.lattice_mat = lattice_mat - self.sites = { - spec: sites[spec] for spec in self.get_spec_strs() - } # Sort sites + self.sites = {spec: sites[spec] for spec in self.get_spec_strs()} # Sort sites self.lattice_param = lattice_param def __repr__(self): - """Represent the structure as a POSCAR. + """ + Represent the structure as a POSCAR. Alias for :meth:`~.as_poscar`. @@ -90,12 +90,14 @@ def __repr__(self): return self.as_poscar() def __eq__(self, other): - """Determine equality of SmactStructures based on their attributes. + """ + Determine equality of SmactStructures based on their attributes. :attr:`~.species`, :attr:`~.lattice_mat`, :attr:`~.lattice_param` and :attr:`~.sites` must all be equal for the comparison to be True. Note: + ---- For the SmactStructures to be equal their attributes must be *identical*. For example, it is insufficient that the two structures have the same space group or the same species; @@ -116,18 +118,22 @@ def __eq__(self, other): @staticmethod def _sanitise_species( - species: List[Union[Tuple[str, int, int], Tuple[smact.Species, int]]], - ) -> List[Tuple[str, int, int]]: - """Sanitise and format a list of species. + species: list[tuple[str, int, int] | tuple[smact.Species, int]], + ) -> list[tuple[str, int, int]]: + """ + Sanitise and format a list of species. Args: + ---- species: See :meth:`~.__init__`. Returns: + ------- sanit_species: Sanity-checked species in the format of a list of (element, oxidation, stoichiometry). Raises: + ------ TypeError: species contains the wrong types. ValueError: species is either empty or contains tuples of incorrect length. @@ -138,9 +144,7 @@ def _sanitise_species( if len(species) == 0: raise ValueError("`species` cannot be empty.") if not isinstance(species[0], tuple): - raise TypeError( - f"`species` must be a list of tuples, got list of {type(species[0])}." - ) + raise TypeError(f"`species` must be a list of tuples, got list of {type(species[0])}.") species_error = ( "`species` list of tuples must contain either " @@ -155,13 +159,9 @@ def _sanitise_species( species.sort(key=itemgetter(0)) sanit_species = species - elif isinstance( - species[0][0], smact.Species - ): # Species class variation of instantiation + elif isinstance(species[0][0], smact.Species): # Species class variation of instantiation species.sort(key=lambda x: (x[0].symbol, -x[0].oxidation)) - sanit_species = [ - (x[0].symbol, x[0].oxidation, x[1]) for x in species - ] + sanit_species = [(x[0].symbol, x[0].oxidation, x[1]) for x in species] else: raise TypeError(species_error) @@ -171,13 +171,16 @@ def _sanitise_species( @staticmethod def __parse_py_sites( structure: pymatgen.core.Structure, - ) -> Tuple[Dict[str, List[List[float]]], List[Tuple[str, int, int]]]: - """Parse the sites of a pymatgen Structure. + ) -> tuple[dict[str, list[list[float]]], list[tuple[str, int, int]]]: + """ + Parse the sites of a pymatgen Structure. Args: + ---- structure: A :class:`pymatgen.core.Structure` instance. Returns: + ------- sites (dict): In which a key is a species string and its corresponding value is a list of the coordinates that species occupies in the supercell. The coordinates @@ -188,9 +191,7 @@ def __parse_py_sites( """ if not isinstance(structure, pymatgen.core.Structure): - raise TypeError( - "structure must be a pymatgen.core.Structure instance." - ) + raise TypeError("structure must be a pymatgen.core.Structure instance.") sites = defaultdict(list) for site in structure.sites: @@ -230,25 +231,24 @@ def __parse_py_sites( return sites, species @staticmethod - def from_py_struct( - structure: pymatgen.core.Structure, determine_oxi: str = "BV" - ): - """Create a SmactStructure from a pymatgen Structure object. + def from_py_struct(structure: pymatgen.core.Structure, determine_oxi: str = "BV"): + """ + Create a SmactStructure from a pymatgen Structure object. Args: + ---- structure: A pymatgen Structure. determine_oxi (str): The method to determine the assignments oxidation states in the structure. Options are 'BV', 'comp_ICSD','both' for determining the oxidation states by bond valence, ICSD statistics or trial both sequentially, respectively. Returns: + ------- :class:`~.SmactStructure` """ if not isinstance(structure, pymatgen.core.Structure): - raise TypeError( - "Structure must be a pymatgen.core.Structure instance." - ) + raise TypeError("Structure must be a pymatgen.core.Structure instance.") if determine_oxi == "BV": bva = BVAnalyzer() @@ -256,9 +256,7 @@ def from_py_struct( elif determine_oxi == "comp_ICSD": comp = structure.composition - oxi_transform = OxidationStateDecorationTransformation( - comp.oxi_state_guesses()[0] - ) + oxi_transform = OxidationStateDecorationTransformation(comp.oxi_state_guesses()[0]) struct = oxi_transform.apply_transformation(structure) print("Charge assigned based on ICSD statistics") @@ -269,9 +267,7 @@ def from_py_struct( print("Oxidation states assigned using bond valence") except ValueError: comp = structure.composition - oxi_transform = OxidationStateDecorationTransformation( - comp.oxi_state_guesses()[0] - ) + oxi_transform = OxidationStateDecorationTransformation(comp.oxi_state_guesses()[0]) struct = oxi_transform.apply_transformation(structure) print("Oxidation states assigned based on ICSD statistics") elif determine_oxi == "predecorated": @@ -298,13 +294,15 @@ def from_py_struct( @staticmethod def from_mp( - species: List[Union[Tuple[str, int, int], Tuple[smact.Species, int]]], + species: list[tuple[str, int, int] | tuple[smact.Species, int]], api_key: str, determine_oxi: str = "BV", ): - """Create a SmactStructure using the first Materials Project entry for a composition. + """ + Create a SmactStructure using the first Materials Project entry for a composition. Args: + ---- species: See :meth:`~.__init__`. determine_oxi (str): The method to determine the assignments oxidation states in the structure. Options are 'BV', 'comp_ICSD','both' for determining the oxidation states by bond valence, @@ -312,6 +310,7 @@ def from_mp( api_key: A www.materialsproject.org API key. Returns: + ------- :class:`~.SmactStructure` """ @@ -333,25 +332,20 @@ def from_mp( if len(structs) == 0: raise ValueError( - "Could not find composition in Materials Project Database, " - "please supply a structure." + "Could not find composition in Materials Project Database, " "please supply a structure." ) # Default to first found structure struct = structs[0]["structure"] - if 0 not in ( - spec[1] for spec in sanit_species - ): # If everything's charged + if 0 not in (spec[1] for spec in sanit_species): # If everything's charged if determine_oxi == "BV": bva = BVAnalyzer() struct = bva.get_oxi_state_decorated_structure(struct) elif determine_oxi == "comp_ICSD": comp = struct.composition - oxi_transform = OxidationStateDecorationTransformation( - comp.oxi_state_guesses()[0] - ) + oxi_transform = OxidationStateDecorationTransformation(comp.oxi_state_guesses()[0]) struct = oxi_transform.apply_transformation(struct) print("Charge assigned based on ICSD statistics") @@ -362,9 +356,7 @@ def from_mp( print("Oxidation states assigned using bond valence") except ValueError: comp = struct.composition - oxi_transform = OxidationStateDecorationTransformation( - comp.oxi_state_guesses()[0] - ) + oxi_transform = OxidationStateDecorationTransformation(comp.oxi_state_guesses()[0]) struct = oxi_transform.apply_transformation(struct) print("Oxidation states assigned based on ICSD statistics") else: @@ -387,13 +379,16 @@ def from_mp( @staticmethod def from_file(fname: str): - """Create SmactStructure from a POSCAR file. + """ + Create SmactStructure from a POSCAR file. Args: + ---- fname: The name of the POSCAR file. See :meth:`~.as_poscar` for format specification. Returns: + ------- :class:`~.SmactStructure` """ @@ -402,13 +397,16 @@ def from_file(fname: str): @staticmethod def from_poscar(poscar: str): - """Create SmactStructure from a POSCAR string. + """ + Create SmactStructure from a POSCAR string. Args: + ---- poscar: A SMACT-formatted POSCAR string. See :meth:`~.as_poscar` for format specification. Returns: + ------- :class:`~.SmactStructure` """ @@ -435,9 +433,7 @@ def from_poscar(poscar: str): lattice_param = float(lines[1]) - lattice = np.array( - [[float(point) for point in line.split(" ")] for line in lines[2:5]] - ) + lattice = np.array([[float(point) for point in line.split(" ")] for line in lines[2:5]]) sites = defaultdict(list) for line in lines[8:]: @@ -457,15 +453,17 @@ def from_poscar(poscar: str): def _format_style( self, template: str, - delim: Optional[str] = " ", - include_ground: Optional[bool] = False, + delim: str | None = " ", + include_ground: bool | None = False, ) -> str: - """Format a given template string with the composition. + """ + Format a given template string with the composition. Formats a python template string with species information, with each species separated by a given delimiter. Args: + ---- template: Template string to format, using python's curly brackets notation. Supported keywords are `ele` for the elemental symbol, `stoic` for the @@ -477,12 +475,14 @@ def _format_style( of neutral species. Returns: + ------- String of templates formatted for each species, separated by `delim`. Examples: - >>> s = SmactStructure.from_file('tests/files/CaTiO3.txt') - >>> template = '{stoic}x{ele}{charge}{sign}' + -------- + >>> s = SmactStructure.from_file("tests/files/CaTiO3.txt") + >>> template = "{stoic}x{ele}{charge}{sign}" >>> print(s._format_style(template)) 1xCa2+ 3xO2- 1xTi4+ @@ -509,17 +509,21 @@ def _format_style( ) @staticmethod - def _get_ele_stoics(species: List[Tuple[str, int, int]]) -> Dict[str, int]: - """Get the number of each element type in the compound, irrespective of oxidation state. + def _get_ele_stoics(species: list[tuple[str, int, int]]) -> dict[str, int]: + """ + Get the number of each element type in the compound, irrespective of oxidation state. Args: + ---- species: See :meth:`~.__init__`. Returns: + ------- eles: Dictionary of {element: stoichiometry}. Examples: - >>> species = [('Fe', 2, 1), ('Fe', 3, 2), ('O', -2, 4)] + -------- + >>> species = [("Fe", 2, 1), ("Fe", 3, 2), ("O", -2, 4)] >>> print(SmactStructure._get_ele_stoics(species)) {'Fe': 3, 'O': 4} @@ -530,18 +534,21 @@ def _get_ele_stoics(species: List[Tuple[str, int, int]]) -> Dict[str, int]: return dict(eles) - def has_species(self, species: Tuple[str, int]) -> bool: + def has_species(self, species: tuple[str, int]) -> bool: """Determine whether a given species is in the structure.""" return species in map(itemgetter(0, 1), self.species) - def get_spec_strs(self) -> List[str]: - """Get string representations of the constituent species. + def get_spec_strs(self) -> list[str]: + """ + Get string representations of the constituent species. Returns: + ------- A list of strings, formatted as '{element}{charge}{sign}'. Examples: - >>> s = SmactStructure.from_file('tests/files/CaTiO3.txt') + -------- + >>> s = SmactStructure.from_file("tests/files/CaTiO3.txt") >>> s.get_spec_strs() ['Ca2+', 'O2-', 'Ti4+'] @@ -549,17 +556,20 @@ def get_spec_strs(self) -> List[str]: return self._format_style("{ele}{charge}{sign}").split(" ") def composition(self) -> str: - """Generate a key that describes the composition. + """ + Generate a key that describes the composition. Key format is '{element}_{stoichiometry}_{charge}{sign}' with no delimiter, *sans brackets*. Species are ordered as stored within the structure, see :class:`~.SmactStructure`. Returns: + ------- Key describing constituent species. Examples: - >>> s = SmactStructure.from_file('tests/files/CaTiO3.txt') + -------- + >>> s = SmactStructure.from_file("tests/files/CaTiO3.txt") >>> print(s.composition()) Ca_1_2+O_3_2-Ti_1_4+ @@ -568,7 +578,8 @@ def composition(self) -> str: return self._format_style(comp_style, delim="", include_ground=True) def as_poscar(self) -> str: - """Represent the structure as a POSCAR file compatible with VASP5. + """ + Represent the structure as a POSCAR file compatible with VASP5. The POSCAR format adopted is as follows: @@ -587,6 +598,7 @@ def as_poscar(self) -> str: For examples of this format, see the text files under tests/files. Returns: + ------- str: POSCAR-style representation of the structure. """ @@ -594,21 +606,13 @@ def as_poscar(self) -> str: poscar += f"{self.lattice_param}\n" - poscar += ( - "\n".join( - " ".join(map(str, vec)) for vec in self.lattice_mat.tolist() - ) - + "\n" - ) + poscar += "\n".join(" ".join(map(str, vec)) for vec in self.lattice_mat.tolist()) + "\n" spec_count = {spec: len(coords) for spec, coords in self.sites.items()} poscar += self._format_style("{ele}") + "\n" - poscar += ( - " ".join(str(spec_count[spec]) for spec in self.get_spec_strs()) - + "\n" - ) + poscar += " ".join(str(spec_count[spec]) for spec in self.get_spec_strs()) + "\n" poscar += "Cartesian\n" for spec, coords in self.sites.items(): diff --git a/smact/structure_prediction/utilities.py b/smact/structure_prediction/utilities.py index 6c265d79..f7f68ae7 100644 --- a/smact/structure_prediction/utilities.py +++ b/smact/structure_prediction/utilities.py @@ -1,24 +1,31 @@ """Miscellaneous tools for data parsing.""" -import re -from typing import Dict, Optional, Tuple, Union +from __future__ import annotations -import pymatgen +import re +from typing import TYPE_CHECKING from . import logger +if TYPE_CHECKING: + import pymatgen -def parse_spec(species: str) -> Tuple[str, int]: - """Parse a species string into its element and charge. + +def parse_spec(species: str) -> tuple[str, int]: + """ + Parse a species string into its element and charge. Args: + ---- species (str): String representation of a species in the format {element}{absolute_charge}{sign}. Returns: + ------- A tuple of (element, signed_charge). Examples: + -------- >>> parse_spec("Fe2+") ('Fe', 2) >>> parse_spec("O2-") @@ -36,18 +43,22 @@ def parse_spec(species: str) -> Tuple[str, int]: return ele, charge -def unparse_spec(species: Tuple[str, int]) -> str: - """Unparse a species into a string representation. +def unparse_spec(species: tuple[str, int]) -> str: + """ + Unparse a species into a string representation. The analogue of :func:`parse_spec`. Args: - A tuple of (element, signed_charge). + ---- + species (tuple): A tuple of (element, signed_charge). Returns: + ------- String of {element}{absolute_charge}{sign}. Examples: + -------- >>> unparse_spec(("Fe", 2)) 'Fe2+' >>> unparse_spec(("O", -2)) @@ -58,12 +69,15 @@ def unparse_spec(species: Tuple[str, int]) -> str: def get_sign(charge: int) -> str: - """Get string representation of a number's sign. + """ + Get string representation of a number's sign. Args: + ---- charge (int): The number whose sign to derive. Returns: + ------- sign (str): either '+', '-', or '' for neutral. """ @@ -77,10 +91,12 @@ def get_sign(charge: int) -> str: def convert_next_gen_mprest_data( doc, -) -> Dict[str, Union[pymatgen.core.Structure, Optional[str]]]: - """Converts the `MPDataDoc` object returned by a next-gen MP query to a dictionary +) -> dict[str, pymatgen.core.Structure | str | None]: + """ + Converts the `MPDataDoc` object returned by a next-gen MP query to a dictionary. Args: + ---- doc (MPDataDoc): A MPDataDoc object (based on a pydantic model) with fields 'structure' and 'material_id' Returns: A dictionary containing at least the keys 'structure' and diff --git a/smact/tests/test_core.py b/smact/tests/test_core.py index a30b4f28..0d46f8ec 100755 --- a/smact/tests/test_core.py +++ b/smact/tests/test_core.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +from __future__ import annotations import os import unittest @@ -34,9 +35,7 @@ def test_Element_class_Pt(self): self.assertEqual(Pt.dipol, 44.00) def test_ordered_elements(self): - self.assertEqual( - smact.ordered_elements(65, 68), ["Tb", "Dy", "Ho", "Er"] - ) + self.assertEqual(smact.ordered_elements(65, 68), ["Tb", "Dy", "Ho", "Er"]) self.assertEqual(smact.ordered_elements(52, 52), ["Te"]) def test_element_dictionary(self): @@ -49,11 +48,7 @@ def test_element_dictionary(self): self.assertTrue("Rn" in smact.element_dictionary()) def test_are_eq(self): - self.assertTrue( - smact.are_eq( - [1.00, 2.00, 3.00], [1.001, 1.999, 3.00], tolerance=1e-2 - ) - ) + self.assertTrue(smact.are_eq([1.00, 2.00, 3.00], [1.001, 1.999, 3.00], tolerance=1e-2)) self.assertFalse(smact.are_eq([1.00, 2.00, 3.00], [1.001, 1.999, 3.00])) def test_gcd_recursive(self): @@ -75,9 +70,7 @@ def test_neutral_ratios(self): def test_compound_eneg_brass(self): self.assertAlmostEqual( - compound_electroneg( - elements=["Cu", "Zn"], stoichs=[0.5, 0.5], source="Pauling" - ), + compound_electroneg(elements=["Cu", "Zn"], stoichs=[0.5, 0.5], source="Pauling"), 5.0638963259, ) @@ -104,11 +97,7 @@ def test_pauling_test(self): (Sn.pauling_eneg, S.pauling_eneg), ) ) - self.assertFalse( - smact.screening.pauling_test( - (-2, +2), (Sn.pauling_eneg, S.pauling_eneg) - ) - ) + self.assertFalse(smact.screening.pauling_test((-2, +2), (Sn.pauling_eneg, S.pauling_eneg))) self.assertFalse( smact.screening.pauling_test( (-2, -2, +2), @@ -194,40 +183,24 @@ def test_pauling_test_old(self): def test_eneg_states_test(self): Na, Fe, Cl = (smact.Element(label) for label in ("Na", "Fe", "Cl")) self.assertTrue( - smact.screening.eneg_states_test( - [1, 3, -1], [Na.pauling_eneg, Fe.pauling_eneg, Cl.pauling_eneg] - ) + smact.screening.eneg_states_test([1, 3, -1], [Na.pauling_eneg, Fe.pauling_eneg, Cl.pauling_eneg]) ) self.assertFalse( - smact.screening.eneg_states_test( - [-1, 3, 1], [Na.pauling_eneg, Fe.pauling_eneg, Cl.pauling_eneg] - ) + smact.screening.eneg_states_test([-1, 3, 1], [Na.pauling_eneg, Fe.pauling_eneg, Cl.pauling_eneg]) ) def test_eneg_states_test_alternate(self): Na, Fe, Cl = (smact.Element(label) for label in ("Na", "Fe", "Cl")) self.assertTrue( - smact.screening.eneg_states_test_alternate( - [1, 3, -1], [Na.pauling_eneg, Fe.pauling_eneg, Cl.pauling_eneg] - ) + smact.screening.eneg_states_test_alternate([1, 3, -1], [Na.pauling_eneg, Fe.pauling_eneg, Cl.pauling_eneg]) ) self.assertFalse( - smact.screening.eneg_states_test_alternate( - [-1, 3, 1], [Na.pauling_eneg, Fe.pauling_eneg, Cl.pauling_eneg] - ) + smact.screening.eneg_states_test_alternate([-1, 3, 1], [Na.pauling_eneg, Fe.pauling_eneg, Cl.pauling_eneg]) ) def test_eneg_states_test_threshold(self): - self.assertFalse( - smact.screening.eneg_states_test_threshold( - [1, -1], [1.83, 1.82], threshold=0 - ) - ) - self.assertTrue( - smact.screening.eneg_states_test_threshold( - [1, -1], [1.83, 1.82], threshold=0.1 - ) - ) + self.assertFalse(smact.screening.eneg_states_test_threshold([1, -1], [1.83, 1.82], threshold=0)) + self.assertTrue(smact.screening.eneg_states_test_threshold([1, -1], [1.83, 1.82], threshold=0.1)) def test_ml_rep_generator(self): Pb, O = (smact.Element(label) for label in ("Pb", "O")) @@ -335,12 +308,8 @@ def test_ml_rep_generator(self): 0.0, 0.0, ] - self.assertEqual( - smact.screening.ml_rep_generator(["Pb", "O"], [1, 2]), PbO2_ml - ) - self.assertEqual( - smact.screening.ml_rep_generator([Pb, O], [1, 2]), PbO2_ml - ) + self.assertEqual(smact.screening.ml_rep_generator(["Pb", "O"], [1, 2]), PbO2_ml) + self.assertEqual(smact.screening.ml_rep_generator([Pb, O], [1, 2]), PbO2_ml) def test_smact_filter(self): Na, Fe, Cl = (smact.Element(label) for label in ("Na", "Fe", "Cl")) @@ -354,39 +323,20 @@ def test_smact_filter(self): ) self.assertEqual( result, - smact.screening.smact_filter( - [Na, Fe, Cl], threshold=2, oxidation_states_set=TEST_OX_STATES - ), - ) - result_comp_tuple = smact.screening.smact_filter( - [Na, Fe, Cl], threshold=2, comp_tuple=True - ) - self.assertTupleEqual( - result_comp_tuple[0].element_symbols, ("Na", "Fe", "Cl") - ) - self.assertTupleEqual(result_comp_tuple[0].stoichiometries, (2, 1, 1)) - self.assertTupleEqual( - result_comp_tuple[0].oxidation_states, (1, -1, -1) + smact.screening.smact_filter([Na, Fe, Cl], threshold=2, oxidation_states_set=TEST_OX_STATES), ) + self.assertEqual( - set( - smact.screening.smact_filter( - [Na, Fe, Cl], threshold=2, species_unique=False - ) - ), + set(smact.screening.smact_filter([Na, Fe, Cl], threshold=2, species_unique=False)), { (("Na", "Fe", "Cl"), (2, 1, 1)), (("Na", "Fe", "Cl"), (1, 1, 2)), }, ) - self.assertEqual( - len(smact.screening.smact_filter([Na, Fe, Cl], threshold=8)), 77 - ) + self.assertEqual(len(smact.screening.smact_filter([Na, Fe, Cl], threshold=8)), 77) - result = smact.screening.smact_filter( - [Na, Fe, Cl], stoichs=[[1], [1], [4]] - ) + result = smact.screening.smact_filter([Na, Fe, Cl], stoichs=[[1], [1], [4]]) self.assertEqual( [(r[0], r[1], r[2]) for r in result], [ @@ -402,33 +352,17 @@ def test_smact_filter(self): def test_smact_validity(self): self.assertTrue(smact.screening.smact_validity("NaCl")) self.assertTrue(smact.screening.smact_validity("Na10Cl10")) - self.assertFalse( - smact.screening.smact_validity("Al3Li", include_alloys=False) - ) - self.assertTrue( - smact.screening.smact_validity("Al3Li", include_alloys=True) - ) + self.assertFalse(smact.screening.smact_validity("Al3Li", include_alloys=False)) + self.assertTrue(smact.screening.smact_validity("Al3Li", include_alloys=True)) # Test for single element self.assertTrue(smact.screening.smact_validity("Al")) # Test for MgB2 which is invalid for the default oxi states but valid for the icsd states self.assertFalse(smact.screening.smact_validity("MgB2")) - self.assertTrue( - smact.screening.smact_validity("MgB2", oxidation_states_set="icsd") - ) - self.assertFalse( - smact.screening.smact_validity( - "MgB2", oxidation_states_set="pymatgen" - ) - ) - self.assertTrue( - smact.screening.smact_validity("MgB2", oxidation_states_set="wiki") - ) - self.assertFalse( - smact.screening.smact_validity( - "MgB2", oxidation_states_set=TEST_OX_STATES - ) - ) + self.assertTrue(smact.screening.smact_validity("MgB2", oxidation_states_set="icsd")) + self.assertFalse(smact.screening.smact_validity("MgB2", oxidation_states_set="pymatgen")) + self.assertTrue(smact.screening.smact_validity("MgB2", oxidation_states_set="wiki")) + self.assertFalse(smact.screening.smact_validity("MgB2", oxidation_states_set=TEST_OX_STATES)) # ---------------- Lattice ---------------- def test_Lattice_class(self): @@ -440,9 +374,7 @@ def test_Lattice_class(self): # ---------- Lattice parameters ----------- def test_lattice_parameters(self): - perovskite = smact.lattice_parameters.cubic_perovskite( - [1.81, 1.33, 1.82] - ) + perovskite = smact.lattice_parameters.cubic_perovskite([1.81, 1.33, 1.82]) wurtz = smact.lattice_parameters.wurtzite([1.81, 1.33]) self.assertAlmostEqual(perovskite[0], 6.3) self.assertAlmostEqual(perovskite[1], 6.3) diff --git a/smact/tests/test_doper.py b/smact/tests/test_doper.py index 79fd7446..12556980 100644 --- a/smact/tests/test_doper.py +++ b/smact/tests/test_doper.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import os import unittest -import smact +import pytest + from smact.dopant_prediction import doper -from smact.structure_prediction import mutation, utilities +from smact.structure_prediction import utilities files_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "files") TEST_LAMBDA_JSON = os.path.join(files_dir, "test_lambda_tab.json") @@ -11,16 +14,11 @@ class DopantPredictionTest(unittest.TestCase): def test_dopant_prediction(self): - num_dopants = 10 test_specie = ("Cu+", "Ga3+", "S2-") test = doper.Doper(test_specie) - cation_max_charge = max( - test_specie, key=lambda x: utilities.parse_spec(x)[1] - ) - anion_min_charge = min( - test_specie, key=lambda x: utilities.parse_spec(x)[1] - ) + cation_max_charge = max(test_specie, key=lambda x: utilities.parse_spec(x)[1]) + anion_min_charge = min(test_specie, key=lambda x: utilities.parse_spec(x)[1]) _, cat_charge = utilities.parse_spec(cation_max_charge) _, an_charge = utilities.parse_spec(anion_min_charge) @@ -48,17 +46,13 @@ def test_dopant_prediction(self): def test_dopant_prediction_skipspecies(self): test_specie = ("Cu+", "Ga3+", "S2-") - with self.assertRaises(ValueError): - doper.Doper( - test_specie, filepath=TEST_LAMBDA_JSON, embedding="skipspecies" - ) + with pytest.raises(ValueError): + doper.Doper(test_specie, filepath=TEST_LAMBDA_JSON, embedding="skipspecies") - with self.assertRaises(ValueError): + with pytest.raises(ValueError): doper.Doper(test_specie, embedding="skip", use_probability=False) - test = doper.Doper( - test_specie, embedding="skipspecies", use_probability=False - ) + test = doper.Doper(test_specie, embedding="skipspecies", use_probability=False) result = test.get_dopants() n_sub_list_cat = result.get("n-type cation substitutions").get("sorted") diff --git a/smact/tests/test_structure.py b/smact/tests/test_structure.py index d92253bc..c021c9ef 100644 --- a/smact/tests/test_structure.py +++ b/smact/tests/test_structure.py @@ -1,4 +1,6 @@ -"""Test structure prediction module""" +"""Test structure prediction module.""" + +from __future__ import annotations import itertools import json @@ -9,10 +11,12 @@ from contextlib import contextmanager from operator import itemgetter from random import sample +from typing import ClassVar import numpy as np import pandas as pd import pymatgen +import pytest from pandas.testing import assert_frame_equal, assert_series_equal from pymatgen.analysis.structure_prediction.substitution_probability import ( SubstitutionProbability, @@ -61,16 +65,15 @@ def ignore_warnings(logger: logging.Logger) -> int: class StructureTest(unittest.TestCase): """`SmactStructure` testing.""" - TEST_SPECIES = { + TEST_SPECIES: ClassVar = { "CaTiO3": [("Ca", 2, 1), ("Ti", 4, 1), ("O", -2, 3)], "NaCl": [("Na", 1, 1), ("Cl", -1, 1)], "Fe": [("Fe", 0, 1)], } - def assertStructAlmostEqual( - self, s1: SmactStructure, s2: SmactStructure, places: int = 7 - ): - """Assert that two SmactStructures are almost equal. + def assertStructAlmostEqual(self, s1: SmactStructure, s2: SmactStructure, places: int = 7): + """ + Assert that two SmactStructures are almost equal. Almost equality dependent on how many decimal places the site coordinates are equal to. @@ -88,7 +91,7 @@ def assertStructAlmostEqual( def test_as_poscar(self): """Test POSCAR generation.""" - for comp in self.TEST_SPECIES.keys(): + for comp in self.TEST_SPECIES: with self.subTest(comp=comp): comp_file = os.path.join(files_dir, f"{comp}.txt") with open(comp_file) as f: @@ -142,9 +145,7 @@ def test_from_py_struct_icsd(self): py_structure = pymatgen.core.Structure.from_dict(d) with ignore_warnings(smact.structure_prediction.logger): - s1 = SmactStructure.from_py_struct( - py_structure, determine_oxi="comp_ICSD" - ) + s1 = SmactStructure.from_py_struct(py_structure, determine_oxi="comp_ICSD") s2 = SmactStructure.from_file(os.path.join(files_dir, "CaTiO3.txt")) @@ -152,11 +153,7 @@ def test_from_py_struct_icsd(self): def test_has_species(self): """Test determining whether a species is in a `SmactStructure`.""" - s1 = SmactStructure( - *self._gen_empty_structure( - [("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)] - ) - ) + s1 = SmactStructure(*self._gen_empty_structure([("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)])) self.assertTrue(s1.has_species(("Ba", 2))) self.assertFalse(s1.has_species(("Ba", 3))) @@ -164,16 +161,8 @@ def test_has_species(self): def test_smactStruc_comp_key(self): """Test generation of a composition key for `SmactStructure`s.""" - s1 = SmactStructure( - *self._gen_empty_structure( - [("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)] - ) - ) - s2 = SmactStructure( - *self._gen_empty_structure( - [("Fe", 2, 1), ("Fe", 3, 2), ("O", -2, 4)] - ) - ) + s1 = SmactStructure(*self._gen_empty_structure([("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)])) + s2 = SmactStructure(*self._gen_empty_structure([("Fe", 2, 1), ("Fe", 3, 2), ("O", -2, 4)])) Ba = Species("Ba", 2) O = Species("O", -2) @@ -181,12 +170,8 @@ def test_smactStruc_comp_key(self): Fe2 = Species("Fe", 2) Fe3 = Species("Fe", 3) - s3 = SmactStructure( - *self._gen_empty_structure([(Ba, 2), (O, 1), (F, 2)]) - ) - s4 = SmactStructure( - *self._gen_empty_structure([(Fe2, 1), (Fe3, 2), (O, 4)]) - ) + s3 = SmactStructure(*self._gen_empty_structure([(Ba, 2), (O, 1), (F, 2)])) + s4 = SmactStructure(*self._gen_empty_structure([(Fe2, 1), (Fe3, 2), (O, 4)])) Ba_2OF_2 = "Ba_2_2+F_2_1-O_1_2-" Fe_3O_4 = "Fe_2_3+Fe_1_2+O_4_2-" @@ -206,9 +191,7 @@ def test_smactStruc_from_file(self): def test_equality(self): """Test equality determination of `SmactStructure`.""" - struct_files = [ - os.path.join(files_dir, f"{x}.txt") for x in ["CaTiO3", "NaCl"] - ] + struct_files = [os.path.join(files_dir, f"{x}.txt") for x in ["CaTiO3", "NaCl"]] CaTiO3 = SmactStructure.from_file(struct_files[0]) NaCl = SmactStructure.from_file(struct_files[1]) @@ -223,28 +206,16 @@ def test_equality(self): def test_ele_stoics(self): """Test acquiring element stoichiometries.""" - s1 = SmactStructure( - *self._gen_empty_structure( - [("Fe", 2, 1), ("Fe", 3, 2), ("O", -2, 4)] - ) - ) + s1 = SmactStructure(*self._gen_empty_structure([("Fe", 2, 1), ("Fe", 3, 2), ("O", -2, 4)])) s1_stoics = {"Fe": 3, "O": 4} - s2 = SmactStructure( - *self._gen_empty_structure( - [("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)] - ) - ) + s2 = SmactStructure(*self._gen_empty_structure([("Ba", 2, 2), ("O", -2, 1), ("F", -1, 2)])) s2_stoics = {"Ba": 2, "O": 1, "F": 2} for test, expected in [(s1, s1_stoics), (s2, s2_stoics)]: with self.subTest(species=test.species): - self.assertEqual( - SmactStructure._get_ele_stoics(test.species), expected - ) + self.assertEqual(SmactStructure._get_ele_stoics(test.species), expected) - @unittest.skipUnless( - os.environ.get("MPI_KEY"), "requires MPI key to be set." - ) + @unittest.skipUnless(os.environ.get("MPI_KEY"), "requires MPI key to be set.") def test_from_mp(self): """Test downloading structures from materialsproject.org.""" # TODO Needs ensuring that the structure query gets the same @@ -294,15 +265,11 @@ def test_db_interface(self): self.fail(e) with self.subTest(msg="Getting structure from table."): - struct_list = self.db.get_structs( - struct.composition(), self.TEST_TABLE - ) + struct_list = self.db.get_structs(struct.composition(), self.TEST_TABLE) self.assertEqual(len(struct_list), 1) self.assertEqual(struct_list[0], struct) - struct_files = [ - os.path.join(files_dir, f"{x}.txt") for x in ["NaCl", "Fe"] - ] + struct_files = [os.path.join(files_dir, f"{x}.txt") for x in ["NaCl", "Fe"]] structs = [SmactStructure.from_file(fname) for fname in struct_files] with self.subTest(msg="Adding multiple structures to table."): @@ -333,21 +300,16 @@ def test_db_interface(self): [struct], ] - for spec, expected in zip( - test_with_species_args, test_with_species_exp - ): + for spec, expected in zip(test_with_species_args, test_with_species_exp): with self.subTest(msg=f"Retrieving species with {spec}"): - self.assertEqual( - self.db.get_with_species(spec, self.TEST_TABLE), expected - ) + self.assertEqual(self.db.get_with_species(spec, self.TEST_TABLE), expected) mp_strucs = [ pymatgen.core.Structure.from_file(os.path.join(files_dir, f)) for f in ["CaTiO3.json", "NaCl.json", "Fe3O4.json"] ] mp_data = [ - {"material_id": mpid, "structure": s} - for mpid, s in zip(["mp-4019", "mp-22862", "mp-19306"], mp_strucs) + {"material_id": mpid, "structure": s} for mpid, s in zip(["mp-4019", "mp-22862", "mp-19306"], mp_strucs) ] with self.subTest(msg="Testing adding downloaded MP structures."): @@ -364,15 +326,11 @@ def setUpClass(cls): cls.test_struct = SmactStructure.from_file(TEST_POSCAR) cls.test_mutator = CationMutator.from_json(lambda_json=TEST_LAMBDA_JSON) - cls.test_pymatgen_mutator = CationMutator.from_json( - lambda_json=None, alpha=lambda x, y: -5 - ) + cls.test_pymatgen_mutator = CationMutator.from_json(lambda_json=None, alpha=lambda x, y: -5) # 5 random test species -> 5! test pairs cls.test_species = sample(list(cls.test_pymatgen_mutator.specs), 5) - cls.test_pairs = list( - itertools.combinations_with_replacement(cls.test_species, 2) - ) + cls.test_pairs = list(itertools.combinations_with_replacement(cls.test_species, 2)) cls.pymatgen_sp = SubstitutionProbability(lambda_table=None, alpha=-5) @@ -400,16 +358,11 @@ def test_partition_func_Z(self): def test_pymatgen_lambda_import(self): """Test importing pymatgen lambda table.""" - self.assertIsInstance( - self.test_pymatgen_mutator.lambda_tab, pd.DataFrame - ) + self.assertIsInstance(self.test_pymatgen_mutator.lambda_tab, pd.DataFrame) def test_lambda_interface(self): """Test getting lambda values.""" - test_cases = [ - itertools.permutations(x) - for x in [("A", "B"), ("A", "C"), ("B", "C")] - ] + test_cases = [itertools.permutations(x) for x in [("A", "B"), ("A", "C"), ("B", "C")]] expected = [0.5, -5.0, 0.3] @@ -417,9 +370,7 @@ def test_lambda_interface(self): for spec_comb in test_case: s1, s2 = spec_comb with self.subTest(s1=s1, s2=s2): - self.assertEqual( - self.test_mutator.get_lambda(s1, s2), expectation - ) + self.assertEqual(self.test_mutator.get_lambda(s1, s2), expectation) def test_ion_mutation(self): """Test mutating an ion of a SmactStructure.""" @@ -430,17 +381,14 @@ def test_ion_mutation(self): BaTiO3 = SmactStructure.from_file(ba_file) with self.subTest(s1="CaTiO3", s2="BaTiO3"): - mutation = self.test_mutator._mutate_structure( - CaTiO3, "Ca2+", "Ba2+" - ) + mutation = self.test_mutator._mutate_structure(CaTiO3, "Ca2+", "Ba2+") self.assertEqual(mutation, BaTiO3) na_file = os.path.join(files_dir, "NaCl.txt") NaCl = SmactStructure.from_file(na_file) - with self.subTest(s1="Na1+Cl1-", s2="Na2+Cl1-"): - with self.assertRaises(ValueError): - self.test_mutator._mutate_structure(NaCl, "Na1+", "Na2+") + with self.subTest(s1="Na1+Cl1-", s2="Na2+Cl1-"), pytest.raises(ValueError): + self.test_mutator._mutate_structure(NaCl, "Na1+", "Na2+") # TODO Confirm functionality with more complex substitutions @@ -459,15 +407,10 @@ def test_cond_sub_probs(self): with self.subTest(s=s1): cond_sub_probs_test = self.test_mutator.cond_sub_probs(s1) - vals = [ - (s1, s2, self.test_mutator.cond_sub_prob(s1, s2)) - for s2 in ["A", "B", "C"] - ] + vals = [(s1, s2, self.test_mutator.cond_sub_prob(s1, s2)) for s2 in ["A", "B", "C"]] test_df = pd.DataFrame(vals) - test_df: pd.DataFrame = test_df.pivot( - index=0, columns=1, values=2 - ) + test_df: pd.DataFrame = test_df.pivot_table(index=0, columns=1, values=2) # Slice to convert to series assert_series_equal(cond_sub_probs_test, test_df.iloc[0]) @@ -515,7 +458,7 @@ def test_complete_cond_probs(self): ] cond_probs = pd.DataFrame(vals) - cond_probs = cond_probs.pivot(index=0, columns=1, values=2) + cond_probs = cond_probs.pivot_table(index=0, columns=1, values=2) assert_frame_equal(self.test_mutator.complete_cond_probs(), cond_probs) @@ -533,7 +476,7 @@ def test_complete_sub_probs(self): ] sub_probs = pd.DataFrame(vals) - sub_probs = sub_probs.pivot(index=0, columns=1, values=2) + sub_probs = sub_probs.pivot_table(index=0, columns=1, values=2) assert_frame_equal(self.test_mutator.complete_sub_probs(), sub_probs) @@ -551,7 +494,7 @@ def test_complete_pair_corrs(self): ] pair_corrs = pd.DataFrame(vals) - pair_corrs = pair_corrs.pivot(index=0, columns=1, values=2) + pair_corrs = pair_corrs.pivot_table(index=0, columns=1, values=2) assert_frame_equal(self.test_mutator.complete_pair_corrs(), pair_corrs) @@ -584,11 +527,7 @@ def test_prediction(self): with self.subTest(msg="Acquiring predictions"): try: - predictions = list( - sp.predict_structs( - test_specs, thresh=0.02, include_same=False - ) - ) + predictions = list(sp.predict_structs(test_specs, thresh=0.02, include_same=False)) except Exception as e: self.fail(e) diff --git a/smact/utils/band_gap_simple.py b/smact/utils/band_gap_simple.py index 28921865..606edca7 100644 --- a/smact/utils/band_gap_simple.py +++ b/smact/utils/band_gap_simple.py @@ -21,21 +21,11 @@ if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser( - description="Compound band gap estimates from elemental data." - ) - parser.add_argument( - "-a", "--anion", type=str, help="Element symbol for anion." - ) - parser.add_argument( - "-c", "--cation", type=str, help="Element symbol for cation." - ) - parser.add_argument( - "-d", "--distance", type=float, help="Internuclear separation." - ) - parser.add_argument( - "-v", "--verbose", action="store_true", help="More Verbose output." - ) + parser = argparse.ArgumentParser(description="Compound band gap estimates from elemental data.") + parser.add_argument("-a", "--anion", type=str, help="Element symbol for anion.") + parser.add_argument("-c", "--cation", type=str, help="Element symbol for cation.") + parser.add_argument("-d", "--distance", type=float, help="Internuclear separation.") + parser.add_argument("-v", "--verbose", action="store_true", help="More Verbose output.") args = parser.parse_args() print( diff --git a/smact/utils/download_compounds_with_mp_api.py b/smact/utils/download_compounds_with_mp_api.py index 777042b0..e6d29b13 100644 --- a/smact/utils/download_compounds_with_mp_api.py +++ b/smact/utils/download_compounds_with_mp_api.py @@ -4,6 +4,7 @@ import time from collections import defaultdict from pathlib import Path +from typing import Optional from mp_api.client import MPRester from pymatgen.core.composition import Composition @@ -11,7 +12,7 @@ def download_mp_data( - mp_api_key: str = None, + mp_api_key: Optional[str] = None, num_elements: int = 2, max_stoich: int = 8, save_dir: str = "data/binary/mp_api", @@ -24,6 +25,7 @@ def download_mp_data( The data is saved to a specified directory. Args: + ---- mp_api_key (str, optional): the API key for Materials Project. num_elements (int, optional): the number of elements in each compound to consider. Defaults to 2. @@ -35,13 +37,13 @@ def download_mp_data( Defaults to 1. Returns: + ------- None + """ # check if MP_API_KEY is set if mp_api_key is None: - raise ValueError( - "Please set your MP_API_KEY in the environment variable." - ) + raise ValueError("Please set your MP_API_KEY in the environment variable.") # set save directory save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) @@ -49,9 +51,7 @@ def download_mp_data( # make a list for all possible combinartions of formula anonymous symbols = string.ascii_uppercase formula_anonymous_list = [] - for stoichs in itertools.combinations_with_replacement( - range(1, max_stoich + 1), num_elements - ): + for stoichs in itertools.combinations_with_replacement(range(1, max_stoich + 1), num_elements): formula_dict = {symbols[i]: stoich for i, stoich in enumerate(stoichs)} formula_anonymous_list.append(Composition(formula_dict).reduced_formula) formula_anonymous_list = sorted(set(formula_anonymous_list)) @@ -89,7 +89,5 @@ def download_mp_data( if (energy_above_hull) < e_hull_dict[formula_pretty]: e_hull_dict[formula_pretty] = energy_above_hull - json.dump( - doc.dict(), open(save_dir / f"{formula_pretty}.json", "w") - ) + json.dump(doc.dict(), open(save_dir / f"{formula_pretty}.json", "w")) time.sleep(request_interval) diff --git a/smact/utils/generate_composition_with_smact.py b/smact/utils/generate_composition_with_smact.py index 55449d2b..a75e21d6 100644 --- a/smact/utils/generate_composition_with_smact.py +++ b/smact/utils/generate_composition_with_smact.py @@ -3,6 +3,7 @@ import warnings from functools import partial from pathlib import Path +from typing import Optional import pandas as pd from pymatgen.core.composition import Composition @@ -17,10 +18,8 @@ def convert_formula(combinations, num_elements, max_stoich): symbols = [element.symbol for element in combinations] local_compounds = [] - for counts in itertools.product( - range(1, max_stoich + 1), repeat=num_elements - ): - formula_dict = {symbol: count for symbol, count in zip(symbols, counts)} + for counts in itertools.product(range(1, max_stoich + 1), repeat=num_elements): + formula_dict = dict(zip(symbols, counts)) formula = Composition(formula_dict).reduced_formula local_compounds.append(formula) return local_compounds @@ -30,13 +29,15 @@ def generate_composition_with_smact( num_elements: int = 2, max_stoich: int = 8, max_atomic_num: int = 103, - num_processes: int = None, - save_path: str = None, + num_processes: Optional[int] = None, + save_path: Optional[str] = None, ): - """Generate all possible compositions of a given number of elements and + """ + Generate all possible compositions of a given number of elements and filter them with SMACT. Args: + ---- num_elements: the number of elements in a compound. Defaults to 2. max_stoich: the maximum stoichiometric coefficient. Defaults to 8. max_atomic_num: the maximum atomic number. Defaults to 103. @@ -44,26 +45,21 @@ def generate_composition_with_smact( save_path: the path to save the results. Defaults to None. Returns: + ------- _description_ - """ + """ # 1. generate all possible combinations of elements print("#1. Generating all possible combinations of elements...") - elements = [ - Element(element) for element in ordered_elements(1, max_atomic_num) - ] + elements = [Element(element) for element in ordered_elements(1, max_atomic_num)] combinations = list(itertools.combinations(elements, num_elements)) print(f"Number of generated combinations: {len(list(combinations))}") # 2. generate all possible stoichiometric combinations print("#2. Generating all possible stoichiometric combinations...") - pool = multiprocessing.Pool( - processes=multiprocessing.cpu_count() - if num_processes is None - else num_processes - ) + pool = multiprocessing.Pool(processes=multiprocessing.cpu_count() if num_processes is None else num_processes) compounds = list( tqdm( pool.imap_unordered( @@ -90,24 +86,14 @@ def generate_composition_with_smact( # 3. filter compounds with smact print("#3. Filtering compounds with SMACT...") elements_pauling = [ - Element(element) - for element in ordered_elements(1, max_atomic_num) - if Element(element).pauling_eneg is not None + Element(element) for element in ordered_elements(1, max_atomic_num) if Element(element).pauling_eneg is not None ] # omit elements without Pauling electronegativity (e.g., He, Ne, Ar, ...) - compounds_pauling = list( - itertools.combinations(elements_pauling, num_elements) - ) + compounds_pauling = list(itertools.combinations(elements_pauling, num_elements)) - pool = multiprocessing.Pool( - processes=multiprocessing.cpu_count() - if num_processes is None - else num_processes - ) + pool = multiprocessing.Pool(processes=multiprocessing.cpu_count() if num_processes is None else num_processes) results = list( tqdm( - pool.imap_unordered( - partial(smact_filter, threshold=max_stoich), compounds_pauling - ), + pool.imap_unordered(partial(smact_filter, threshold=max_stoich), compounds_pauling), total=len(compounds_pauling), ) ) @@ -122,9 +108,7 @@ def generate_composition_with_smact( for result in results: for res in result: symbols_stoich = zip(res[0], res[2]) - composition_dict = { - symbol: stoich for symbol, stoich in symbols_stoich - } + composition_dict = dict(symbols_stoich) smact_allowed.append(Composition(composition_dict).reduced_formula) smact_allowed = list(set(smact_allowed)) print(f"Number of compounds allowed by SMACT: {len(smact_allowed)}") diff --git a/smact/utils/plot_embedding.py b/smact/utils/plot_embedding.py index 8b76f5ef..4a13a4ae 100644 --- a/smact/utils/plot_embedding.py +++ b/smact/utils/plot_embedding.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List import pandas as pd import plotly.graph_objects as go @@ -60,8 +59,8 @@ def update_layout( def plot_reducers_embeddings( df_label: pd.DataFrame, - reducers: List[str], - embedding_names: List[str], + reducers: list[str], + embedding_names: list[str], embedding_dir: Path, save_path: Path, symbol: str = "circle", @@ -70,11 +69,7 @@ def plot_reducers_embeddings( fig = make_subplots( rows=6, cols=3, - subplot_titles=[ - f"{reducer} - {embedding_name}" - for embedding_name in embedding_names - for reducer in reducers - ], + subplot_titles=[f"{reducer} - {embedding_name}" for embedding_name in embedding_names for reducer in reducers], vertical_spacing=0.02, horizontal_spacing=0.02, ) @@ -121,7 +116,7 @@ def plot_reducers_embeddings( ) # add legend - for label, _ in legend_colors.items(): + for label in legend_colors: fig.add_trace( go.Scatter( x=[None], diff --git a/utils/bandgap.py b/utils/bandgap.py deleted file mode 100755 index e1ddb85c..00000000 --- a/utils/bandgap.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python - -from smact.properties import band_gap_Harrison - -# python band_gap_simple.py --h -# usage: band_gap_simple.py [-h] [-a ANION] [-c CATION] [-d DISTANCE] [-v] - -# python band_gap_simple.py -c Mg -a Cl -d 2.38 -# 3.8944137939094166 - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser( - description="Compound band gap estimates from elemental data." - ) - parser.add_argument( - "-a", "--anion", type=str, help="Element symbol for anion." - ) - parser.add_argument( - "-c", "--cation", type=str, help="Element symbol for cation." - ) - parser.add_argument( - "-d", "--distance", type=float, help="Internuclear separation." - ) - parser.add_argument( - "-v", "--verbose", action="store_true", help="More Verbose output." - ) - args = parser.parse_args() - - print( - band_gap_Harrison( - verbose=args.verbose, - anion=args.anion, - cation=args.cation, - distance=args.distance, - ) - )