-
Notifications
You must be signed in to change notification settings - Fork 510
/
sample.py
84 lines (67 loc) · 2.68 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
from queue import Queue
from threading import Thread
import torch
from PIL import Image
from transformers import AutoTokenizer, TextIteratorStreamer
from moondream.hf import LATEST_REVISION, Moondream, detect_device
def strip_end_marker(text, markers=("END", "<")):
    """Remove a trailing end-of-answer marker fragment from streamed text.

    The streaming loop below holds back any chunk ending in ``"<"`` or
    ``"END"``, because such a chunk may be part of the model's terminator
    rather than answer text (presumably something like ``<END`` — the code
    only shows that these two suffixes are treated as suspect). Whatever is
    still held back when the stream finishes is therefore that fragment and
    should not be printed.

    Each suffix in ``markers`` is removed at most once, left to right, so with
    the defaults ``"<END"`` -> ``""`` and ``"yes<"`` -> ``"yes"``. An answer
    that genuinely ends in one of the markers loses it as well; this mirrors
    the hold-back rule used while streaming.

    Args:
        text: The buffered tail of the streamed answer.
        markers: Suffixes to strip, applied in order.

    Returns:
        ``text`` with the marker suffixes removed.
    """
    for marker in markers:
        if marker and text.endswith(marker):
            text = text[: -len(marker)]
    return text


def _answer_in_background(answer_fn, args, kwargs, streamer, errors):
    """Call ``answer_fn(*args, **kwargs)``, recording any exception in ``errors``.

    Intended as a ``Thread`` target. On failure, the streamer is ended so the
    consumer iterating over it stops waiting instead of hanging forever.
    """
    try:
        answer_fn(*args, **kwargs)
    except Exception as exc:  # re-raised by the main thread after join()
        errors.append(exc)
        streamer.end()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Caption or ask questions about an image with moondream2."
    )
    parser.add_argument(
        "--image", type=str, required=True, help="Path to the input image."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=False,
        help="Ask a single question and exit; omit for an interactive chat.",
    )
    parser.add_argument(
        "--caption",
        action="store_true",
        help="Print a caption for the image instead of answering questions.",
    )
    parser.add_argument(
        "--cpu", action="store_true", help="Force CPU inference in float32."
    )
    args = parser.parse_args()

    if args.cpu:
        device = torch.device("cpu")
        dtype = torch.float32
    else:
        device, dtype = detect_device()
        if device != torch.device("cpu"):
            print("Using device:", device)
            print("If you run into issues, pass the `--cpu` flag to this script.")
            print()

    image_path = args.image
    prompt = args.prompt

    model_id = "vikhyatk/moondream2"
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
    moondream = Moondream.from_pretrained(
        model_id,
        revision=LATEST_REVISION,
        torch_dtype=dtype,
    ).to(device=device)
    moondream.eval()

    image = Image.open(image_path)

    if args.caption:
        print(moondream.caption(images=[image], tokenizer=tokenizer)[0])
    else:
        image_embeds = moondream.encode_image(image)

        if prompt is None:
            # Interactive chat: every turn is appended to chat_history and
            # passed back to the model on the next question.
            chat_history = ""

            while True:
                try:
                    question = input("> ")
                except (EOFError, KeyboardInterrupt):
                    # Ctrl-D / Ctrl-C at the prompt ends the session cleanly
                    # instead of exiting with a traceback.
                    print()
                    break

                result_queue = Queue()
                streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

                # answer_question runs on a worker thread so the main thread
                # can print tokens from the streamer as they arrive; the final
                # answer is delivered through result_queue.
                thread_args = (image_embeds, question, tokenizer, chat_history)
                thread_kwargs = {"streamer": streamer, "result_queue": result_queue}
                worker_errors = []
                thread = Thread(
                    target=_answer_in_background,
                    args=(
                        moondream.answer_question,
                        thread_args,
                        thread_kwargs,
                        streamer,
                        worker_errors,
                    ),
                )
                thread.start()

                buffer = ""
                for new_text in streamer:
                    buffer += new_text
                    # Hold back chunks that may begin the end marker; they are
                    # flushed with the next chunk, or stripped if the stream
                    # ends on them.
                    if not new_text.endswith(("<", "END")):
                        print(buffer, end="", flush=True)
                        buffer = ""
                print(strip_end_marker(buffer))

                thread.join()
                if worker_errors:
                    # Surface the worker's failure here; result_queue is empty
                    # in this case and get() would block forever.
                    raise worker_errors[0]
                answer = result_queue.get()
                chat_history += f"Question: {question}\n\nAnswer: {answer}\n\n"
        else:
            print(">", prompt)
            answer = moondream.answer_question(image_embeds, prompt, tokenizer)
            print(answer)