From cfaa1c67c3135023a08c5a4912a602fc90f82cda Mon Sep 17 00:00:00 2001
From: Wahaj Ali <894764+wahajali@users.noreply.github.com>
Date: Tue, 29 Oct 2024 13:42:04 +0500
Subject: [PATCH] Support for pgdiskann client (#388)

* Add pgdiskann client

* Add CLI support in pgdiskann

* add pgdiskann load config in frontend.

---------

Co-authored-by: Sheharyar Ahmad
---
Review notes (this section is not part of the commit message):
- pgdiskann.py: fix the `coursor` attribute typo; validate the case config
  before connecting and close the setup connection in a finally block; keep
  the password out of the log and error text; skip session settings that
  are None instead of sending the string "None".
- cli.py: move the sentence about parallel-worker parameters from the
  --maintenance-work-mem help to the --max-parallel-workers help; end the
  file with a newline.
- config.py: docstring corrections; explicit None defaults on the base class.
- Hunk sizes and the diffstat below were recomputed for the three edited
  files; the blob ids on their `index` lines were carried over and are stale.

 README.md                                     |   1 +
 pyproject.toml                                |   1 +
 vectordb_bench/backend/clients/__init__.py    |  13 +
 .../backend/clients/pgdiskann/cli.py          | 102 +++++
 .../backend/clients/pgdiskann/config.py       | 164 ++++++++
 .../backend/clients/pgdiskann/pgdiskann.py    | 384 ++++++++++++++++++
 vectordb_bench/cli/vectordbbench.py           |   2 +
 .../frontend/config/dbCaseConfigs.py          |  63 ++++
 vectordb_bench/models.py                      |   3 +
 9 files changed, 733 insertions(+)
 create mode 100644 vectordb_bench/backend/clients/pgdiskann/cli.py
 create mode 100644 vectordb_bench/backend/clients/pgdiskann/config.py
 create mode 100644 vectordb_bench/backend/clients/pgdiskann/pgdiskann.py

diff --git a/README.md b/README.md
index b779af11f..15767a980 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@ All the database client supported
 | pgvector | `pip install vectordb-bench[pgvector]` |
 | pgvecto.rs | `pip install vectordb-bench[pgvecto_rs]` |
 | pgvectorscale | `pip install vectordb-bench[pgvectorscale]` |
+| pgdiskann | `pip install vectordb-bench[pgdiskann]` |
 | redis | `pip install vectordb-bench[redis]` |
 | memorydb | `pip install vectordb-bench[memorydb]` |
 | chromadb | `pip install vectordb-bench[chromadb]` |
diff --git a/pyproject.toml b/pyproject.toml
index 015ed8c3f..000800389 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,6 +72,7 @@ weaviate = [ "weaviate-client" ]
 elastic = [ "elasticsearch" ]
 pgvector = [ "psycopg", "psycopg-binary", "pgvector" ]
 pgvectorscale = [ "psycopg", "psycopg-binary", "pgvector" ]
+pgdiskann = [ "psycopg", "psycopg-binary", "pgvector" ]
 pgvecto_rs = [ "pgvecto_rs[psycopg3]>=0.2.2" ]
 redis = [ "redis" ]
 memorydb = [ "memorydb" ]
diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py
index 3e87e1fbe..c26aa3d6d 100644
--- a/vectordb_bench/backend/clients/__init__.py
+++ b/vectordb_bench/backend/clients/__init__.py
@@ -31,6 +31,7 @@ class DB(Enum):
     PgVector = "PgVector"
     PgVectoRS = "PgVectoRS"
     PgVectorScale = "PgVectorScale"
+    PgDiskANN = "PgDiskANN"
     Redis = "Redis"
     MemoryDB = "MemoryDB"
     Chroma = "Chroma"
@@ -77,6 +78,10 @@ def init_cls(self) -> Type[VectorDB]:
             from .pgvectorscale.pgvectorscale import PgVectorScale
             return PgVectorScale
 
+        if self == DB.PgDiskANN:
+            from .pgdiskann.pgdiskann import PgDiskANN
+            return PgDiskANN
+
         if self == DB.Redis:
             from .redis.redis import Redis
             return Redis
@@ -132,6 +137,10 @@ def config_cls(self) -> Type[DBConfig]:
             from .pgvectorscale.config import PgVectorScaleConfig
             return PgVectorScaleConfig
 
+        if self == DB.PgDiskANN:
+            from .pgdiskann.config import PgDiskANNConfig
+            return PgDiskANNConfig
+
         if self == DB.Redis:
             from .redis.config import RedisConfig
             return RedisConfig
@@ -185,6 +194,10 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
             from .pgvectorscale.config import _pgvectorscale_case_config
             return _pgvectorscale_case_config.get(index_type)
 
+        if self == DB.PgDiskANN:
+            from .pgdiskann.config import _pgdiskann_case_config
+            return _pgdiskann_case_config.get(index_type)
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
 
diff --git a/vectordb_bench/backend/clients/pgdiskann/cli.py b/vectordb_bench/backend/clients/pgdiskann/cli.py
new file mode 100644
index 000000000..18a9ecbd5
--- /dev/null
+++ b/vectordb_bench/backend/clients/pgdiskann/cli.py
@@ -0,0 +1,102 @@
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from typing import Annotated, Optional, Unpack
+from vectordb_bench.backend.clients import DB
+
+
+# Command-line options of the PgDiskAnn command, added on top of CommonTypedDict.
+class PgDiskAnnTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--user-name", type=str, help="Db username", required=True)
+    ]
+    password: Annotated[
+        str,
+        click.option("--password",
+                     type=str,
+                     help="Postgres database password",
+                     default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+                     show_default="$POSTGRES_PASSWORD",
+                     ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host", type=str, help="Db host", required=True)
+    ]
+    db_name: Annotated[
+        str, click.option("--db-name", type=str, help="Db name", required=True)
+    ]
+    max_neighbors: Annotated[
+        int,
+        click.option(
+            "--max-neighbors", type=int, help="PgDiskAnn max neighbors",
+        ),
+    ]
+    l_value_ib: Annotated[
+        int,
+        click.option(
+            "--l-value-ib", type=int, help="PgDiskAnn l_value_ib",
+        ),
+    ]
+    l_value_is: Annotated[
+        float,
+        click.option(
+            "--l-value-is", type=float, help="PgDiskAnn l_value_is",
+        ),
+    ]
+    maintenance_work_mem: Annotated[
+        Optional[str],
+        click.option(
+            "--maintenance-work-mem",
+            type=str,
+            help="Sets the maximum memory to be used for maintenance operations (index creation). "
+            "Can be entered as string with unit like '64GB' or as an integer number of KB.",
+            required=False,
+        ),
+    ]
+    max_parallel_workers: Annotated[
+        Optional[int],
+        click.option(
+            "--max-parallel-workers",
+            type=int,
+            help="Sets the maximum number of parallel processes per maintenance operation (index creation). "
+            "This will set the parameters: max_parallel_maintenance_workers,"
+            " max_parallel_workers & table(parallel_workers)",
+            required=False,
+        ),
+    ]
+
+
+# Runs the benchmark against PgDiskANN using the connection and index options above.
+@cli.command()
+@click_parameter_decorators_from_typed_dict(PgDiskAnnTypedDict)
+def PgDiskAnn(
+    **parameters: Unpack[PgDiskAnnTypedDict],
+):
+    from .config import PgDiskANNConfig, PgDiskANNImplConfig
+
+    run(
+        db=DB.PgDiskANN,
+        db_config=PgDiskANNConfig(
+            db_label=parameters["db_label"],
+            user_name=SecretStr(parameters["user_name"]),
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            db_name=parameters["db_name"],
+        ),
+        db_case_config=PgDiskANNImplConfig(
+            max_neighbors=parameters["max_neighbors"],
+            l_value_ib=parameters["l_value_ib"],
+            l_value_is=parameters["l_value_is"],
+            max_parallel_workers=parameters["max_parallel_workers"],
+            maintenance_work_mem=parameters["maintenance_work_mem"],
+        ),
+        **parameters,
+    )
diff --git a/vectordb_bench/backend/clients/pgdiskann/config.py b/vectordb_bench/backend/clients/pgdiskann/config.py
new file mode 100644
index 000000000..970720afa
--- /dev/null
+++ b/vectordb_bench/backend/clients/pgdiskann/config.py
@@ -0,0 +1,164 @@
+from abc import abstractmethod
+from typing import Any, Mapping, Optional, Sequence, TypedDict
+from pydantic import BaseModel, SecretStr
+from typing_extensions import LiteralString
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
+
+
+class PgDiskANNConfigDict(TypedDict):
+    """These keys will be directly used as kwargs in psycopg connection string,
+    so the names must match exactly psycopg API"""
+
+    user: str
+    password: str
+    host: str
+    port: int
+    dbname: str
+
+
+class PgDiskANNConfig(DBConfig):
+    """Connection settings for the PostgreSQL server used by the PgDiskANN client."""
+
+    user_name: SecretStr = SecretStr("postgres")
+    password: SecretStr
+    host: str = "localhost"
+    port: int = 5432
+    db_name: str
+
+    def to_dict(self) -> PgDiskANNConfigDict:
+        """Return the settings as psycopg keyword arguments, with secrets unwrapped."""
+        user_str = self.user_name.get_secret_value()
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "dbname": self.db_name,
+            "user": user_str,
+            "password": pwd_str,
+        }
+
+
+class PgDiskANNIndexConfig(BaseModel, DBCaseConfig):
+    """Base case config: metric mapping plus index-build tuning shared by subclasses."""
+
+    metric_type: MetricType | None = None
+    create_index_before_load: bool = False
+    create_index_after_load: bool = True
+    # Explicit None defaults, matching PgDiskANNImplConfig below.
+    maintenance_work_mem: Optional[str] = None
+    max_parallel_workers: Optional[int] = None
+
+    def parse_metric(self) -> str:
+        """Return the name placed after the column in CREATE INDEX (cosine by default)."""
+        if self.metric_type == MetricType.L2:
+            return "vector_l2_ops"
+        elif self.metric_type == MetricType.IP:
+            return "vector_ip_ops"
+        return "vector_cosine_ops"
+
+    def parse_metric_fun_op(self) -> LiteralString:
+        """Return the distance operator used in ORDER BY (cosine by default)."""
+        if self.metric_type == MetricType.L2:
+            return "<->"
+        elif self.metric_type == MetricType.IP:
+            return "<#>"
+        return "<=>"
+
+    def parse_metric_fun_str(self) -> str:
+        """Return the distance function name for the metric (cosine by default)."""
+        if self.metric_type == MetricType.L2:
+            return "l2_distance"
+        elif self.metric_type == MetricType.IP:
+            return "max_inner_product"
+        return "cosine_distance"
+
+    @abstractmethod
+    def index_param(self) -> dict:
+        ...
+
+    @abstractmethod
+    def search_param(self) -> dict:
+        ...
+
+    @abstractmethod
+    def session_param(self) -> dict:
+        ...
+
+    @staticmethod
+    def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
+        """Walk through mappings, creating a list of {key1 = value} pairs.
+
+        Entries whose value is None are skipped. The result is meant for
+        building the WITH clause of CREATE INDEX."""
+        options = []
+        for option_name, value in with_options.items():
+            if value is not None:
+                options.append(
+                    {
+                        "option_name": option_name,
+                        "val": str(value),
+                    }
+                )
+        return options
+
+    @staticmethod
+    def _optionally_build_set_options(
+        set_mapping: Mapping[str, Any]
+    ) -> Sequence[dict[str, Any]]:
+        """Walk through options, creating a list for 'SET key1 = "value1";' commands.
+
+        Falsy values (None, 0, "") are skipped."""
+        session_options = []
+        for setting_name, value in set_mapping.items():
+            if value:
+                session_options.append(
+                    {"parameter": {
+                            "setting_name": setting_name,
+                            "val": str(value),
+                        },
+                    }
+                )
+        return session_options
+
+
+class PgDiskANNImplConfig(PgDiskANNIndexConfig):
+    """Case config for the DiskANN index type."""
+
+    index: IndexType = IndexType.DISKANN
+    max_neighbors: int | None
+    l_value_ib: int | None
+    l_value_is: float | None
+    maintenance_work_mem: Optional[str] = None
+    max_parallel_workers: Optional[int] = None
+
+    def index_param(self) -> dict:
+        """Return the metric, index type, WITH options and build-tuning values."""
+        return {
+            "metric": self.parse_metric(),
+            "index_type": self.index.value,
+            "options": {
+                "max_neighbors": self.max_neighbors,
+                "l_value_ib": self.l_value_ib,
+            },
+            "maintenance_work_mem": self.maintenance_work_mem,
+            "max_parallel_workers": self.max_parallel_workers,
+        }
+
+    def search_param(self) -> dict:
+        """Return the metric name and the distance operator used by search queries."""
+        return {
+            "metric": self.parse_metric(),
+            "metric_fun_op": self.parse_metric_fun_op(),
+        }
+
+    def session_param(self) -> dict:
+        """Return per-session settings; a None value means the option was not given."""
+        return {
+            "diskann.l_value_is": self.l_value_is,
+        }
+
+_pgdiskann_case_config = {
+    IndexType.DISKANN: PgDiskANNImplConfig,
+}
diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
new file mode 100644
index 000000000..c363490f7
--- /dev/null
+++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
@@ -0,0 +1,384 @@
+"""Wrapper around the pg_diskann vector database over VectorDB"""
+
+import logging
+import pprint
+from contextlib import contextmanager
+from typing import Any, Generator, Optional, Tuple
+
+import numpy as np
+import psycopg
+from pgvector.psycopg import register_vector
+from psycopg import Connection, Cursor, sql
+
+from ..api import VectorDB
+from .config import PgDiskANNConfigDict, PgDiskANNIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class PgDiskANN(VectorDB):
+    """Use psycopg instructions"""
+
+    conn: psycopg.Connection[Any] | None = None
+    cursor: psycopg.Cursor[Any] | None = None
+
+    _filtered_search: sql.Composed
+    _unfiltered_search: sql.Composed
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: PgDiskANNConfigDict,
+        db_case_config: PgDiskANNIndexConfig,
+        collection_name: str = "pg_diskann_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        """Validate the config and, when drop_old is set, recreate the table.
+
+        The connection opened here is only used for this setup and is closed
+        before returning; init() opens the connection used afterwards.
+
+        Raises:
+            RuntimeError: if neither create_index_before_load nor
+                create_index_after_load is enabled.
+        """
+        self.name = "PgDiskANN"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "pgdiskann_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        # Never write the plain-text password to the log or into error messages.
+        safe_db_config = {**self.db_config, "password": "***"}
+        log.info(f"{self.name} config values: {safe_db_config}\n{self.case_config}")
+        if not any(
+            (
+                self.case_config.create_index_before_load,
+                self.case_config.create_index_after_load,
+            )
+        ):
+            err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+            log.error(err)
+            raise RuntimeError(
+                f"{err}\n{pprint.pformat(safe_db_config)}\n{pprint.pformat(self.case_config)}"
+            )
+
+        # Connect only after validation so a bad config cannot leak a connection,
+        # and always close the setup connection, even if a statement fails.
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+        try:
+            if drop_old:
+                self._drop_index()
+                self._drop_table()
+                self._create_table(dim)
+                if self.case_config.create_index_before_load:
+                    self._create_index()
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+        """Connect, make sure the pg_diskann extension exists and return (conn, cursor)."""
+        conn = psycopg.connect(**kwargs)
+        conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS pg_diskann CASCADE")
+        conn.commit()
+        register_vector(conn)
+        conn.autocommit = False
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        """Open a connection, apply session settings and build the search queries.
+
+        The cursor and connection are closed when the context exits.
+        """
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        # index configuration may have commands defined that we should set during each client session
+        session_options: dict[str, Any] = self.case_config.session_param()
+
+        if len(session_options) > 0:
+            for setting_name, setting_val in session_options.items():
+                if setting_val is None:
+                    # Option not given: keep the server default rather than sending "None".
+                    continue
+                command = sql.SQL("SET {setting_name} " + "= {setting_val};").format(
+                    setting_name=sql.Identifier(setting_name),
+                    setting_val=sql.Identifier(str(setting_val)),
+                )
+                log.debug(command.as_string(self.cursor))
+                self.cursor.execute(command)
+            self.conn.commit()
+
+        self._filtered_search = sql.Composed(
+            [
+                sql.SQL(
+                    "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
+                ).format(table_name=sql.Identifier(self.table_name)),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        self._unfiltered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name)
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def _drop_table(self):
+        """Drop the benchmark table if it exists and commit."""
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop table : {self.table_name}")
+
+        self.cursor.execute(
+            sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
+                table_name=sql.Identifier(self.table_name)
+            )
+        )
+        self.conn.commit()
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self):
+        self._post_insert()
+
+    def _post_insert(self):
+        """Rebuild the index after loading when create_index_after_load is set."""
+        log.info(f"{self.name} post insert before optimize")
+        if self.case_config.create_index_after_load:
+            self._drop_index()
+            self._create_index()
+
+    def _drop_index(self):
+        """Drop the vector index if it exists and commit."""
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop index : {self._index_name}")
+
+        drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
+            index_name=sql.Identifier(self._index_name)
+        )
+        log.debug(drop_index_sql.as_string(self.cursor))
+        self.cursor.execute(drop_index_sql)
+        self.conn.commit()
+
+    def _set_parallel_index_build_param(self):
+        """Apply maintenance_work_mem / parallel-worker settings before an index build.
+
+        Each value is set for the current session and also stored with
+        ALTER USER; max_parallel_workers is additionally applied to the table.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        if index_param["maintenance_work_mem"] is not None:
+            self.cursor.execute(
+                sql.SQL("SET maintenance_work_mem TO {};").format(
+                    index_param["maintenance_work_mem"]
+                )
+            )
+            self.cursor.execute(
+                sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
+                    sql.Identifier(self.db_config["user"]),
+                    index_param["maintenance_work_mem"],
+                )
+            )
+            self.conn.commit()
+
+        if index_param["max_parallel_workers"] is not None:
+            self.cursor.execute(
+                sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
+                    index_param["max_parallel_workers"]
+                )
+            )
+            self.cursor.execute(
+                sql.SQL(
+                    "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
+                ).format(
+                    sql.Identifier(self.db_config["user"]),
+                    index_param["max_parallel_workers"],
+                )
+            )
+            self.cursor.execute(
+                sql.SQL("SET max_parallel_workers TO '{}';").format(
+                    index_param["max_parallel_workers"]
+                )
+            )
+            self.cursor.execute(
+                sql.SQL(
+                    "ALTER USER {} SET max_parallel_workers TO '{}';"
+                ).format(
+                    sql.Identifier(self.db_config["user"]),
+                    index_param["max_parallel_workers"],
+                )
+            )
+            self.cursor.execute(
+                sql.SQL(
+                    "ALTER TABLE {} SET (parallel_workers = {});"
+                ).format(
+                    sql.Identifier(self.table_name),
+                    index_param["max_parallel_workers"],
+                )
+            )
+            self.conn.commit()
+
+        results = self.cursor.execute(
+            sql.SQL("SHOW max_parallel_maintenance_workers;")
+        ).fetchall()
+        results.extend(
+            self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
+        )
+        results.extend(
+            self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
+        )
+        log.info(f"{self.name} parallel index creation parameters: {results}")
+
+    def _create_index(self):
+        """Create the vector index from index_param(), after applying build settings."""
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client create index : {self._index_name}")
+
+        index_param: dict[str, Any] = self.case_config.index_param()
+        self._set_parallel_index_build_param()
+
+        options = []
+        for option_name, option_val in index_param["options"].items():
+            if option_val is not None:
+                options.append(
+                    sql.SQL("{option_name} = {val}").format(
+                        option_name=sql.Identifier(option_name),
+                        val=sql.Identifier(str(option_val)),
+                    )
+                )
+
+        if any(options):
+            with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
+        else:
+            with_clause = sql.Composed(())
+
+        index_create_sql = sql.SQL(
+            """
+            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+            USING {index_type} (embedding {embedding_metric})
+            """
+        ).format(
+            index_name=sql.Identifier(self._index_name),
+            table_name=sql.Identifier(self.table_name),
+            index_type=sql.Identifier(index_param["index_type"].lower()),
+            embedding_metric=sql.Identifier(index_param["metric"]),
+        )
+        index_create_sql_with_with_clause = (
+            index_create_sql + with_clause
+        ).join(" ")
+        log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
+        self.cursor.execute(index_create_sql_with_with_clause)
+        self.conn.commit()
+
+    def _create_table(self, dim: int):
+        """Create the (id, embedding) table if missing; log and re-raise on failure."""
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            log.info(f"{self.name} client create table : {self.table_name}")
+
+            self.cursor.execute(
+                sql.SQL(
+                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
+                ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+            )
+            self.conn.commit()
+        except Exception as e:
+            log.warning(
+                f"Failed to create pgdiskann table: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        """Bulk-load rows with COPY; return (inserted_count, None) or (0, error)."""
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            with self.cursor.copy(
+                sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
+                    table_name=sql.Identifier(self.table_name)
+                )
+            ) as copy:
+                copy.set_types(["bigint", "vector"])
+                for i, row in enumerate(metadata_arr):
+                    copy.write_row((row, embeddings_arr[i]))
+            self.conn.commit()
+
+            if kwargs.get("last_batch"):
+                self._post_insert()
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(
+                f"Failed to insert data into table ({self.table_name}), error: {e}"
+            )
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        """Return the ids of the k nearest rows; filters["id"] is a lower bound on id."""
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        q = np.asarray(query)
+        if filters:
+            gt = filters.get("id")
+            result = self.cursor.execute(
+                self._filtered_search, (gt, q, k), prepare=True, binary=True
+            )
+        else:
+            result = self.cursor.execute(
+                self._unfiltered_search, (q, k), prepare=True, binary=True
+            )
+
+        return [int(i[0]) for i in result.fetchall()]
diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py
index e62c25a3d..4d23ed952 100644
--- a/vectordb_bench/cli/vectordbbench.py
+++ b/vectordb_bench/cli/vectordbbench.py
@@ -1,6 +1,7 @@
 from ..backend.clients.pgvector.cli import PgVectorHNSW
 from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat
 from ..backend.clients.pgvectorscale.cli import PgVectorScaleDiskAnn
+from ..backend.clients.pgdiskann.cli import PgDiskAnn
 from ..backend.clients.redis.cli import Redis
 from ..backend.clients.memorydb.cli import MemoryDB
 from ..backend.clients.test.cli import Test
@@ -22,6 +23,7 @@
 cli.add_command(MilvusAutoIndex)
 cli.add_command(AWSOpenSearch)
 cli.add_command(PgVectorScaleDiskAnn)
+cli.add_command(PgDiskAnn)
 
 
 if __name__ == "__main__":
diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py
index 68bf83f19..7736a4e03 100644
--- a/vectordb_bench/frontend/config/dbCaseConfigs.py
+++ b/vectordb_bench/frontend/config/dbCaseConfigs.py
@@ -180,6 +180,16 @@ class CaseConfigInput(BaseModel):
     },
 )
 
+CaseConfigParamInput_IndexType_PgDiskANN = CaseConfigInput(
+    label=CaseConfigParamType.IndexType,
+    inputHelp="Select Index Type",
+    inputType=InputType.Option,
+    inputConfig={
+        "options": [
+            IndexType.DISKANN.value,
+        ],
+    },
+)
 
 CaseConfigParamInput_IndexType_PgVectorScale = CaseConfigInput(
     label=CaseConfigParamType.IndexType,
@@ -205,6 +215,42 @@ class CaseConfigInput(BaseModel):
     },
 )
 
+CaseConfigParamInput_max_neighbors = CaseConfigInput(
+    label=CaseConfigParamType.max_neighbors,
+    inputType=InputType.Number,
+    inputConfig={
+        "min": 10,
+        "max": 300,
+        "value": 32,
+    },
+    isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
+    == IndexType.DISKANN.value,
+)
+
+CaseConfigParamInput_l_value_ib = CaseConfigInput(
+    label=CaseConfigParamType.l_value_ib,
+    inputType=InputType.Number,
+    inputConfig={
+        "min": 10,
+        "max": 300,
+        "value": 50,
+    },
+    isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
+    == IndexType.DISKANN.value,
+)
+
+CaseConfigParamInput_l_value_is = CaseConfigInput(
+    label=CaseConfigParamType.l_value_is,
+    inputType=InputType.Number,
+    inputConfig={
+        "min": 10,
+        "max": 300,
+        "value": 40,
+    },
+    isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
+    == IndexType.DISKANN.value,
+)
+
 CaseConfigParamInput_num_neighbors = CaseConfigInput(
     label=CaseConfigParamType.num_neighbors,
     inputType=InputType.Number,
@@ -942,6 +988,19 @@ class CaseConfigInput(BaseModel):
     CaseConfigParamInput_query_search_list_size,
 ]
 
+PgDiskANNLoadConfig = [
+    CaseConfigParamInput_IndexType_PgDiskANN,
+    CaseConfigParamInput_max_neighbors,
+    CaseConfigParamInput_l_value_ib,
+]
+
+PgDiskANNPerformanceConfig = [
+    CaseConfigParamInput_IndexType_PgDiskANN,
+    CaseConfigParamInput_max_neighbors,
+    CaseConfigParamInput_l_value_ib,
+    CaseConfigParamInput_l_value_is,
+]
+
 CASE_CONFIG_MAP = {
     DB.Milvus: {
         CaseLabel.Load: MilvusLoadConfig,
@@ -974,4 +1033,8 @@ class CaseConfigInput(BaseModel):
         CaseLabel.Load: PgVectorScaleLoadingConfig,
         CaseLabel.Performance: PgVectorScalePerformanceConfig,
     },
+    DB.PgDiskANN: {
+        CaseLabel.Load: PgDiskANNLoadConfig,
+        CaseLabel.Performance: PgDiskANNPerformanceConfig,
+    },
 }
diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py
index 7968e3e26..d431479ed 100644
--- a/vectordb_bench/models.py
+++ b/vectordb_bench/models.py
@@ -64,6 +64,9 @@ class CaseConfigParamType(Enum):
     max_parallel_workers = "max_parallel_workers"
     storage_layout = "storage_layout"
     num_neighbors = "num_neighbors"
+    max_neighbors = "max_neighbors"
+    l_value_ib = "l_value_ib"
+    l_value_is = "l_value_is"
     search_list_size = "search_list_size"
     max_alpha = "max_alpha"
     num_dimensions = "num_dimensions"