Skip to content

Commit

Permalink
Smurching patch 1 (#2)
Browse files Browse the repository at this point in the history
* mlflow_utils.py: minor cosmetic

* http_client.py: added params arg to get() method

* http_client.py: added params arg to get() method

* Update setup.py

Co-authored-by: amesar <[email protected]>
  • Loading branch information
smurching and amesar authored Jan 4, 2022
1 parent 3f83427 commit b4a9d3a
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 22 deletions.
9 changes: 5 additions & 4 deletions mlflow_export_import/common/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@ def __init__(self, api_name, host=None, token=None):
self.api_uri = os.path.join(host,api_name)
self.token = token

def _get(self, resource):
def _get(self, resource, params=None):
""" Executes an HTTP GET call
:param resource: Relative path name of resource such as cluster/list
:param params: Dict of query parameters
"""
uri = self._mk_uri(resource)
rsp = requests.get(uri, headers=self._mk_headers())
rsp = requests.get(uri, headers=self._mk_headers(), params=params)
self._check_response(rsp, uri)
return rsp

def get(self, resource):
return json.loads(self._get(resource).text)
def get(self, resource, params=None):
return json.loads(self._get(resource, params).text)

def post(self, resource, data):
""" Executes an HTTP POST call
Expand Down
30 changes: 15 additions & 15 deletions mlflow_export_import/common/mlflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,39 @@ def dump_mlflow_info():
print(" DATABRICKS_TOKEN:", os.environ.get("DATABRICKS_TOKEN",""))

def get_mlflow_host():
''' Returns the host (tracking URI) and token '''
""" Returns the host (tracking URI) and token """
return get_mlflow_host_token()[0]

def get_mlflow_host_token():
''' Returns the host (tracking URI) and token '''
uri = os.environ.get('MLFLOW_TRACKING_URI',None)
""" Returns the host (tracking URI) and token """
uri = os.environ.get("MLFLOW_TRACKING_URI",None)
if uri is not None and uri != "databricks":
return (uri,None)
try:
from mlflow_export_import.common import databricks_cli_utils
profile = os.environ.get('MLFLOW_PROFILE',None)
profile = os.environ.get("MLFLOW_PROFILE",None)
##host_token = databricks_cli_utils.get_host_token(profile)
return databricks_cli_utils.get_host_token(profile)
#except databricks_cli.utils.InvalidConfigurationError as e:
except Exception as e: # TODO: make more specific
print("WARNING:",e)
return (None,None)

def get_experiment(client, exp_id_or_name):
''' Gets an experiment either by ID or name. '''
exp = client.get_experiment_by_name(exp_id_or_name)
def get_experiment(mlflow_client, exp_id_or_name):
""" Gets an experiment either by ID or name. """
exp = mlflow_client.get_experiment_by_name(exp_id_or_name)
if exp is None:
try:
exp = client.get_experiment(exp_id_or_name)
exp = mlflow_client.get_experiment(exp_id_or_name)
except Exception:
raise Exception(f"Cannot find experiment ID or name '{exp_id_or_name}'. Client: {client}'")
raise Exception(f"Cannot find experiment ID or name '{exp_id_or_name}'. Client: {mlflow_client}'")
return exp

def set_experiment(dbx_client, exp_name):
'''
"""
Set experiment name.
For Databricks, create the workspace directory if it doesn't exist.
'''
"""
from mlflow_export_import import utils
if utils.importing_into_databricks():
exp_dir = os.path.dirname(exp_name)
Expand All @@ -53,11 +53,11 @@ def set_experiment(dbx_client, exp_name):
mlflow.set_experiment(exp_name)

# BUG
def _get_experiment(client, exp_id_or_name):
def _get_experiment(mlflow_client, exp_id_or_name):
try:
exp = client.get_experiment(exp_id_or_name)
exp = mlflow_client.get_experiment(exp_id_or_name)
except Exception:
exp = client.get_experiment_by_name(exp_id_or_name)
exp = mlflow_client.get_experiment_by_name(exp_id_or_name)
if exp is None:
raise Exception(f"Cannot find experiment ID or name '{exp_id_or_name}'. Client: {client}'")
raise Exception(f"Cannot find experiment ID or name '{exp_id_or_name}'. Client: {mlflow_client}'")
return exp
2 changes: 1 addition & 1 deletion mlflow_export_import/model/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, export_metadata_tags=False, notebook_formats=[], stages=None)

def export_model(self, output_dir, model_name):
path = os.path.join(output_dir,"model.json")
model = self.http_client.get(f"registered-models/get?name={model_name}")
model = self.http_client.get(f"registered-models/get", {"name": model_name})
model["registered_model"]["latest_versions"] = []
versions = self.mlflow_client.search_model_versions(f"name='{model_name}'")
print(f"Found {len(versions)} versions for model {model_name}")
Expand Down
9 changes: 7 additions & 2 deletions mlflow_export_import/run/export_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,14 @@ def export_notebook(self, output_dir, notebook, tags, fs):
self.export_notebook_format(notebook_dir, notebook, format, format.lower(), notebook_name, revision_id)

def export_notebook_format(self, notebook_dir, notebook, format, extension, notebook_name, revision_id):
resource = f'workspace/export?path={notebook}&direct_download=true&format={format}&revision={{"revision_timestamp":{revision_id}}}'
params = {
"path": notebook,
"direct_download": True,
"format": format
## "revision": '{"revision_timestamp":{revision_id}}' # TODO: coming shortly
}
try:
rsp = self.dbx_client._get(resource)
rsp = self.dbx_client._get("workspace/export", params)
notebook_path = os.path.join(notebook_dir, f"{notebook_name}.{extension}")
utils.write_file(notebook_path, rsp.content)
except MlflowExportImportException as e:
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
python_requires=">=3.7",
packages=find_packages(),
zip_safe=False,
entry_points={
"mlflow.request_header_provider": "unused=mlflow_export_import.common.usr_agent_header:MlflowExportImportRequestHeaderProvider",
}
install_requires=[
"mlflow>=1.15.0",
"pytest==5.3.5",
Expand Down

0 comments on commit b4a9d3a

Please sign in to comment.