Skip to content

Commit

Permalink
feat(ingest/mlflow): Support onfigurable base_external_url
Browse files Browse the repository at this point in the history
  • Loading branch information
asikowitz committed Dec 18, 2024
1 parent 8c724db commit 8b14334
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 5 deletions.
35 changes: 30 additions & 5 deletions metadata-ingestion/src/datahub/ingestion/source/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,30 @@
class MLflowConfig(EnvConfigMixin):
tracking_uri: Optional[str] = Field(
default=None,
description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)",
description=(
"Tracking server URI. If not set, an MLflow default tracking_uri is used"
" (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)"
),
)
registry_uri: Optional[str] = Field(
default=None,
description="Registry server URI. If not set, an MLflow default registry_uri is used (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)",
description=(
"Registry server URI. If not set, an MLflow default registry_uri is used"
" (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)"
),
)
model_name_separator: str = Field(
default="_",
description="A string which separates model name from its version (e.g. model_1 or model-1)",
)
base_external_url: Optional[str] = Field(
default=None,
description=(
"Base URL to use when constructing external URLs to MLflow."
" If not set, tracking_uri is used if it's an HTTP URL."
" If neither is set, external URLs are not generated."
),
)


@dataclass
Expand Down Expand Up @@ -279,12 +293,23 @@ def _make_ml_model_urn(self, model_version: ModelVersion) -> str:
)
return urn

def _make_external_url(self, model_version: ModelVersion) -> Union[None, str]:
def _get_base_external_url_from_tracking_uri(self) -> Optional[str]:
if isinstance(
self.client.tracking_uri, str
) and self.client.tracking_uri.startswith("http"):
return self.client.tracking_uri
else:
return None

def _make_external_url(self, model_version: ModelVersion) -> Optional[str]:
"""
Generate URL for a Model Version to MLflow UI.
"""
base_uri = self.client.tracking_uri
if base_uri.startswith("http"):
base_uri = (
self.config.base_external_url
or self._get_base_external_url_from_tracking_uri()
)
if base_uri:
return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}"
else:
return None
Expand Down
13 changes: 13 additions & 0 deletions metadata-ingestion/tests/unit/test_mlflow_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,16 @@ def test_make_external_link_remote(source, model_version):
url = source._make_external_url(model_version)

assert url == expected_url


def test_make_external_link_remote_via_config(source, model_version):
custom_base_url = "https://custom-server.org"
source.config.base_external_url = custom_base_url
source.client = MlflowClient(
tracking_uri="https://dummy-mlflow-tracking-server.org"
)
expected_url = f"{custom_base_url}/#/models/{model_version.name}/versions/{model_version.version}"

url = source._make_external_url(model_version)

assert url == expected_url

0 comments on commit 8b14334

Please sign in to comment.