Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Initial MPcules Summary RESTer #758

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
24 changes: 7 additions & 17 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def _post_resource(
response = self.session.post(url, json=payload, verify=True, params=params)

if response.status_code == 200:

if self.monty_decode:
data = json.loads(response.text, cls=MontyDecoder)
else:
Expand Down Expand Up @@ -234,7 +233,6 @@ def _post_resource(
)

except RequestException as ex:

raise MPRestError(str(ex))

def _query_resource(
Expand Down Expand Up @@ -307,7 +305,6 @@ def _query_resource(
return data

except RequestException as ex:

raise MPRestError(str(ex))

def _submit_requests(
Expand Down Expand Up @@ -345,7 +342,6 @@ def _submit_requests(
# trying to evenly divide num_chunks by the total number of new
# criteria dicts.
if parallel_param is not None:

# Determine slice size accounting for character maximum in HTTP URL
# First get URl length without parallel param
url_string = ""
Expand All @@ -372,7 +368,6 @@ def _submit_requests(
]

if len(parallel_param_str_chunks) > 0:

params_min_chunk = min(parallel_param_str_chunks, key=lambda x: len(x.split("%2C")))

num_params_min_chunk = len(params_min_chunk.split("%2C"))
Expand Down Expand Up @@ -431,7 +426,6 @@ def _submit_requests(
initial_data_tuples = self._multi_thread(use_document_model, initial_params_list)

for data, subtotal, crit_ind in initial_data_tuples:

subtotals.append(subtotal)
sub_diff = subtotal - new_limits[crit_ind]
remaining_docs_avail[crit_ind] = sub_diff
Expand Down Expand Up @@ -475,7 +469,6 @@ def _submit_requests(

# Obtain missing initial data after rebalancing
if len(rebalance_params) > 0:

rebalance_data_tuples = self._multi_thread(use_document_model, rebalance_params)

for data, _, _ in rebalance_data_tuples:
Expand Down Expand Up @@ -611,12 +604,10 @@ def _multi_thread(
params_ind = 0

with ThreadPoolExecutor(max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS) as executor:

# Get list of initial futures defined by max number of parallel requests
futures = set()

for params in itertools.islice(params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS):

future = executor.submit(
self._submit_request_and_process,
use_document_model=use_document_model,
Expand All @@ -632,7 +623,6 @@ def _multi_thread(
finished, futures = wait(futures, return_when=FIRST_COMPLETED)

for future in finished:

data, subtotal = future.result()

if progress_bar is not None:
Expand All @@ -641,7 +631,6 @@ def _multi_thread(

# Populate more futures to replace finished
for params in itertools.islice(params_gen, len(finished)):

new_future = executor.submit(
self._submit_request_and_process,
use_document_model=use_document_model,
Expand Down Expand Up @@ -677,12 +666,17 @@ def _submit_request_and_process(
Tuple with data and total number of docs in matching the query in the database.
"""
try:
response = self.session.get(url=url, verify=verify, params=params, timeout=timeout, headers=self.headers)
response = self.session.get(
url=url,
verify=verify,
params=params,
timeout=timeout,
headers=self.headers,
)
except requests.exceptions.ConnectTimeout:
raise MPRestError(f"REST query timed out on URL {url}. Try again with a smaller request.")

if response.status_code == 200:

if self.monty_decode:
data = json.loads(response.text, cls=MontyDecoder)
else:
Expand All @@ -691,7 +685,6 @@ def _submit_request_and_process(
# other sub-urls may use different document models
# the client does not handle this in a particularly smart way currently
if self.document_model and use_document_model:

raw_doc_list = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore

if len(raw_doc_list) > 0:
Expand Down Expand Up @@ -727,7 +720,6 @@ def _submit_request_and_process(
)

def _generate_returned_model(self, doc):

set_fields = [field for field, _ in doc if field in doc.dict(exclude_unset=True)]
unset_fields = [field for field in doc.__fields__ if field not in set_fields]

Expand Down Expand Up @@ -840,7 +832,6 @@ def get_data_by_id(
try:
results = self._query_resource_data(criteria=criteria, fields=fields, suburl=document_id) # type: ignore
except MPRestError:

if self.primary_key == "material_id":
# see if the material_id has changed, perhaps a task_id was supplied
# this should likely be re-thought
Expand All @@ -857,7 +848,6 @@ def get_data_by_id(
docs = mpr.search(task_ids=[document_id], fields=["material_id"])

if len(docs) > 0:

new_document_id = docs[0].get("material_id", None)

if new_document_id is not None:
Expand Down
4 changes: 1 addition & 3 deletions mp_api/client/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ class MAPIClientSettings(BaseSettings):
description="Number of parallel requests to send.",
)

MAX_RETRIES: int = Field(
_MAX_RETRIES, description="Maximum number of retries for requests."
)
MAX_RETRIES: int = Field(_MAX_RETRIES, description="Maximum number of retries for requests.")

BACKOFF_FACTOR: float = Field(
0.1,
Expand Down
8 changes: 2 additions & 6 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def api_sanitize(
"""

models = [
model
for model in get_flat_models_from_model(pydantic_model)
if issubclass(model, BaseModel)
model for model in get_flat_models_from_model(pydantic_model) if issubclass(model, BaseModel)
] # type: List[Type[BaseModel]]

fields_to_leave = fields_to_leave or []
Expand Down Expand Up @@ -100,9 +98,7 @@ def validate_monty(cls, v):
errors.append("@class")

if len(errors) > 0:
raise ValueError(
"Missing Monty seriailzation fields in dictionary: {errors}"
)
raise ValueError("Missing Monty seriailzation fields in dictionary: {errors}")

return v
else:
Expand Down
17 changes: 9 additions & 8 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class MPRester:
bonds: BondsRester
alloys: AlloysRester
absorption: AbsorptionRester
mpcules_summary: MPculesSummaryRester
_user_settings: UserSettingsRester
_general_store: GeneralStoreRester

Expand Down Expand Up @@ -133,7 +134,9 @@ def __init__(
self.endpoint = endpoint
self.headers = headers or {}
self.session = session or BaseRester._create_session(
api_key=self.api_key, include_user_agent=include_user_agent, headers=self.headers
api_key=self.api_key,
include_user_agent=include_user_agent,
headers=self.headers,
)
self.use_document_model = use_document_model
self.monty_decode = monty_decode
Expand Down Expand Up @@ -162,7 +165,6 @@ def __init__(
self.endpoint += "/"

for cls in BaseRester.__subclasses__():

rester = cls(
api_key=api_key,
endpoint=endpoint,
Expand Down Expand Up @@ -506,7 +508,6 @@ def get_entries(
try:
input_params = {"material_ids": validate_ids(chemsys_formula_mpids)}
except ValueError:

if any("-" in entry for entry in chemsys_formula_mpids):
input_params = {"chemsys": chemsys_formula_mpids}
else:
Expand Down Expand Up @@ -548,7 +549,6 @@ def get_entries(
)

if conventional_unit_cell:

entry_struct = Structure.from_dict(entry_dict["structure"])
s = SpacegroupAnalyzer(entry_struct).get_conventional_standard_structure()
site_ratio = len(s) / len(entry_struct)
Expand Down Expand Up @@ -604,7 +604,6 @@ def get_pourbaix_entries(
MaterialsProjectAqueousCompatibility,
MaterialsProjectCompatibility,
)
from pymatgen.entries.computed_entries import ComputedEntry

if solid_compat == "MaterialsProjectCompatibility":
solid_compat = MaterialsProjectCompatibility()
Expand Down Expand Up @@ -711,7 +710,9 @@ def get_ion_reference_data(self) -> List[Dict]:
compounds and aqueous species, Wiley, New York (1978)'}}
"""
return self.contribs.query_contributions(
query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True
query={"project": "ion_ref_data"},
fields=["identifier", "formula", "data"],
paginate=True,
).get("data")

def get_ion_reference_data_for_chemsys(self, chemsys: Union[str, List]) -> List[Dict]:
Expand Down Expand Up @@ -1122,9 +1123,9 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None):

meta = {}
for doc in self.materials.search(
task_ids=material_ids, fields=["calc_types", "deprecated_tasks", "material_id"]
task_ids=material_ids,
fields=["calc_types", "deprecated_tasks", "material_id"],
):

for task_id, calc_type in doc.calc_types.items():
if calc_types and calc_type not in calc_types:
continue
Expand Down
1 change: 1 addition & 0 deletions mp_api/client/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .bonds import BondsRester
from .robocrys import RobocrysRester
from .absorption import AbsorptionRester
from mp_api.client.routes.mpcules.summary import MPculesSummaryRester

try:
from .alloys import AlloysRester
Expand Down
5 changes: 1 addition & 4 deletions mp_api/client/routes/_general_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@


class GeneralStoreRester(BaseRester[GeneralStoreDoc]): # pragma: no cover

suffix = "_general_store"
document_model = GeneralStoreDoc # type: ignore
primary_key = "submission_id"
Expand All @@ -24,9 +23,7 @@ def add_item(self, kind: str, markdown: str, meta: Dict): # pragma: no cover
Raises:
MPRestError
"""
return self._post_resource(
body=meta, params={"kind": kind, "markdown": markdown}
).get("data")
return self._post_resource(body=meta, params={"kind": kind, "markdown": markdown}).get("data")

def get_items(self, kind): # pragma: no cover
"""
Expand Down
5 changes: 1 addition & 4 deletions mp_api/client/routes/_user_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


class UserSettingsRester(BaseRester[UserSettingsDoc]): # pragma: no cover

suffix = "_user_settings"
document_model = UserSettingsDoc # type: ignore
primary_key = "consumer_id"
Expand All @@ -21,9 +20,7 @@ def set_user_settings(self, consumer_id, settings): # pragma: no cover
Raises:
MPRestError
"""
return self._post_resource(
body=settings, params={"consumer_id": consumer_id}
).get("data")
return self._post_resource(body=settings, params={"consumer_id": consumer_id}).get("data")

def get_user_settings(self, consumer_id): # pragma: no cover
"""
Expand Down
1 change: 0 additions & 1 deletion mp_api/client/routes/alloys.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


class AlloysRester(BaseRester[AlloyPairDoc]):

suffix = "alloys"
document_model = AlloyPairDoc # type: ignore
primary_key = "pair_id"
Expand Down
21 changes: 4 additions & 17 deletions mp_api/client/routes/bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@


class BondsRester(BaseRester[BondingDoc]):

suffix = "bonds"
document_model = BondingDoc # type: ignore
primary_key = "material_id"
Expand Down Expand Up @@ -102,25 +101,13 @@ def search(
query_params.update({"coordination_envs": ",".join(coordination_envs)})

if coordination_envs_anonymous is not None:
query_params.update(
{"coordination_envs_anonymous": ",".join(coordination_envs_anonymous)}
)
query_params.update({"coordination_envs_anonymous": ",".join(coordination_envs_anonymous)})

if sort_fields:
query_params.update(
{"_sort_fields": ",".join([s.strip() for s in sort_fields])}
)
query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])})

query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}
query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}

return super()._search(
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=all_fields,
fields=fields,
**query_params
num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params
)
11 changes: 5 additions & 6 deletions mp_api/client/routes/charge_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


class ChargeDensityRester(BaseRester[ChgcarDataDoc]):

suffix = "charge_density"
primary_key = "fs_id"
document_model = ChgcarDataDoc # type: ignore
Expand Down Expand Up @@ -50,7 +49,11 @@ def download_for_task_ids(
return num_downloads

def search( # type: ignore
self, task_ids: Optional[List[str]] = None, num_chunks: Optional[int] = 1, chunk_size: int = 10, **kwargs
self,
task_ids: Optional[List[str]] = None,
num_chunks: Optional[int] = 1,
chunk_size: int = 10,
**kwargs,
) -> Union[List[ChgcarDataDoc], List[Dict]]: # type: ignore
"""
A search method to find what charge densities are available via this API.
Expand Down Expand Up @@ -80,13 +83,11 @@ def get_charge_density_from_file_id(self, fs_id: str):
url_doc = self.get_data_by_id(fs_id)

if url_doc:

# The check below is performed to see if the client is being
# used by our internal AWS deployment. If it is, we pull charge
# density data from a private S3 bucket. Else, we pull data
# from public MinIO buckets.
if environ.get("AWS_EXECUTION_ENV", None) == "AWS_ECS_FARGATE":

if self.boto_resource is None:
self.boto_resource = self._get_s3_resource(use_minio=False, unsigned=False)

Expand Down Expand Up @@ -118,7 +119,6 @@ def get_charge_density_from_file_id(self, fs_id: str):
return None

def _extract_s3_url_info(self, url_doc, use_minio: bool = True):

if use_minio:
url_list = url_doc.url.split("/")
bucket = url_list[3]
Expand All @@ -131,7 +131,6 @@ def _extract_s3_url_info(self, url_doc, use_minio: bool = True):
return (bucket, obj_prefix)

def _get_s3_resource(self, use_minio: bool = True, unsigned: bool = True):

resource = boto3.resource(
"s3",
endpoint_url="https://minio.materialsproject.org" if use_minio else None,
Expand Down
Loading