Created generic prediction + annotation mechanism #571

Merged (9 commits) on Jan 7, 2025
128 changes: 87 additions & 41 deletions in dagshub/data_engine/model/query_result.py
@@ -58,6 +58,13 @@

logger = logging.getLogger(__name__)

CustomPredictor = Callable[
[
List[str],
],
List[Tuple[Any, Optional[float]]],
]
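
# For illustration only (not part of this diff): a minimal function satisfying the
# CustomPredictor contract, taking a batch of local file paths and returning one
# (prediction, optional score) pair per path. The name and logic are hypothetical.
def _example_predictor(paths: List[str]) -> List[Tuple[Any, Optional[float]]]:
    # classify each file by extension, with a constant dummy confidence score
    return [("image" if p.lower().endswith((".png", ".jpg")) else "other", 1.0) for p in paths]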


class VisualizeError(Exception):
""":meta private:"""
@@ -570,21 +577,6 @@ def predict_with_mlflow_model(
                Default batch size is 1, but inputs are still sent as a list for consistency.
log_to_field: If set, writes prediction results to this metadata field in the datasource.
"""

        # to support dependency-free dataloading, `Batcher` is a barebones dataloader that sets up batched inference
class Batcher:
def __init__(self, dset, batch_size):
self.dset = dset
self.batch_size = batch_size

def __iter__(self):
self.curr_idx = 0
return self

def __next__(self):
self.curr_idx += self.batch_size
return [self.dset[idx] for idx in range(self.curr_idx - self.batch_size, self.curr_idx)]

if not host:
host = self.datasource.source.repoApi.host

@@ -609,32 +601,7 @@ def __next__(self):
if "torch" in loader_module:
model.predict = model.__call__

dset = DagsHubDataset(self, tensorizers=[lambda x: x])

predictions = {}
progress = get_rich_progress(rich.progress.MofNCompleteColumn())
task = progress.add_task("Running inference...", total=len(dset))
with progress:
for idx, local_paths in enumerate(
Batcher(dset, batch_size) if batch_size != 1 else dset
): # encapsulates dataset with batcher if necessary and iterates over it
for prediction, remote_path in zip(
post_hook(model.predict(pre_hook(local_paths))),
[result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
):
predictions[remote_path] = {
"data": {"image": multi_urljoin(self.datasource.source.root_raw_path, remote_path)},
"annotations": [prediction],
}
progress.update(task, advance=batch_size, refresh=True)

if log_to_field:
with self.datasource.metadata_context() as ctx:
for remote_path in predictions:
ctx.update_metadata(
remote_path, {log_to_field: json.dumps(predictions[remote_path]).encode("utf-8")}
)
return predictions
return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field)

def get_annotations(self, **kwargs) -> "QueryResult":
"""
@@ -876,6 +843,14 @@ def to_voxel51_dataset(self, **kwargs) -> "fo.Dataset":
ds.merge_samples(samples)
return ds

    @staticmethod
    def _get_predict_dict(predictions, remote_path, log_to_field):
        # JSON-serialize the prediction itself into `log_to_field`; if the predictor
        # also returned a score, log it alongside under `<log_to_field>_score`
        res = {log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8")}
        if len(predictions[remote_path]) == 2:
            res[f"{log_to_field}_score"] = predictions[remote_path][1]

        return res
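
    # For illustration only: a prediction ("cat", 0.93) logged to field "prediction"
    # produces the metadata {"prediction": b'"cat"', "prediction_score": 0.93}, i.e.
    # the prediction is JSON-serialized to bytes and the score is kept as-is.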

def _check_downloaded_dataset_size(self):
download_size_prompt_threshold = 100 * (2**20) # 100 Megabytes
dp_size = self._calculate_datapoint_size()
@@ -938,6 +913,62 @@ def visualize(self, visualizer: Literal["dagshub", "fiftyone"] = "dagshub", **kw

return sess

def generate_predictions(
self,
predict_fn: CustomPredictor,
batch_size: int = 1,
log_to_field: Optional[str] = None,
    ) -> Dict[str, Tuple[Any, Optional[float]]]:
"""
Sends all the datapoints returned in this QueryResult as prediction targets for
a generic object.

Args:
predict_fn: function that handles batched input and returns predictions with an optional prediction score.
batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously
log_to_field: (optional, default: 'prediction') write prediction results to metadata logged in data engine.
If None, just returns predictions.
(in addition to logging to a field, iff that parameter is set)
"""
dset = DagsHubDataset(self, tensorizers=[lambda x: x])

predictions = {}
progress = get_rich_progress(rich.progress.MofNCompleteColumn())
task = progress.add_task("Running inference...", total=len(dset))
with progress:
for idx, local_paths in enumerate(
_Batcher(dset, batch_size) if batch_size != 1 else dset
): # encapsulates dataset with batcher if necessary and iterates over it
for prediction, remote_path in zip(
predict_fn(local_paths),
[result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
):
predictions[remote_path] = prediction
progress.update(task, advance=batch_size, refresh=True)

if log_to_field:
with self.datasource.metadata_context() as ctx:
for remote_path in predictions:
ctx.update_metadata(remote_path, self._get_predict_dict(predictions, remote_path, log_to_field))
return predictions
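
    # For illustration only (hypothetical names): given a QueryResult `qr` and the
    # `_example_predictor` sketched above, predictions can be generated and logged with:
    #
    #     preds = qr.generate_predictions(_example_predictor, batch_size=8, log_to_field="prediction")
    #     # `preds` maps each datapoint's remote path to its (prediction, score) pair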

def generate_annotations(self, predict_fn: CustomPredictor, batch_size: int = 1, log_to_field: str = "annotation"):
"""
Sends all the datapoints returned in this QueryResult as prediction targets for
a generic object.

Args:
predict_fn: function that handles batched input and returns predictions with an optional prediction score.
batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously.
log_to_field: (optional, default: 'prediction') write prediction results to metadata logged in data engine.
"""
self.generate_predictions(
predict_fn,
batch_size=batch_size,
log_to_field=log_to_field,
)
self.datasource.metadata_field(log_to_field).set_annotation().apply()
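
    # For illustration only (hypothetical names): the call below logs predictions into
    # the default 'annotation' field, then marks that field as an annotation field:
    #
    #     qr.generate_annotations(_example_predictor, batch_size=8)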

def annotate(
self,
open_project: bool = True,
@@ -1007,3 +1038,18 @@ def log_to_mlflow(self, run: Optional["mlflow.entities.Run"] = None) -> "mlflow.
assert self.query_data_time is not None
artifact_name = self.datasource._get_mlflow_artifact_name("log", self.query_data_time)
return self.datasource._log_to_mlflow(artifact_name, run, self.query_data_time)


# to support dependency-free dataloading, `_Batcher` is a barebones dataloader that sets up batched inference
class _Batcher:
    def __init__(self, dset, batch_size):
        self.dset = dset
        self.batch_size = batch_size

    def __iter__(self):
        self.curr_idx = 0
        return self

    def __next__(self):
        # stop cleanly once the dataset is exhausted; clamp the final batch so a
        # partial batch doesn't index past the end of the dataset
        if self.curr_idx >= len(self.dset):
            raise StopIteration
        self.curr_idx += self.batch_size
        return [self.dset[idx] for idx in range(self.curr_idx - self.batch_size, min(self.curr_idx, len(self.dset)))]
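
# For illustration only: with the bounds check above, iterating
# _Batcher(["a.jpg", "b.jpg", "c.jpg"], batch_size=2) yields ["a.jpg", "b.jpg"]
# followed by the partial batch ["c.jpg"].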