diff --git a/src/pai_rag/integrations/vector_stores/postgresql/postgresql.py b/src/pai_rag/integrations/vector_stores/postgresql/postgresql.py index 50f950d0..e1f286ea 100644 --- a/src/pai_rag/integrations/vector_stores/postgresql/postgresql.py +++ b/src/pai_rag/integrations/vector_stores/postgresql/postgresql.py @@ -71,7 +71,7 @@ class HybridAbstractData(base): # type: ignore id = Column(BIGINT, primary_key=True, autoincrement=True) text = Column(VARCHAR, nullable=False) metadata_ = Column(metadata_dtype) - node_id = Column(VARCHAR) + node_id = Column(VARCHAR, unique=True) embedding = Column(Vector(embed_dim)) # type: ignore text_search_tsv = Column( # type: ignore TSVector(), @@ -98,7 +98,7 @@ class AbstractData(base): # type: ignore id = Column(BIGINT, primary_key=True, autoincrement=True) text = Column(VARCHAR, nullable=False) metadata_ = Column(metadata_dtype) - node_id = Column(VARCHAR) + node_id = Column(VARCHAR, unique=True) embedding = Column(Vector(embed_dim)) # type: ignore model = type( @@ -398,7 +398,13 @@ def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]: self._initialize() ids = [] with self._session() as session, session.begin(): + from sqlalchemy import delete + for node in nodes: + stmt = delete(self._table_class).where( + self._table_class.node_id == node.node_id + ) + session.execute(stmt) ids.append(node.node_id) item = self._node_to_table_row(node) session.add(item)