
Commit

Issue #18: updates to notebook Batch_Score_Llama_2
amesar committed Oct 17, 2023
1 parent e797575 commit 2d688d3
Showing 3 changed files with 11 additions and 8 deletions.
databricks/notebooks/llama2/Batch_Score_Llama_2.py (5 additions, 5 deletions)
@@ -40,8 +40,8 @@

 dbutils.widgets.text("1. Model", default_model_name)
 dbutils.widgets.text("2. Version", "1")
-dbutils.widgets.text("3. Input File or Table", "andre_m.default.llama2_in")
-dbutils.widgets.text("4. Output Table", "andre_m.default.llama2_out")
+dbutils.widgets.text("3. Input File or Table", "questions.csv")
+dbutils.widgets.text("4. Output Table", "")
 dbutils.widgets.dropdown("5. Write mode", "none", ["none", "append","overwrite"])

 model_name = dbutils.widgets.get("1. Model")
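
Only the first widget read survives as context in this hunk. As a hedged sketch of how the remaining widget values are presumably picked up a few lines further down (input_file_or_table does appear later in this diff; version, output_table, and write_mode are assumed names, not taken from the commit):

    # Assumed continuation of the widget reads; only input_file_or_table is
    # confirmed by this diff, the other variable names are illustrative.
    version = dbutils.widgets.get("2. Version")
    input_file_or_table = dbutils.widgets.get("3. Input File or Table")
    output_table = dbutils.widgets.get("4. Output Table")
    write_mode = dbutils.widgets.get("5. Write mode")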
@@ -72,7 +72,7 @@

 # COMMAND ----------

-df_questions = mk_df_from_file_or_table(input_file_or_table)
+df_questions = load_data(input_file_or_table)
 display(df_questions)

 # COMMAND ----------
@@ -92,7 +92,7 @@

 # MAGIC %md ##### Load model as Spark UDF
 # MAGIC
-# MAGIC This takes a minute or so.
+# MAGIC This takes a minute or two.

 # COMMAND ----------
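
The cell that actually loads the model is elided from this hunk. A minimal sketch of loading a registered MLflow model version as a Spark UDF, assuming mlflow.pyfunc.spark_udf and a models:/ URI built from the widget values (predict_udf and the result type are illustrative names, not shown in the commit):

    # Assumed sketch -- the notebook's real cell is not part of this diff.
    import mlflow
    model_uri = f"models:/{model_name}/{version}"
    predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri=model_uri, result_type="string")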

@@ -102,7 +102,7 @@

 # MAGIC %md ##### Call model with questions
 # MAGIC
-# MAGIC Takes about a minute per question.
+# MAGIC Takes about 30-60 seconds per question.

 # COMMAND ----------
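
The scoring cell itself is likewise elided. A hedged sketch of applying the UDF to the questions DataFrame, reusing predict_udf from the sketch above (the question and answer column names are assumptions):

    # Assumed sketch -- column names are not confirmed by this diff.
    from pyspark.sql import functions as F
    df_answers = df_questions.withColumn("answer", predict_udf(F.col("question")))
    display(df_answers)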

databricks/notebooks/llama2/Common.py (3 additions, 3 deletions)
@@ -23,7 +23,7 @@ def mk_absolute_path(path):

 from pyspark.sql.types import *

-def mk_df_from_csv_file(path):
+def load_from_path(path):
     print(f"Reading from file '{path}'")
     path = mk_absolute_path(path)
     print(f"Reading from file '{path}'")
@@ -35,10 +35,10 @@ def mk_df_from_csv_file(path):

 # COMMAND ----------

-def mk_df_from_file_or_table(name):
+def load_data(name):
     toks = name.split(".")
     if len(toks) == 3: # If unity catalog 3 component name
         print(f"Reading from table '{name}'")
         return spark.table(name)
     else: # otherwise assume its a CSV file
-        return mk_df_from_csv_file(name)
+        return load_from_path(name)
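
A usage sketch of the renamed load_data helper; the example arguments are the defaults shown elsewhere in this commit, while the calls themselves are illustrative and not part of the diff:

    # A three-part name is treated as a Unity Catalog table, anything else as a CSV file.
    df = load_data("andre_m.default.llama2_in")   # -> spark.table(...)
    df = load_data("questions.csv")               # -> load_from_path(...)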
databricks/notebooks/llama2/_README.py (3 additions, 0 deletions)
@@ -9,5 +9,8 @@

 # MAGIC * [Model_Serve_Llama_2]($Model_Serve_Llama_22) - Real-time scoring.
 # MAGIC * [Common]($Common)
 # MAGIC
+# MAGIC ##### Github
+# MAGIC * https://github.com/amesar/mlflow-examples/tree/master/databricks/notebooks/llama2
+# MAGIC
 # MAGIC
 # MAGIC Last updated: 2023-10-16
