Skip to content

Commit

Permalink
Issue #18: Fixed Model_Serve_Llama_2 request format
Browse files Browse the repository at this point in the history
  • Loading branch information
amesar committed Dec 11, 2023
1 parent ed9fcf5 commit 4cdf6a0
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 35 deletions.
16 changes: 7 additions & 9 deletions databricks/notebooks/llama2/Batch_Score_Llama_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
# MAGIC %md ### Batch Score LLama 2 model
# MAGIC
# MAGIC * Simple Llama2 model batch scoring example.
# MAGIC * Load the new Marketplace LLama 2 7b model from Unity Catalog (UC) registry and asks some questions.
# MAGIC * Load the new Marketplace `Llama 2 7b` model from the Unity Catalog (UC) registry and ask some questions.
# MAGIC * Questions can be from a file or a table.
# MAGIC * The table has a one string column called `question`.
# MAGIC * The input file is a one column CSV file with no header.
# MAGIC * You can optionally write the answers to an output table.
# MAGIC * All table names are 3 part UC names such `andre_m.data.llama2_answers`.
# MAGIC * Runs on `e2-dogfood` workspace. Mileage may vary on other workspaces.
# MAGIC * Cluster instance type: for `llama_2_7b_chat_hf` instance `g4dn.xlarge` is OK.
# MAGIC * All table names are 3 part UC names such as `andre_m.ml_data.llama2_answers`.
# MAGIC * Cluster instance type: for `llama_2_7b_chat_hf`, instance `g4dn.xlarge` is OK.
# MAGIC
# MAGIC ##### Widgets
# MAGIC * `1. Model` - Model name.
Expand All @@ -20,7 +19,7 @@
# MAGIC * `5. Write mode` - Write mode for output table. If "none", will not write to the table.
# MAGIC
# MAGIC
# MAGIC ##### Last updated: 2023-12-10
# MAGIC ##### Last updated: _2023-12-10_

# COMMAND ----------

Expand All @@ -34,7 +33,6 @@

import mlflow
# Point MLflow at the Unity Catalog model registry (not the legacy workspace registry).
mlflow.set_registry_uri("databricks-uc")
# Bare expression: displayed as the notebook cell's output for quick version inspection.
mlflow.__version__

# COMMAND ----------

Expand Down Expand Up @@ -68,7 +66,7 @@

# COMMAND ----------

# MAGIC %md ##### Load input questions from either a file or table.
# MAGIC %md ##### Load input questions from either a file or table

# COMMAND ----------

Expand All @@ -92,7 +90,7 @@

# MAGIC %md ##### Load model as Spark UDF
# MAGIC
# MAGIC This takes a minute or two for `llama_2_7b_chat_hf` model.
# MAGIC This may take a few minutes to load the `llama_2_7b_chat_hf` model.

# COMMAND ----------

Expand All @@ -111,7 +109,7 @@

# COMMAND ----------

# MAGIC %md #### Write to table
# MAGIC %md #### Write results to table

# COMMAND ----------

Expand Down
18 changes: 16 additions & 2 deletions databricks/notebooks/llama2/Common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Databricks notebook source
default_model_name = "marketplace_staging_llama_2_models.models.llama_2_7b_chat_hf" # e2_dogfood
#default_model_name = "databricks_llama_2_models.models.llama_2_7b_chat_hf" # e2_demo_west
default_model_name = "databricks_llama_2_models.models.llama_2_7b_chat_hf" # per Marketplace notebook
#default_model_name = "marketplace_staging_llama_2_models.models.llama_2_7b_chat_hf" # e2_dogfood

print("default_model_name:", default_model_name)

Expand Down Expand Up @@ -55,3 +55,17 @@ def load_data(name):
return spark.table(name)
    else: # otherwise assume it's a CSV file
return load_from_path(name)

# COMMAND ----------

def dump(dct, title=None, sort_keys=None, indent=2):
    """Pretty-print an object as indented JSON.

    Args:
        dct: JSON-serializable object to print.
        title: Optional heading, printed as "<title>:" before the JSON.
        sort_keys: Passed through to json.dumps.
        indent: Indentation width passed through to json.dumps.
    """
    header = f"{title}:" if title else None
    if header:
        print(header)
    print(json.dumps(dct, sort_keys=sort_keys, indent=indent))

# COMMAND ----------

# Print runtime versions to help debug environment-specific notebook issues.
import mlflow
import os
print("MLflow versions:", mlflow.__version__)
# DATABRICKS_RUNTIME_VERSION is set by Databricks; None when run outside a cluster.
print("DBR version ", os.environ.get("DATABRICKS_RUNTIME_VERSION"))
63 changes: 44 additions & 19 deletions databricks/notebooks/llama2/Model_Serve_Llama_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# MAGIC #### Docs
# MAGIC * https://docs.databricks.com/api/workspace/servingendpoints
# MAGIC * https://docs.databricks.com/en/machine-learning/model-serving/create-manage-serving-endpoints.html#gpu
# MAGIC * [Send scoring requests to serving endpoints](https://docs.databricks.com/en/machine-learning/model-serving/score-model-serving-endpoints.html)
# MAGIC
# MAGIC #### Widget values
# MAGIC ##### _Workload type_
Expand All @@ -26,12 +27,7 @@
# MAGIC * Medium
# MAGIC * Large
# MAGIC
# MAGIC ##### Last updated: 2023-11-05

# COMMAND ----------

import os
print("DBR: ", os.environ.get("DATABRICKS_RUNTIME_VERSION"))
# MAGIC ##### Last updated: _2023-12-10_

# COMMAND ----------

Expand All @@ -52,18 +48,21 @@
# Notebook input widgets: endpoint name and serving configuration.
# NOTE: dbutils.widgets.get returns strings — numeric values such as
# "6. Max tokens" must be cast by consumers before being sent to the endpoint.
dbutils.widgets.text("3. Endpoint", "llama2_simple")
dbutils.widgets.text("4. Workload type", "GPU_MEDIUM")
dbutils.widgets.text("5. Workload size", "Small")
dbutils.widgets.text("6. Max tokens", "128")

model_name = dbutils.widgets.get("1. Model")
version = dbutils.widgets.get("2. Version")
endpoint_name = dbutils.widgets.get("3. Endpoint")
workload_type = dbutils.widgets.get("4. Workload type")
workload_size = dbutils.widgets.get("5. Workload size")
max_tokens = dbutils.widgets.get("6. Max tokens")

# Echo all widget values so the run's configuration is visible in the cell output.
print("model:", model_name)
print("version:", version)
print("endpoint_name:", endpoint_name)
print("workload_type:", workload_type)
print("workload_size:", workload_size)
print("max_tokens:", max_tokens)

# COMMAND ----------

Expand Down Expand Up @@ -114,10 +113,11 @@
# COMMAND ----------

# MAGIC %md #### Wait until endpoint is in READY state
# MAGIC * This can take up to 10 minutes.

# COMMAND ----------

model_serving_client.wait_until(endpoint_name, max=60, sleep_time=10)
model_serving_client.wait_until(endpoint_name, max=120, sleep_time=10)

# COMMAND ----------

Expand All @@ -129,18 +129,42 @@

# COMMAND ----------

# MAGIC %md #### Make questions
# MAGIC %md #### Create the questions
# MAGIC * Several different input formats are supported:
# MAGIC * inputs
# MAGIC * instances
# MAGIC * dataframe_records
# MAGIC * dataframe_split
# MAGIC
# MAGIC See the documentation: [Send scoring requests to serving endpoints](https://docs.databricks.com/en/machine-learning/model-serving/score-model-serving-endpoints.html).

# COMMAND ----------

import pandas as pd
import json

def mk_questions(questions):
    """Serialize a list of question strings into a JSON 'dataframe_split' payload.

    Args:
        questions: Iterable of question strings.

    Returns:
        JSON string of the form {"dataframe_split": {...}} suitable for a
        model-serving scoring request.
    """
    frame = pd.DataFrame({"question": list(questions)})
    payload = {"dataframe_split": frame.to_dict(orient="split")}
    return json.dumps(payload, allow_nan=True)
# COMMAND ----------

def as_dataframe_records(questions):
    """Build a model-serving request body in 'dataframe_records' format.

    Args:
        questions: List of prompt strings.

    Returns:
        Dict with one {"prompt": ...} record per question plus sampling params.
    """
    return {
        "dataframe_records": [ { "prompt": q } for q in questions],
        "params": {
            "temperature": 0.5,
            # Widget values arrive as strings (dbutils.widgets.get); the serving
            # endpoint expects an integer for max_tokens, so cast explicitly.
            "max_tokens": int(max_tokens)
        }
    }

# COMMAND ----------

def as_inputs(questions):
    """Build a model-serving request body in 'inputs' format.

    Args:
        questions: List of prompt strings.

    Returns:
        Dict with the prompts under "inputs"/"prompt" plus sampling params.
    """
    return {
        "inputs": {
            "prompt": questions,
        },
        "params": {
            "temperature": 0.5,
            # Widget values arrive as strings (dbutils.widgets.get); the serving
            # endpoint expects an integer for max_tokens, so cast explicitly.
            "max_tokens": int(max_tokens)
        }
    }

# COMMAND ----------

Expand All @@ -151,12 +175,13 @@ def mk_questions(questions):
"What is the western most town in the world?"
]

questions = mk_questions(questions)
questions
request = as_inputs(questions)
#request = as_dataframe_records(questions)
dump(request)

# COMMAND ----------

# MAGIC %md #### Call Model Server
# MAGIC %md #### Call Model serving endpoint

# COMMAND ----------

Expand All @@ -168,5 +193,5 @@ def mk_questions(questions):
import requests

headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json" }
rsp = requests.post(endpoint_uri, headers=headers, data=questions, timeout=15)
rsp.status_code, rsp.text
response = requests.post(endpoint_uri, headers=headers, json=request, timeout=15)
dump(response.json())
11 changes: 6 additions & 5 deletions databricks/notebooks/llama2/_README.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Databricks notebook source
# MAGIC %md ### Take Llama 2 model for a spin
# MAGIC
# MAGIC * Simple example for [Databricks Marketplace LLama2](https://marketplace.databricks.com/details/46527194-66f5-4ea1-9619-d7ec6be6a2fa/Databricks_Llama-2-Models) model.
# MAGIC * Based on `e2-dogfood` workspace.
# MAGIC * Simple example for Llama 2 models available at [Databricks Marketplace Llama2](https://marketplace.databricks.com/details/46527194-66f5-4ea1-9619-d7ec6be6a2fa/Databricks_Llama-2-Models).
# MAGIC * See [llama_2_marketplace_listing_example](https://marketplace.databricks.com/details/46527194-66f5-4ea1-9619-d7ec6be6a2fa/Databricks_Llama-2-Models) Marketplace example notebook.
# MAGIC * Demonstrates how to do both real-time and batch model inference.
# MAGIC * Cluster instance type:
# MAGIC * For batch use `g4dn.xlarge`.
# MAGIC * For batch, use `g4dn.xlarge`.
# MAGIC * For model serving, use GPU_MEDIUM for `Workload type` and Small for `Workload size`.
# MAGIC
# MAGIC ##### Notebooks
# MAGIC * [Batch_Score_Llama_2]($Batch_Score_Llama_2) - Batch scoring.
# MAGIC * [Model_Serve_Llama_2]($Model_Serve_Llama_2) - Realtime scoring - model serving endpoint.
# MAGIC * [Batch_Score_Llama_2]($Batch_Score_Llama_2) - Batch scoring with Spark UDF.
# MAGIC * [Model_Serve_Llama_2]($Model_Serve_Llama_2) - Real-time scoring with model serving endpoint.
# MAGIC * [Common]($Common)
# MAGIC
# MAGIC ##### Github
Expand Down

0 comments on commit 4cdf6a0

Please sign in to comment.