From 9d0755d79b2f6ecb94d56a587820a75251b404c5 Mon Sep 17 00:00:00 2001
From: Aviv Rosenberg
Date: Mon, 1 Jan 2024 09:38:33 +0200
Subject: [PATCH] collect: fix Pool type annotation

---
 src/pp5/collect.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/pp5/collect.py b/src/pp5/collect.py
index 11e4c62..55cc8fd 100644
--- a/src/pp5/collect.py
+++ b/src/pp5/collect.py
@@ -14,7 +14,7 @@
 from pathlib import Path
 from datetime import datetime
 from dataclasses import dataclass
-from multiprocessing.pool import AsyncResult
+from multiprocessing.pool import Pool, AsyncResult
 
 import numpy as np
 import pandas as pd
@@ -183,7 +183,7 @@ def collect(self) -> dict:
         LOGGER.info(f"Collection metadata:\n" f"{collection_meta_formatted}")
         return self._collection_meta
 
-    def _finalize_collection(self, pool):
+    def _finalize_collection(self, pool: Pool):
         LOGGER.info(f"Finalizing collection for {self.id}...")
         if self.out_dir is None:
             return
@@ -438,7 +438,7 @@ def _collection_functions(self):
             "Write dataset": self._write_dataset,
         }
 
-    def _collect_precs(self, pool: mp.Pool):
+    def _collect_precs(self, pool: Pool):
         meta = {}
         pdb_ids = self.query.execute()
         n_structs = len(pdb_ids)
@@ -481,7 +481,7 @@
         )
         return meta
 
-    def _filter_collected(self, pool: mp.Pool) -> dict:
+    def _filter_collected(self, pool: Pool) -> dict:
         """
         Filters collected structures according to conditions on their metadata.
         """
@@ -534,14 +534,14 @@ def _update_rejected_counts(filter_name: str, idx: pd.Series):
             "n_collected_filtered": len(df_filtered),
         }
 
-    def _filter_metadata(self, pool: mp.Pool, df_all: pd.DataFrame) -> pd.Series:
+    def _filter_metadata(self, pool: Pool, df_all: pd.DataFrame) -> pd.Series:
         # Even though we query by resolution, the metadata resolution is different
         # than what we can query on. Metadata shows resolution after refinement,
         # while the query is using data collection resolution.
         idx_filter = df_all[COL_RESOLUTION].astype(float) <= self.resolution
         return idx_filter
 
-    def _filter_redundant_unps(self, pool: mp.Pool, df_all: pd.DataFrame) -> pd.Series:
+    def _filter_redundant_unps(self, pool: Pool, df_all: pd.DataFrame) -> pd.Series:
 
         if self.seq_similarity_thresh == 1.0:
             LOGGER.info("Skipping sequence similarity filter...")
@@ -620,7 +620,7 @@ def _filter_redundant_unps(self, pool: mp.Pool, df_all: pd.DataFrame) -> pd.Seri
         filtered_idx = df_all[COL_UNP_ID].isin(filtered_unp_ids)
         return filtered_idx
 
-    def _write_dataset(self, pool: mp.Pool) -> dict:
+    def _write_dataset(self, pool: Pool) -> dict:
         all_csvs = tuple(self.prec_csv_out_dir.glob("*.csv"))
         n_pdb_ids = len(all_csvs)
         LOGGER.info(f"Creating dataset file from {n_pdb_ids} precs...")
@@ -812,7 +812,7 @@ def _collection_functions(self):
             "Collect pgroups": self._collect_all_pgroups,
         }
 
-    def _collect_all_structures(self, pool: mp.Pool):
+    def _collect_all_structures(self, pool: Pool):
         meta = {}
 
         if self._all_file:
@@ -864,7 +864,7 @@
         meta["n_all_structures"] = len(self._df_all)
         return meta
 
-    def _collect_all_refs(self, pool: mp.Pool):
+    def _collect_all_refs(self, pool: Pool):
         meta = {}
 
         if self._ref_file:
@@ -906,7 +906,7 @@
         self._out_filepaths.append(filepath)
         return meta
 
-    def _collect_all_pgroups(self, pool: mp.Pool):
+    def _collect_all_pgroups(self, pool: Pool):
         meta = {}
 
         # Initialize a local BLAST DB.
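
Why mp.Pool fails in annotation position: multiprocessing.Pool is a factory
function bound to the default start-method context, not a class, so static
type checkers reject it as a type; the class it instantiates lives at
multiprocessing.pool.Pool. A minimal sketch of the distinction, standard
library only (the abs_all helper is illustrative and not part of pp5):

    import multiprocessing as mp
    from multiprocessing.pool import Pool  # the actual class, valid as a type

    def abs_all(pool: Pool) -> list:
        # Annotating with multiprocessing.pool.Pool satisfies type checkers;
        # pool.map() is the standard Pool API and behaves the same at runtime.
        return pool.map(abs, [-1, -2, 3])

    if __name__ == "__main__":
        # mp.Pool() remains the correct way to *create* a pool at runtime.
        with mp.Pool(processes=2) as pool:
            print(abs_all(pool))  # [1, 2, 3]

Annotations are not evaluated as types at runtime, so def f(pool: mp.Pool)
runs fine; the annotation only breaks under a checker such as mypy, which is
what this patch addresses.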