Skip to content

Commit

Permalink
Issue #12: enhanced Sklearn_Wine with UC support + added 'models:/' p…
Browse files Browse the repository at this point in the history
…rediction
  • Loading branch information
amesar committed Jul 3, 2023
1 parent 3cde840 commit e0a3834
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
3 changes: 2 additions & 1 deletion databricks/notebooks/basic/Common.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,5 +236,6 @@ def show_mlflow_uris(msg):

def activate_unity_catalog():
mlflow.set_registry_uri("databricks-uc")
client = mlflow.MlflowClient()
show_mlflow_uris("After UC settings")
client = mlflow.MlflowClient()
return client
55 changes: 50 additions & 5 deletions databricks/notebooks/basic/Sklearn_Wine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Databricks notebook source
# MAGIC %md ## Sklearn Wine Quality MLflow model
# MAGIC * Trains and saves model as Sklearn flavor
# MAGIC * Predicts using Sklearn, PyFunc and UDF flavors
# MAGIC * Predicts using Sklearn, Pyfunc and UDF flavors
# MAGIC * Support Unity Catalog MLflow
# MAGIC
# MAGIC ### Widgets
# MAGIC * 01. Run name
Expand Down Expand Up @@ -97,7 +98,8 @@
# COMMAND ----------

if use_uc:
activate_unity_catalog()
client = activate_unity_catalog()
print("New client._registry_uri:",client._registry_uri)

# COMMAND ----------

Expand Down Expand Up @@ -192,14 +194,15 @@ def set_run_name_to_run_id(run):
if model_name:
if use_uc:
version = register_model_uc(run, model_name, model_alias)
print(f"Registered UC model '{model_name}' as version {version.version}")
else:
version = register_model(run,
model_name,
model_version_stage,
archive_existing_versions,
model_alias
)

print(f"Registered model '{model_name}' as version {version.version}")
rmse = np.sqrt(mean_squared_error(test_y, predictions))
r2 = r2_score(test_y, predictions)
print("Metrics:")
Expand All @@ -210,6 +213,7 @@ def set_run_name_to_run_id(run):

if shap:
mlflow.shap.log_explanation(model.predict, train_x)
print("version:", version.version)

# COMMAND ----------

Expand Down Expand Up @@ -253,7 +257,7 @@ def set_run_name_to_run_id(run):

# COMMAND ----------

# MAGIC %md ### Predict with `runs` URI
# MAGIC %md ### Predict with `runs:/` URI

# COMMAND ----------

Expand All @@ -277,7 +281,7 @@ def set_run_name_to_run_id(run):

# COMMAND ----------

# MAGIC %md #### Predict as PyFunc
# MAGIC %md #### Predict as Pyfunc

# COMMAND ----------

Expand All @@ -303,3 +307,44 @@ def set_run_name_to_run_id(run):
# COMMAND ----------

type(predictions)

# COMMAND ----------

# MAGIC %md ### Predict with `models:/` URI

# COMMAND ----------

model_name

# COMMAND ----------

if not model_name:
print("No registered model specified")
exit(0)

# COMMAND ----------

model_uri = f"models:/{model_name}/{version.version}"
model_uri


# COMMAND ----------

# MAGIC %md #### Predict as Pyfunc

# COMMAND ----------

model = mlflow.pyfunc.load_model(model_uri)
predictions = model.predict(data_to_predict)
display(pd.DataFrame(predictions,columns=[WineQuality.colPrediction]))

# COMMAND ----------

# MAGIC %md #### Predict as Spark UDF

# COMMAND ----------

df = spark.createDataFrame(data_to_predict)
udf = mlflow.pyfunc.spark_udf(spark, model_uri)
predictions = df.withColumn("prediction", udf(*df.columns)).select("prediction")
display(predictions)

0 comments on commit e0a3834

Please sign in to comment.