# test_cohere_wiki.py
#
# Manual end-to-end check against a service listening on 127.0.0.1:8000:
# registers a namespace, indexes the pre-embedded
# "Cohere/wikipedia-22-12-simple-embeddings" dataset into it, then runs one
# query and prints the returned documents.
from os import environ
import cohere
import httpx
from datasets import load_dataset
from tqdm import tqdm
# Target namespace and the two vector sizes it is registered with.
namespace = "cohere_wiki"
dim = 768
vocab = 30522

# One HTTP client, reused for every request made by this script.
client = httpx.Client(base_url="http://127.0.0.1:8000")

# Register the namespace up front; any error status aborts the script here.
namespace_spec = {
    "name": namespace,
    "vector_dim": dim,
    "sparse_vector_dim": vocab,
}
resp = client.post("/api/namespace", json=namespace_spec)
resp.raise_for_status()
# Open the dataset in streaming mode so rows are consumed as they arrive.
docs = load_dataset(
    "Cohere/wikipedia-22-12-simple-embeddings", split="train", streaming=True
)

# Index one document per request. `count` only counts documents the service
# accepted; rejected ones are reported and skipped.
count = 0
for doc in tqdm(docs):
    resp = client.post(
        "/api/doc",
        json={
            "namespace": namespace,
            "text": doc["text"],
            "doc_id": doc["id"],
            "vector": doc["emb"],
            "title": doc["title"],
        },
    )
    if resp.is_error:
        # Best effort: keep indexing after a rejected document.
        # tqdm.write emits the message without breaking the progress bar.
        tqdm.write(f"Error adding doc: ({resp.status_code}) {resp.text}")
        continue
    count += 1
print(f"Added {count} docs")
# Free-text query; it is sent both as text and as an embedding.
query = "the cat is on the mat"

try:
    # Raises KeyError if COHERE_TOKEN is not set in the environment.
    co = cohere.Client(api_key=environ["COHERE_TOKEN"])
    emb_resp = co.embed([query], model="multilingual-22-12")

    resp = client.post(
        "/api/query",
        json={
            "namespace": namespace,
            "query": query,
            "vector": emb_resp.embeddings[0],
        },
    )
    resp.raise_for_status()

    # Print each returned document followed by a separator line.
    for doc in resp.json():
        print(f"[{doc['id']}] {doc['title']}")
        print(doc["text"])
        print("=" * 80)
finally:
    # Last use of the shared HTTP client: release its connections even if
    # embedding or querying failed.
    client.close()