collect: fix Pool type annotation
avivrosenberg committed Jan 1, 2024
1 parent 5795d4c commit 9d0755d
Showing 1 changed file with 10 additions and 10 deletions.
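Context for the fix (a general Python detail, not stated in the commit itself): multiprocessing.Pool is a factory function (a bound method of the default context object), not a class, so annotations such as pool: mp.Pool are rejected by static type checkers; the actual class is multiprocessing.pool.Pool, which this commit imports instead. A minimal sketch of the distinction (the run helper below is hypothetical, not from pp5):

    import multiprocessing as mp
    from multiprocessing.pool import Pool  # the real class, valid as a type annotation

    def run(pool: Pool) -> list:
        # Fan work out over the pool; Pool here names the class, so checkers accept it.
        return pool.map(abs, [-1, -2, -3])

    if __name__ == "__main__":
        # mp.Pool is a factory that returns a multiprocessing.pool.Pool instance;
        # as an annotation it fails type checking, since a function is not a type.
        print(type(mp.Pool))  # <class 'method'>
        with mp.Pool(processes=2) as pool:
            print(isinstance(pool, Pool))  # True
            print(run(pool))  # [1, 2, 3]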
src/pp5/collect.py: 20 changes (10 additions, 10 deletions)
@@ -14,7 +14,7 @@
 from pathlib import Path
 from datetime import datetime
 from dataclasses import dataclass
-from multiprocessing.pool import AsyncResult
+from multiprocessing.pool import Pool, AsyncResult
 
 import numpy as np
 import pandas as pd
@@ -183,7 +183,7 @@ def collect(self) -> dict:
         LOGGER.info(f"Collection metadata:\n" f"{collection_meta_formatted}")
         return self._collection_meta
 
-    def _finalize_collection(self, pool):
+    def _finalize_collection(self, pool: Pool):
         LOGGER.info(f"Finalizing collection for {self.id}...")
         if self.out_dir is None:
             return
@@ -438,7 +438,7 @@ def _collection_functions(self):
             "Write dataset": self._write_dataset,
         }
 
-    def _collect_precs(self, pool: mp.Pool):
+    def _collect_precs(self, pool: Pool):
         meta = {}
         pdb_ids = self.query.execute()
         n_structs = len(pdb_ids)
@@ -481,7 +481,7 @@ def _collect_precs(self, pool: mp.Pool):
         )
         return meta
 
-    def _filter_collected(self, pool: mp.Pool) -> dict:
+    def _filter_collected(self, pool: Pool) -> dict:
         """
         Filters collected structures according to conditions on their metadata.
         """
@@ -534,14 +534,14 @@ def _update_rejected_counts(filter_name: str, idx: pd.Series):
             "n_collected_filtered": len(df_filtered),
         }
 
-    def _filter_metadata(self, pool: mp.Pool, df_all: pd.DataFrame) -> pd.Series:
+    def _filter_metadata(self, pool: Pool, df_all: pd.DataFrame) -> pd.Series:
         # Even though we query by resolution, the metadata resolution is different
         # than what we can query on. Metadata shows resolution after refinement,
         # while the query is using data collection resolution.
         idx_filter = df_all[COL_RESOLUTION].astype(float) <= self.resolution
         return idx_filter
 
-    def _filter_redundant_unps(self, pool: mp.Pool, df_all: pd.DataFrame) -> pd.Series:
+    def _filter_redundant_unps(self, pool: Pool, df_all: pd.DataFrame) -> pd.Series:
 
         if self.seq_similarity_thresh == 1.0:
             LOGGER.info("Skipping sequence similarity filter...")
@@ -620,7 +620,7 @@ def _filter_redundant_unps(self, pool: mp.Pool, df_all: pd.DataFrame) -> pd.Series:
         filtered_idx = df_all[COL_UNP_ID].isin(filtered_unp_ids)
         return filtered_idx
 
-    def _write_dataset(self, pool: mp.Pool) -> dict:
+    def _write_dataset(self, pool: Pool) -> dict:
         all_csvs = tuple(self.prec_csv_out_dir.glob("*.csv"))
         n_pdb_ids = len(all_csvs)
         LOGGER.info(f"Creating dataset file from {n_pdb_ids} precs...")
@@ -812,7 +812,7 @@ def _collection_functions(self):
             "Collect pgroups": self._collect_all_pgroups,
         }
 
-    def _collect_all_structures(self, pool: mp.Pool):
+    def _collect_all_structures(self, pool: Pool):
         meta = {}
 
         if self._all_file:
@@ -864,7 +864,7 @@ def _collect_all_structures(self, pool: mp.Pool):
         meta["n_all_structures"] = len(self._df_all)
         return meta
 
-    def _collect_all_refs(self, pool: mp.Pool):
+    def _collect_all_refs(self, pool: Pool):
         meta = {}
 
         if self._ref_file:
@@ -906,7 +906,7 @@ def _collect_all_refs(self, pool: mp.Pool):
         self._out_filepaths.append(filepath)
         return meta
 
-    def _collect_all_pgroups(self, pool: mp.Pool):
+    def _collect_all_pgroups(self, pool: Pool):
         meta = {}
 
         # Initialize a local BLAST DB.
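Usage sketch after the change (hypothetical, not part of the commit): any collection step annotated with the real Pool class can receive the shared pool, and AsyncResult handles from apply_async pair naturally with the same import:

    import multiprocessing as mp
    from multiprocessing.pool import Pool, AsyncResult

    def collect_step(pool: Pool) -> dict:
        # Hypothetical reduction of a pp5 collection step: submit jobs to the
        # shared pool, gather the results, and return step metadata.
        results = [pool.apply_async(abs, (x,)) for x in (-1, -2, -3)]  # AsyncResult handles
        return {"n_collected": len([r.get() for r in results])}

    if __name__ == "__main__":
        with mp.Pool(processes=2) as pool:
            print(collect_step(pool))  # {'n_collected': 3}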
