From 7c09cfb8b19123e2c05bbb412a528e2db9dc0c92 Mon Sep 17 00:00:00 2001 From: Evan Walter Clark Spotte-Smith Date: Sat, 25 Feb 2023 10:35:38 -0800 Subject: [PATCH 1/9] Beginning MPcules summary rester, modeled off of materials summary rester --- mp_api/client/routes/mpcules/__init__.py | 0 mp_api/client/routes/mpcules/summary.py | 212 +++++++++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 mp_api/client/routes/mpcules/__init__.py create mode 100644 mp_api/client/routes/mpcules/summary.py diff --git a/mp_api/client/routes/mpcules/__init__.py b/mp_api/client/routes/mpcules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mp_api/client/routes/mpcules/summary.py b/mp_api/client/routes/mpcules/summary.py new file mode 100644 index 00000000..3cee1edd --- /dev/null +++ b/mp_api/client/routes/mpcules/summary.py @@ -0,0 +1,212 @@ +import warnings +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +from emmet.core.mpid import MPID +from emmet.core.molecules.summary import HasProps, SummaryDoc + +from mp_api.client.core import BaseRester +from mp_api.client.core.utils import validate_ids + + +class MPculeSummaryRester(BaseRester[SummaryDoc]): + + suffix = "summary" + document_model = SummaryDoc # type: ignore + primary_key = "molecule_id" + + def search( + self, + charge: Optional[Tuple[int, int]] = None, + spin_multiplicity: Optional[Tuple[int, int]] = None, + nelements: Optional[Tuple[int, int]] = None, + has_solvent: Optional[Union[str, List[str]]] = None, + has_level_of_theory: Optional[Union[str, List[str]]] = None, + has_lot_solvent: Optional[Union[str, List[str]]] = None, + with_solvent: Optional[str] = None, + electronic_energy: Optional[Tuple[float, float]] = None, + ionization_energy: Optional[Tuple[float, float]] = None, + electron_affinity: Optional[Tuple[float, float]] = None, + reduction_free_energy: Optional[Tuple[float, float]] = None, + oxidation_free_energy: Optional[Tuple[float, float]] = None, + # zero_point_energy: Optional[Tuple[float, float]] = None, + # total_enthalpy: Optional[Tuple[float, float]] = None, + # total_entropy: Optional[Tuple[float, float]] = None, + # translational_enthalpy: Optional[Tuple[float, float]] = None, + # translational_entropy: Optional[Tuple[float, float]] = None, + # vibrational_enthalpy: Optional[Tuple[float, float]] = None, + # vibrational_entropy: Optional[Tuple[float, float]] = None, + # rotational_enthalpy: Optional[Tuple[float, float]] = None, + # rotational_entropy: Optional[Tuple[float, float]] = None, + # free_energy: Optional[Tuple[float, float]] = None, + chemsys: Optional[Union[str, List[str]]] = None, + deprecated: Optional[bool] = None, + elements: Optional[List[str]] = None, + exclude_elements: Optional[List[str]] = None, + formula: Optional[Union[str, List[str]]] = None, + has_props: Optional[List[HasProps]] = None, + material_ids: Optional[List[MPID]] = None, + num_elements: Optional[Tuple[int, int]] = None, + num_sites: Optional[Tuple[int, int]] = None, + sort_fields: Optional[List[str]] = None, + num_chunks: Optional[int] = None, + chunk_size: int = 1000, + all_fields: bool = True, + fields: Optional[List[str]] = None, + ): + """ + Query core data using a variety of search criteria. + + Arguments: + charge (Tuple[int, int]): Minimum and maximum charge for the molecule. + spin_multiplicity (Tuple[int, int]): Minimum and maximum spin for the molecule. + nelements (Tuple[int, int]): Minimum and maximum number of elements + chemsys (str, List[str]): A chemical system, list of chemical systems + (e.g., Li-C-O, [C-O-H-N, Li-N]), or single formula (e.g., C2 H4). + has_solvent (str, List[str]): Whether the molecule has properties calculated in + solvents (e.g., SOLVENT=THF, [SOLVENT=WATER, VACUUM]) + # TODO: continue documentation + deprecated (bool): Whether the material is tagged as deprecated. + elements (List[str]): A list of elements. + exclude_elements (List(str)): List of elements to exclude. + formula (str, List[str]): A formula including anonymized formula + or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed + (e.g., [Fe2O3, ABO3]). + has_props: (List[HasProps]): The calculated properties available for the material. + material_ids (List[MPID]): List of Materials Project IDs to return data for. + num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider. + num_sites (Tuple[int,int]): Minimum and maximum number of sites to consider. + sort_fields (List[str]): Fields used to sort results. Prefixing with '-' will sort in descending order. + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. + fields (List[str]): List of fields in SearchDoc to return data for. + Default is material_id if all_fields is False. + + Returns: + """ + + query_params = defaultdict(dict) # type: dict + + min_max_name_dict = { + "total_energy": "energy_per_atom", + "formation_energy": "formation_energy_per_atom", + "energy_above_hull": "energy_above_hull", + "uncorrected_energy": "uncorrected_energy_per_atom", + "equilibrium_reaction_energy": "equilibrium_reaction_energy_per_atom", + "nsites": "nsites", + "volume": "volume", + "density": "density", + "band_gap": "band_gap", + "efermi": "efermi", + "total_magnetization": "total_magnetization", + "total_magnetization_normalized_vol": "total_magnetization_normalized_vol", + "total_magnetization_normalized_formula_units": "total_magnetization_normalized_formula_units", + "num_magnetic_sites": "num_magnetic_sites", + "num_unique_magnetic_sites": "num_unique_magnetic_sites", + "k_voigt": "k_voigt", + "k_reuss": "k_reuss", + "k_vrh": "k_vrh", + "g_voigt": "g_voigt", + "g_reuss": "g_reuss", + "g_vrh": "g_vrh", + "elastic_anisotropy": "universal_anisotropy", + "poisson_ratio": "homogeneous_poisson", + "e_total": "e_total", + "e_ionic": "e_ionic", + "e_electronic": "e_electronic", + "n": "n", + "num_sites": "nsites", + "num_elements": "nelements", + "piezoelectric_modulus": "e_ij_max", + "weighted_surface_energy": "weighted_surface_energy", + "weighted_work_function": "weighted_work_function", + "surface_energy_anisotropy": "surface_anisotropy", + "shape_factor": "shape_factor", + } + + for param, value in locals().items(): + if param in min_max_name_dict and value: + if isinstance(value, (int, float)): + value = (value, value) + query_params.update( + { + f"{min_max_name_dict[param]}_min": value[0], + f"{min_max_name_dict[param]}_max": value[1], + } + ) + + if material_ids: + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + + if deprecated is not None: + query_params.update({"deprecated": deprecated}) + + if formula: + if isinstance(formula, str): + formula = [formula] + + query_params.update({"formula": ",".join(formula)}) + + if chemsys: + if isinstance(chemsys, str): + chemsys = [chemsys] + + query_params.update({"chemsys": ",".join(chemsys)}) + + if elements: + query_params.update({"elements": ",".join(elements)}) + + if exclude_elements is not None: + query_params.update({"exclude_elements": ",".join(exclude_elements)}) + + if possible_species is not None: + query_params.update({"possible_species": ",".join(possible_species)}) + + query_params.update( + { + "crystal_system": crystal_system, + "spacegroup_number": spacegroup_number, + "spacegroup_symbol": spacegroup_symbol, + } + ) + + if is_stable is not None: + query_params.update({"is_stable": is_stable}) + + if is_gap_direct is not None: + query_params.update({"is_gap_direct": is_gap_direct}) + + if is_metal is not None: + query_params.update({"is_metal": is_metal}) + + if magnetic_ordering: + query_params.update({"ordering": magnetic_ordering.value}) + + if has_reconstructed is not None: + query_params.update({"has_reconstructed": has_reconstructed}) + + if has_props: + query_params.update({"has_props": ",".join([i.value for i in has_props])}) + + if theoretical is not None: + query_params.update({"theoretical": theoretical}) + + if sort_fields: + query_params.update( + {"_sort_fields": ",".join([s.strip() for s in sort_fields])} + ) + + query_params = { + entry: query_params[entry] + for entry in query_params + if query_params[entry] is not None + } + + return super()._search( + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params, + ) From 97fe303e08580acc37aa79736670e44c0b579eb8 Mon Sep 17 00:00:00 2001 From: Evan Walter Clark Spotte-Smith Date: Fri, 3 Mar 2023 09:29:34 -0800 Subject: [PATCH 2/9] Continuing work on minimal rester --- mp_api/client/routes/mpcules/summary.py | 60 +++++++------------------ 1 file changed, 15 insertions(+), 45 deletions(-) diff --git a/mp_api/client/routes/mpcules/summary.py b/mp_api/client/routes/mpcules/summary.py index 3cee1edd..f80f274a 100644 --- a/mp_api/client/routes/mpcules/summary.py +++ b/mp_api/client/routes/mpcules/summary.py @@ -47,7 +47,7 @@ def search( has_props: Optional[List[HasProps]] = None, material_ids: Optional[List[MPID]] = None, num_elements: Optional[Tuple[int, int]] = None, - num_sites: Optional[Tuple[int, int]] = None, + # num_sites: Optional[Tuple[int, int]] = None, sort_fields: Optional[List[str]] = None, num_chunks: Optional[int] = None, chunk_size: int = 1000, @@ -61,21 +61,28 @@ def search( charge (Tuple[int, int]): Minimum and maximum charge for the molecule. spin_multiplicity (Tuple[int, int]): Minimum and maximum spin for the molecule. nelements (Tuple[int, int]): Minimum and maximum number of elements + has_solvent (str, List[str]): Whether the molecule has properties calculated in + solvents (e.g., "SOLVENT=THF", ["SOLVENT=WATER", "VACUUM"]) + has_level_of_theory (str, List[str]): Whether the molecule has properties calculated + using a particular level of theory (e.g. "wB97M-V/def2-SVPD/SMD", + ["wB97X-V/def2-TZVPPD/SMD", "wB97M-V/def2-QZVPPD/SMD"]) + with_solvent (str): For property-based queries, ensure that the properties are calculated + in a particular solvent + electronic_energy (Tuple[float, float]): Minimum and maximum electronic energy + ionization_energy (Tuple[float, float]): Minimum and maximum ionization energy + electron_affinity (Tuple[float, float]): Minimum and maximum electron affinity + reduction_free_energy (Tuple[float, float]): Minimum and maximum reduction free energy + oxidation_free_energy (Tuple[float, float]): Minimum and maximum oxidation free energy chemsys (str, List[str]): A chemical system, list of chemical systems (e.g., Li-C-O, [C-O-H-N, Li-N]), or single formula (e.g., C2 H4). - has_solvent (str, List[str]): Whether the molecule has properties calculated in - solvents (e.g., SOLVENT=THF, [SOLVENT=WATER, VACUUM]) - # TODO: continue documentation deprecated (bool): Whether the material is tagged as deprecated. elements (List[str]): A list of elements. exclude_elements (List(str)): List of elements to exclude. - formula (str, List[str]): A formula including anonymized formula - or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed - (e.g., [Fe2O3, ABO3]). + formula (str, List[str]): An alphabetical formula or list of formulas + (e.g. "C2 Li2 O4", ["C2 H4", "C2 H6"]). has_props: (List[HasProps]): The calculated properties available for the material. material_ids (List[MPID]): List of Materials Project IDs to return data for. num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider. - num_sites (Tuple[int,int]): Minimum and maximum number of sites to consider. sort_fields (List[str]): Fields used to sort results. Prefixing with '-' will sort in descending order. num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. @@ -88,43 +95,6 @@ def search( query_params = defaultdict(dict) # type: dict - min_max_name_dict = { - "total_energy": "energy_per_atom", - "formation_energy": "formation_energy_per_atom", - "energy_above_hull": "energy_above_hull", - "uncorrected_energy": "uncorrected_energy_per_atom", - "equilibrium_reaction_energy": "equilibrium_reaction_energy_per_atom", - "nsites": "nsites", - "volume": "volume", - "density": "density", - "band_gap": "band_gap", - "efermi": "efermi", - "total_magnetization": "total_magnetization", - "total_magnetization_normalized_vol": "total_magnetization_normalized_vol", - "total_magnetization_normalized_formula_units": "total_magnetization_normalized_formula_units", - "num_magnetic_sites": "num_magnetic_sites", - "num_unique_magnetic_sites": "num_unique_magnetic_sites", - "k_voigt": "k_voigt", - "k_reuss": "k_reuss", - "k_vrh": "k_vrh", - "g_voigt": "g_voigt", - "g_reuss": "g_reuss", - "g_vrh": "g_vrh", - "elastic_anisotropy": "universal_anisotropy", - "poisson_ratio": "homogeneous_poisson", - "e_total": "e_total", - "e_ionic": "e_ionic", - "e_electronic": "e_electronic", - "n": "n", - "num_sites": "nsites", - "num_elements": "nelements", - "piezoelectric_modulus": "e_ij_max", - "weighted_surface_energy": "weighted_surface_energy", - "weighted_work_function": "weighted_work_function", - "surface_energy_anisotropy": "surface_anisotropy", - "shape_factor": "shape_factor", - } - for param, value in locals().items(): if param in min_max_name_dict and value: if isinstance(value, (int, float)): From 4e1c7f4e796376908d972ac5afce5339f1354ec6 Mon Sep 17 00:00:00 2001 From: Evan Walter Clark Spotte-Smith Date: Fri, 3 Mar 2023 10:03:46 -0800 Subject: [PATCH 3/9] Finished very rough draft mpcules summary rester; incorporated into MPRester --- mp_api/client/mprester.py | 1 + mp_api/client/routes/__init__.py | 1 + mp_api/client/routes/mpcules/summary.py | 65 +++++++++---------------- 3 files changed, 25 insertions(+), 42 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index c4e4a433..4de5731f 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -74,6 +74,7 @@ class MPRester: provenance: ProvenanceRester bonds: BondsRester alloys: AlloysRester + mpcules_summary: MPculesSummaryRester _user_settings: UserSettingsRester _general_store: GeneralStoreRester diff --git a/mp_api/client/routes/__init__.py b/mp_api/client/routes/__init__.py index d1fc9f69..58cb820c 100644 --- a/mp_api/client/routes/__init__.py +++ b/mp_api/client/routes/__init__.py @@ -30,6 +30,7 @@ from ._general_store import GeneralStoreRester from .bonds import BondsRester from .robocrys import RobocrysRester +from mp_api.routes.mpcules.summary import MPculesSummaryRester try: from .alloys import AlloysRester diff --git a/mp_api/client/routes/mpcules/summary.py b/mp_api/client/routes/mpcules/summary.py index f80f274a..9d7d732d 100644 --- a/mp_api/client/routes/mpcules/summary.py +++ b/mp_api/client/routes/mpcules/summary.py @@ -2,7 +2,7 @@ from collections import defaultdict from typing import List, Optional, Tuple, Union -from emmet.core.mpid import MPID +from emmet.core.mpid import MPculeID from emmet.core.molecules.summary import HasProps, SummaryDoc from mp_api.client.core import BaseRester @@ -23,7 +23,7 @@ def search( has_solvent: Optional[Union[str, List[str]]] = None, has_level_of_theory: Optional[Union[str, List[str]]] = None, has_lot_solvent: Optional[Union[str, List[str]]] = None, - with_solvent: Optional[str] = None, + # with_solvent: Optional[str] = None, electronic_energy: Optional[Tuple[float, float]] = None, ionization_energy: Optional[Tuple[float, float]] = None, electron_affinity: Optional[Tuple[float, float]] = None, @@ -45,8 +45,7 @@ def search( exclude_elements: Optional[List[str]] = None, formula: Optional[Union[str, List[str]]] = None, has_props: Optional[List[HasProps]] = None, - material_ids: Optional[List[MPID]] = None, - num_elements: Optional[Tuple[int, int]] = None, + molecule_ids: Optional[List[MPculeID]] = None, # num_sites: Optional[Tuple[int, int]] = None, sort_fields: Optional[List[str]] = None, num_chunks: Optional[int] = None, @@ -66,8 +65,9 @@ def search( has_level_of_theory (str, List[str]): Whether the molecule has properties calculated using a particular level of theory (e.g. "wB97M-V/def2-SVPD/SMD", ["wB97X-V/def2-TZVPPD/SMD", "wB97M-V/def2-QZVPPD/SMD"]) - with_solvent (str): For property-based queries, ensure that the properties are calculated - in a particular solvent + has_lot_solvent (str, List[str]): Whether the molecule has properties calculated + using a particular combination of level of theory and solvent (e.g. "wB97X-V/def2-SVPD/SMD(SOLVENT=THF)", + ["wB97X-V/def2-TZVPPD/SMD(VACUUM)", "wB97M-V/def2-QZVPPD/SMD(SOLVENT=WATER)"]) electronic_energy (Tuple[float, float]): Minimum and maximum electronic energy ionization_energy (Tuple[float, float]): Minimum and maximum ionization energy electron_affinity (Tuple[float, float]): Minimum and maximum electron affinity @@ -81,8 +81,7 @@ def search( formula (str, List[str]): An alphabetical formula or list of formulas (e.g. "C2 Li2 O4", ["C2 H4", "C2 H6"]). has_props: (List[HasProps]): The calculated properties available for the material. - material_ids (List[MPID]): List of Materials Project IDs to return data for. - num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider. + molecule_ids (List[MPculeID]): List of Materials Project Molecule IDs (MPculeIDs) to return data for. sort_fields (List[str]): Fields used to sort results. Prefixing with '-' will sort in descending order. num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. @@ -95,19 +94,30 @@ def search( query_params = defaultdict(dict) # type: dict + min_max = [ + "charge", + "spin_multiplicity", + "nelements", + "electronic_energy", + "ionization_energy", + "electron_affinity", + "reduction_free_energy", + "oxidation_free_energy" + ] + for param, value in locals().items(): - if param in min_max_name_dict and value: + if param in min_max and value: if isinstance(value, (int, float)): value = (value, value) query_params.update( { - f"{min_max_name_dict[param]}_min": value[0], - f"{min_max_name_dict[param]}_max": value[1], + f"{param}_min": value[0], + f"{param}_max": value[1], } ) - if material_ids: - query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + if molecule_ids: + query_params.update({"molecule_ids": ",".join(molecule_ids)}) if deprecated is not None: query_params.update({"deprecated": deprecated}) @@ -130,38 +140,9 @@ def search( if exclude_elements is not None: query_params.update({"exclude_elements": ",".join(exclude_elements)}) - if possible_species is not None: - query_params.update({"possible_species": ",".join(possible_species)}) - - query_params.update( - { - "crystal_system": crystal_system, - "spacegroup_number": spacegroup_number, - "spacegroup_symbol": spacegroup_symbol, - } - ) - - if is_stable is not None: - query_params.update({"is_stable": is_stable}) - - if is_gap_direct is not None: - query_params.update({"is_gap_direct": is_gap_direct}) - - if is_metal is not None: - query_params.update({"is_metal": is_metal}) - - if magnetic_ordering: - query_params.update({"ordering": magnetic_ordering.value}) - - if has_reconstructed is not None: - query_params.update({"has_reconstructed": has_reconstructed}) - if has_props: query_params.update({"has_props": ",".join([i.value for i in has_props])}) - if theoretical is not None: - query_params.update({"theoretical": theoretical}) - if sort_fields: query_params.update( {"_sort_fields": ",".join([s.strip() for s in sort_fields])} From b3e2e042ff4a2509e563c17fa4f92789c488e657 Mon Sep 17 00:00:00 2001 From: Evan Walter Clark Spotte-Smith Date: Fri, 3 Mar 2023 10:37:22 -0800 Subject: [PATCH 4/9] Tests still pass, even with new addition. Now just need to write tests for new rester --- mp_api/client/routes/__init__.py | 2 +- mp_api/client/routes/mpcules/summary.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mp_api/client/routes/__init__.py b/mp_api/client/routes/__init__.py index 58cb820c..24e4e04a 100644 --- a/mp_api/client/routes/__init__.py +++ b/mp_api/client/routes/__init__.py @@ -30,7 +30,7 @@ from ._general_store import GeneralStoreRester from .bonds import BondsRester from .robocrys import RobocrysRester -from mp_api.routes.mpcules.summary import MPculesSummaryRester +from mp_api.client.routes.mpcules.summary import MPculesSummaryRester try: from .alloys import AlloysRester diff --git a/mp_api/client/routes/mpcules/summary.py b/mp_api/client/routes/mpcules/summary.py index 9d7d732d..5e44e117 100644 --- a/mp_api/client/routes/mpcules/summary.py +++ b/mp_api/client/routes/mpcules/summary.py @@ -9,9 +9,9 @@ from mp_api.client.core.utils import validate_ids -class MPculeSummaryRester(BaseRester[SummaryDoc]): +class MPculesSummaryRester(BaseRester[SummaryDoc]): - suffix = "summary" + suffix = "mpcules_summary" document_model = SummaryDoc # type: ignore primary_key = "molecule_id" From 2aae8669054a8caef015eaeda6897fb61b59c14b Mon Sep 17 00:00:00 2001 From: Evan Walter Clark Spotte-Smith Date: Tue, 14 Mar 2023 11:12:43 -0700 Subject: [PATCH 5/9] Update to Jason can test --- mp_api/client/routes/mpcules/summary.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mp_api/client/routes/mpcules/summary.py b/mp_api/client/routes/mpcules/summary.py index 5e44e117..1a3c3b75 100644 --- a/mp_api/client/routes/mpcules/summary.py +++ b/mp_api/client/routes/mpcules/summary.py @@ -3,16 +3,16 @@ from typing import List, Optional, Tuple, Union from emmet.core.mpid import MPculeID -from emmet.core.molecules.summary import HasProps, SummaryDoc +from emmet.core.molecules.summary import HasProps, MoleculeSummaryDoc from mp_api.client.core import BaseRester from mp_api.client.core.utils import validate_ids -class MPculesSummaryRester(BaseRester[SummaryDoc]): +class MPculesSummaryRester(BaseRester[MoleculeSummaryDoc]): - suffix = "mpcules_summary" - document_model = SummaryDoc # type: ignore + suffix = "mpcules" + document_model = MoleculeSummaryDoc # type: ignore primary_key = "molecule_id" def search( From b916936c1cdaeed83728633bc014fb5fea5cc22a Mon Sep 17 00:00:00 2001 From: Evan Walter Clark Spotte-Smith Date: Tue, 14 Mar 2023 11:19:27 -0700 Subject: [PATCH 6/9] Add test --- tests/mpcules/__init__.py | 0 tests/mpcules/test_summary.py | 77 +++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 tests/mpcules/__init__.py create mode 100644 tests/mpcules/test_summary.py diff --git a/tests/mpcules/__init__.py b/tests/mpcules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mpcules/test_summary.py b/tests/mpcules/test_summary.py new file mode 100644 index 00000000..4f3589fb --- /dev/null +++ b/tests/mpcules/test_summary.py @@ -0,0 +1,77 @@ +from emmet.core.molecules.summary import HasProps +from mp_api.client.routes.mpcules.summary import MPculesSummaryRester +import os +import pytest + +import typing + +excluded_params = [ + "sort_fields", + "chunk_size", + "num_chunks", + "all_fields", + "fields", +] + +custom_field_tests = { + "molecule_ids": ["9f153b9f3caa3124fb404b42e4cf82c8-C2H4-0-1"], + "formula": "C2 H4", + "chemsys": "C-H", + "elements": ["C", "H"], + "has_solvent": "DIELECTRIC=18,500;N=1,415;ALPHA=0,000;BETA=0,735;GAMMA=20,200;PHI=0,000;PSI=0,000", + "has_level_of_theory": "wB97X-V/def2-TZVPPD/SMD", + "has_lot_solvent": "wB97X-V/def2-TZVPPD/SMD(SOLVENT=THF)", + "nelements": 2, + "has_props": HasProps.orbitals +} # type: dict + + +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") +def test_client(): + + search_method = MPculesSummaryRester().search + + # Get list of parameters + param_tuples = list(typing.get_type_hints(search_method).items()) + + # Query API for each numeric and boolean parameter and check if returned + for entry in param_tuples: + param = entry[0] + if param not in excluded_params: + param_type = entry[1].__args__[0] + q = None + print(param) + + if param in custom_field_tests: + q = { + param: custom_field_tests[param], + "chunk_size": 1, + "num_chunks": 1, + } + elif param_type == typing.Tuple[int, int]: + q = { + param: (-100, 100), + "chunk_size": 1, + "num_chunks": 1, + } + elif param_type == typing.Tuple[float, float]: + q = { + param: (-3000.12, 3000.12), + "chunk_size": 1, + "num_chunks": 1, + } + elif param_type is bool: + q = { + param: False, + "chunk_size": 1, + "num_chunks": 1, + } + + docs = search_method(**q) + + if len(docs) > 0: + doc = docs[0].dict() + else: + raise ValueError("No documents returned") + + assert doc[param] is not None From 435a5d9b9ea22fbefc3fab99f9692c14c776b313 Mon Sep 17 00:00:00 2001 From: Evan Walter Clark Spotte-Smith Date: Wed, 15 Mar 2023 08:12:27 -0700 Subject: [PATCH 7/9] Some small frustrations with the mpcules summary rester, but ready to PR --- mp_api/client/routes/mpcules/summary.py | 42 +++++++------------------ tests/mpcules/test_summary.py | 14 +++++++-- 2 files changed, 23 insertions(+), 33 deletions(-) diff --git a/mp_api/client/routes/mpcules/summary.py b/mp_api/client/routes/mpcules/summary.py index 1a3c3b75..02b6aa41 100644 --- a/mp_api/client/routes/mpcules/summary.py +++ b/mp_api/client/routes/mpcules/summary.py @@ -20,25 +20,10 @@ def search( charge: Optional[Tuple[int, int]] = None, spin_multiplicity: Optional[Tuple[int, int]] = None, nelements: Optional[Tuple[int, int]] = None, - has_solvent: Optional[Union[str, List[str]]] = None, - has_level_of_theory: Optional[Union[str, List[str]]] = None, - has_lot_solvent: Optional[Union[str, List[str]]] = None, + # has_solvent: Optional[Union[str, List[str]]] = None, + # has_level_of_theory: Optional[Union[str, List[str]]] = None, + # has_lot_solvent: Optional[Union[str, List[str]]] = None, # with_solvent: Optional[str] = None, - electronic_energy: Optional[Tuple[float, float]] = None, - ionization_energy: Optional[Tuple[float, float]] = None, - electron_affinity: Optional[Tuple[float, float]] = None, - reduction_free_energy: Optional[Tuple[float, float]] = None, - oxidation_free_energy: Optional[Tuple[float, float]] = None, - # zero_point_energy: Optional[Tuple[float, float]] = None, - # total_enthalpy: Optional[Tuple[float, float]] = None, - # total_entropy: Optional[Tuple[float, float]] = None, - # translational_enthalpy: Optional[Tuple[float, float]] = None, - # translational_entropy: Optional[Tuple[float, float]] = None, - # vibrational_enthalpy: Optional[Tuple[float, float]] = None, - # vibrational_entropy: Optional[Tuple[float, float]] = None, - # rotational_enthalpy: Optional[Tuple[float, float]] = None, - # rotational_entropy: Optional[Tuple[float, float]] = None, - # free_energy: Optional[Tuple[float, float]] = None, chemsys: Optional[Union[str, List[str]]] = None, deprecated: Optional[bool] = None, elements: Optional[List[str]] = None, @@ -60,19 +45,14 @@ def search( charge (Tuple[int, int]): Minimum and maximum charge for the molecule. spin_multiplicity (Tuple[int, int]): Minimum and maximum spin for the molecule. nelements (Tuple[int, int]): Minimum and maximum number of elements - has_solvent (str, List[str]): Whether the molecule has properties calculated in - solvents (e.g., "SOLVENT=THF", ["SOLVENT=WATER", "VACUUM"]) - has_level_of_theory (str, List[str]): Whether the molecule has properties calculated - using a particular level of theory (e.g. "wB97M-V/def2-SVPD/SMD", - ["wB97X-V/def2-TZVPPD/SMD", "wB97M-V/def2-QZVPPD/SMD"]) - has_lot_solvent (str, List[str]): Whether the molecule has properties calculated - using a particular combination of level of theory and solvent (e.g. "wB97X-V/def2-SVPD/SMD(SOLVENT=THF)", - ["wB97X-V/def2-TZVPPD/SMD(VACUUM)", "wB97M-V/def2-QZVPPD/SMD(SOLVENT=WATER)"]) - electronic_energy (Tuple[float, float]): Minimum and maximum electronic energy - ionization_energy (Tuple[float, float]): Minimum and maximum ionization energy - electron_affinity (Tuple[float, float]): Minimum and maximum electron affinity - reduction_free_energy (Tuple[float, float]): Minimum and maximum reduction free energy - oxidation_free_energy (Tuple[float, float]): Minimum and maximum oxidation free energy + # has_solvent (str, List[str]): Whether the molecule has properties calculated in + # solvents (e.g., "SOLVENT=THF", ["SOLVENT=WATER", "VACUUM"]) + # has_level_of_theory (str, List[str]): Whether the molecule has properties calculated + # using a particular level of theory (e.g. "wB97M-V/def2-SVPD/SMD", + # ["wB97X-V/def2-TZVPPD/SMD", "wB97M-V/def2-QZVPPD/SMD"]) + # has_lot_solvent (str, List[str]): Whether the molecule has properties calculated + # using a particular combination of level of theory and solvent (e.g. "wB97X-V/def2-SVPD/SMD(SOLVENT=THF)", + # ["wB97X-V/def2-TZVPPD/SMD(VACUUM)", "wB97M-V/def2-QZVPPD/SMD(SOLVENT=WATER)"]) chemsys (str, List[str]): A chemical system, list of chemical systems (e.g., Li-C-O, [C-O-H-N, Li-N]), or single formula (e.g., C2 H4). deprecated (bool): Whether the material is tagged as deprecated. diff --git a/tests/mpcules/test_summary.py b/tests/mpcules/test_summary.py index 4f3589fb..f9961350 100644 --- a/tests/mpcules/test_summary.py +++ b/tests/mpcules/test_summary.py @@ -11,8 +11,18 @@ "num_chunks", "all_fields", "fields", + "has_solvent", + "exclude_elements", + # Below: currently timing out + "nelements", + "has_props" ] +alt_name = { + "formula": "formula_alphabetical", + "molecule_ids": "molecule_id" +} + custom_field_tests = { "molecule_ids": ["9f153b9f3caa3124fb404b42e4cf82c8-C2H4-0-1"], "formula": "C2 H4", @@ -22,7 +32,7 @@ "has_level_of_theory": "wB97X-V/def2-TZVPPD/SMD", "has_lot_solvent": "wB97X-V/def2-TZVPPD/SMD(SOLVENT=THF)", "nelements": 2, - "has_props": HasProps.orbitals + "has_props": [HasProps.orbitals] } # type: dict @@ -74,4 +84,4 @@ def test_client(): else: raise ValueError("No documents returned") - assert doc[param] is not None + assert doc[alt_name.get(param, param)] is not None From e580f344cb3e7ff9c0393bb3ba3c6c0a5d5bb844 Mon Sep 17 00:00:00 2001 From: Evan Walter Clark Spotte-Smith Date: Wed, 15 Mar 2023 08:21:21 -0700 Subject: [PATCH 8/9] Linting and so on --- mp_api/client/core/client.py | 24 ++---- mp_api/client/core/settings.py | 4 +- mp_api/client/core/utils.py | 8 +- mp_api/client/mprester.py | 16 ++-- mp_api/client/routes/_general_store.py | 5 +- mp_api/client/routes/_user_settings.py | 5 +- mp_api/client/routes/alloys.py | 1 - mp_api/client/routes/bonds.py | 21 +---- mp_api/client/routes/charge_density.py | 11 ++- mp_api/client/routes/dielectric.py | 17 +--- mp_api/client/routes/doi.py | 1 - mp_api/client/routes/elasticity.py | 21 +---- mp_api/client/routes/electrodes.py | 23 ++---- mp_api/client/routes/electronic_structure.py | 85 +++++--------------- mp_api/client/routes/eos.py | 24 ++---- mp_api/client/routes/fermi.py | 1 - mp_api/client/routes/grain_boundary.py | 29 ++----- mp_api/client/routes/magnetism.py | 40 +++------ mp_api/client/routes/materials.py | 32 ++------ mp_api/client/routes/molecules.py | 24 ++---- mp_api/client/routes/mpcules/summary.py | 20 ++--- mp_api/client/routes/oxidation_states.py | 17 +--- mp_api/client/routes/phonon.py | 1 - mp_api/client/routes/piezo.py | 17 +--- mp_api/client/routes/provenance.py | 7 +- mp_api/client/routes/robocrys.py | 1 - mp_api/client/routes/similarity.py | 1 - mp_api/client/routes/substrates.py | 25 ++---- mp_api/client/routes/summary.py | 20 ++--- mp_api/client/routes/surface_properties.py | 17 +--- mp_api/client/routes/tasks.py | 1 - mp_api/client/routes/thermo.py | 1 - mp_api/client/routes/xas.py | 8 +- tests/mpcules/test_summary.py | 10 +-- tests/test_bonds.py | 10 +-- tests/test_charge_density.py | 2 - tests/test_client.py | 1 - tests/test_core_client.py | 28 ++----- tests/test_dielectric.py | 9 +-- tests/test_elasticity.py | 9 +-- tests/test_electrodes.py | 9 +-- tests/test_electronic_structure.py | 22 ++--- tests/test_eos.py | 9 +-- tests/test_grain_boundary.py | 9 +-- tests/test_magnetism.py | 9 +-- tests/test_molecules.py | 9 +-- tests/test_mprester.py | 37 +++------ tests/test_oxidation_states.py | 10 +-- tests/test_piezo.py | 9 +-- tests/test_provenance.py | 9 +-- tests/test_robocrys.py | 5 +- tests/test_substrates.py | 9 +-- tests/test_summary.py | 1 - tests/test_surface_properties.py | 9 +-- tests/test_synthesis.py | 45 ++++------- tests/test_tasks.py | 7 +- 56 files changed, 192 insertions(+), 613 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 6a4b9f31..20a553aa 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -199,7 +199,6 @@ def _post_resource( response = self.session.post(url, json=payload, verify=True, params=params) if response.status_code == 200: - if self.monty_decode: data = json.loads(response.text, cls=MontyDecoder) else: @@ -232,7 +231,6 @@ def _post_resource( ) except RequestException as ex: - raise MPRestError(str(ex)) def _query_resource( @@ -305,7 +303,6 @@ def _query_resource( return data except RequestException as ex: - raise MPRestError(str(ex)) def _submit_requests( @@ -343,7 +340,6 @@ def _submit_requests( # trying to evenly divide num_chunks by the total number of new # criteria dicts. if parallel_param is not None: - # Determine slice size accounting for character maximum in HTTP URL # First get URl length without parallel param url_string = "" @@ -370,7 +366,6 @@ def _submit_requests( ] if len(parallel_param_str_chunks) > 0: - params_min_chunk = min(parallel_param_str_chunks, key=lambda x: len(x.split("%2C"))) num_params_min_chunk = len(params_min_chunk.split("%2C")) @@ -429,7 +424,6 @@ def _submit_requests( initial_data_tuples = self._multi_thread(use_document_model, initial_params_list) for data, subtotal, crit_ind in initial_data_tuples: - subtotals.append(subtotal) sub_diff = subtotal - new_limits[crit_ind] remaining_docs_avail[crit_ind] = sub_diff @@ -473,7 +467,6 @@ def _submit_requests( # Obtain missing initial data after rebalancing if len(rebalance_params) > 0: - rebalance_data_tuples = self._multi_thread(use_document_model, rebalance_params) for data, _, _ in rebalance_data_tuples: @@ -609,12 +602,10 @@ def _multi_thread( params_ind = 0 with ThreadPoolExecutor(max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS) as executor: - # Get list of initial futures defined by max number of parallel requests futures = set() for params in itertools.islice(params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS): - future = executor.submit( self._submit_request_and_process, use_document_model=use_document_model, @@ -630,7 +621,6 @@ def _multi_thread( finished, futures = wait(futures, return_when=FIRST_COMPLETED) for future in finished: - data, subtotal = future.result() if progress_bar is not None: @@ -639,7 +629,6 @@ def _multi_thread( # Populate more futures to replace finished for params in itertools.islice(params_gen, len(finished)): - new_future = executor.submit( self._submit_request_and_process, use_document_model=use_document_model, @@ -675,12 +664,17 @@ def _submit_request_and_process( Tuple with data and total number of docs in matching the query in the database. """ try: - response = self.session.get(url=url, verify=verify, params=params, timeout=timeout, headers=self.headers) + response = self.session.get( + url=url, + verify=verify, + params=params, + timeout=timeout, + headers=self.headers, + ) except requests.exceptions.ConnectTimeout: raise MPRestError(f"REST query timed out on URL {url}. Try again with a smaller request.") if response.status_code == 200: - if self.monty_decode: data = json.loads(response.text, cls=MontyDecoder) else: @@ -689,7 +683,6 @@ def _submit_request_and_process( # other sub-urls may use different document models # the client does not handle this in a particularly smart way currently if self.document_model and use_document_model: - raw_doc_list = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore if len(raw_doc_list) > 0: @@ -725,7 +718,6 @@ def _submit_request_and_process( ) def _generate_returned_model(self, doc): - set_fields = [field for field, _ in doc if field in doc.dict(exclude_unset=True)] unset_fields = [field for field in doc.__fields__ if field not in set_fields] @@ -838,7 +830,6 @@ def get_data_by_id( try: results = self._query_resource_data(criteria=criteria, fields=fields, suburl=document_id) # type: ignore except MPRestError: - if self.primary_key == "material_id": # see if the material_id has changed, perhaps a task_id was supplied # this should likely be re-thought @@ -855,7 +846,6 @@ def get_data_by_id( docs = mpr.search(task_ids=[document_id], fields=["material_id"]) if len(docs) > 0: - new_document_id = docs[0].get("material_id", None) if new_document_id is not None: diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index c539a4c2..dfa9525b 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -59,9 +59,7 @@ class MAPIClientSettings(BaseSettings): description="Number of parallel requests to send.", ) - MAX_RETRIES: int = Field( - _MAX_RETRIES, description="Maximum number of retries for requests." - ) + MAX_RETRIES: int = Field(_MAX_RETRIES, description="Maximum number of retries for requests.") MUTE_PROGRESS_BARS: bool = Field( _MUTE_PROGRESS_BAR, diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 0433217f..8cdf8651 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -45,9 +45,7 @@ def api_sanitize( """ models = [ - model - for model in get_flat_models_from_model(pydantic_model) - if issubclass(model, BaseModel) + model for model in get_flat_models_from_model(pydantic_model) if issubclass(model, BaseModel) ] # type: List[Type[BaseModel]] fields_to_leave = fields_to_leave or [] @@ -100,9 +98,7 @@ def validate_monty(cls, v): errors.append("@class") if len(errors) > 0: - raise ValueError( - "Missing Monty seriailzation fields in dictionary: {errors}" - ) + raise ValueError("Missing Monty seriailzation fields in dictionary: {errors}") return v else: diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 4de5731f..06b0edc7 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -133,7 +133,9 @@ def __init__( self.endpoint = endpoint self.headers = headers or {} self.session = session or BaseRester._create_session( - api_key=self.api_key, include_user_agent=include_user_agent, headers=self.headers + api_key=self.api_key, + include_user_agent=include_user_agent, + headers=self.headers, ) self.use_document_model = use_document_model self.monty_decode = monty_decode @@ -162,7 +164,6 @@ def __init__( self.endpoint += "/" for cls in BaseRester.__subclasses__(): - rester = cls( api_key=api_key, endpoint=endpoint, @@ -506,7 +507,6 @@ def get_entries( try: input_params = {"material_ids": validate_ids(chemsys_formula_mpids)} except ValueError: - if any("-" in entry for entry in chemsys_formula_mpids): input_params = {"chemsys": chemsys_formula_mpids} else: @@ -548,7 +548,6 @@ def get_entries( ) if conventional_unit_cell: - entry_struct = Structure.from_dict(entry_dict["structure"]) s = SpacegroupAnalyzer(entry_struct).get_conventional_standard_structure() site_ratio = len(s) / len(entry_struct) @@ -604,7 +603,6 @@ def get_pourbaix_entries( MaterialsProjectAqueousCompatibility, MaterialsProjectCompatibility, ) - from pymatgen.entries.computed_entries import ComputedEntry if solid_compat == "MaterialsProjectCompatibility": solid_compat = MaterialsProjectCompatibility() @@ -711,7 +709,9 @@ def get_ion_reference_data(self) -> List[Dict]: compounds and aqueous species, Wiley, New York (1978)'}} """ return self.contribs.query_contributions( - query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True + query={"project": "ion_ref_data"}, + fields=["identifier", "formula", "data"], + paginate=True, ).get("data") def get_ion_reference_data_for_chemsys(self, chemsys: Union[str, List]) -> List[Dict]: @@ -1122,9 +1122,9 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None): meta = {} for doc in self.materials.search( - task_ids=material_ids, fields=["calc_types", "deprecated_tasks", "material_id"] + task_ids=material_ids, + fields=["calc_types", "deprecated_tasks", "material_id"], ): - for task_id, calc_type in doc.calc_types.items(): if calc_types and calc_type not in calc_types: continue diff --git a/mp_api/client/routes/_general_store.py b/mp_api/client/routes/_general_store.py index 795ae95c..01926689 100644 --- a/mp_api/client/routes/_general_store.py +++ b/mp_api/client/routes/_general_store.py @@ -5,7 +5,6 @@ class GeneralStoreRester(BaseRester[GeneralStoreDoc]): # pragma: no cover - suffix = "_general_store" document_model = GeneralStoreDoc # type: ignore primary_key = "submission_id" @@ -24,9 +23,7 @@ def add_item(self, kind: str, markdown: str, meta: Dict): # pragma: no cover Raises: MPRestError """ - return self._post_resource( - body=meta, params={"kind": kind, "markdown": markdown} - ).get("data") + return self._post_resource(body=meta, params={"kind": kind, "markdown": markdown}).get("data") def get_items(self, kind): # pragma: no cover """ diff --git a/mp_api/client/routes/_user_settings.py b/mp_api/client/routes/_user_settings.py index 5265f3e1..5a8c770e 100644 --- a/mp_api/client/routes/_user_settings.py +++ b/mp_api/client/routes/_user_settings.py @@ -3,7 +3,6 @@ class UserSettingsRester(BaseRester[UserSettingsDoc]): # pragma: no cover - suffix = "_user_settings" document_model = UserSettingsDoc # type: ignore primary_key = "consumer_id" @@ -21,9 +20,7 @@ def set_user_settings(self, consumer_id, settings): # pragma: no cover Raises: MPRestError """ - return self._post_resource( - body=settings, params={"consumer_id": consumer_id} - ).get("data") + return self._post_resource(body=settings, params={"consumer_id": consumer_id}).get("data") def get_user_settings(self, consumer_id): # pragma: no cover """ diff --git a/mp_api/client/routes/alloys.py b/mp_api/client/routes/alloys.py index bc6edcb9..b243b9a9 100644 --- a/mp_api/client/routes/alloys.py +++ b/mp_api/client/routes/alloys.py @@ -7,7 +7,6 @@ class AlloysRester(BaseRester[AlloyPairDoc]): - suffix = "alloys" document_model = AlloyPairDoc # type: ignore primary_key = "pair_id" diff --git a/mp_api/client/routes/bonds.py b/mp_api/client/routes/bonds.py index c3f2e72a..69057f52 100644 --- a/mp_api/client/routes/bonds.py +++ b/mp_api/client/routes/bonds.py @@ -9,7 +9,6 @@ class BondsRester(BaseRester[BondingDoc]): - suffix = "bonds" document_model = BondingDoc # type: ignore primary_key = "material_id" @@ -102,25 +101,13 @@ def search( query_params.update({"coordination_envs": ",".join(coordination_envs)}) if coordination_envs_anonymous is not None: - query_params.update( - {"coordination_envs_anonymous": ",".join(coordination_envs_anonymous)} - ) + query_params.update({"coordination_envs_anonymous": ",".join(coordination_envs_anonymous)}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/charge_density.py b/mp_api/client/routes/charge_density.py index e3262fab..bb1c694f 100644 --- a/mp_api/client/routes/charge_density.py +++ b/mp_api/client/routes/charge_density.py @@ -16,7 +16,6 @@ class ChargeDensityRester(BaseRester[ChgcarDataDoc]): - suffix = "charge_density" primary_key = "fs_id" document_model = ChgcarDataDoc # type: ignore @@ -50,7 +49,11 @@ def download_for_task_ids( return num_downloads def search( # type: ignore - self, task_ids: Optional[List[str]] = None, num_chunks: Optional[int] = 1, chunk_size: int = 10, **kwargs + self, + task_ids: Optional[List[str]] = None, + num_chunks: Optional[int] = 1, + chunk_size: int = 10, + **kwargs, ) -> Union[List[ChgcarDataDoc], List[Dict]]: # type: ignore """ A search method to find what charge densities are available via this API. @@ -80,13 +83,11 @@ def get_charge_density_from_file_id(self, fs_id: str): url_doc = self.get_data_by_id(fs_id) if url_doc: - # The check below is performed to see if the client is being # used by our internal AWS deployment. If it is, we pull charge # density data from a private S3 bucket. Else, we pull data # from public MinIO buckets. if environ.get("AWS_EXECUTION_ENV", None) == "AWS_ECS_FARGATE": - if self.boto_resource is None: self.boto_resource = self._get_s3_resource(use_minio=False, unsigned=False) @@ -118,7 +119,6 @@ def get_charge_density_from_file_id(self, fs_id: str): return None def _extract_s3_url_info(self, url_doc, use_minio: bool = True): - if use_minio: url_list = url_doc.url.split("/") bucket = url_list[3] @@ -131,7 +131,6 @@ def _extract_s3_url_info(self, url_doc, use_minio: bool = True): return (bucket, obj_prefix) def _get_s3_resource(self, use_minio: bool = True, unsigned: bool = True): - resource = boto3.resource( "s3", endpoint_url="https://minio.materialsproject.org" if use_minio else None, diff --git a/mp_api/client/routes/dielectric.py b/mp_api/client/routes/dielectric.py index 5380d920..8003711d 100644 --- a/mp_api/client/routes/dielectric.py +++ b/mp_api/client/routes/dielectric.py @@ -9,7 +9,6 @@ class DielectricRester(BaseRester[DielectricDoc]): - suffix = "dielectric" document_model = DielectricDoc # type: ignore primary_key = "material_id" @@ -87,20 +86,10 @@ def search( query_params.update({"n_min": n[0], "n_max": n[1]}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/doi.py b/mp_api/client/routes/doi.py index 3ba7a586..5a465766 100644 --- a/mp_api/client/routes/doi.py +++ b/mp_api/client/routes/doi.py @@ -3,7 +3,6 @@ class DOIRester(BaseRester[DOIDoc]): - suffix = "doi" document_model = DOIDoc # type: ignore primary_key = "task_id" diff --git a/mp_api/client/routes/elasticity.py b/mp_api/client/routes/elasticity.py index 57f5a78b..f2f3140e 100644 --- a/mp_api/client/routes/elasticity.py +++ b/mp_api/client/routes/elasticity.py @@ -8,7 +8,6 @@ class ElasticityRester(BaseRester[ElasticityDoc]): - suffix = "elasticity" document_model = ElasticityDoc # type: ignore primary_key = "task_id" @@ -102,25 +101,13 @@ def search( ) if poisson_ratio: - query_params.update( - {"poisson_min": poisson_ratio[0], "poisson_max": poisson_ratio[1]} - ) + query_params.update({"poisson_min": poisson_ratio[0], "poisson_max": poisson_ratio[1]}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/electrodes.py b/mp_api/client/routes/electrodes.py index f6a1c757..f95f607a 100644 --- a/mp_api/client/routes/electrodes.py +++ b/mp_api/client/routes/electrodes.py @@ -9,7 +9,6 @@ class ElectrodeRester(BaseRester[InsertionElectrodeDoc]): - suffix = "insertion_electrodes" document_model = InsertionElectrodeDoc # type: ignore primary_key = "battery_id" @@ -117,9 +116,7 @@ def search( # pragma: ignore if isinstance(working_ion, str) or isinstance(working_ion, Element): working_ion = [working_ion] # type: ignore - query_params.update( - {"working_ion": ",".join([str(ele) for ele in working_ion])} # type: ignore - ) + query_params.update({"working_ion": ",".join([str(ele) for ele in working_ion])}) # type: ignore if formula: if isinstance(formula, str): @@ -133,17 +130,13 @@ def search( # pragma: ignore if num_elements: if isinstance(num_elements, int): num_elements = (num_elements, num_elements) - query_params.update( - {"nelements_min": num_elements[0], "nelements_max": num_elements[1]} - ) + query_params.update({"nelements_min": num_elements[0], "nelements_max": num_elements[1]}) if exclude_elements: query_params.update({"exclude_elements": ",".join(exclude_elements)}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) for param, value in locals().items(): if ( @@ -158,16 +151,10 @@ def search( # pragma: ignore and value ): if isinstance(value, tuple): - query_params.update( - {f"{param}_min": value[0], f"{param}_max": value[1]} - ) + query_params.update({f"{param}_min": value[0], f"{param}_max": value[1]}) else: query_params.update({param: value}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search(**query_params) diff --git a/mp_api/client/routes/electronic_structure.py b/mp_api/client/routes/electronic_structure.py index a2798ab2..3a569068 100644 --- a/mp_api/client/routes/electronic_structure.py +++ b/mp_api/client/routes/electronic_structure.py @@ -20,7 +20,6 @@ class ElectronicStructureRester(BaseRester[ElectronicStructureDoc]): - suffix = "electronic_structure" document_model = ElectronicStructureDoc # type: ignore primary_key = "material_id" @@ -115,9 +114,7 @@ def search( query_params.update({"exclude_elements": ",".join(exclude_elements)}) if band_gap: - query_params.update( - {"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]} - ) + query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}) if efermi: query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]}) @@ -128,9 +125,7 @@ def search( if num_elements: if isinstance(num_elements, int): num_elements = (num_elements, num_elements) - query_params.update( - {"nelements_min": num_elements[0], "nelements_max": num_elements[1]} - ) + query_params.update({"nelements_min": num_elements[0], "nelements_max": num_elements[1]}) if is_gap_direct is not None: query_params.update({"is_gap_direct": is_gap_direct}) @@ -139,15 +134,9 @@ def search( query_params.update({"is_metal": is_metal}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( num_chunks=num_chunks, @@ -159,7 +148,6 @@ def search( class BandStructureRester(BaseRester): - suffix = "electronic_structure/bandstructure" document_model = ElectronicStructureDoc # type: ignore @@ -217,9 +205,7 @@ def search( query_params["path_type"] = path_type.value if band_gap: - query_params.update( - {"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]} - ) + query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}) if efermi: query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]}) @@ -234,15 +220,9 @@ def search( query_params.update({"is_metal": is_metal}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( num_chunks=num_chunks, @@ -294,46 +274,32 @@ def get_bandstructure_from_material_id( bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object """ - es_rester = ElectronicStructureRester( - endpoint=self.base_endpoint, api_key=self.api_key - ) + es_rester = ElectronicStructureRester(endpoint=self.base_endpoint, api_key=self.api_key) if line_mode: - bs_data = es_rester.get_data_by_id( - document_id=material_id, fields=["bandstructure"] - ).bandstructure + bs_data = es_rester.get_data_by_id(document_id=material_id, fields=["bandstructure"]).bandstructure if bs_data is None: - raise MPRestError( - f"No {path_type.value} band structure data found for {material_id}" - ) + raise MPRestError(f"No {path_type.value} band structure data found for {material_id}") else: bs_data = bs_data.dict() if bs_data.get(path_type.value, None): bs_task_id = bs_data[path_type.value]["task_id"] else: - raise MPRestError( - f"No {path_type.value} band structure data found for {material_id}" - ) + raise MPRestError(f"No {path_type.value} band structure data found for {material_id}") else: - bs_data = es_rester.get_data_by_id( - document_id=material_id, fields=["dos"] - ).dos + bs_data = es_rester.get_data_by_id(document_id=material_id, fields=["dos"]).dos if bs_data is None: - raise MPRestError( - f"No uniform band structure data found for {material_id}" - ) + raise MPRestError(f"No uniform band structure data found for {material_id}") else: bs_data = bs_data.dict() if bs_data.get("total", None): bs_task_id = bs_data["total"]["1"]["task_id"] else: - raise MPRestError( - f"No uniform band structure data found for {material_id}" - ) + raise MPRestError(f"No uniform band structure data found for {material_id}") bs_obj = self.get_bandstructure_from_task_id(bs_task_id) @@ -349,7 +315,6 @@ def get_bandstructure_from_material_id( class DosRester(BaseRester): - suffix = "electronic_structure/dos" document_model = ElectronicStructureDoc # type: ignore @@ -416,9 +381,7 @@ def search( query_params["orbital"] = orbital.value if band_gap: - query_params.update( - {"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]} - ) + query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}) if efermi: query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]}) @@ -427,15 +390,9 @@ def search( query_params.update({"magnetic_ordering": magnetic_ordering.value}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( num_chunks=num_chunks, @@ -480,13 +437,9 @@ def get_dos_from_material_id(self, material_id: str): dos (CompleteDos): CompleteDos object """ - es_rester = ElectronicStructureRester( - endpoint=self.base_endpoint, api_key=self.api_key - ) + es_rester = ElectronicStructureRester(endpoint=self.base_endpoint, api_key=self.api_key) - dos_data = es_rester.get_data_by_id( - document_id=material_id, fields=["dos"] - ).dict() + dos_data = es_rester.get_data_by_id(document_id=material_id, fields=["dos"]).dict() if dos_data["dos"]: dos_task_id = dos_data["dos"]["total"]["1"]["task_id"] diff --git a/mp_api/client/routes/eos.py b/mp_api/client/routes/eos.py index ffc2676f..6f472391 100644 --- a/mp_api/client/routes/eos.py +++ b/mp_api/client/routes/eos.py @@ -8,7 +8,6 @@ class EOSRester(BaseRester[EOSDoc]): - suffix = "eos" document_model = EOSDoc # type: ignore primary_key = "task_id" @@ -19,8 +18,7 @@ def search_eos_docs(self, *args, **kwargs): # pragma: no cover """ warnings.warn( - "MPRester.eos.search_eos_docs is deprecated. " - "Please use MPRester.eos.search instead.", + "MPRester.eos.search_eos_docs is deprecated. " "Please use MPRester.eos.search instead.", DeprecationWarning, stacklevel=2, ) @@ -60,25 +58,13 @@ def search( query_params.update({"volumes_min": volumes[0], "volumes_max": volumes[1]}) if energies: - query_params.update( - {"energies_min": energies[0], "energies_max": energies[1]} - ) + query_params.update({"energies_min": energies[0], "energies_max": energies[1]}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/fermi.py b/mp_api/client/routes/fermi.py index 9cb03f0d..4f10d1cb 100644 --- a/mp_api/client/routes/fermi.py +++ b/mp_api/client/routes/fermi.py @@ -4,7 +4,6 @@ class FermiRester(BaseRester[FermiDoc]): - suffix = "fermi" document_model = FermiDoc # type: ignore primary_key = "task_id" diff --git a/mp_api/client/routes/grain_boundary.py b/mp_api/client/routes/grain_boundary.py index cc622bad..916949ae 100644 --- a/mp_api/client/routes/grain_boundary.py +++ b/mp_api/client/routes/grain_boundary.py @@ -10,7 +10,6 @@ class GrainBoundaryRester(BaseRester[GrainBoundaryDoc]): - suffix = "grain_boundary" document_model = GrainBoundaryDoc # type: ignore primary_key = "task_id" @@ -82,14 +81,10 @@ def search( query_params.update({"gb_plane": ",".join([str(n) for n in gb_plane])}) if gb_energy: - query_params.update( - {"gb_energy_min": gb_energy[0], "gb_energy_max": gb_energy[1]} - ) + query_params.update({"gb_energy_min": gb_energy[0], "gb_energy_max": gb_energy[1]}) if separation_energy: - query_params.update( - {"w_sep_min": separation_energy[0], "w_sep_max": separation_energy[1]} - ) + query_params.update({"w_sep_min": separation_energy[0], "w_sep_max": separation_energy[1]}) if rotation_angle: query_params.update( @@ -100,9 +95,7 @@ def search( ) if rotation_axis: - query_params.update( - {"rotation_axis": ",".join([str(n) for n in rotation_axis])} - ) + query_params.update({"rotation_axis": ",".join([str(n) for n in rotation_axis])}) if sigma: query_params.update({"sigma": sigma}) @@ -117,20 +110,10 @@ def search( query_params.update({"pretty_formula": pretty_formula}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/magnetism.py b/mp_api/client/routes/magnetism.py index b1d04811..b5a6904e 100644 --- a/mp_api/client/routes/magnetism.py +++ b/mp_api/client/routes/magnetism.py @@ -11,7 +11,6 @@ class MagnetismRester(BaseRester[MagnetismDoc]): - suffix = "magnetism" document_model = MagnetismDoc # type: ignore primary_key = "material_id" @@ -22,8 +21,7 @@ def search_magnetism_docs(self, *args, **kwargs): # pragma: no cover """ warnings.warn( - "MPRester.magnetism.search_magnetism_docs is deprecated. " - "Please use MPRester.magnetism.search instead.", + "MPRester.magnetism.search_magnetism_docs is deprecated. " "Please use MPRester.magnetism.search instead.", DeprecationWarning, stacklevel=2, ) @@ -38,9 +36,7 @@ def search( ordering: Optional[Ordering] = None, total_magnetization: Optional[Tuple[float, float]] = None, total_magnetization_normalized_vol: Optional[Tuple[float, float]] = None, - total_magnetization_normalized_formula_units: Optional[ - Tuple[float, float] - ] = None, + total_magnetization_normalized_formula_units: Optional[Tuple[float, float]] = None, sort_fields: Optional[List[str]] = None, num_chunks: Optional[int] = None, chunk_size: int = 1000, @@ -92,24 +88,16 @@ def search( if total_magnetization_normalized_vol: query_params.update( { - "total_magnetization_normalized_vol_min": total_magnetization_normalized_vol[ - 0 - ], - "total_magnetization_normalized_vol_max": total_magnetization_normalized_vol[ - 1 - ], + "total_magnetization_normalized_vol_min": total_magnetization_normalized_vol[0], + "total_magnetization_normalized_vol_max": total_magnetization_normalized_vol[1], } ) if total_magnetization_normalized_formula_units: query_params.update( { - "total_magnetization_normalized_formula_units_min": total_magnetization_normalized_formula_units[ - 0 - ], - "total_magnetization_normalized_formula_units_max": total_magnetization_normalized_formula_units[ - 1 - ], + "total_magnetization_normalized_formula_units_min": total_magnetization_normalized_formula_units[0], + "total_magnetization_normalized_formula_units_max": total_magnetization_normalized_formula_units[1], } ) @@ -133,20 +121,10 @@ def search( query_params.update({"ordering": ordering.value}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/materials.py b/mp_api/client/routes/materials.py index 26e78481..1539548f 100644 --- a/mp_api/client/routes/materials.py +++ b/mp_api/client/routes/materials.py @@ -13,15 +13,12 @@ class MaterialsRester(BaseRester[MaterialsDoc]): - suffix = "materials" document_model = MaterialsDoc # type: ignore supports_versions = True primary_key = "material_id" - def get_structure_by_material_id( - self, material_id: str, final: bool = True - ) -> Union[Structure, List[Structure]]: + def get_structure_by_material_id(self, material_id: str, final: bool = True) -> Union[Structure, List[Structure]]: """ Get a structure for a given Materials Project ID. @@ -47,8 +44,7 @@ def search_material_docs(self, *args, **kwargs): # pragma: no cover """ warnings.warn( - "MPRester.materials.search_material_docs is deprecated. " - "Please use MPRester.materials.search instead.", + "MPRester.materials.search_material_docs is deprecated. " "Please use MPRester.materials.search instead.", DeprecationWarning, stacklevel=2, ) @@ -148,16 +144,12 @@ def search( ) if num_sites: - query_params.update( - {"nsites_min": num_sites[0], "nsites_max": num_sites[1]} - ) + query_params.update({"nsites_min": num_sites[0], "nsites_max": num_sites[1]}) if num_elements: if isinstance(num_elements, int): num_elements = (num_elements, num_elements) - query_params.update( - {"nelements_min": num_elements[0], "nelements_max": num_elements[1]} - ) + query_params.update({"nelements_min": num_elements[0], "nelements_max": num_elements[1]}) if volume: query_params.update({"volume_min": volume[0], "volume_max": volume[1]}) @@ -166,22 +158,12 @@ def search( query_params.update({"density_min": density[0], "density_max": density[1]}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) def find_structure( diff --git a/mp_api/client/routes/molecules.py b/mp_api/client/routes/molecules.py index c57b7697..0352223a 100644 --- a/mp_api/client/routes/molecules.py +++ b/mp_api/client/routes/molecules.py @@ -10,7 +10,6 @@ class MoleculesRester(BaseRester[MoleculesDoc]): - suffix = "molecules" document_model = MoleculesDoc # type: ignore primary_key = "task_id" @@ -21,8 +20,7 @@ def search_molecules_docs(self, *args, **kwargs): # pragma: no cover """ warnings.warn( - "MPRester.molecules.search_molecules_docs is deprecated. " - "Please use MPRester.molecules.search instead.", + "MPRester.molecules.search_molecules_docs is deprecated. " "Please use MPRester.molecules.search instead.", DeprecationWarning, stacklevel=2, ) @@ -79,9 +77,7 @@ def search( query_params.update({"smiles": smiles}) if nelements: - query_params.update( - {"nelements_min": nelements[0], "nelements_max": nelements[1]} - ) + query_params.update({"nelements_min": nelements[0], "nelements_max": nelements[1]}) if EA: query_params.update({"EA_min": EA[0], "EA_max": EA[1]}) @@ -93,20 +89,10 @@ def search( query_params.update({"charge_min": charge[0], "charge_max": charge[1]}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/mpcules/summary.py b/mp_api/client/routes/mpcules/summary.py index 02b6aa41..8114e094 100644 --- a/mp_api/client/routes/mpcules/summary.py +++ b/mp_api/client/routes/mpcules/summary.py @@ -10,7 +10,6 @@ class MPculesSummaryRester(BaseRester[MoleculeSummaryDoc]): - suffix = "mpcules" document_model = MoleculeSummaryDoc # type: ignore primary_key = "molecule_id" @@ -48,10 +47,11 @@ def search( # has_solvent (str, List[str]): Whether the molecule has properties calculated in # solvents (e.g., "SOLVENT=THF", ["SOLVENT=WATER", "VACUUM"]) # has_level_of_theory (str, List[str]): Whether the molecule has properties calculated - # using a particular level of theory (e.g. "wB97M-V/def2-SVPD/SMD", + # using a particular level of theory (e.g. "wB97M-V/def2-SVPD/SMD", # ["wB97X-V/def2-TZVPPD/SMD", "wB97M-V/def2-QZVPPD/SMD"]) # has_lot_solvent (str, List[str]): Whether the molecule has properties calculated - # using a particular combination of level of theory and solvent (e.g. "wB97X-V/def2-SVPD/SMD(SOLVENT=THF)", + # using a particular combination of level of theory and solvent (e.g. + # "wB97X-V/def2-SVPD/SMD(SOLVENT=THF)", # ["wB97X-V/def2-TZVPPD/SMD(VACUUM)", "wB97M-V/def2-QZVPPD/SMD(SOLVENT=WATER)"]) chemsys (str, List[str]): A chemical system, list of chemical systems (e.g., Li-C-O, [C-O-H-N, Li-N]), or single formula (e.g., C2 H4). @@ -82,7 +82,7 @@ def search( "ionization_energy", "electron_affinity", "reduction_free_energy", - "oxidation_free_energy" + "oxidation_free_energy", ] for param, value in locals().items(): @@ -124,15 +124,9 @@ def search( query_params.update({"has_props": ",".join([i.value for i in has_props])}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) - - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) + + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( num_chunks=num_chunks, diff --git a/mp_api/client/routes/oxidation_states.py b/mp_api/client/routes/oxidation_states.py index 020dda33..d6caad54 100644 --- a/mp_api/client/routes/oxidation_states.py +++ b/mp_api/client/routes/oxidation_states.py @@ -6,7 +6,6 @@ class OxidationStatesRester(BaseRester[OxidationStateDoc]): - suffix = "oxidation_states" document_model = OxidationStateDoc # type: ignore primary_key = "material_id" @@ -70,20 +69,10 @@ def search( query_params.update({"possible_species": ",".join(possible_species)}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/phonon.py b/mp_api/client/routes/phonon.py index 119a5f2b..3696852d 100644 --- a/mp_api/client/routes/phonon.py +++ b/mp_api/client/routes/phonon.py @@ -3,7 +3,6 @@ class PhononRester(BaseRester[PhononBSDOSDoc]): - suffix = "phonon" document_model = PhononBSDOSDoc # type: ignore primary_key = "material_id" diff --git a/mp_api/client/routes/piezo.py b/mp_api/client/routes/piezo.py index 0479a805..9fa7d858 100644 --- a/mp_api/client/routes/piezo.py +++ b/mp_api/client/routes/piezo.py @@ -9,7 +9,6 @@ class PiezoRester(BaseRester[PiezoelectricDoc]): - suffix = "piezoelectric" document_model = PiezoelectricDoc # type: ignore primary_key = "material_id" @@ -74,20 +73,10 @@ def search( ) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/provenance.py b/mp_api/client/routes/provenance.py index 2e758cf6..9c089de4 100644 --- a/mp_api/client/routes/provenance.py +++ b/mp_api/client/routes/provenance.py @@ -5,7 +5,6 @@ class ProvenanceRester(BaseRester[ProvenanceDoc]): - suffix = "provenance" document_model = ProvenanceDoc # type: ignore primary_key = "material_id" @@ -45,9 +44,5 @@ def search( query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/robocrys.py b/mp_api/client/routes/robocrys.py index a368d59f..42118efc 100644 --- a/mp_api/client/routes/robocrys.py +++ b/mp_api/client/routes/robocrys.py @@ -7,7 +7,6 @@ class RobocrysRester(BaseRester[RobocrystallogapherDoc]): - suffix = "robocrys" document_model = RobocrystallogapherDoc # type: ignore primary_key = "material_id" diff --git a/mp_api/client/routes/similarity.py b/mp_api/client/routes/similarity.py index 64abddde..7531b8a8 100644 --- a/mp_api/client/routes/similarity.py +++ b/mp_api/client/routes/similarity.py @@ -3,7 +3,6 @@ class SimilarityRester(BaseRester[SimilarityDoc]): - suffix = "similarity" document_model = SimilarityDoc # type: ignore primary_key = "material_id" diff --git a/mp_api/client/routes/substrates.py b/mp_api/client/routes/substrates.py index ce59020e..c6ba9c65 100644 --- a/mp_api/client/routes/substrates.py +++ b/mp_api/client/routes/substrates.py @@ -7,7 +7,6 @@ class SubstratesRester(BaseRester[SubstratesDoc]): - suffix = "substrates" document_model = SubstratesDoc # type: ignore primary_key = "film_id" @@ -76,18 +75,10 @@ def search( query_params.update({"sub_form": substrate_formula}) if film_orientation: - query_params.update( - {"film_orientation": ",".join([str(i) for i in film_orientation])} - ) + query_params.update({"film_orientation": ",".join([str(i) for i in film_orientation])}) if substrate_orientation: - query_params.update( - { - "substrate_orientation": ",".join( - [str(i) for i in substrate_orientation] - ) - } - ) + query_params.update({"substrate_orientation": ",".join([str(i) for i in substrate_orientation])}) if area: query_params.update({"area_min": area[0], "area_max": area[1]}) @@ -96,15 +87,9 @@ def search( query_params.update({"energy_min": energy[0], "energy_max": energy[1]}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) - - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) + + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( **query_params, diff --git a/mp_api/client/routes/summary.py b/mp_api/client/routes/summary.py index 84ac082f..a55410fd 100644 --- a/mp_api/client/routes/summary.py +++ b/mp_api/client/routes/summary.py @@ -12,7 +12,6 @@ class SummaryRester(BaseRester[SummaryDoc]): - suffix = "summary" document_model = SummaryDoc # type: ignore primary_key = "material_id" @@ -23,8 +22,7 @@ def search_summary_docs(self, *args, **kwargs): # pragma: no cover """ warnings.warn( - "MPRester.summary.search_summary_docs is deprecated. " - "Please use MPRester.summary.search instead.", + "MPRester.summary.search_summary_docs is deprecated. " "Please use MPRester.summary.search instead.", DeprecationWarning, stacklevel=2, ) @@ -77,9 +75,7 @@ def search( theoretical: Optional[bool] = None, total_energy: Optional[Tuple[float, float]] = None, total_magnetization: Optional[Tuple[float, float]] = None, - total_magnetization_normalized_formula_units: Optional[ - Tuple[float, float] - ] = None, + total_magnetization_normalized_formula_units: Optional[Tuple[float, float]] = None, total_magnetization_normalized_vol: Optional[Tuple[float, float]] = None, uncorrected_energy: Optional[Tuple[float, float]] = None, volume: Optional[Tuple[float, float]] = None, @@ -276,15 +272,9 @@ def search( query_params.update({"theoretical": theoretical}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) - - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) + + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( num_chunks=num_chunks, diff --git a/mp_api/client/routes/surface_properties.py b/mp_api/client/routes/surface_properties.py index cdb82a70..dc3fd3ff 100644 --- a/mp_api/client/routes/surface_properties.py +++ b/mp_api/client/routes/surface_properties.py @@ -8,7 +8,6 @@ class SurfacePropertiesRester(BaseRester[SurfacePropDoc]): - suffix = "surface_properties" document_model = SurfacePropDoc # type: ignore primary_key = "task_id" @@ -100,20 +99,10 @@ def search( query_params.update({"has_reconstructed": has_reconstructed}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) diff --git a/mp_api/client/routes/tasks.py b/mp_api/client/routes/tasks.py index 086bd489..c47e164d 100644 --- a/mp_api/client/routes/tasks.py +++ b/mp_api/client/routes/tasks.py @@ -9,7 +9,6 @@ class TaskRester(BaseRester[TaskDoc]): - suffix = "tasks" document_model = TaskDoc # type: ignore primary_key = "task_id" diff --git a/mp_api/client/routes/thermo.py b/mp_api/client/routes/thermo.py index f968149c..4d592521 100644 --- a/mp_api/client/routes/thermo.py +++ b/mp_api/client/routes/thermo.py @@ -12,7 +12,6 @@ class ThermoRester(BaseRester[ThermoDoc]): - suffix = "thermo" document_model = ThermoDoc # type: ignore supports_versions = True diff --git a/mp_api/client/routes/xas.py b/mp_api/client/routes/xas.py index b9dd6319..8f2716bd 100644 --- a/mp_api/client/routes/xas.py +++ b/mp_api/client/routes/xas.py @@ -9,7 +9,6 @@ class XASRester(BaseRester[XASDoc]): - suffix = "xas" document_model = XASDoc # type: ignore primary_key = "spectrum_id" @@ -20,8 +19,7 @@ def search_xas_docs(self, *args, **kwargs): # pragma: no cover """ warnings.warn( - "MPRester.xas.search_xas_docs is deprecated. " - "Please use MPRester.xas.search instead.", + "MPRester.xas.search_xas_docs is deprecated. " "Please use MPRester.xas.search instead.", DeprecationWarning, stacklevel=2, ) @@ -99,9 +97,7 @@ def search( query_params["material_ids"] = ",".join(validate_ids(material_ids)) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) return super()._search( num_chunks=num_chunks, diff --git a/tests/mpcules/test_summary.py b/tests/mpcules/test_summary.py index f9961350..533d6cd0 100644 --- a/tests/mpcules/test_summary.py +++ b/tests/mpcules/test_summary.py @@ -15,13 +15,10 @@ "exclude_elements", # Below: currently timing out "nelements", - "has_props" + "has_props", ] -alt_name = { - "formula": "formula_alphabetical", - "molecule_ids": "molecule_id" -} +alt_name = {"formula": "formula_alphabetical", "molecule_ids": "molecule_id"} custom_field_tests = { "molecule_ids": ["9f153b9f3caa3124fb404b42e4cf82c8-C2H4-0-1"], @@ -32,13 +29,12 @@ "has_level_of_theory": "wB97X-V/def2-TZVPPD/SMD", "has_lot_solvent": "wB97X-V/def2-TZVPPD/SMD(SOLVENT=THF)", "nelements": 2, - "has_props": [HasProps.orbitals] + "has_props": [HasProps.orbitals], } # type: dict @pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(): - search_method = MPculesSummaryRester().search # Get list of parameters diff --git a/tests/test_bonds.py b/tests/test_bonds.py index 15a68322..364b43bd 100644 --- a/tests/test_bonds.py +++ b/tests/test_bonds.py @@ -36,11 +36,8 @@ def rester(): } # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): - search_method = rester.search if search_method is not None: @@ -89,7 +86,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_charge_density.py b/tests/test_charge_density.py index fbae5056..9adbcc6e 100644 --- a/tests/test_charge_density.py +++ b/tests/test_charge_density.py @@ -83,7 +83,6 @@ def test_client(rester): def test_download_for_task_ids(tmpdir, rester): - n = rester.download_for_task_ids( task_ids=["mp-655585", "mp-1057373", "mp-1059589", "mp-1440634", "mp-1791788"], path=tmpdir, @@ -94,7 +93,6 @@ def test_download_for_task_ids(tmpdir, rester): def test_extract_s3_url_info(rester): - url_doc_dict = { "task_id": "mp-1896591", "url": "https://minio.materialsproject.org/phuck/atomate_chgcar_fs/6021584c12afbe14911d1b8e", diff --git a/tests/test_client.py b/tests/test_client.py index 6ac12a1c..e844c3a1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -45,7 +45,6 @@ @pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") @pytest.mark.parametrize("rester", mpr._all_resters) def test_generic_get_methods(rester): - # -- Test generic search and get_data_by_id methods name = rester.suffix.replace("/", "_") if name not in ignore_generic: diff --git a/tests/test_core_client.py b/tests/test_core_client.py index 96afdc6e..bd5de38a 100644 --- a/tests/test_core_client.py +++ b/tests/test_core_client.py @@ -18,45 +18,31 @@ def mpr(): rester.session.close() -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") @pytest.mark.xfail def test_post_fail(rester): rester._post_resource({}, suburl="materials/find_structure") -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_pagination(mpr): - mpids = mpr.materials.search( - all_fields=False, fields=["material_id"], num_chunks=2, chunk_size=1000 - ) + mpids = mpr.materials.search(all_fields=False, fields=["material_id"], num_chunks=2, chunk_size=1000) assert len(mpids) > 1000 -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_count(mpr): - count = mpr.materials.count( - dict(task_ids="mp-149", _all_fields=False, _fields="material_id") - ) + count = mpr.materials.count(dict(task_ids="mp-149", _all_fields=False, _fields="material_id")) assert count == 1 -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") @pytest.mark.xfail def test_get_document_no_id(mpr): mpr.materials.get_data_by_id(None) -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") @pytest.mark.xfail def test_get_document_no_doc(mpr): mpr.materials.get_data_by_id("mp-1a") diff --git a/tests/test_dielectric.py b/tests/test_dielectric.py index d174ca07..6bb9df4f 100644 --- a/tests/test_dielectric.py +++ b/tests/test_dielectric.py @@ -32,9 +32,7 @@ def rester(): } # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -82,7 +80,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_elasticity.py b/tests/test_elasticity.py index 64b314e2..d9c2729f 100644 --- a/tests/test_elasticity.py +++ b/tests/test_elasticity.py @@ -30,9 +30,7 @@ def rester(): custom_field_tests = {} # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -80,7 +78,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_electrodes.py b/tests/test_electrodes.py index 04ac2a00..051ed44e 100644 --- a/tests/test_electrodes.py +++ b/tests/test_electrodes.py @@ -42,9 +42,7 @@ def rester(): } # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -92,7 +90,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_electronic_structure.py b/tests/test_electronic_structure.py index f053be76..eafbaaf0 100644 --- a/tests/test_electronic_structure.py +++ b/tests/test_electronic_structure.py @@ -47,9 +47,7 @@ def es_rester(): } # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_es_client(es_rester): search_method = es_rester.search @@ -61,7 +59,6 @@ def test_es_client(es_rester): for entry in param_tuples: param = entry[0] if param not in es_excluded_params: - param_type = entry[1].__args__[0] q = None @@ -96,10 +93,7 @@ def test_es_client(es_rester): doc = search_method(**q)[0].dict() - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None bs_custom_field_tests = { @@ -122,9 +116,7 @@ def bs_rester(): rester.session.close() -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_bs_client(bs_rester): # Get specific search method search_method = bs_rester.search @@ -169,9 +161,7 @@ def dos_rester(): rester.session.close() -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_dos_client(dos_rester): search_method = dos_rester.search @@ -192,6 +182,4 @@ def test_dos_client(dos_rester): if param != "projection_type" and param != "magnetic_ordering": doc = doc["total"]["1"] - assert ( - doc[project_field if project_field is not None else param] is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_eos.py b/tests/test_eos.py index 9b966c2a..a0cd231c 100644 --- a/tests/test_eos.py +++ b/tests/test_eos.py @@ -27,9 +27,7 @@ def rester(): custom_field_tests = {} # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -78,7 +76,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_grain_boundary.py b/tests/test_grain_boundary.py index 1d1b8554..205051a8 100644 --- a/tests/test_grain_boundary.py +++ b/tests/test_grain_boundary.py @@ -36,9 +36,7 @@ def rester(): } # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -87,7 +85,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_magnetism.py b/tests/test_magnetism.py index d08fbcc2..a4cfeb07 100644 --- a/tests/test_magnetism.py +++ b/tests/test_magnetism.py @@ -28,9 +28,7 @@ def rester(): custom_field_tests = {"material_ids": ["mp-149"], "ordering": Ordering.FM} # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -79,7 +77,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_molecules.py b/tests/test_molecules.py index 5419243c..8a119a7f 100644 --- a/tests/test_molecules.py +++ b/tests/test_molecules.py @@ -33,9 +33,7 @@ def rester(): } # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -84,7 +82,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_mprester.py b/tests/test_mprester.py index 84587486..c526f93c 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -38,9 +38,7 @@ def mpr(): rester.session.close() -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") class TestMPRester: def test_get_structure_by_material_id(self, mpr): s1 = mpr.get_structure_by_material_id("mp-149") @@ -147,9 +145,7 @@ def test_get_entries(self, mpr): # Conventional structure formula = "BiFeO3" - entry = mpr.get_entry_by_material_id( - "mp-22526", inc_structure=True, conventional_unit_cell=True - )[0] + entry = mpr.get_entry_by_material_id("mp-22526", inc_structure=True, conventional_unit_cell=True)[0] s = entry.structure assert pytest.approx(s.lattice.a) == s.lattice.b @@ -159,9 +155,7 @@ def test_get_entries(self, mpr): assert pytest.approx(s.lattice.gamma) == 120 # Ensure energy per atom is same - prim = mpr.get_entry_by_material_id( - "mp-22526", inc_structure=True, conventional_unit_cell=False - )[0] + prim = mpr.get_entry_by_material_id("mp-22526", inc_structure=True, conventional_unit_cell=False)[0] assert pytest.approx(prim.energy_per_atom) == entry.energy_per_atom s = prim.structure @@ -263,15 +257,9 @@ def test_get_ion_entries(self, mpr): # test ion energy calculation ion_data = mpr.get_ion_reference_data_for_chemsys("S") - ion_ref_comps = [ - Ion.from_formula(d["data"]["RefSolid"]).composition for d in ion_data - ] - ion_ref_elts = set( - itertools.chain.from_iterable(i.elements for i in ion_ref_comps) - ) - ion_ref_entries = mpr.get_entries_in_chemsys( - list([str(e) for e in ion_ref_elts] + ["O", "H"]) - ) + ion_ref_comps = [Ion.from_formula(d["data"]["RefSolid"]).composition for d in ion_data] + ion_ref_elts = set(itertools.chain.from_iterable(i.elements for i in ion_ref_comps)) + ion_ref_entries = mpr.get_entries_in_chemsys(list([str(e) for e in ion_ref_elts] + ["O", "H"])) mpc = MaterialsProjectAqueousCompatibility() ion_ref_entries = mpc.process_entries(ion_ref_entries) ion_ref_pd = PhaseDiagram(ion_ref_entries) @@ -280,9 +268,7 @@ def test_get_ion_entries(self, mpr): # In ion ref data, SO4-2 is -744.27 kJ/mol; ref solid is -1,279.0 kJ/mol # so the ion entry should have an energy (-744.27 +1279) = 534.73 kJ/mol # or 5.542 eV/f.u. above the energy of Na2SO4 - so4_two_minus = [e for e in ion_entries if e.ion.reduced_formula == "SO4[-2]"][ - 0 - ] + so4_two_minus = [e for e in ion_entries if e.ion.reduced_formula == "SO4[-2]"][0] # the ref solid is Na2SO4, ground state mp-4770 # the rf factor correction is necessary to make sure the composition @@ -305,9 +291,7 @@ def test_get_charge_density_data(self, mpr): chgcar = mpr.get_charge_density_from_material_id("mp-149") assert isinstance(chgcar, Chgcar) - chgcar, task_doc = mpr.get_charge_density_from_material_id( - "mp-149", inc_task_doc=True - ) + chgcar, task_doc = mpr.get_charge_density_from_material_id("mp-149", inc_task_doc=True) assert isinstance(chgcar, Chgcar) assert isinstance(task_doc, TaskDoc) @@ -317,10 +301,7 @@ def test_get_wulff_shape(self, mpr): def test_large_list(self, mpr): mpids = [ - str(doc.material_id) - for doc in mpr.summary.search( - chunk_size=1000, num_chunks=15, fields=["material_id"] - ) + str(doc.material_id) for doc in mpr.summary.search(chunk_size=1000, num_chunks=15, fields=["material_id"]) ] docs = mpr.summary.search(material_ids=mpids, fields=["material_ids"]) assert len(docs) == 15000 diff --git a/tests/test_oxidation_states.py b/tests/test_oxidation_states.py index a6c299ac..365940e2 100644 --- a/tests/test_oxidation_states.py +++ b/tests/test_oxidation_states.py @@ -33,9 +33,7 @@ def rester(): } # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -47,7 +45,6 @@ def test_client(rester): for entry in param_tuples: param = entry[0] if param not in excluded_params: - param_type = entry[1].__args__[0] q = None @@ -85,7 +82,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_piezo.py b/tests/test_piezo.py index 847708f8..4a6d8f51 100644 --- a/tests/test_piezo.py +++ b/tests/test_piezo.py @@ -30,9 +30,7 @@ def rester(): custom_field_tests = {"material_ids": ["mp-9900"]} # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -81,7 +79,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_provenance.py b/tests/test_provenance.py index 1dc7ae59..e1561cd1 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -27,9 +27,7 @@ def rester(): custom_field_tests = {"material_ids": ["mp-149"]} # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -78,7 +76,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_robocrys.py b/tests/test_robocrys.py index 7693059b..7aa56e43 100644 --- a/tests/test_robocrys.py +++ b/tests/test_robocrys.py @@ -12,14 +12,11 @@ def rester(): rester.session.close() -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search if search_method is not None: - q = {"keywords": ["silicon"], "num_chunks": 1} doc = search_method(**q)[0] diff --git a/tests/test_substrates.py b/tests/test_substrates.py index 7acd0d72..7dcd3c4e 100644 --- a/tests/test_substrates.py +++ b/tests/test_substrates.py @@ -38,9 +38,7 @@ def rester(): } # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -89,7 +87,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_summary.py b/tests/test_summary.py index 538404a5..3a952131 100644 --- a/tests/test_summary.py +++ b/tests/test_summary.py @@ -56,7 +56,6 @@ @pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(): - search_method = SummaryRester().search # Get list of parameters diff --git a/tests/test_surface_properties.py b/tests/test_surface_properties.py index c2040f00..e4006941 100644 --- a/tests/test_surface_properties.py +++ b/tests/test_surface_properties.py @@ -27,9 +27,7 @@ def rester(): custom_field_tests = {} # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -78,7 +76,4 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None diff --git a/tests/test_synthesis.py b/tests/test_synthesis.py index ce69a464..c1404b94 100644 --- a/tests/test_synthesis.py +++ b/tests/test_synthesis.py @@ -15,9 +15,7 @@ def rester(): rester.session.close() -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -31,9 +29,7 @@ def test_client(rester): assert doc.synthesis_type is not None -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_filters_keywords(rester): search_method = rester.search @@ -45,25 +41,19 @@ def test_filters_keywords(rester): assert "silicon" in " ".join([x["value"] for x in highlighted]).lower() -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_filters_synthesis_type(rester): search_method = rester.search if search_method is not None: - doc = search_method( - synthesis_type=[SynthesisTypeEnum.solid_state], num_chunks=1 - ) + doc = search_method(synthesis_type=[SynthesisTypeEnum.solid_state], num_chunks=1) assert all(x.synthesis_type == SynthesisTypeEnum.solid_state for x in doc) doc = search_method(synthesis_type=[SynthesisTypeEnum.sol_gel], num_chunks=1) assert all(x.synthesis_type == SynthesisTypeEnum.sol_gel for x in doc) -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") @pytest.mark.xfail # Needs fixing def test_filters_temperature_range(rester): search_method = rester.search @@ -81,9 +71,7 @@ def test_filters_temperature_range(rester): assert 700 <= val <= 1000 -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") @pytest.mark.xfail # Needs fixing def test_filters_time_range(rester): search_method = rester.search @@ -99,15 +87,14 @@ def test_filters_time_range(rester): assert 7 <= val <= 11 -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_filters_atmosphere(rester): search_method = rester.search if search_method is not None: docs: List[SynthesisRecipe] = search_method( - condition_heating_atmosphere=["air", "O2"], num_chunks=5, + condition_heating_atmosphere=["air", "O2"], + num_chunks=5, ) for doc in docs: found = False @@ -118,15 +105,14 @@ def test_filters_atmosphere(rester): assert found -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_filters_mixing_device(rester): search_method = rester.search if search_method is not None: docs: List[SynthesisRecipe] = search_method( - condition_mixing_device=["zirconia", "Al2O3"], num_chunks=5, + condition_mixing_device=["zirconia", "Al2O3"], + num_chunks=5, ) for doc in docs: found = False @@ -136,15 +122,14 @@ def test_filters_mixing_device(rester): assert found -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_filters_mixing_media(rester): search_method = rester.search if search_method is not None: docs: List[SynthesisRecipe] = search_method( - condition_mixing_media=["water", "alcohol"], num_chunks=5, + condition_mixing_media=["water", "alcohol"], + num_chunks=5, ) for doc in docs: found = False diff --git a/tests/test_tasks.py b/tests/test_tasks.py index e6708dc9..319f55ef 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -26,7 +26,11 @@ def rester(): sub_doc_fields = [] # type: list -alt_name_dict = {"formula": "task_id", "task_ids": "task_id", "exclude_elements": "task_id"} # type: dict +alt_name_dict = { + "formula": "task_id", + "task_ids": "task_id", + "exclude_elements": "task_id", +} # type: dict custom_field_tests = { "chemsys": "Si-O", @@ -91,7 +95,6 @@ def test_client(rester): def test_get_trajectories(rester): - trajectories = rester.get_trajectory("mp-149") for traj in trajectories: From 9dfac7236395af1784ab26005540849797d1c205 Mon Sep 17 00:00:00 2001 From: Evan Walter Clark Spotte-Smith Date: Wed, 12 Apr 2023 11:18:58 -0700 Subject: [PATCH 9/9] Update version of emmet-core --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 61bd4c95..b98e3946 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -emmet-core[all]==0.39.8 +emmet-core[all]==0.51.11 pydantic>=1.8.2 pymatgen>=2022.3.7 pymatgen-analysis-alloys>=0.0.3