Skip to content

Commit

Permalink
script to test parity between cloud and local
Browse files Browse the repository at this point in the history
  • Loading branch information
vikhyat committed Dec 5, 2024
1 parent b3d0a78 commit 23a8d55
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
10 changes: 5 additions & 5 deletions clients/python/moondream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@


def vl(*, model_path: Optional[str] = None, api_key: Optional[str] = None) -> VLM:
if not model_path and not api_key:
raise ValueError("Either model_path or api_key must be provided")

if api_key:
return CloudVL.from_api_key(api_key)
return CloudVL(api_key)

if model_path:
return OnnxVL.from_path(model_path)

return OnnxVL.from_path(model_path)
raise ValueError("Either model_path or api_key must be provided.")
49 changes: 49 additions & 0 deletions clients/python/scripts/test_cloud_parity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import argparse
import os
import moondream as md

from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, required=True)
args = parser.parse_args()

local = md.vl(model_path=args.model_path)
cloud = md.vl(api_key=os.environ["MOONDREAM_API_KEY"])

image_path = "../../assets/demo-1.jpg"
image = Image.open(image_path)

print("# Captioning")
print("Local:", local.caption(image))
print("Cloud:", cloud.caption(image))

print("# Querying")
question = "What is the character eating?"
print("Local:", local.query(image, question))
print("Cloud:", cloud.query(image, question))

print("# Detecting")
object_to_detect = "burger"
print("Local:", local.detect(image, object_to_detect))
print("Cloud:", cloud.detect(image, object_to_detect))

print("# Streaming Caption")
print("Local:")
for tok in local.caption(image, stream=True)["caption"]:
print(tok, end="", flush=True)
print()
print("Cloud:")
for tok in cloud.caption(image, stream=True)["caption"]:
print(tok, end="", flush=True)
print()

print("# Streaming Query")
print("Local:")
for tok in local.query(image, question, stream=True)["answer"]:
print(tok, end="", flush=True)
print()
print("Cloud:")
for tok in cloud.query(image, question, stream=True)["answer"]:
print(tok, end="", flush=True)
print()

0 comments on commit 23a8d55

Please sign in to comment.