diff --git a/mlflow_export_import/common/http_client.py b/mlflow_export_import/common/http_client.py index 2e90640..f1a6239 100644 --- a/mlflow_export_import/common/http_client.py +++ b/mlflow_export_import/common/http_client.py @@ -15,17 +15,18 @@ def __init__(self, api_name, host=None, token=None): self.api_uri = os.path.join(host,api_name) self.token = token - def _get(self, resource): + def _get(self, resource, params=None): """ Executes an HTTP GET call :param resource: Relative path name of resource such as cluster/list + :param params: Dict of query parameters """ uri = self._mk_uri(resource) - rsp = requests.get(uri, headers=self._mk_headers()) + rsp = requests.get(uri, headers=self._mk_headers(), params=params) self._check_response(rsp, uri) return rsp - def get(self, resource): - return json.loads(self._get(resource).text) + def get(self, resource, params=None): + return json.loads(self._get(resource, params).text) def post(self, resource, data): """ Executes an HTTP POST call diff --git a/mlflow_export_import/common/mlflow_utils.py b/mlflow_export_import/common/mlflow_utils.py index 4a49b92..9e16455 100644 --- a/mlflow_export_import/common/mlflow_utils.py +++ b/mlflow_export_import/common/mlflow_utils.py @@ -12,17 +12,17 @@ def dump_mlflow_info(): print(" DATABRICKS_TOKEN:", os.environ.get("DATABRICKS_TOKEN","")) def get_mlflow_host(): - ''' Returns the host (tracking URI) and token ''' + """ Returns the host (tracking URI) and token """ return get_mlflow_host_token()[0] def get_mlflow_host_token(): - ''' Returns the host (tracking URI) and token ''' - uri = os.environ.get('MLFLOW_TRACKING_URI',None) + """ Returns the host (tracking URI) and token """ + uri = os.environ.get("MLFLOW_TRACKING_URI",None) if uri is not None and uri != "databricks": return (uri,None) try: from mlflow_export_import.common import databricks_cli_utils - profile = os.environ.get('MLFLOW_PROFILE',None) + profile = os.environ.get("MLFLOW_PROFILE",None) ##host_token = 
databricks_cli_utils.get_host_token(profile) return databricks_cli_utils.get_host_token(profile) #except databricks_cli.utils.InvalidConfigurationError as e: @@ -30,21 +30,21 @@ def get_mlflow_host_token(): print("WARNING:",e) return (None,None) -def get_experiment(client, exp_id_or_name): - ''' Gets an experiment either by ID or name. ''' - exp = client.get_experiment_by_name(exp_id_or_name) +def get_experiment(mlflow_client, exp_id_or_name): + """ Gets an experiment either by ID or name. """ + exp = mlflow_client.get_experiment_by_name(exp_id_or_name) if exp is None: try: - exp = client.get_experiment(exp_id_or_name) + exp = mlflow_client.get_experiment(exp_id_or_name) except Exception: - raise Exception(f"Cannot find experiment ID or name '{exp_id_or_name}'. Client: {client}'") + raise Exception(f"Cannot find experiment ID or name '{exp_id_or_name}'. Client: {mlflow_client}'") return exp def set_experiment(dbx_client, exp_name): - ''' + """ Set experiment name. For Databricks, create the workspace directory if it doesn't exist. - ''' + """ from mlflow_export_import import utils if utils.importing_into_databricks(): exp_dir = os.path.dirname(exp_name) @@ -53,11 +53,11 @@ def set_experiment(dbx_client, exp_name): mlflow.set_experiment(exp_name) # BUG -def _get_experiment(client, exp_id_or_name): +def _get_experiment(mlflow_client, exp_id_or_name): try: - exp = client.get_experiment(exp_id_or_name) + exp = mlflow_client.get_experiment(exp_id_or_name) except Exception: - exp = client.get_experiment_by_name(exp_id_or_name) + exp = mlflow_client.get_experiment_by_name(exp_id_or_name) if exp is None: - raise Exception(f"Cannot find experiment ID or name '{exp_id_or_name}'. Client: {client}'") + raise Exception(f"Cannot find experiment ID or name '{exp_id_or_name}'. 
Client: {mlflow_client}'") return exp diff --git a/mlflow_export_import/model/export_model.py b/mlflow_export_import/model/export_model.py index c9310b2..6243c3f 100644 --- a/mlflow_export_import/model/export_model.py +++ b/mlflow_export_import/model/export_model.py @@ -19,7 +19,7 @@ def __init__(self, export_metadata_tags=False, notebook_formats=[], stages=None) def export_model(self, output_dir, model_name): path = os.path.join(output_dir,"model.json") - model = self.http_client.get(f"registered-models/get?name={model_name}") + model = self.http_client.get(f"registered-models/get", {"name": model_name}) model["registered_model"]["latest_versions"] = [] versions = self.mlflow_client.search_model_versions(f"name='{model_name}'") print(f"Found {len(versions)} versions for model {model_name}") diff --git a/mlflow_export_import/run/export_run.py b/mlflow_export_import/run/export_run.py index df1b404..acb662f 100644 --- a/mlflow_export_import/run/export_run.py +++ b/mlflow_export_import/run/export_run.py @@ -74,9 +74,14 @@ def export_notebook(self, output_dir, notebook, tags, fs): self.export_notebook_format(notebook_dir, notebook, format, format.lower(), notebook_name, revision_id) def export_notebook_format(self, notebook_dir, notebook, format, extension, notebook_name, revision_id): - resource = f'workspace/export?path={notebook}&direct_download=true&format={format}&revision={{"revision_timestamp":{revision_id}}}' + params = { + "path": notebook, + "direct_download": True, + "format": format + ## "revision": '{"revision_timestamp":{revision_id}}' # TODO: coming shortly + } try: - rsp = self.dbx_client._get(resource) + rsp = self.dbx_client._get("workspace/export", params) notebook_path = os.path.join(notebook_dir, f"{notebook_name}.{extension}") utils.write_file(notebook_path, rsp.content) except MlflowExportImportException as e: diff --git a/setup.py b/setup.py index 8240393..ae8d279 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,9 @@ python_requires=">=3.7", 
packages=find_packages(), zip_safe=False, + entry_points={ + "mlflow.request_header_provider": "unused=mlflow_export_import.common.usr_agent_header:MlflowExportImportRequestHeaderProvider", + }, install_requires=[ "mlflow>=1.15.0", "pytest==5.3.5",