diff --git a/openprotein/api/align.py b/openprotein/api/align.py index ddb2eb0..3b05d04 100644 --- a/openprotein/api/align.py +++ b/openprotein/api/align.py @@ -1,5 +1,5 @@ from typing import Iterator, Optional, List, BinaryIO, Literal, Union -from pydantic import BaseModel, Field, validator, root_validator +from openprotein.pydantic import BaseModel, Field, validator, root_validator from enum import Enum from io import BytesIO import random diff --git a/openprotein/api/data.py b/openprotein/api/data.py index 01397f9..0ab9b45 100644 --- a/openprotein/api/data.py +++ b/openprotein/api/data.py @@ -1,6 +1,6 @@ import pandas as pd -import pydantic -from pydantic import BaseModel +import openprotein.pydantic as pydantic +from openprotein.pydantic import BaseModel from typing import Optional, List, Union from datetime import datetime from io import BytesIO @@ -468,7 +468,7 @@ def create( metadata.sequence_length = len(table["sequence"].values[0]) return AssayDataset(self.session, metadata) - def get(self, assay_id: str) -> AssayMetadata: + def get(self, assay_id: str, verbose: bool = False) -> AssayMetadata: """ Get an assay dataset by its ID. diff --git a/openprotein/api/design.py b/openprotein/api/design.py index 5d5650a..7951185 100644 --- a/openprotein/api/design.py +++ b/openprotein/api/design.py @@ -9,7 +9,7 @@ from openprotein.errors import APIError from openprotein.futures import FutureFactory, FutureBase -from pydantic import BaseModel, Field, validator +from openprotein.pydantic import BaseModel, Field, validator from datetime import datetime import re diff --git a/openprotein/api/embedding.py b/openprotein/api/embedding.py index dd45cc7..dbe9bd2 100644 --- a/openprotein/api/embedding.py +++ b/openprotein/api/embedding.py @@ -11,7 +11,7 @@ ) from openprotein.futures import FutureBase, FutureFactory -from pydantic import BaseModel, parse_obj_as +from openprotein.pydantic import BaseModel, parse_obj_as import numpy as np from typing import Optional, List, Union, Any import io @@ -247,7 +247,7 @@ def __init__( def get(self, verbose=False) -> List: return super().get(verbose=verbose) - + @property def sequences(self): if self._sequences is None: @@ -305,9 +305,7 @@ def embedding_model_post( """ endpoint = PATH_PREFIX + f"/models/{model_id}/embed" - sequences_unicode = [ - (s if isinstance(s, str) else s.decode()) for s in sequences - ] + sequences_unicode = [(s if isinstance(s, str) else s.decode()) for s in sequences] body = { "sequences": sequences_unicode, } @@ -345,9 +343,7 @@ def embedding_model_logits_post( """ endpoint = PATH_PREFIX + f"/models/{model_id}/logits" - sequences_unicode = [ - (s if isinstance(s, str) else s.decode()) for s in sequences - ] + sequences_unicode = [(s if isinstance(s, str) else s.decode()) for s in sequences] body = { "sequences": sequences_unicode, } @@ -385,9 +381,7 @@ def embedding_model_attn_post( """ endpoint = PATH_PREFIX + f"/models/{model_id}/attn" - sequences_unicode = [ - (s if isinstance(s, str) else s.decode()) for s in sequences - ] + sequences_unicode = [(s if isinstance(s, str) else s.decode()) for s in sequences] body = { "sequences": sequences_unicode, } @@ -500,9 +494,7 @@ def svd_embed_post(session: APISession, svd_id: str, sequences: List[bytes]) -> """ endpoint = PATH_PREFIX + f"/svd/{svd_id}/embed" - sequences_unicode = [ - (s if isinstance(s, str) else s.decode()) for s in sequences - ] + sequences_unicode = [(s if isinstance(s, str) else s.decode()) for s in sequences] body = { "sequences": sequences_unicode, } @@ -715,7 +707,7 @@ def get_job(self) -> Job: """Get job associated with this SVD model""" return job_get(self.session, self.id) - def get(self): + def get(self, verbose: bool = False): # overload for AsyncJobFuture return self @@ -963,7 +955,7 @@ def fit_svd( sequences: List[bytes], n_components: int = 1024, reduction: Optional[str] = None, - ) -> SVDModel: + ) -> SVDModel: # type: ignore """ Fit an SVD on the embedding results of this model. diff --git a/openprotein/api/fold.py b/openprotein/api/fold.py index c38ed0f..c8800b4 100644 --- a/openprotein/api/fold.py +++ b/openprotein/api/fold.py @@ -3,7 +3,7 @@ import openprotein.config as config from openprotein.api.embedding import ModelMetadata from openprotein.api.align import validate_msa, MSAFuture -import pydantic +import openprotein.pydantic as pydantic from typing import Optional, List, Union, Tuple from openprotein.futures import FutureBase, FutureFactory from abc import ABC, abstractmethod diff --git a/openprotein/api/jobs.py b/openprotein/api/jobs.py index 0430315..113d8dc 100644 --- a/openprotein/api/jobs.py +++ b/openprotein/api/jobs.py @@ -1,12 +1,12 @@ # Jobs and job centric flows -from typing import List, Union +from typing import List, Union, Optional import concurrent.futures import time import tqdm -import pydantic +import openprotein.pydantic as pydantic from openprotein.base import APISession import openprotein.config as config @@ -105,7 +105,7 @@ def list( more_recent_than=more_recent_than, ) - def get(self, job_id) -> Job: + def get(self, job_id: str, verbose: bool = False) -> Job: """get Job by ID""" return load_job(self.session, job_id) # return job_get(self.session, job_id) @@ -150,7 +150,7 @@ def done(self): def cancelled(self): return self.job.cancelled() - def get(self, verbose=False): + def get(self, verbose: bool = False): raise NotImplementedError() def wait_until_done( @@ -176,7 +176,7 @@ def wait_until_done( def wait( self, interval: int = config.POLLING_INTERVAL, - timeout: int = None, + timeout: Optional[int] = None, verbose: bool = False, ): """ @@ -195,7 +195,7 @@ def wait( self.session, interval=interval, timeout=timeout, verbose=verbose ) self.job = job - return self.get(verbose=verbose) + return self.get() class StreamingAsyncJobFuture(AsyncJobFuture): diff --git a/openprotein/api/poet.py b/openprotein/api/poet.py index 5a2685a..8beaf64 100644 --- a/openprotein/api/poet.py +++ b/openprotein/api/poet.py @@ -1,5 +1,5 @@ from typing import Iterator, Optional, List, Literal, Dict -from pydantic import BaseModel, validator +from openprotein.pydantic import BaseModel, validator from io import BytesIO import random import requests diff --git a/openprotein/api/predict.py b/openprotein/api/predict.py index be6e6fa..6475e3b 100644 --- a/openprotein/api/predict.py +++ b/openprotein/api/predict.py @@ -1,5 +1,5 @@ from typing import Optional, List, Union, Any, Dict, Literal -from pydantic import BaseModel +from openprotein.pydantic import BaseModel, root_validator from openprotein.base import APISession from openprotein.api.jobs import AsyncJobFuture @@ -16,6 +16,29 @@ class SequenceDataset(BaseModel): sequences: List[str] +class _Prediction(BaseModel): + """Prediction details.""" + + @root_validator(pre=True) + def extract_pred(cls, values): + p = values.pop("properties") + name = list(p.keys())[0] + ymu = p[name]["y_mu"] + yvar = p[name]["y_var"] + p["name"] = name + p["y_mu"] = ymu + p["y_var"] = yvar + + values.update(p) + return values + + model_id: str + model_name: str + y_mu: Optional[float] = None + y_var: Optional[float] = None + name: Optional[str] + + class Prediction(BaseModel): """Prediction details.""" @@ -35,6 +58,17 @@ class PredictJobBase(Job): class PredictJob(PredictJobBase): """Properties about predict job returned via API.""" + @root_validator(pre=True) + def extract_pred(cls, values): + # Extracting 'predictions' and 'sequences' from the input values + v = values.pop("result") + preds = [i["predictions"] for i in v] + seqs = [i["sequence"] for i in v] + values["result"] = [ + {"sequence": i, "predictions": p} for i, p in zip(seqs, preds) + ] + return values + class SequencePrediction(BaseModel): """Sequence prediction.""" @@ -42,7 +76,7 @@ class SequencePrediction(BaseModel): predictions: List[Prediction] = [] result: Optional[List[SequencePrediction]] = None - job_type: Literal[JobType.workflow_predict] = JobType.workflow_predict + job_type: str @register_job_type(JobType.worflow_predict_single_site) @@ -128,9 +162,9 @@ def _create_predict_job( def create_predict_job( session: APISession, sequences: SequenceDataset, - train_job: Any, + train_job: Optional[Any] = None, model_ids: Optional[List[str]] = None, -) -> PredictJob: +) -> FutureBase: """ Creates a predict job with a given set of sequences and a train job. @@ -167,8 +201,9 @@ def create_predict_job( model_ids = [model_ids] endpoint = "v1/workflow/predict" payload = {"sequences": sequences.sequences} + train_job_id = train_job.id if train_job is not None else None return _create_predict_job( - session, endpoint, payload, model_ids=model_ids, train_job_id=train_job.id + session, endpoint, payload, model_ids=model_ids, train_job_id=train_job_id ) @@ -177,7 +212,7 @@ def create_predict_single_site( sequence: SequenceData, train_job: Any, model_ids: Optional[List[str]] = None, -) -> PredictJob: +) -> FutureBase: """ Creates a predict job for single site mutants with a given sequence and a train job. @@ -318,6 +353,7 @@ class PredictFutureMixin: session: APISession job: PredictJob + id: Optional[str] = None def get_results( self, page_size: Optional[int] = None, page_offset: Optional[int] = None @@ -344,6 +380,7 @@ def get_results( HTTPError If the GET request does not succeed. """ + assert self.id is not None if "single_site" in self.job.job_type: return get_single_site_prediction_results( self.session, self.id, page_size, page_offset @@ -352,7 +389,7 @@ def get_results( return get_prediction_results(self.session, self.id, page_size, page_offset) -class PredictFuture(PredictFutureMixin, AsyncJobFuture, FutureBase): +class PredictFuture(PredictFutureMixin, AsyncJobFuture, FutureBase): # type: ignore """Future Job for manipulating results""" job_type = [JobType.workflow_predict, JobType.worflow_predict_single_site] @@ -372,24 +409,29 @@ def id(self): return self.job.job_id def _fmt_results(self, results): - properties = set(results[0].dict()["predictions"][0]["properties"].keys()) + properties = set( + list(i["properties"].keys())[0] for i in results[0].dict()["predictions"] + ) dict_results = {} for p in properties: dict_results[p] = {} for i, r in enumerate(results): s = r.sequence - props = r.predictions[0].properties[p] + props = [i.properties[p] for i in r.predictions if p in i.properties][0] dict_results[p][s] = {"mean": props["y_mu"], "variance": props["y_var"]} + dict_results return dict_results def _fmt_ssp_results(self, results): - properties = set(results[0].dict()["predictions"][0]["properties"].keys()) + properties = set( + list(i["properties"].keys())[0] for i in results[0].dict()["predictions"] + ) dict_results = {} for p in properties: dict_results[p] = {} for i, r in enumerate(results): - s = f"{r.position+1}{r.amino_acid}" - props = r.predictions[0].properties[p] + s = s = f"{r.position+1}{r.amino_acid}" + props = [i.properties[p] for i in r.predictions if p in i.properties][0] dict_results[p][s] = {"mean": props["y_mu"], "variance": props["y_var"]} return dict_results @@ -408,19 +450,21 @@ def get(self, verbose: bool = False) -> Dict: """ step = self.page_size - results = [] + results: List = [] num_returned = step offset = 0 while num_returned >= step: try: response = self.get_results(page_offset=offset, page_size=step) + assert isinstance(response.result, list) results += response.result num_returned = len(response.result) offset += num_returned except APIError as exc: if verbose: print(f"Failed to get results: {exc}") + if self.job.job_type == JobType.workflow_predict: return self._fmt_results(results) else: @@ -444,7 +488,7 @@ def __init__(self, session: APISession): def create_predict_job( self, sequences: List, - train_job: Any, + train_job: Optional[Any] = None, model_ids: Optional[List[str]] = None, ) -> PredictFuture: """ @@ -475,17 +519,18 @@ def create_predict_job( APIError If the backend refuses the job (due to sequence length or invalid inputs) """ - if train_job.assaymetadata is not None: - if train_job.assaymetadata.sequence_length is not None: - if any( - [ - train_job.assaymetadata.sequence_length != len(s) - for s in sequences - ] - ): - raise InvalidParameterError( - f"Predict sequences length {len(sequences[0])} != training assaydata ({train_job.assaymetadata.sequence_length})" - ) + if train_job is not None: + if train_job.assaymetadata is not None: + if train_job.assaymetadata.sequence_length is not None: + if any( + [ + train_job.assaymetadata.sequence_length != len(s) + for s in sequences + ] + ): + raise InvalidParameterError( + f"Predict sequences length {len(sequences[0])} != training assaydata ({train_job.assaymetadata.sequence_length})" + ) if not train_job.done(): print(f"WARNING: training job has status {train_job.status}") # raise InvalidParameterError( @@ -494,7 +539,7 @@ def create_predict_job( sequence_dataset = SequenceDataset(sequences=sequences) return create_predict_job( - self.session, sequence_dataset, train_job, model_ids=model_ids + self.session, sequence_dataset, train_job, model_ids=model_ids # type: ignore ) def create_predict_single_site( @@ -546,5 +591,5 @@ def create_predict_single_site( sequence_dataset = SequenceData(sequence=sequence) return create_predict_single_site( - self.session, sequence_dataset, train_job, model_ids=model_ids + self.session, sequence_dataset, train_job, model_ids=model_ids # type: ignore ) diff --git a/openprotein/api/train.py b/openprotein/api/train.py index a557811..0186006 100644 --- a/openprotein/api/train.py +++ b/openprotein/api/train.py @@ -1,7 +1,7 @@ from typing import Optional, List, Union -from pydantic import BaseModel +from openprotein.pydantic import BaseModel -import pydantic +import openprotein.pydantic as pydantic from openprotein.base import APISession from openprotein.api.jobs import AsyncJobFuture, Job from openprotein.futures import FutureFactory, FutureBase diff --git a/openprotein/futures.py b/openprotein/futures.py index 619f089..3cb39cd 100644 --- a/openprotein/futures.py +++ b/openprotein/futures.py @@ -1,14 +1,15 @@ # Store for Model and Future classes from openprotein.jobs import job_get, ResultsParser -from typing import Optional +from typing import Optional, Any class FutureBase: """Base class for all Future classes. - + This class needs to be directly inherited for class discovery.""" + # overridden by subclasses - job_type = None + job_type: Optional[Any] = None @classmethod def get_job_type(cls): @@ -23,7 +24,9 @@ class FutureFactory: """Factory class for creating Future instances based on job_type.""" @staticmethod - def create_future(session, job_id:Optional[str] = None, response:dict =None, **kwargs): + def create_future( + session, job_id: Optional[str] = None, response: Optional[dict] = None, **kwargs + ): """ Create and return an instance of the appropriate Future class based on the job type. @@ -36,22 +39,22 @@ def create_future(session, job_id:Optional[str] = None, response:dict =None, **k - An instance of the appropriate Future class. """ - # parse job + # parse job if job_id: job = job_get(session, job_id) else: - if 'job' not in kwargs: + if "job" not in kwargs: job = ResultsParser.parse_obj(response) else: job = kwargs.pop("job") - + # Dynamically discover all subclasses of FutureBase future_classes = FutureBase.__subclasses__() - kwargs = {k:v for k,v in kwargs.items() if v is not None} + kwargs = {k: v for k, v in kwargs.items() if v is not None} # Find the Future class that matches the job type for future_class in future_classes: if job.job_type in future_class.get_job_type(): - return future_class(session=session, job=job, **kwargs) + return future_class(session=session, job=job, **kwargs) # type: ignore raise ValueError(f"Unsupported job type: {job.job_type}") diff --git a/openprotein/jobs.py b/openprotein/jobs.py index b3ce71c..73fd6aa 100644 --- a/openprotein/jobs.py +++ b/openprotein/jobs.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Optional, Literal import time -from pydantic import BaseModel, Field +from openprotein.pydantic import BaseModel, Field from openprotein.errors import TimeoutException from openprotein.base import APISession import openprotein.config as config @@ -13,7 +13,7 @@ class Job(BaseModel): status: JobStatus - job_id: str + job_id: Optional[str] # must be optional as predict can return None # new emb service get doesnt have job_type job_type: Optional[Literal[tuple(member.value for member in JobType.__members__.values())]] # type: ignore created_date: Optional[datetime] = None diff --git a/openprotein/pydantic.py b/openprotein/pydantic.py new file mode 100644 index 0000000..295e901 --- /dev/null +++ b/openprotein/pydantic.py @@ -0,0 +1,20 @@ +try: + from pydantic.v1 import ( + BaseModel, + Field, + ConfigDict, + validator, + root_validator, + parse_obj_as, + ) + import pydantic.v1 as pydantic +except ImportError: + from pydantic import ( + BaseModel, + Field, + ConfigDict, + validator, + root_validator, + parse_obj_as, + ) + import pydantic diff --git a/openprotein/schemas.py b/openprotein/schemas.py index ca19eaa..1af4968 100644 --- a/openprotein/schemas.py +++ b/openprotein/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, ConfigDict +from openprotein.pydantic import BaseModel, ConfigDict from enum import Enum