From 2d688d324265897d26d6dc4377ae2c1da8342f0f Mon Sep 17 00:00:00 2001 From: amesar Date: Tue, 17 Oct 2023 02:53:53 +0000 Subject: [PATCH] Issue #18: updates to notebook Batch_Score_Llama_2 --- databricks/notebooks/llama2/Batch_Score_Llama_2.py | 10 +++++----- databricks/notebooks/llama2/Common.py | 6 +++--- databricks/notebooks/llama2/_README.py | 3 +++ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/databricks/notebooks/llama2/Batch_Score_Llama_2.py b/databricks/notebooks/llama2/Batch_Score_Llama_2.py index 32fe99f..21a3d67 100644 --- a/databricks/notebooks/llama2/Batch_Score_Llama_2.py +++ b/databricks/notebooks/llama2/Batch_Score_Llama_2.py @@ -40,8 +40,8 @@ dbutils.widgets.text("1. Model", default_model_name) dbutils.widgets.text("2. Version", "1") -dbutils.widgets.text("3. Input File or Table", "andre_m.default.llama2_in") -dbutils.widgets.text("4. Output Table", "andre_m.default.llama2_out") +dbutils.widgets.text("3. Input File or Table", "questions.csv") +dbutils.widgets.text("4. Output Table", "") dbutils.widgets.dropdown("5. Write mode", "none", ["none", "append","overwrite"]) model_name = dbutils.widgets.get("1. Model") @@ -72,7 +72,7 @@ # COMMAND ---------- -df_questions = mk_df_from_file_or_table(input_file_or_table) +df_questions = load_data(input_file_or_table) display(df_questions) # COMMAND ---------- @@ -92,7 +92,7 @@ # MAGIC %md ##### Load model as Spark UDF # MAGIC -# MAGIC This takes a minute or so. +# MAGIC This takes a minute or two. # COMMAND ---------- @@ -102,7 +102,7 @@ # MAGIC %md ##### Call model with questions # MAGIC -# MAGIC Takes about a minute per question. +# MAGIC Takes about 30-60 seconds per question. # COMMAND ---------- diff --git a/databricks/notebooks/llama2/Common.py b/databricks/notebooks/llama2/Common.py index cfffbaa..22a15de 100644 --- a/databricks/notebooks/llama2/Common.py +++ b/databricks/notebooks/llama2/Common.py @@ -23,7 +23,7 @@ def mk_absolute_path(path): from pyspark.sql.types import * -def mk_df_from_csv_file(path): +def load_from_path(path): print(f"Reading from file '{path}'") path = mk_absolute_path(path) print(f"Reading from file '{path}'") @@ -35,10 +35,10 @@ def mk_df_from_csv_file(path): # COMMAND ---------- -def mk_df_from_file_or_table(name): +def load_data(name): toks = name.split(".") if len(toks) == 3: # If unity catalog 3 component name print(f"Reading from table '{name}'") return spark.table(name) else: # otherwise assume its a CSV file - return mk_df_from_csv_file(name) + return load_from_path(name) diff --git a/databricks/notebooks/llama2/_README.py b/databricks/notebooks/llama2/_README.py index 540e712..5225cab 100644 --- a/databricks/notebooks/llama2/_README.py +++ b/databricks/notebooks/llama2/_README.py @@ -9,5 +9,8 @@ # MAGIC * [Model_Serve_Llama_2]($Model_Serve_Llama_22) - Real-time scoring. # MAGIC * [Common]($Common) # MAGIC +# MAGIC ##### Github +# MAGIC * https://github.com/amesar/mlflow-examples/tree/master/databricks/notebooks/llama2 +# MAGIC # MAGIC # MAGIC Last updated: 2023-10-16