-
Notifications
You must be signed in to change notification settings - Fork 0
/
chroma_db_component.py
116 lines (99 loc) · 4.38 KB
/
chroma_db_component.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import chromadb
import logging
# Module-level logger, named after this module; used by ChromaDBComponent below.
logger = logging.getLogger(__name__)
class ChromaDBComponent:
    """Wrap a persistent ChromaDB client together with a single collection.

    The collection is created (or reopened) with cosine distance and without
    an embedding function, so callers pass pre-computed embeddings both when
    adding documents and when querying.

    Attributes:
        persist_directory: Path handed to ``chromadb.PersistentClient``.
        client: The ChromaDB client, set by ``_initialize_client``.
        collection: The active collection, set by ``_create_collection``.
        EMBEDDING_DIMENSION: Vector length that embeddings already stored in
            the collection must have.
    """

    def __init__(self, persist_directory="./chroma_db", embedding_dimension=768):
        """Open the client and create/retrieve the default collection.

        Args:
            persist_directory: Directory used by the persistent client.
            embedding_dimension: Expected length of stored embedding vectors.

        Raises:
            Exception: Whatever the client or collection setup raised; the
                error is logged first and then re-raised unchanged.
        """
        self.persist_directory = persist_directory
        self.client = None
        self.collection = None
        self.EMBEDDING_DIMENSION = embedding_dimension
        self._initialize_client()
        self._create_collection()

    def _log_and_raise_error(self, message, error):
        """Log ``message`` with the error text and re-raise ``error``.

        Intended to be called from inside an ``except`` block so that the
        active traceback is attached to the log record.

        Args:
            message: Short description of the operation that failed.
            error: The exception instance that was caught.

        Raises:
            Exception: Always re-raises ``error`` itself, preserving its type.
        """
        logger.error(f"{message}: {str(error)}", exc_info=True)
        raise error

    def _initialize_client(self):
        """Create the persistent ChromaDB client for ``persist_directory``."""
        try:
            self.client = chromadb.PersistentClient(path=self.persist_directory)
            logger.info("ChromaDB client initialized successfully")
        except Exception as e:
            self._log_and_raise_error("Error initializing ChromaDB client", e)

    def _create_collection(self, collection_name="document_collection"):
        """Create or retrieve the collection and validate its dimensionality.

        Args:
            collection_name: Name of the collection to create or reopen.

        Raises:
            ValueError: If an embedding already stored in the collection does
                not have ``EMBEDDING_DIMENSION`` entries.
            Exception: Any error raised by the ChromaDB client.
        """
        try:
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"},
                embedding_function=None,
            )
            logger.info(f"Collection '{collection_name}' created or retrieved successfully")

            # Embeddings are only returned when explicitly requested, and one
            # record is enough to learn the dimensionality, so avoid loading
            # the whole collection.
            sample = self.collection.get(limit=1, include=["embeddings"])
            embeddings = sample["embeddings"]
            # Explicit None/length test: truthiness is ambiguous if the client
            # returns an array type instead of a list.
            if embeddings is not None and len(embeddings) > 0:
                dim = len(embeddings[0])
                logger.info(f"Collection dimensionality: {dim}")
                if dim != self.EMBEDDING_DIMENSION:
                    raise ValueError(f"Collection dimensionality {dim} does not match expected {self.EMBEDDING_DIMENSION}")
            else:
                logger.info("Collection is empty")
        except Exception as e:
            self._log_and_raise_error("Error creating/retrieving collection", e)

    def add_documents(self, ids, embeddings, metadatas, documents):
        """Add documents with their embeddings and metadata to the collection.

        Args:
            ids: Identifiers, one per document.
            embeddings: Pre-computed embedding vectors, one per document.
            metadatas: Metadata entries, one per document.
            documents: The document texts.

        Raises:
            ValueError: If the four sequences do not all have the same length.
            Exception: Any error raised by the collection's ``add`` call.
        """
        try:
            logger.info(f"Adding {len(documents)} documents, {len(embeddings)} embeddings, and {len(metadatas)} metadata entries")
            if not (len(ids) == len(documents) == len(embeddings) == len(metadatas)):
                raise ValueError("Mismatch in the number of documents, embeddings, metadatas, or ids")
            self.collection.add(
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents,
            )
            logger.info(f"Successfully added {len(ids)} documents to the collection")
        except Exception as e:
            self._log_and_raise_error("Error adding documents to collection", e)

    def list_all_documents(self):
        """Return everything ``collection.get()`` yields and log the count.

        Returns:
            The result of ``self.collection.get()`` called without arguments.

        Raises:
            Exception: Any error raised while reading the collection.
        """
        try:
            all_docs = self.collection.get()
            logger.info(f"Retrieved {len(all_docs['ids'])} documents from the collection")
            return all_docs
        except Exception as e:
            self._log_and_raise_error("Error listing documents", e)

    def query(self, query_embedding, n_results=5):
        """Run a nearest-neighbour query for a single embedding.

        Args:
            query_embedding: The embedding vector to search with.
            n_results: Maximum number of results requested.

        Returns:
            The raw result of ``self.collection.query``.

        Raises:
            Exception: Any error raised by the query.
        """
        try:
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
            )
            logger.info(f"Query executed successfully, returned {len(results['ids'][0])} results")
            return results
        except Exception as e:
            self._log_and_raise_error("Error querying collection", e)

    def get_collection_stats(self):
        """Return the collection's name and item count.

        Returns:
            A dict with the keys ``"name"`` and ``"count"``.

        Raises:
            Exception: Any error raised while reading the collection.
        """
        try:
            return {
                "name": self.collection.name,
                "count": self.collection.count(),
            }
        except Exception as e:
            self._log_and_raise_error("Error getting collection stats", e)

    def similarity_search(self, query_embedding, n_results=5):
        """Query the collection and sanity-check the shape of the result.

        Args:
            query_embedding: The embedding vector to search with.
            n_results: Maximum number of results requested.

        Returns:
            The raw result of ``self.collection.query``, or an empty list when
            the result lacks the ``'ids'`` or ``'documents'`` key. Callers
            must therefore handle both return types.

        Raises:
            Exception: Any error raised by the query.
        """
        try:
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
            )
            if 'ids' not in results or 'documents' not in results:
                logger.error("Unexpected structure in results from ChromaDB")
                return []
            # The full payload can be large, so it is logged at DEBUG only.
            logger.debug(f"Results: {results}")
            logger.info(f"Similarity search executed successfully, returned {len(results['ids'][0])} results")
            return results
        except Exception as e:
            self._log_and_raise_error("Error with similarity search", e)