Skip to content

Commit

Permalink
better error handling and docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
jinensetpal committed Dec 28, 2024
1 parent 40a4c05 commit efa6263
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions dagshub/data_engine/model/query_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,11 @@ def predict_with_mlflow_model(
remote_path,
{
log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8"),
f"{log_to_field}_score": json.dumps(predictions[remote_path][1]).encode("utf-8"),
f"{log_to_field}_score": (
None
if len(predictions[remote_path]) == 1
else json.dumps(predictions[remote_path][1]).encode("utf-8")
),
},
)
return predictions
Expand Down Expand Up @@ -856,7 +860,7 @@ def predict_with_callable(
a generic object.
Args:
generic: function that returns predictions in the form of (prediction, prediction_score: Optional[float] = None)
generic: function that handles batched input and returns predictions in the form of (prediction, prediction_score: Optional[float] = None)

Check failure on line 863 in dagshub/data_engine/model/query_result.py

View workflow job for this annotation

GitHub Actions / Flake8

dagshub/data_engine/model/query_result.py#L863

Line too long (150 > 120 characters) (E501)
batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously
log_to_field: (optional, default: 'prediction') write prediction results to metadata logged in data engine.
If None, just returns predictions.
Expand Down Expand Up @@ -885,7 +889,11 @@ def predict_with_callable(
remote_path,
{
log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8"),
f"{log_to_field}_score": json.dumps(predictions[remote_path][1]).encode("utf-8"),
f"{log_to_field}_score": (
None
if len(predictions[remote_path]) == 1
else json.dumps(predictions[remote_path][1]).encode("utf-8")
),
},
)
return predictions
Expand All @@ -896,7 +904,7 @@ def annotate_with_callable(self, generic, batch_size: int = 1, log_to_field: str
a generic object.
Args:
generic: function that returns (annotation, prediction_score)
generic: function that handles batched input and returns predictions in the form of (prediction, prediction_score: Optional[float] = None)

Check failure on line 907 in dagshub/data_engine/model/query_result.py

View workflow job for this annotation

GitHub Actions / Flake8

dagshub/data_engine/model/query_result.py#L907

Line too long (150 > 120 characters) (E501)
batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously
log_to_field: (optional, default: 'prediction') write prediction results to metadata logged in data engine.
"""
Expand Down

0 comments on commit efa6263

Please sign in to comment.