From 23a8d55ba55937d97cef49436c258eae5f712d30 Mon Sep 17 00:00:00 2001
From: vik
Date: Wed, 4 Dec 2024 19:49:57 -0800
Subject: [PATCH] script to test parity between cloud and local

---
Reviewer notes (text in this area is not part of the commit message):

 * moondream.vl() no longer validates its arguments up front. It returns
   CloudVL(api_key) when api_key is truthy (previously
   CloudVL.from_api_key(api_key)), otherwise OnnxVL.from_path(model_path)
   when model_path is truthy, and raises ValueError only when neither
   argument is given.
 * scripts/test_cloud_parity.py prints the local and cloud results one after
   the other for caption, query, detect, and the streaming forms of caption
   and query. It requires --model-path and the MOONDREAM_API_KEY environment
   variable, and opens "../../assets/demo-1.jpg" relative to the current
   working directory.

 clients/python/moondream/__init__.py        | 10 ++---
 clients/python/scripts/test_cloud_parity.py | 49 +++++++++++++++++++++
 2 files changed, 54 insertions(+), 5 deletions(-)
 create mode 100644 clients/python/scripts/test_cloud_parity.py

diff --git a/clients/python/moondream/__init__.py b/clients/python/moondream/__init__.py
index 516315b..e34c5c3 100644
--- a/clients/python/moondream/__init__.py
+++ b/clients/python/moondream/__init__.py
@@ -5,10 +5,10 @@
 
 
 def vl(*, model_path: Optional[str] = None, api_key: Optional[str] = None) -> VLM:
-    if not model_path and not api_key:
-        raise ValueError("Either model_path or api_key must be provided")
-
     if api_key:
-        return CloudVL.from_api_key(api_key)
+        return CloudVL(api_key)
+
+    if model_path:
+        return OnnxVL.from_path(model_path)
 
-    return OnnxVL.from_path(model_path)
+    raise ValueError("Either model_path or api_key must be provided.")
diff --git a/clients/python/scripts/test_cloud_parity.py b/clients/python/scripts/test_cloud_parity.py
new file mode 100644
index 0000000..370af92
--- /dev/null
+++ b/clients/python/scripts/test_cloud_parity.py
@@ -0,0 +1,49 @@
+import argparse
+import os
+import moondream as md
+
+from PIL import Image
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model-path", type=str, required=True)
+args = parser.parse_args()
+
+local = md.vl(model_path=args.model_path)
+cloud = md.vl(api_key=os.environ["MOONDREAM_API_KEY"])
+
+image_path = "../../assets/demo-1.jpg"
+image = Image.open(image_path)
+
+print("# Captioning")
+print("Local:", local.caption(image))
+print("Cloud:", cloud.caption(image))
+
+print("# Querying")
+question = "What is the character eating?"
+print("Local:", local.query(image, question))
+print("Cloud:", cloud.query(image, question))
+
+print("# Detecting")
+object_to_detect = "burger"
+print("Local:", local.detect(image, object_to_detect))
+print("Cloud:", cloud.detect(image, object_to_detect))
+
+print("# Streaming Caption")
+print("Local:")
+for tok in local.caption(image, stream=True)["caption"]:
+    print(tok, end="", flush=True)
+print()
+print("Cloud:")
+for tok in cloud.caption(image, stream=True)["caption"]:
+    print(tok, end="", flush=True)
+print()
+
+print("# Streaming Query")
+print("Local:")
+for tok in local.query(image, question, stream=True)["answer"]:
+    print(tok, end="", flush=True)
+print()
+print("Cloud:")
+for tok in cloud.query(image, question, stream=True)["answer"]:
+    print(tok, end="", flush=True)
+print()