Created generic prediction + annotation mechanism #571

Merged
merged 9 commits on Jan 7, 2025
163 changes: 120 additions & 43 deletions dagshub/data_engine/model/query_result.py
@@ -6,7 +6,7 @@
from dataclasses import field, dataclass
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union, Tuple, Literal, Callable
from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union, Tuple, Literal, Callable, Protocol
import json
import os
import os.path
@@ -59,6 +59,10 @@
logger = logging.getLogger(__name__)


class CustomPredictor(Protocol):
    def __call__(self, local_paths: List[str]) -> List[Tuple[Any, Optional[float]]]: ...
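# A minimal sketch of a conforming predictor (hypothetical example, purely for
# illustration): it receives a batch of local file paths and returns one
# (prediction, score) tuple per input; the score may be None if the model
# produces no confidences.
def example_predict_fn(local_paths: List[str]) -> List[Tuple[Any, Optional[float]]]:
    # constant label with a dummy confidence, one entry per input path
    return [("cat", 0.99) for _ in local_paths]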


class VisualizeError(Exception):
    """:meta private:"""

@@ -559,7 +563,6 @@

        Args:
            repo: repository to extract the model from
            name: name of the model in the repository's MLflow registry.
            host: address of the DagsHub instance with the repo to load the model from.
                Set it if the model is hosted on a different DagsHub instance than the datasource.
            version: version of the model in the mlflow registry.
@@ -570,21 +573,6 @@
                Default batch size is 1, but the input is still sent as a list for consistency.
            log_to_field: If set, writes prediction results to this metadata field in the datasource.
        """

        # to support dependency-free dataloading, `Batcher` is a barebones dataloader that sets up batched inference
        class Batcher:
            def __init__(self, dset, batch_size):
                self.dset = dset
                self.batch_size = batch_size

            def __iter__(self):
                self.curr_idx = 0
                return self

            def __next__(self):
                self.curr_idx += self.batch_size
                return [self.dset[idx] for idx in range(self.curr_idx - self.batch_size, self.curr_idx)]

        if not host:
            host = self.datasource.source.repoApi.host

@@ -609,32 +597,7 @@
if "torch" in loader_module:
model.predict = model.__call__

        dset = DagsHubDataset(self, tensorizers=[lambda x: x])

        predictions = {}
        progress = get_rich_progress(rich.progress.MofNCompleteColumn())
        task = progress.add_task("Running inference...", total=len(dset))
        with progress:
            for idx, local_paths in enumerate(
                Batcher(dset, batch_size) if batch_size != 1 else dset
            ):  # encapsulates dataset with batcher if necessary and iterates over it
                for prediction, remote_path in zip(
                    post_hook(model.predict(pre_hook(local_paths))),
                    [result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
                ):
                    predictions[remote_path] = {
                        "data": {"image": multi_urljoin(self.datasource.source.root_raw_path, remote_path)},
                        "annotations": [prediction],
                    }
                progress.update(task, advance=batch_size, refresh=True)

        if log_to_field:
            with self.datasource.metadata_context() as ctx:
                for remote_path in predictions:
                    ctx.update_metadata(
                        remote_path, {log_to_field: json.dumps(predictions[remote_path]).encode("utf-8")}
                    )
        return predictions
        return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field)
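        # Usage sketch (hypothetical repo and model names, shown for illustration):
        #   qr = ds.all()
        #   qr.predict_with_mlflow_model("user/repo", "my-model", version="latest",
        #                                batch_size=4, log_to_field="prediction")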

    def get_annotations(self, **kwargs) -> "QueryResult":
        """
@@ -876,6 +839,13 @@
        ds.merge_samples(samples)
        return ds

    @staticmethod
    def _get_predict_dict(predictions, remote_path, log_to_field):
        return {
            log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8"),
            f"{log_to_field}_score": (0.0 if len(predictions[remote_path]) == 1 else predictions[remote_path][1]),
        }
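    # For example, with predictions[remote_path] == ("cat", 0.9) and
    # log_to_field == "prediction", this produces:
    #   {"prediction": b'"cat"', "prediction_score": 0.9}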

    def _check_downloaded_dataset_size(self):
        download_size_prompt_threshold = 100 * (2**20)  # 100 Megabytes
        dp_size = self._calculate_datapoint_size()
@@ -938,6 +908,98 @@

        return sess

    def generate_predictions(
        self,
        predict_fn: CustomPredictor,
        batch_size: int = 1,
        log_to_field: Optional[str] = None,
    ) -> Dict[str, Tuple[Any, Optional[float]]]:
        """
        Runs a generic prediction function over all the datapoints returned in this QueryResult.

        Args:
            predict_fn: function that handles batched input and returns predictions with an optional prediction score.
            batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously.
            log_to_field: (optional) if set, writes prediction results to this metadata field in the datasource.
                If None, predictions are only returned.
        """
        dset = DagsHubDataset(self, tensorizers=[lambda x: x])

        predictions = {}
        progress = get_rich_progress(rich.progress.MofNCompleteColumn())
        task = progress.add_task("Running inference...", total=len(dset))
        with progress:
            for idx, local_paths in enumerate(
                _Batcher(dset, batch_size) if batch_size != 1 else dset
            ):  # encapsulates dataset with batcher if necessary and iterates over it
                for prediction, remote_path in zip(
                    predict_fn(local_paths),
                    [result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
                ):
                    predictions[remote_path] = prediction
                progress.update(task, advance=batch_size, refresh=True)

        if log_to_field:
            with self.datasource.metadata_context() as ctx:
                for remote_path in predictions:
                    ctx.update_metadata(remote_path, self._get_predict_dict(predictions, remote_path, log_to_field))
        return predictions
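        # Usage sketch, assuming a hypothetical `classify` helper that maps one
        # file to a (label, confidence) tuple:
        #   def predict_fn(local_paths):
        #       return [classify(path) for path in local_paths]
        #
        #   qr.generate_predictions(predict_fn, batch_size=8, log_to_field="prediction")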

    def generate_annotations(self, predict_fn: CustomPredictor, batch_size: int = 1, log_to_field: str = "annotation"):
        """
        Runs a generic prediction function over all the datapoints returned in this QueryResult
        and marks the logged field as an annotation field.

        Args:
            predict_fn: function that handles batched input and returns predictions with an optional prediction score.
            batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously.
            log_to_field: (optional, default: 'annotation') write prediction results to this metadata field in the datasource.
        """
        self.generate_predictions(
            predict_fn,
            batch_size=batch_size,
            log_to_field=log_to_field,
        )
        self.datasource.metadata_field(log_to_field).set_annotation().apply()
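        # Usage sketch, assuming a hypothetical `to_labelstudio_task` helper that
        # returns a Label Studio-format annotation per file:
        #   qr.generate_annotations(lambda paths: [(to_labelstudio_task(p), None) for p in paths])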

    def annotate_with_mlflow_model(
        # Check failure (GitHub Actions / Flake8) on line 967: redefinition of unused 'annotate_with_mlflow_model' from line 485 (F811)
        self,
        repo: str,
        name: str,
        post_hook: Callable = lambda x: x,
        pre_hook: Callable = lambda x: x,
        host: Optional[str] = None,
        version: str = "latest",
        batch_size: int = 1,
        log_to_field: str = "annotation",
    ) -> Optional[str]:
"""
Sends all the datapoints returned in this QueryResult to an MLFlow model which automatically labels datapoints.

Args:
repo: repository to extract the model from
name: name of the model in the mlflow registry
version: (optional, default: 'latest') version of the model in the mlflow registry
pre_hook: (optional, default: identity function) function that runs
before the datapoint is sent to the model
post_hook: (optional, default: identity function) function that converts
mlflow model output converts to labelstudio format
batch_size: (optional, default: 1) batched annotation size
"""
        self.predict_with_mlflow_model(
            repo,
            name,
            host=host,
            version=version,
            pre_hook=pre_hook,
            post_hook=post_hook,
            batch_size=batch_size,
            log_to_field=log_to_field,
        )
        self.datasource.metadata_field(log_to_field).set_annotation().apply()
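        # Usage sketch (hypothetical repo, model, and converter names):
        #   qr.annotate_with_mlflow_model("user/repo", "my-model",
        #                                 post_hook=to_labelstudio, batch_size=4)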

    def annotate(
        self,
        open_project: bool = True,
@@ -1007,3 +1069,18 @@
        assert self.query_data_time is not None
        artifact_name = self.datasource._get_mlflow_artifact_name("log", self.query_data_time)
        return self.datasource._log_to_mlflow(artifact_name, run, self.query_data_time)


# to support dependency-free dataloading, `_Batcher` is a barebones dataloader that sets up batched inference
class _Batcher:
    def __init__(self, dset, batch_size):
        self.dset = dset
        self.batch_size = batch_size

    def __iter__(self):
        self.curr_idx = 0
        return self

    def __next__(self):
        # stop once the dataset is exhausted, clamping the final (possibly partial) batch
        if self.curr_idx >= len(self.dset):
            raise StopIteration
        self.curr_idx += self.batch_size
        return [self.dset[idx] for idx in range(self.curr_idx - self.batch_size, min(self.curr_idx, len(self.dset)))]
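# For example, batching a 5-element dataset in twos yields a short final batch:
#   list(_Batcher([1, 2, 3, 4, 5], 2))  ->  [[1, 2], [3, 4], [5]]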