Skip to content

Commit

Permalink
Support loading a user dictionary in PostgreSQL (#150)
Browse files Browse the repository at this point in the history
* make format

* Allow running without the pg_jieba extension installed

* table name data_default
  • Loading branch information
zt2645802240 authored Aug 9, 2024
1 parent f15385e commit fe4d752
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
48 changes: 46 additions & 2 deletions src/pai_rag/integrations/vector_stores/postgresql/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ class PGVectorStore(BasePydanticVectorStore):
_async_engine: Any = PrivateAttr()
_async_session: Any = PrivateAttr()
_is_initialized: bool = PrivateAttr(default=False)
_is_extension_load: bool = PrivateAttr(default=False)
_is_async_extension_load: bool = PrivateAttr(default=False)

def __init__(
self,
Expand Down Expand Up @@ -327,8 +329,11 @@ def _create_extension(self) -> None:
with self._session() as session, session.begin():
statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")
session.execute(statement)
statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS pg_jieba")
session.execute(statement)
try:
statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS pg_jieba")
session.execute(statement)
except Exception:
_logger.warning("create extension pg_jieba failed")
session.commit()

def _initialize(self) -> None:
Expand All @@ -340,6 +345,43 @@ def _initialize(self) -> None:
self._create_tables_if_not_exists()
self._is_initialized = True

def _extension_load(self) -> None:
    """Attempt the one-time pg_jieba user-dictionary load on a sync session.

    Reads ``shared_preload_libraries`` from the server and, when the value
    mentions ``pg_jieba``, executes ``SELECT jieba_load_user_dict(0,0)``.

    Any exception raised along the way is logged as a warning and
    swallowed, so the calling query proceeds regardless. The attempt is
    made at most once per store instance: ``_is_extension_load`` is set
    afterwards whether or not the load succeeded.
    """
    if self._is_extension_load:
        return

    try:
        # session.begin() commits when the block exits cleanly and rolls
        # back on error, so no explicit session.commit() is needed; this
        # also matches how _async_extension_load handles its transaction.
        with self._session() as session, session.begin():
            res = session.execute(
                sqlalchemy.text("show shared_preload_libraries")
            )
            result = res.all()
            # The setting's value is read from the first column of the
            # first returned row; a missing row raises and is handled below.
            if "pg_jieba" in result[0][0]:
                session.execute(
                    sqlalchemy.text("SELECT jieba_load_user_dict(0,0)")
                )
                _logger.info("session load jieba_load_user_dict success!")
    except Exception as e:
        # Best effort: a failed dictionary load must not break querying.
        _logger.warning(e)

    # Mark as attempted even on failure so later queries do not retry.
    self._is_extension_load = True
    _logger.info("load extension done!")

async def _async_extension_load(self) -> None:
    """Attempt the one-time pg_jieba user-dictionary load on an async session.

    Asynchronous counterpart of ``_extension_load``. It reads
    ``shared_preload_libraries`` and, when the value mentions ``pg_jieba``,
    executes ``SELECT jieba_load_user_dict(0,0)``.

    Exceptions are logged as warnings and swallowed. The
    ``_is_async_extension_load`` flag is set afterwards regardless of the
    outcome, so the attempt happens at most once per store instance.
    """
    if self._is_async_extension_load:
        return

    try:
        async with self._async_session() as async_session, async_session.begin():
            preload_rows = (
                await async_session.execute(
                    sqlalchemy.text("show shared_preload_libraries")
                )
            ).all()
            # The setting's value is read from the first column of the
            # first returned row; a missing row raises and is handled below.
            preloaded_libraries = preload_rows[0][0]
            if "pg_jieba" in preloaded_libraries:
                await async_session.execute(
                    sqlalchemy.text("SELECT jieba_load_user_dict(0,0)")
                )
                _logger.info("async_session load jieba_load_user_dict success!")
    except Exception as err:
        # Best effort: a failed dictionary load must not break querying.
        _logger.warning(err)

    # Mark as attempted even on failure so later queries do not retry.
    self._is_async_extension_load = True
    _logger.info("async load extension done!")

def _node_to_table_row(self, node: BaseNode) -> Any:
return self._table_class(
node_id=node.node_id,
Expand Down Expand Up @@ -732,6 +774,7 @@ async def aquery(
self, query: VectorStoreQuery, **kwargs: Any
) -> VectorStoreQueryResult:
self._initialize()
await self._async_extension_load()
if query.mode == VectorStoreQueryMode.HYBRID:
results = await self._async_hybrid_query(query, **kwargs)
elif query.mode in [
Expand All @@ -756,6 +799,7 @@ async def aquery(

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
self._initialize()
self._extension_load()
if query.mode == VectorStoreQueryMode.HYBRID:
results = self._hybrid_query(query, **kwargs)
elif query.mode in [
Expand Down
4 changes: 3 additions & 1 deletion src/pai_rag/modules/index/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ def _get_or_create_postgresql_store(self):
host=pg_config["host"],
port=pg_config["port"],
database=pg_config["database"],
table_name=pg_config["table_name"],
table_name=pg_config["table_name"]
if pg_config["table_name"].strip()
else "default",
user=pg_config["username"],
password=pg_config["password"],
embed_dim=self.embed_dims,
Expand Down

0 comments on commit fe4d752

Please sign in to comment.