Skip to content

Commit

Permalink
Issue #12: Factored out Unity Catalog stuff from Sklearn_Wine to Skle…
Browse files Browse the repository at this point in the history
…arn_Wine_UC notebook
  • Loading branch information
amesar committed Jul 16, 2023
1 parent 39bb8d4 commit 9134034
Show file tree
Hide file tree
Showing 4 changed files with 405 additions and 57 deletions.
22 changes: 11 additions & 11 deletions databricks/notebooks/basic/Common.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def register_model(run,
model = client.get_registered_model(model_name)
source = f"{run.info.artifact_uri}/{model_artifact}"
vr = client.create_model_version(model_name, source, run.info.run_id)
if model_version_stage:
if model_version_stage and model_version_stage != "None":
print(f"Transitioning model '{model_name}/{vr.version}' to stage '{model_version_stage}'")
client.transition_model_version_stage(model_name, vr.version, model_version_stage, archive_existing_versions=False)
if model_alias:
Expand All @@ -143,20 +143,20 @@ def register_model(run,
# COMMAND ----------

def register_model_uc(run,
model_name,
model_alias = None,
model_artifact = "model"
reg_model_name,
reg_model_alias = None,
run_model_artifact = "model"
):
""" Register mode with specified alias """
try:
model = client.create_registered_model(model_name)
model = client.create_registered_model(reg_model_name)
except RestException as e:
model = client.get_registered_model(model_name)
source = f"{run.info.artifact_uri}/{model_artifact}"
vr = client.create_model_version(model_name, source, run.info.run_id)
if model_alias:
print(f"Setting model '{model_name}/{vr.version}' alias to '{model_alias}'")
client.set_registered_model_alias(model_name, model_alias, vr.version)
model = client.get_registered_model(reg_model_name)
source = f"{run.info.artifact_uri}/{run_model_artifact}"
vr = client.create_model_version(reg_model_name, source, run.info.run_id)
if reg_model_alias:
print(f"Setting model '{reg_model_name}/{vr.version}' alias to '{reg_model_alias}'")
client.set_registered_model_alias(reg_model_name, reg_model_alias, vr.version)
return vr

# COMMAND ----------
Expand Down
86 changes: 41 additions & 45 deletions databricks/notebooks/basic/Sklearn_Wine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# MAGIC %md ## Sklearn Wine Quality MLflow model
# MAGIC * Trains and saves model as Sklearn flavor
# MAGIC * Predicts using Sklearn, Pyfunc and UDF flavors
# MAGIC * Support Unity Catalog MLflow
# MAGIC * See [Sklearn_Wine_UC]($Sklearn_Wine_UC) notebook for Unity Catalog version
# MAGIC
# MAGIC ### Widgets
# MAGIC * 01. Run name
Expand All @@ -19,18 +19,21 @@
# MAGIC * 12. Max depth
# MAGIC * 13. Unity Catalog
# MAGIC
# MAGIC ### Notes
# MAGIC #### Notes
# MAGIC
# MAGIC * Registered model:
# MAGIC * Sklearn_Wine_test
# MAGIC
# MAGIC * Experiment:
# MAGIC * /Users/[email protected]/experiments/sklearn_wine/Sklearn_Wine_ws
# MAGIC * /Users/[email protected]/experiments/best/Sklearn_Wine_repo
# MAGIC * Delta table: andre.wine_quality
# MAGIC
# MAGIC * UC
# MAGIC * Model: andre_catalog.ml_models.Sklearn_Wine_best
# MAGIC * Experiment: /Users/[email protected]/experiments/best/Sklearn_Wine_repo_uc
# MAGIC * Delta table: andre_catalog.ml_data.winequality_white
# MAGIC * Delta tables:
# MAGIC * andre.wine_quality
# MAGIC * andre_catalog.ml_data.winequality_white
# MAGIC * andre_catalog.ml_data.winequality_red
# MAGIC
# MAGIC Last udpated: 2023-07-07
# MAGIC Last udpated: 2023-07-16

# COMMAND ----------

Expand Down Expand Up @@ -196,19 +199,7 @@ def set_run_name_to_run_id(run):
log_data_input(run, log_input, data_source, train_x)

mlflow.sklearn.log_model(model, "model", signature=signature, input_example=test_x)
if model_name:
if use_uc:
version = register_model_uc(run, model_name, model_alias)
print(f"Registered UC model '{model_name}' as version {version.version}")
else:
version = register_model(run,
model_name,
model_version_stage,
archive_existing_versions,
model_alias
)
print(f"Registered model '{model_name}' as version {version.version}")
print("version:", version.version)

rmse = np.sqrt(mean_squared_error(test_y, predictions))
r2 = r2_score(test_y, predictions)
print("Metrics:")
Expand All @@ -227,19 +218,30 @@ def set_run_name_to_run_id(run):

# COMMAND ----------

# MAGIC %md ### Display UI links
# MAGIC %md ### Register model

# COMMAND ----------

display_run_uri(run.info.experiment_id, run_id)
if model_name:
version = register_model(run,
model_name,
model_version_stage,
archive_existing_versions,
model_alias
)
print(f"Registered model '{model_name}' as version {version.version}")

# COMMAND ----------

display_experiment_id_info(run.info.experiment_id)
# MAGIC %md ### Display UI links

# COMMAND ----------

display_run_uri(run.info.experiment_id, run_id)

# COMMAND ----------

model_name
display_experiment_id_info(run.info.experiment_id)

# COMMAND ----------

Expand Down Expand Up @@ -320,40 +322,34 @@ def set_run_name_to_run_id(run):

# COMMAND ----------

model_name

# COMMAND ----------

if not model_name:
print("No registered model specified. Exiting")
exit(0)

# COMMAND ----------

model_uri = f"models:/{model_name}/{version.version}"
model_uri

if model_name:
model_uri = f"models:/{model_name}/{version.version}"
print(model_uri)
else:
print("No registered model specified")

# COMMAND ----------

# MAGIC %md #### Predict as Pyfunc

# COMMAND ----------

model = mlflow.pyfunc.load_model(model_uri)
predictions = model.predict(data_to_predict)
display(pd.DataFrame(predictions,columns=[WineQuality.colPrediction]))
if model_name:
model = mlflow.pyfunc.load_model(model_uri)
predictions = model.predict(data_to_predict)
display(pd.DataFrame(predictions,columns=[WineQuality.colPrediction]))

# COMMAND ----------

# MAGIC %md #### Predict as Spark UDF

# COMMAND ----------

df = spark.createDataFrame(data_to_predict)
udf = mlflow.pyfunc.spark_udf(spark, model_uri)
predictions = df.withColumn("prediction", udf(*df.columns)).select("prediction")
display(predictions)
if model_name:
df = spark.createDataFrame(data_to_predict)
udf = mlflow.pyfunc.spark_udf(spark, model_uri)
predictions = df.withColumn("prediction", udf(*df.columns)).select("prediction")
display(predictions)

# COMMAND ----------

Expand Down
Loading

0 comments on commit 9134034

Please sign in to comment.