
Created generic prediction + annotation mechanism #571

Merged: 9 commits, Jan 7, 2025
116 changes: 98 additions & 18 deletions dagshub/data_engine/model/query_result.py
@@ -495,26 +495,11 @@
pre_hook: (optional, default: identity function) function that runs before datapoint is sent to the model
post_hook: (optional, default: identity function) function that converts mlflow model output
to the desired format
batch_size: (optional, default: 1) function that sets batch_size
batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously
log_to_field: (optional, default: 'prediction') write prediction results to metadata logged in data engine.
If None, predictions are only returned;
if set, they are both logged to the field and returned.
"""

# to support dependency-free dataloading, `Batcher` is a barebones dataloader that sets up batched inference
class Batcher:
def __init__(self, dset, batch_size):
self.dset = dset
self.batch_size = batch_size

def __iter__(self):
self.curr_idx = 0
return self

def __next__(self):
self.curr_idx += self.batch_size
return [self.dset[idx] for idx in range(self.curr_idx - self.batch_size, self.curr_idx)]

if not host:
host = self.datasource.source.repoApi.host
prev_uri = mlflow.get_tracking_uri()
@@ -534,7 +519,7 @@
task = progress.add_task("Running inference...", total=len(dset))
with progress:
for idx, local_paths in enumerate(
Batcher(dset, batch_size) if batch_size != 1 else dset
_Batcher(dset, batch_size) if batch_size != 1 else dset
): # encapsulates dataset with batcher if necessary and iterates over it
for prediction, remote_path in zip(
post_hook(model.predict(pre_hook(local_paths))),
@@ -550,7 +535,15 @@
with self.datasource.metadata_context() as ctx:
for remote_path in predictions:
ctx.update_metadata(
remote_path, {log_to_field: json.dumps(predictions[remote_path]).encode("utf-8")}
remote_path,
{
log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8"),
f"{log_to_field}_score": (
None
if len(predictions[remote_path]) == 1
else json.dumps(predictions[remote_path][1]).encode("utf-8")
),
},
)
return predictions
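
For context, a minimal usage sketch of the updated method (not part of the diff). `qr` stands for a QueryResult, the repo and model name are placeholders, and, per the new logging code, each element of post_hook's output is assumed to be a (prediction,) or (prediction, score) tuple:

preds = qr.predict_with_mlflow_model(
    "user/repo",  # hypothetical DagsHub repo hosting the MLflow model
    "my-model",  # hypothetical registered model name
    post_hook=lambda outs: [(o,) for o in outs],  # wrap each raw model output as a 1-tuple
    batch_size=8,  # run inference on 8 datapoints at a time
    log_to_field="prediction",  # also log results (and scores, when present) to metadata
)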

@@ -856,6 +849,75 @@

return sess

def predict_with_callable(
self,
generic: Callable,
batch_size: int = 1,
log_to_field: Optional[str] = None,
) -> Optional[list]:
"""
Sends all the datapoints returned in this QueryResult as prediction targets to
a generic callable.

Args:
generic: function that handles batched input and returns predictions in the form of
    (prediction, prediction_score: Optional[float] = None)

batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously
log_to_field: (optional, default: None) metadata field in the data engine to write prediction results to.
If None, predictions are only returned;
if set, they are both logged to the field and returned.
"""
dset = DagsHubDataset(self, tensorizers=[lambda x: x])

predictions = {}
progress = get_rich_progress(rich.progress.MofNCompleteColumn())
task = progress.add_task("Running inference...", total=len(dset))
with progress:
for idx, local_paths in enumerate(
_Batcher(dset, batch_size) if batch_size != 1 else dset
): # encapsulates dataset with batcher if necessary and iterates over it
for prediction, remote_path in zip(
generic(local_paths),
[result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
):
predictions[remote_path] = prediction
progress.update(task, advance=batch_size, refresh=True)

if log_to_field:
with self.datasource.metadata_context() as ctx:
for remote_path in predictions:
ctx.update_metadata(
remote_path,
{
log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8"),
f"{log_to_field}_score": (
None
if len(predictions[remote_path]) == 1
else json.dumps(predictions[remote_path][1]).encode("utf-8")
),
},
)
return predictions
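
Again for context, a minimal usage sketch (not part of the diff); `qr` stands for a QueryResult and the classifier is hypothetical:

# hypothetical classifier: takes a batch of local file paths and returns
# one (prediction, score) tuple per path, as the docstring above describes
def classify(paths):
    return [("cat", 0.93) for _ in paths]

preds = qr.predict_with_callable(classify, batch_size=4, log_to_field="prediction")
# preds maps each datapoint's remote path to its ("cat", 0.93) tuple; with
# log_to_field set, "prediction" and "prediction_score" metadata are also written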

def annotate_with_callable(self, generic, batch_size: int = 1, log_to_field: str = "annotation"):
"""
Sends all the datapoints returned in this QueryResult as prediction targets to
a generic callable, and marks the resulting metadata field as an annotation field.

Args:
generic: function that handles batched input and returns predictions in the form of
    (prediction, prediction_score: Optional[float] = None)

batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously
log_to_field: (optional, default: 'annotation') metadata field in the data engine to write annotation results to.
"""
if log_to_field is None:
raise ValueError("`log_to_field` == None, there is nothing to do!")

self.predict_with_callable(
generic,
batch_size=batch_size,
log_to_field=log_to_field,
)
self.datasource.metadata_field(log_to_field).set_annotation().apply()
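
A matching sketch for the annotation path (not part of the diff); the annotator and its payload are hypothetical, following the same per-datapoint tuple convention:

# hypothetical annotator: returns one (annotation,) tuple per input path
def annotate(paths):
    return [({"annotations": []},) for _ in paths]  # placeholder payload

qr.annotate_with_callable(annotate, batch_size=4, log_to_field="annotation")
# afterwards, the "annotation" field is marked as an annotation field in the datasource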

def annotate_with_mlflow_model(
self,
repo: str,
@@ -880,6 +942,9 @@
mlflow model output converts to labelstudio format
batch_size: (optional, default: 1) batched annotation size
"""
if log_to_field is None:
raise ValueError("`log_to_field` == None, there is nothing to do!")

self.predict_with_mlflow_model(
repo,
name,
@@ -961,3 +1026,18 @@
assert self.query_data_time is not None
artifact_name = self.datasource._get_mlflow_artifact_name("log", self.query_data_time)
return self.datasource._log_to_mlflow(artifact_name, run, self.query_data_time)


# to support dependency-free dataloading, `_Batcher` is a barebones dataloader that sets up batched inference
class _Batcher:
    def __init__(self, dset, batch_size):
        self.dset = dset
        self.batch_size = batch_size

    def __iter__(self):
        self.curr_idx = 0
        return self

    def __next__(self):
        # stop cleanly at the end of the dataset instead of indexing past it
        if self.curr_idx >= len(self.dset):
            raise StopIteration
        batch = [
            self.dset[idx]
            for idx in range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.dset)))
        ]
        self.curr_idx += self.batch_size
        return batch
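
For illustration (not part of the diff), with the end-of-dataset guard above the batcher yields a short final batch; a plain list stands in for the dataset:

# hypothetical sanity check
assert list(_Batcher(list(range(5)), batch_size=2)) == [[0, 1], [2, 3], [4]]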
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -7,3 +7,4 @@ pytest-mock==3.14.0
fiftyone==0.23.8
datasets==2.19.1
ultralytics==8.3.47
dagshub-annotation-converter>=0.1.0