From f1e3cf1af87d3214989b9d5b979f12e159d865e8 Mon Sep 17 00:00:00 2001 From: Justin Ramsey Date: Fri, 13 Dec 2024 14:31:38 +0800 Subject: [PATCH] Initial commit --- .gitignore | 5 + LICENSE | 21 ++ README.md | 153 ++++++++++++++ langgraph_checkpoint_dynamodb/__init__.py | 0 langgraph_checkpoint_dynamodb/saver.py | 247 ++++++++++++++++++++++ langgraph_checkpoint_dynamodb/write.py | 80 +++++++ requirements.txt | 47 ++++ setup.py | 23 ++ 8 files changed, 576 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 langgraph_checkpoint_dynamodb/__init__.py create mode 100644 langgraph_checkpoint_dynamodb/saver.py create mode 100644 langgraph_checkpoint_dynamodb/write.py create mode 100644 requirements.txt create mode 100644 setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3dba330 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +**/build +**/.venv +**/dist +**/langgraph_checkpoint_dynamodb.egg-info +**/.idea diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..339fd03 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 researchwiseai + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..9dd655c --- /dev/null +++ b/README.md @@ -0,0 +1,153 @@ +# langgraph-checkpoint-dynamodb + +Implementation of a LangGraph CheckpointSaver that uses AWS DynamoDB. + +## Inspiration + +Based on: https://github.com/researchwiseai/langgraphjs-checkpoint-dynamodb + +## Required DynamoDB Tables + +To be able to use this checkpointer, two DynamoDB tables are needed, one to store +checkpoints and the other to store writes. Below are some examples of how you +can create the required tables. + +### Terraform + +```hcl +# Variables for table names +variable "checkpoints_table_name" { + type = string +} + +variable "writes_table_name" { + type = string +} + +# Checkpoints Table +resource "aws_dynamodb_table" "checkpoints_table" { + name = var.checkpoints_table_name + billing_mode = "PAY_PER_REQUEST" + + hash_key = "thread_id" + range_key = "checkpoint_id" + + attribute { + name = "thread_id" + type = "S" + } + + attribute { + name = "checkpoint_id" + type = "S" + } +} + +# Writes Table +resource "aws_dynamodb_table" "writes_table" { + name = var.writes_table_name + billing_mode = "PAY_PER_REQUEST" + + hash_key = "thread_id_checkpoint_id_checkpoint_ns" + range_key = "task_id_idx" + + attribute { + name = "thread_id_checkpoint_id_checkpoint_ns" + type = "S" + } + + attribute { + name = "task_id_idx" + type = "S" + } +} +``` + +### AWS CDK + +```python +from aws_cdk import ( + Stack, + aws_dynamodb as dynamodb, +) +from constructs import Construct + +class DynamoDbStack(Stack): + def __init__(self, scope: Construct, id: str, **kwargs): + super().__init__(scope, id, **kwargs) + + checkpoints_table_name = 'YourCheckpointsTableName' + 
writes_table_name = 'YourWritesTableName' + + # Checkpoints Table + dynamodb.Table( + self, + 'CheckpointsTable', + table_name=checkpoints_table_name, + billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST, + partition_key=dynamodb.Attribute( + name='thread_id', + type=dynamodb.AttributeType.STRING, + ), + sort_key=dynamodb.Attribute( + name='checkpoint_id', + type=dynamodb.AttributeType.STRING, + ), + ) + + # Writes Table + dynamodb.Table( + self, + 'WritesTable', + table_name=writes_table_name, + billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST, + partition_key=dynamodb.Attribute( + name='thread_id_checkpoint_id_checkpoint_ns', + type=dynamodb.AttributeType.STRING, + ), + sort_key=dynamodb.Attribute( + name='task_id_idx', + type=dynamodb.AttributeType.STRING, + ), + ) +``` + +## Using the Checkpoint Saver + +### Default + +To use the DynamoDB checkpoint saver, you only need to specify the names of +the checkpoints and writes tables. In this scenario the DynamoDB client will +be instantiated with the default configuration, great for running on AWS Lambda. + +```python +from langgraph_checkpoint_dynamodb.saver import DynamoDBSaver +... +checkpoints_table_name = 'YourCheckpointsTableName' +writes_table_name = 'YourWritesTableName' + +memory = DynamoDBSaver( + checkpoints_table_name=checkpoints_table_name, + writes_table_name=writes_table_name, +) + +graph = workflow.compile(checkpointer=memory) +``` + +### Providing Client Configuration + +If you need to provide custom configuration to the DynamoDB client, you can +pass in a dictionary of keyword arguments that is forwarded to boto3. Below is an +example of how you can provide custom configuration. 
+ +```python +memory = DynamoDBSaver( + checkpoints_table_name=checkpoints_table_name, + writes_table_name=writes_table_name, + client_config={ + 'region_name': 'us-west-2', + 'aws_access_key_id': 'your-access-key-id', + 'aws_secret_access_key': 'your-secret-access-key', + } +) +``` diff --git a/langgraph_checkpoint_dynamodb/__init__.py b/langgraph_checkpoint_dynamodb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/langgraph_checkpoint_dynamodb/saver.py b/langgraph_checkpoint_dynamodb/saver.py new file mode 100644 index 0000000..97fb95e --- /dev/null +++ b/langgraph_checkpoint_dynamodb/saver.py @@ -0,0 +1,247 @@ +import boto3 +from boto3.dynamodb.conditions import Key, Attr +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple + +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import BaseCheckpointSaver, CheckpointTuple, Checkpoint, CheckpointMetadata, \ + ChannelVersions +from langgraph.checkpoint.serde.base import SerializerProtocol + +from langgraph_checkpoint_dynamodb.write import Write + + +class DynamoDBSaver(BaseCheckpointSaver): + def __init__( + self, + *, + client_config: Optional[Dict[str, Any]] = None, + serde: Optional[SerializerProtocol] = None, + checkpoints_table_name: str, + writes_table_name: str, + ) -> None: + super().__init__(serde=serde) + self.client = boto3.client("dynamodb", **(client_config or {})) + self.dynamodb = boto3.resource("dynamodb", **(client_config or {})) + self.checkpoints_table_name = checkpoints_table_name + self.writes_table_name = writes_table_name + self.checkpoints_table = self.dynamodb.Table(self.checkpoints_table_name) + self.writes_table = self.dynamodb.Table(self.writes_table_name) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + configurable = self.validate_configurable(config.get("configurable")) + item = self._get_item(configurable) + if not item: + return None + + checkpoint = self.serde.loads_typed([item["type"], 
def _get_item(self, configurable):
    """Fetch one raw checkpoint item from the checkpoints table.

    Args:
        configurable: Mapping with the ``thread_id``, ``checkpoint_ns`` and
            ``checkpoint_id`` keys read below.

    Returns:
        The stored item, or ``None`` when no item matches.
    """
    if configurable["checkpoint_id"] is not None:
        # An explicit id addresses exactly one item through the table's
        # (thread_id, checkpoint_id) primary key.
        response = self.checkpoints_table.get_item(
            Key={
                "thread_id": configurable["thread_id"],
                "checkpoint_id": configurable["checkpoint_id"],
            }
        )
        return response.get("Item")

    # No id given: walk the thread's items in descending checkpoint_id order
    # and return the first match (presumably the newest checkpoint; this
    # assumes checkpoint ids sort chronologically -- TODO confirm).
    query_kwargs = {
        "KeyConditionExpression": Key("thread_id").eq(configurable["thread_id"]),
        "ConsistentRead": True,
        "ScanIndexForward": False,
    }
    if configurable["checkpoint_ns"]:
        # DynamoDB applies ``Limit`` *before* ``FilterExpression``, so
        # ``Limit=1`` plus a filter can return nothing even though a matching
        # item exists further down. With a filter we therefore page until the
        # first match instead of limiting the evaluated items.
        query_kwargs["FilterExpression"] = Attr("checkpoint_ns").eq(
            configurable["checkpoint_ns"]
        )
    else:
        # NOTE(review): an empty namespace applies no filter, so the first
        # item of *any* namespace in the thread is returned (unchanged from
        # the previous behaviour) -- confirm this is intended.
        query_kwargs["Limit"] = 1

    while True:
        # Keyword expansion is required here: Table.query() accepts keyword
        # arguments only.
        response = self.checkpoints_table.query(**query_kwargs)
        items = response.get("Items", [])
        if items:
            return items[0]
        last_key = response.get("LastEvaluatedKey")
        if not last_key:
            return None
        query_kwargs["ExclusiveStartKey"] = last_key


def list(
    self,
    config: Optional["RunnableConfig"],
    *,
    filter: Optional[Dict[str, Any]] = None,
    before: Optional["RunnableConfig"] = None,
    limit: Optional[int] = None,
) -> Iterator["CheckpointTuple"]:
    """Yield the checkpoints of one thread in descending checkpoint_id order.

    Args:
        config: Config whose ``configurable.thread_id`` selects the thread.
            A non-empty ``configurable.checkpoint_ns`` restricts the result
            to that namespace.
        filter: Optional metadata key/value pairs; only checkpoints whose
            metadata equals every pair are yielded.
        before: Optional config; when it carries a ``checkpoint_id`` only
            checkpoints with a smaller id are returned.
        limit: Maximum number of checkpoints to yield (``None`` = no limit).

    Yields:
        ``CheckpointTuple`` objects without pending writes.

    Raises:
        ValueError: If no string ``thread_id`` is provided.
    """
    configurable = (config or {}).get("configurable") or {}
    thread_id = configurable.get("thread_id")
    if not isinstance(thread_id, str):
        # The table is queried by partition key, so a thread id is mandatory.
        raise ValueError("Invalid thread_id")
    if limit is not None and limit <= 0:
        return
    namespace = configurable.get("checkpoint_ns") or ""

    key_condition = Key("thread_id").eq(thread_id)
    before_id = ((before or {}).get("configurable") or {}).get("checkpoint_id")
    if before_id:
        key_condition &= Key("checkpoint_id").lt(before_id)

    query_kwargs = {
        "KeyConditionExpression": key_condition,
        "ScanIndexForward": False,
    }
    # When items are dropped client-side, DynamoDB's Limit (which counts
    # evaluated items) cannot be used as the result limit.
    filters_client_side = bool(filter) or bool(namespace)
    yielded = 0

    while True:
        if limit is not None and not filters_client_side:
            # boto3 rejects Limit=None, so only send it when it has a value.
            query_kwargs["Limit"] = limit - yielded
        response = self.checkpoints_table.query(**query_kwargs)

        for item in response.get("Items", []):
            if namespace and item.get("checkpoint_ns", "") != namespace:
                continue
            # loads_typed takes ONE (type, bytes) pair; binary attributes are
            # unwrapped via ``.value`` exactly as get_tuple does.
            metadata = self.serde.loads_typed(
                (item["type"], item["metadata"].value)
            )
            if filter and any(
                metadata.get(key) != expected for key, expected in filter.items()
            ):
                continue
            checkpoint = self.serde.loads_typed(
                (item["type"], item["checkpoint"].value)
            )

            item_config = {
                "configurable": {
                    "thread_id": item["thread_id"],
                    "checkpoint_ns": item.get("checkpoint_ns", ""),
                    "checkpoint_id": item["checkpoint_id"],
                }
            }
            parent_config = None
            if item.get("parent_checkpoint_id"):
                parent_config = {
                    "configurable": {
                        "thread_id": item["thread_id"],
                        "checkpoint_ns": item.get("checkpoint_ns", ""),
                        "checkpoint_id": item["parent_checkpoint_id"],
                    }
                }
            yield CheckpointTuple(
                config=item_config,
                checkpoint=checkpoint,
                metadata=metadata,
                parent_config=parent_config,
            )
            yielded += 1
            if limit is not None and yielded >= limit:
                return

        # A single Query response may be truncated; keep paging until done.
        last_key = response.get("LastEvaluatedKey")
        if not last_key:
            return
        query_kwargs["ExclusiveStartKey"] = last_key
def put_writes(
    self,
    config: "RunnableConfig",
    writes: Sequence[Tuple[str, Any]],
    task_id: str,
) -> None:
    """Persist the pending writes of one task for a checkpoint.

    Each ``(channel, value)`` pair is serialized with ``self.serde``, wrapped
    in a ``Write`` and sent to the writes table through ``batch_write_item``.

    Args:
        config: Config whose ``configurable`` section identifies the thread,
            namespace and checkpoint the writes belong to.
        writes: ``(channel, value)`` pairs; the position in the sequence
            becomes the write's ``idx``.
        task_id: Identifier of the task that produced the writes.

    Raises:
        ValueError: If the config is invalid or has no ``checkpoint_id``.
        RuntimeError: If DynamoDB keeps returning unprocessed items after
            all retry attempts.
    """
    import time  # local import: only needed for the retry backoff below

    configurable = self.validate_configurable(config.get("configurable"))
    thread_id = configurable["thread_id"]
    checkpoint_ns = configurable.get("checkpoint_ns", "")
    checkpoint_id = configurable.get("checkpoint_id")
    if checkpoint_id is None:
        raise ValueError("Missing checkpoint_id")

    write_requests = []
    for idx, (channel, value) in enumerate(writes):
        type_, serialized_value = self.serde.dumps_typed(value)
        write = Write(
            thread_id=thread_id,
            checkpoint_ns=checkpoint_ns,
            checkpoint_id=checkpoint_id,
            task_id=task_id,
            idx=idx,
            channel=channel,
            type=type_,
            value=serialized_value,
        )
        write_requests.append({"PutRequest": {"Item": write.to_dynamodb_item()}})

    batch_size = 25  # BatchWriteItem accepts at most 25 requests per call
    max_attempts = 5
    for start in range(0, len(write_requests), batch_size):
        pending = {self.writes_table_name: write_requests[start:start + batch_size]}
        attempt = 0
        while True:
            response = self.client.batch_write_item(RequestItems=pending)
            # A successful call may still leave some requests unprocessed
            # (e.g. when throttled); they must be re-submitted or the writes
            # are silently lost.
            pending = response.get("UnprocessedItems") or {}
            if not pending:
                break
            attempt += 1
            if attempt >= max_attempts:
                raise RuntimeError(
                    "DynamoDB left pending writes unprocessed after "
                    f"{max_attempts} attempts"
                )
            time.sleep(min(0.05 * 2 ** attempt, 1.0))  # exponential backoff
class Write:
    """One pending write, convertible to and from a writes-table item.

    The writes table uses two composite string keys, both built by joining
    their parts with ``separator``:

    * partition key ``thread_id_checkpoint_id_checkpoint_ns``
    * sort key ``task_id_idx``
    """

    # Delimiter between the parts of the composite keys.
    separator = ":::"

    def __init__(
        self,
        *,
        thread_id: str,
        checkpoint_ns: str,
        checkpoint_id: str,
        task_id: str,
        idx: int,
        channel: str,
        type: str,
        value: Any,
    ) -> None:
        """Store the given fields unchanged on the instance."""
        self.thread_id = thread_id
        self.checkpoint_ns = checkpoint_ns
        self.checkpoint_id = checkpoint_id
        self.task_id = task_id
        self.idx = idx
        self.channel = channel
        self.type = type
        self.value = value

    def to_dynamodb_item(self):
        """Return the write as a typed attribute map (``{'S': ...}`` / ``{'B': ...}``).

        Strings are tagged ``'S'`` and ``value`` is tagged ``'B'``, so
        ``value`` is assumed to be binary data.
        """
        partition_key = self.get_partition_key(
            {
                "thread_id": self.thread_id,
                "checkpoint_id": self.checkpoint_id,
                "checkpoint_ns": self.checkpoint_ns,
            }
        )
        sort_key = self.get_sort_key({"task_id": self.task_id, "idx": self.idx})
        return {
            "thread_id_checkpoint_id_checkpoint_ns": {"S": partition_key},
            "task_id_idx": {"S": sort_key},
            "channel": {"S": self.channel},
            "type": {"S": self.type},
            "value": {"B": self.value},
        }

    @classmethod
    def from_dynamodb_item(cls, item):
        """Build a ``Write`` from an item holding plain attribute values.

        The item's fields are read directly (not the typed ``{'S': ...}``
        form produced by ``to_dynamodb_item``), and both composite keys are
        split back into their parts.

        Raises:
            KeyError: If a required attribute is missing.
            ValueError: If a composite key has too few parts or ``idx`` is
                not an integer.
        """
        partition_key = item["thread_id_checkpoint_id_checkpoint_ns"]
        sort_key = item["task_id_idx"]
        # Limit the splits so that a separator occurring inside the trailing
        # ``checkpoint_ns`` or inside ``task_id`` does not break unpacking.
        thread_id, checkpoint_id, checkpoint_ns = partition_key.split(
            cls.separator, 2
        )
        task_id, idx = sort_key.rsplit(cls.separator, 1)
        return cls(
            thread_id=thread_id,
            checkpoint_ns=checkpoint_ns,
            checkpoint_id=checkpoint_id,
            task_id=task_id,
            idx=int(idx),
            channel=item["channel"],
            type=item["type"],
            value=item["value"],
        )

    @staticmethod
    def get_partition_key(item):
        """Join ``thread_id``, ``checkpoint_id`` and ``checkpoint_ns`` (default ``""``)."""
        return Write.separator.join(
            [item["thread_id"], item["checkpoint_id"], item.get("checkpoint_ns", "")]
        )

    @staticmethod
    def get_sort_key(item):
        """Join ``task_id`` and the stringified ``idx``."""
        return Write.separator.join([item["task_id"], str(item["idx"])])
url='https://github.com/justinram11/langgraph-checkpoint-dynamodb', + packages=find_packages(), + install_requires=[ + 'boto3>=1.17.0', + 'langchain>=0.3.9', + 'langgraph>=0.2.53', + ], + classifiers=[ + 'Programming Language :: Python :: 3', + 'Operating System :: OS Independent', + ], + python_requires='>=3.6', +) \ No newline at end of file