Skip to content

Commit

Permalink
add files
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 committed Dec 17, 2024
1 parent c04f6ed commit 3004313
Show file tree
Hide file tree
Showing 2 changed files with 461 additions and 0 deletions.
204 changes: 204 additions & 0 deletions libs/langchain-mongodb/langchain_mongodb/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
from typing import Any, Dict, List, Optional, Sequence

from langchain_core.indexing.base import RecordManager


def _get_pymongo_client(mongodb_url: str, **kwargs: Any) -> Any:
"""Get MongoClient for sync operations from the mongodb_url,
otherwise raise error."""
from pymongo import MongoClient
try:
client = MongoClient(mongodb_url, **kwargs)
except ValueError as e:
raise ImportError(
f"MongoClient string provided is not in proper format. " f"Got error: {e} "
) from None
return client


def _get_motor_client(mongodb_url: str, **kwargs: Any) -> Any:
"""Get AsyncIOMotorClient for async operations from the mongodb_url,
otherwise raise error."""
from motor.motor_asyncio import AsyncIOMotorClient
try:
client = AsyncIOMotorClient(mongodb_url, **kwargs)
except ValueError as e:
raise ImportError(
f"AsyncIOMotorClient string provided is not in proper format. "
f"Got error: {e} "
) from None
return client


class MongoDBRecordManager(RecordManager):
"""A MongoDB-based implementation of the record manager."""

def __init__(
self,
*,
connection_string: str,
db_name: str,
collection_name: str,
) -> None:
"""Initialize the MongoDBRecordManager.
Args:
connection_string: A valid MongoDB connection URI.
db_name: The name of the database to use.
collection_name: The name of the collection to use.
"""
super().__init__(namespace=".".join([db_name, collection_name]))
self.sync_client = _get_pymongo_client(connection_string)
self.sync_db = self.sync_client[db_name]
self.sync_collection = self.sync_db[collection_name]
self.async_client = _get_motor_client(connection_string)
self.async_db = self.async_client[db_name]
self.async_collection = self.async_db[collection_name]

def create_schema(self) -> None:
"""Create the database schema for the document manager."""
pass

async def acreate_schema(self) -> None:
"""Create the database schema for the document manager."""
pass

def update(
self,
keys: Sequence[str],
*,
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
"""Upsert documents into the MongoDB collection."""
if group_ids is None:
group_ids = [None] * len(keys)

if len(keys) != len(group_ids):
raise ValueError("Number of keys does not match number of group_ids")

for key, group_id in zip(keys, group_ids):
self.sync_collection.find_one_and_update(
{"namespace": self.namespace, "key": key},
{"$set": {"group_id": group_id, "updated_at": self.get_time()}},
upsert=True,
)

async def aupdate(
self,
keys: Sequence[str],
*,
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
"""Asynchronously upsert documents into the MongoDB collection."""
if group_ids is None:
group_ids = [None] * len(keys)

if len(keys) != len(group_ids):
raise ValueError("Number of keys does not match number of group_ids")

update_time = await self.aget_time()
if time_at_least and update_time < time_at_least:
raise ValueError("Server time is behind the expected time_at_least")

for key, group_id in zip(keys, group_ids):
await self.async_collection.find_one_and_update(
{"namespace": self.namespace, "key": key},
{"$set": {"group_id": group_id, "updated_at": update_time}},
upsert=True,
)

def get_time(self) -> float:
"""Get the current server time as a timestamp."""
server_info = self.sync_db.command("hostInfo")
local_time = server_info["system"]["currentTime"]
timestamp = local_time.timestamp()
return timestamp

async def aget_time(self) -> float:
"""Asynchronously get the current server time as a timestamp."""
host_info = await self.async_collection.database.command("hostInfo")
local_time = host_info["system"]["currentTime"]
return local_time.timestamp()

def exists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the given keys exist in the MongoDB collection."""
existing_keys = {
doc["key"]
for doc in self.sync_collection.find(
{"namespace": self.namespace, "key": {"$in": keys}}, {"key": 1}
)
}
return [key in existing_keys for key in keys]

async def aexists(self, keys: Sequence[str]) -> List[bool]:
"""Asynchronously check if the given keys exist in the MongoDB collection."""
cursor = self.async_collection.find(
{"namespace": self.namespace, "key": {"$in": keys}}, {"key": 1}
)
existing_keys = {doc["key"] async for doc in cursor}
return [key in existing_keys for key in keys]

def list_keys(
self,
*,
before: Optional[float] = None,
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
"""List documents in the MongoDB collection based on the provided date range."""
query: Dict[str, Any] = {"namespace": self.namespace}
if before:
query["updated_at"] = {"$lt": before}
if after:
query["updated_at"] = {"$gt": after}
if group_ids:
query["group_id"] = {"$in": group_ids}

cursor = (
self.sync_collection.find(query, {"key": 1}).limit(limit)
if limit
else self.sync_collection.find(query, {"key": 1})
)
return [doc["key"] for doc in cursor]

async def alist_keys(
self,
*,
before: Optional[float] = None,
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
"""
Asynchronously list documents in the MongoDB collection
based on the provided date range.
"""
query: Dict[str, Any] = {"namespace": self.namespace}
if before:
query["updated_at"] = {"$lt": before}
if after:
query["updated_at"] = {"$gt": after}
if group_ids:
query["group_id"] = {"$in": group_ids}

cursor = (
self.async_collection.find(query, {"key": 1}).limit(limit)
if limit
else self.async_collection.find(query, {"key": 1})
)
return [doc["key"] async for doc in cursor]

def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete documents from the MongoDB collection."""
self.sync_collection.delete_many(
{"namespace": self.namespace, "key": {"$in": keys}}
)

async def adelete_keys(self, keys: Sequence[str]) -> None:
"""Asynchronously delete documents from the MongoDB collection."""
await self.async_collection.delete_many(
{"namespace": self.namespace, "key": {"$in": keys}}
)
Loading

0 comments on commit 3004313

Please sign in to comment.