Skip to content

Commit

Permalink
Support for pgdiskann client (#388)
Browse files Browse the repository at this point in the history
* Add pgdiskann client

* Add CLI support in pgdiskann

* add pgdiskann load config in frontend.

---------

Co-authored-by: Sheharyar Ahmad <[email protected]>
  • Loading branch information
wahajali and Sheharyar570 authored Oct 29, 2024
1 parent c66dfb5 commit cfaa1c6
Show file tree
Hide file tree
Showing 9 changed files with 677 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ All the database client supported
| pgvector | `pip install vectordb-bench[pgvector]` |
| pgvecto.rs | `pip install vectordb-bench[pgvecto_rs]` |
| pgvectorscale | `pip install vectordb-bench[pgvectorscale]` |
| pgdiskann | `pip install vectordb-bench[pgdiskann]` |
| redis | `pip install vectordb-bench[redis]` |
| memorydb | `pip install vectordb-bench[memorydb]` |
| chromadb | `pip install vectordb-bench[chromadb]` |
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ weaviate = [ "weaviate-client" ]
elastic = [ "elasticsearch" ]
pgvector = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvectorscale = [ "psycopg", "psycopg-binary", "pgvector" ]
pgdiskann = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvecto_rs = [ "pgvecto_rs[psycopg3]>=0.2.2" ]
redis = [ "redis" ]
memorydb = [ "memorydb" ]
Expand Down
13 changes: 13 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class DB(Enum):
PgVector = "PgVector"
PgVectoRS = "PgVectoRS"
PgVectorScale = "PgVectorScale"
PgDiskANN = "PgDiskANN"
Redis = "Redis"
MemoryDB = "MemoryDB"
Chroma = "Chroma"
Expand Down Expand Up @@ -77,6 +78,10 @@ def init_cls(self) -> Type[VectorDB]:
from .pgvectorscale.pgvectorscale import PgVectorScale
return PgVectorScale

if self == DB.PgDiskANN:
from .pgdiskann.pgdiskann import PgDiskANN
return PgDiskANN

if self == DB.Redis:
from .redis.redis import Redis
return Redis
Expand Down Expand Up @@ -132,6 +137,10 @@ def config_cls(self) -> Type[DBConfig]:
from .pgvectorscale.config import PgVectorScaleConfig
return PgVectorScaleConfig

if self == DB.PgDiskANN:
from .pgdiskann.config import PgDiskANNConfig
return PgDiskANNConfig

if self == DB.Redis:
from .redis.config import RedisConfig
return RedisConfig
Expand Down Expand Up @@ -185,6 +194,10 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
from .pgvectorscale.config import _pgvectorscale_case_config
return _pgvectorscale_case_config.get(index_type)

if self == DB.PgDiskANN:
from .pgdiskann.config import _pgdiskann_case_config
return _pgdiskann_case_config.get(index_type)

# DB.Pinecone, DB.Chroma, DB.Redis
return EmptyDBCaseConfig

Expand Down
99 changes: 99 additions & 0 deletions vectordb_bench/backend/clients/pgdiskann/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import click
import os
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from typing import Annotated, Optional, Unpack
from vectordb_bench.backend.clients import DB


class PgDiskAnnTypedDict(CommonTypedDict):
user_name: Annotated[
str, click.option("--user-name", type=str, help="Db username", required=True)
]
password: Annotated[
str,
click.option("--password",
type=str,
help="Postgres database password",
default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
show_default="$POSTGRES_PASSWORD",
),
]

host: Annotated[
str, click.option("--host", type=str, help="Db host", required=True)
]
db_name: Annotated[
str, click.option("--db-name", type=str, help="Db name", required=True)
]
max_neighbors: Annotated[
int,
click.option(
"--max-neighbors", type=int, help="PgDiskAnn max neighbors",
),
]
l_value_ib: Annotated[
int,
click.option(
"--l-value-ib", type=int, help="PgDiskAnn l_value_ib",
),
]
l_value_is: Annotated[
float,
click.option(
"--l-value-is", type=float, help="PgDiskAnn l_value_is",
),
]
maintenance_work_mem: Annotated[
Optional[str],
click.option(
"--maintenance-work-mem",
type=str,
help="Sets the maximum memory to be used for maintenance operations (index creation). "
"Can be entered as string with unit like '64GB' or as an integer number of KB."
"This will set the parameters: max_parallel_maintenance_workers,"
" max_parallel_workers & table(parallel_workers)",
required=False,
),
]
max_parallel_workers: Annotated[
Optional[int],
click.option(
"--max-parallel-workers",
type=int,
help="Sets the maximum number of parallel processes per maintenance operation (index creation)",
required=False,
),
]

@cli.command()
@click_parameter_decorators_from_typed_dict(PgDiskAnnTypedDict)
def PgDiskAnn(
**parameters: Unpack[PgDiskAnnTypedDict],
):
from .config import PgDiskANNConfig, PgDiskANNImplConfig

run(
db=DB.PgDiskANN,
db_config=PgDiskANNConfig(
db_label=parameters["db_label"],
user_name=SecretStr(parameters["user_name"]),
password=SecretStr(parameters["password"]),
host=parameters["host"],
db_name=parameters["db_name"],
),
db_case_config=PgDiskANNImplConfig(
max_neighbors=parameters["max_neighbors"],
l_value_ib=parameters["l_value_ib"],
l_value_is=parameters["l_value_is"],
max_parallel_workers=parameters["max_parallel_workers"],
maintenance_work_mem=parameters["maintenance_work_mem"],
),
**parameters,
)
145 changes: 145 additions & 0 deletions vectordb_bench/backend/clients/pgdiskann/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from abc import abstractmethod
from typing import Any, Mapping, Optional, Sequence, TypedDict
from pydantic import BaseModel, SecretStr
from typing_extensions import LiteralString
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType

POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"


class PgDiskANNConfigDict(TypedDict):
"""These keys will be directly used as kwargs in psycopg connection string,
so the names must match exactly psycopg API"""

user: str
password: str
host: str
port: int
dbname: str


class PgDiskANNConfig(DBConfig):
user_name: SecretStr = SecretStr("postgres")
password: SecretStr
host: str = "localhost"
port: int = 5432
db_name: str

def to_dict(self) -> PgDiskANNConfigDict:
user_str = self.user_name.get_secret_value()
pwd_str = self.password.get_secret_value()
return {
"host": self.host,
"port": self.port,
"dbname": self.db_name,
"user": user_str,
"password": pwd_str,
}


class PgDiskANNIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType | None = None
create_index_before_load: bool = False
create_index_after_load: bool = True
maintenance_work_mem: Optional[str]
max_parallel_workers: Optional[int]

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "vector_l2_ops"
elif self.metric_type == MetricType.IP:
return "vector_ip_ops"
return "vector_cosine_ops"

def parse_metric_fun_op(self) -> LiteralString:
if self.metric_type == MetricType.L2:
return "<->"
elif self.metric_type == MetricType.IP:
return "<#>"
return "<=>"

def parse_metric_fun_str(self) -> str:
if self.metric_type == MetricType.L2:
return "l2_distance"
elif self.metric_type == MetricType.IP:
return "max_inner_product"
return "cosine_distance"

@abstractmethod
def index_param(self) -> dict:
...

@abstractmethod
def search_param(self) -> dict:
...

@abstractmethod
def session_param(self) -> dict:
...

@staticmethod
def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
"""Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause"""
options = []
for option_name, value in with_options.items():
if value is not None:
options.append(
{
"option_name": option_name,
"val": str(value),
}
)
return options

@staticmethod
def _optionally_build_set_options(
set_mapping: Mapping[str, Any]
) -> Sequence[dict[str, Any]]:
"""Walk through options, creating 'SET 'key1 = "value1";' list"""
session_options = []
for setting_name, value in set_mapping.items():
if value:
session_options.append(
{"parameter": {
"setting_name": setting_name,
"val": str(value),
},
}
)
return session_options


class PgDiskANNImplConfig(PgDiskANNIndexConfig):
index: IndexType = IndexType.DISKANN
max_neighbors: int | None
l_value_ib: int | None
l_value_is: float | None
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None

def index_param(self) -> dict:
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
"options": {
"max_neighbors": self.max_neighbors,
"l_value_ib": self.l_value_ib,
},
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
}

def search_param(self) -> dict:
return {
"metric": self.parse_metric(),
"metric_fun_op": self.parse_metric_fun_op(),
}

def session_param(self) -> dict:
return {
"diskann.l_value_is": self.l_value_is,
}

_pgdiskann_case_config = {
IndexType.DISKANN: PgDiskANNImplConfig,
}
Loading

0 comments on commit cfaa1c6

Please sign in to comment.