From e5cacc47c48e9283923c594526e8c4716b0f7b87 Mon Sep 17 00:00:00 2001 From: Derek Tapley Date: Thu, 3 Oct 2024 14:44:19 -0400 Subject: [PATCH] Issue mlflow#192: Added support for MLFlow authentication --- mlflow_export_import/client/http_client.py | 11 ++++++++--- mlflow_export_import/client/mlflow_auth_utils.py | 3 ++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mlflow_export_import/client/http_client.py b/mlflow_export_import/client/http_client.py index 475bb29d..3a11757e 100644 --- a/mlflow_export_import/client/http_client.py +++ b/mlflow_export_import/client/http_client.py @@ -1,6 +1,7 @@ from abc import abstractmethod, ABCMeta import os import json +from requests.auth import HTTPBasicAuth import requests import click from mlflow_export_import.common import MlflowExportImportException @@ -96,11 +97,14 @@ def __init__(self, api_name, host=None, token=None): self.host = host self.api_uri = os.path.join(host, api_name) self.token = token + user = os.environ.get("MLFLOW_TRACKING_USERNAME") + password = os.environ.get("MLFLOW_TRACKING_PASSWORD") + self.auth = HTTPBasicAuth(user, password) if user and password else None def _get(self, resource, params=None): uri = self._mk_uri(resource) - rsp = requests.get(uri, headers=self._mk_headers(), data=params, timeout=_TIMEOUT) + rsp = requests.get(uri, headers=self._mk_headers(), data=params, auth=self.auth, timeout=_TIMEOUT) return self._check_response(rsp, params) @@ -151,7 +155,7 @@ def patch(self, resource, data=None): def _delete(self, resource): uri = self._mk_uri(resource) - rsp = requests.delete(uri, headers=self._mk_headers(), timeout=_TIMEOUT) + rsp = requests.delete(uri, headers=self._mk_headers(), auth=self.auth, timeout=_TIMEOUT) return self._check_response(rsp) def delete(self, resource): @@ -163,7 +167,7 @@ def delete(self, resource): def _mutator(self, method, resource, data=None): uri = self._mk_uri(resource) - rsp = method(uri, headers=self._mk_headers(), data=data, timeout=_TIMEOUT) + rsp = method(uri, headers=self._mk_headers(), data=data, auth=self.auth, timeout=_TIMEOUT) return self._check_response(rsp) @@ -265,6 +269,7 @@ def __init__(self, host=None, token=None): type=str, required=False ) + def main(api, resource, method, params, data, output_file): print("Options:") for k,v in locals().items(): diff --git a/mlflow_export_import/client/mlflow_auth_utils.py b/mlflow_export_import/client/mlflow_auth_utils.py index 18e5b198..6a3c601b 100644 --- a/mlflow_export_import/client/mlflow_auth_utils.py +++ b/mlflow_export_import/client/mlflow_auth_utils.py @@ -1,6 +1,7 @@ from mlflow_export_import.client import databricks_cli_utils from mlflow_export_import.common import MlflowExportImportException from mlflow_export_import.common import utils +import os _logger = utils.getLogger(__name__) @@ -23,7 +24,7 @@ def get_mlflow_host_token(): if not uri.startswith("http"): _raise_exception(uri) else: - return (uri, None) + return (uri, os.environ.get("MLFLOW_TRACKING_TOKEN")) else: _raise_exception(uri)