diff --git a/mlflow/getml/autologging.py b/mlflow/getml/autologging.py index 88fdc6afec411..e6531d68563ca 100644 --- a/mlflow/getml/autologging.py +++ b/mlflow/getml/autologging.py @@ -154,7 +154,7 @@ def _extract_engine_system_metrics( step += 1 stop_event.wait(1) - def patched_fit_mlflow(original, self: getml.Pipeline, *args, **kwargs): + def patched_fit_mlflow(original, self: getml.Pipeline, *args, **kwargs) -> getml.pipeline.Pipeline: autologging_client = MlflowAutologgingQueueingClient() assert (active_run := mlflow.active_run()) run_id = active_run.info.run_id @@ -187,7 +187,7 @@ def patched_fit_mlflow(original, self: getml.Pipeline, *args, **kwargs): autologging_client.flush(synchronous=True) return fit_output - def patched_score_method(original, self: getml.Pipeline, *args, **kwargs): + def patched_score_method(original, self: getml.Pipeline, *args, **kwargs) -> getml.pipeline.Scores: target = self.data_model.population.roles.target[0] pop_df = args[0].population.to_pandas() @@ -202,9 +202,13 @@ def patched_score_method(original, self: getml.Pipeline, *args, **kwargs): model_type=["regressor" if self.is_regression else "classifier"][0], evaluators=["default"], ) - return original(self, *args, **kwargs) - def _log_pretraining_metadata(autologging_client, self: getml.Pipeline, run_id, *args): + + def _log_pretraining_metadata(autologging_client: MlflowAutologgingQueueingClient, + self: getml.Pipeline, + run_id: str, + *args + ) -> dict: pipeline_log_info = _extract_pipeline_informations(self) autologging_client.log_params(