Skip to content

Commit

Permalink
add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-meyer-1986 committed Nov 22, 2024
1 parent e0fc867 commit 4797edc
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions mlflow/getml/autologging.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _extract_engine_system_metrics(
step += 1
stop_event.wait(1)

def patched_fit_mlflow(original, self: getml.Pipeline, *args, **kwargs):
def patched_fit_mlflow(original, self: getml.Pipeline, *args, **kwargs) -> getml.pipeline.Pipeline:

Check failure on line 157 in mlflow/getml/autologging.py

View workflow job for this annotation

GitHub Actions / lint

Line too long (103 > 100). See https://docs.astral.sh/ruff/rules/E501 for how to fix this error.
autologging_client = MlflowAutologgingQueueingClient()
assert (active_run := mlflow.active_run())
run_id = active_run.info.run_id
Expand Down Expand Up @@ -187,7 +187,7 @@ def patched_fit_mlflow(original, self: getml.Pipeline, *args, **kwargs):
autologging_client.flush(synchronous=True)
return fit_output

def patched_score_method(original, self: getml.Pipeline, *args, **kwargs):
def patched_score_method(original, self: getml.Pipeline, *args, **kwargs) -> getml.pipeline.Scores:

Check failure on line 190 in mlflow/getml/autologging.py

View workflow job for this annotation

GitHub Actions / lint

Line too long (103 > 100). See https://docs.astral.sh/ruff/rules/E501 for how to fix this error.

target = self.data_model.population.roles.target[0]
pop_df = args[0].population.to_pandas()
Expand All @@ -202,9 +202,13 @@ def patched_score_method(original, self: getml.Pipeline, *args, **kwargs):
model_type=["regressor" if self.is_regression else "classifier"][0],
evaluators=["default"],
)

return original(self, *args, **kwargs)
def _log_pretraining_metadata(autologging_client, self: getml.Pipeline, run_id, *args):

def _log_pretraining_metadata(autologging_client: MlflowAutologgingQueueingClient,
self: getml.Pipeline,
run_id: str,
*args
) -> dict:

pipeline_log_info = _extract_pipeline_informations(self)
autologging_client.log_params(
Expand Down

0 comments on commit 4797edc

Please sign in to comment.