diff --git a/README.md b/README.md index 4e8b6ae..536b51c 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ End-to-end service to query the text. - [x] full text search (Postgres GIN + text search) - [x] vector similarity search ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs) HNSW) - [x] generate vector if not provided -- [ ] sparse search +- [x] sparse search ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs) HNSW) - [ ] filtering - [x] reranking with [reranker](https://github.com/kemingy/reranker) - [x] semantic highlight @@ -20,7 +20,8 @@ docker compose -f docker/compose.yaml up -d server Some of the dependent services can be opt-out: - `emb`: used to generate embedding for query and documents -- `colbert`: used to provide the semantic highlight feature +- `sparse`: used to generate sparse embedding for query and documents +- `highlight`: used to provide the semantic highlight feature - `encoder`: rerank with cross-encoder model, you can choose other methods or other online services For the client example, check: @@ -29,7 +30,7 @@ For the client example, check: ## API -- `/api/namespace` POST: create a new namespace and configure the text + vector index +- `/api/namespace` POST: create a new namespace and configure the index - `/api/doc` POST: add a new doc - `/api/query` POST: query the docs - `/api/highlight` POST: semantic highlight @@ -50,7 +51,7 @@ This project has most of the components you need for the RAG except for the last > If you already have the table in Postgres, you will be responsible for the text-indexing and vector-indexing part. 1. Define a `dataclass` that includes the **necessary** columns as class attributes - - annotate the `primary_key`, `text_index`, `vector_index` with metadata + - annotate the `primary_key`, `text_index`, `vector_index`, `sparse_index` with metadata (not all the columns are required, only the necessary ones) - attributes without default value or default factory is treated as required when you add new docs 2. Implement the `to_record` and `from_record` methods to be used in the reranking stage 3. Change the `config.vector_store.schema` to the class you have defined diff --git a/docker/compose.yaml b/docker/compose.yaml index 15e70e1..f536f0c 100644 --- a/docker/compose.yaml +++ b/docker/compose.yaml @@ -7,7 +7,7 @@ services: - "8080:8000" pg: - image: "tensorchord/pgvecto-rs:pg16-v0.2.1" + image: "tensorchord/pgvecto-rs:pg16-v0.3.0-alpha.1" environment: - POSTGRES_PASSWORD=password ports: @@ -15,7 +15,7 @@ services: volumes: - "/tmp/qtext_pgdata:/var/lib/postgresql/data" - colbert: + highlight: image: "kemingy/colbert-highlight" ports: - "8081:8000" @@ -27,6 +27,13 @@ services: ports: - "8082:8000" + sparse: + image: "kemingy/spladepp" + environment: + - MOSEC_TIMEOUT=10000 + ports: + - "8083:8000" + server: build: context: ../ @@ -36,5 +43,6 @@ services: depends_on: - pg - emb - - colbert + - highlight - encoder + - sparse diff --git a/docker/config.json b/docker/config.json index 9541b45..cf2deb1 100644 --- a/docker/config.json +++ b/docker/config.json @@ -6,7 +6,10 @@ "api_endpoint": "http://emb:8000" }, "highlight": { - "addr": "http://colbert:8000" + "addr": "http://highlight:8000" + }, + "sparse": { + "addr": "http://sparse:8000" }, "server": { "log_level": 10 diff --git a/pyproject.toml b/pyproject.toml index f424d0b..1c16e6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "reranker~=0.2", "openai~=1.12.0", "defspec~=0.1.1", + "httpx~=0.27", ] [project.optional-dependencies] dev = [ diff --git a/qtext/config.py b/qtext/config.py index b5b7398..9257d71 100644 --- a/qtext/config.py +++ b/qtext/config.py @@ -31,6 +31,12 @@ class EmbeddingConfig(msgspec.Struct, kw_only=True, frozen=True): timeout: int = 300 +class SparseEmbeddingConfig(msgspec.Struct, kw_only=True, frozen=True): + addr: str = "http://127.0.0.1:8083" + timeout: int = 10 + dim: int = 30522 + + class RankConfig(msgspec.Struct, kw_only=True, frozen=True): ranker: Type[Ranker] = CrossEncoderClient params: dict[str, str] = msgspec.field( @@ -49,6 +55,7 @@ class Config(msgspec.Struct, kw_only=True, frozen=True): server: ServerConfig = ServerConfig() vector_store: VectorStoreConfig = VectorStoreConfig() embedding: EmbeddingConfig = EmbeddingConfig() + sparse: SparseEmbeddingConfig = SparseEmbeddingConfig() ranker: RankConfig = RankConfig() highlight: HighlightConfig = HighlightConfig() diff --git a/qtext/emb_client.py b/qtext/emb_client.py index 29cdfc6..dd4e500 100644 --- a/qtext/emb_client.py +++ b/qtext/emb_client.py @@ -1,7 +1,11 @@ from __future__ import annotations +import httpx +import msgspec import openai +from qtext.log import logger +from qtext.spec import SparseEmbedding from qtext.utils import time_it @@ -23,3 +27,28 @@ def embedding(self, text: str | list[str]) -> list[float]: if len(response.data) > 1: return [data.embedding for data in response.data] return response.data[0].embedding + + +class SparseEmbeddingClient: + def __init__(self, endpoint: str, dim: int, timeout: int) -> None: + self.dim = dim + self.client = httpx.Client(base_url=endpoint, timeout=timeout) + self.decoder = msgspec.json.Decoder(type=list[SparseEmbedding]) + + @time_it + def sparse_embedding( + self, text: str | list[str] + ) -> list[SparseEmbedding] | SparseEmbedding: + resp = self.client.post("/inference", json=text) + if resp.is_error: + logger.info( + "failed to call sparse embedding [%d]: %s", + resp.status_code, + resp.content, + ) + resp.raise_for_status() + sparse = self.decoder.decode(resp.content) + + if len(sparse) == 1: + return sparse[0] + return sparse diff --git a/qtext/engine.py b/qtext/engine.py index 9d4c467..aa3f940 100644 --- a/qtext/engine.py +++ b/qtext/engine.py @@ -1,7 +1,7 @@ from __future__ import annotations from qtext.config import Config -from qtext.emb_client import EmbeddingClient +from qtext.emb_client import EmbeddingClient, SparseEmbeddingClient from qtext.highlight_client import ENGLISH_STOPWORDS, HighlightClient from qtext.pg_client import PgVectorsClient from qtext.schema import DefaultTable, Querier @@ -27,6 +27,11 @@ def __init__(self, config: Config) -> None: endpoint=config.embedding.api_endpoint, timeout=config.embedding.timeout, ) + self.sparse_client = SparseEmbeddingClient( + endpoint=config.sparse.addr, + dim=config.sparse.dim, + timeout=config.sparse.timeout, + ) self.ranker = config.ranker.ranker(**config.ranker.params) @time_it @@ -36,10 +41,17 @@ def add_namespace(self, req: AddNamespaceRequest) -> None: @time_it def add_doc(self, req) -> None: if self.querier.has_vector_index(): - text = self.querier.retrieve_text(req) vector = self.querier.retrieve_vector(req) - if not vector: + if not vector and self.querier.has_text_index(): + text = self.querier.retrieve_text(req) self.querier.fill_vector(req, self.emb_client.embedding(text=text)) + if self.querier.has_sparse_index(): + sparse = self.querier.retrieve_sparse_vector(req) + if not sparse and self.querier.has_text_index(): + text = self.querier.retrieve_text(req) + self.querier.fill_sparse_vector( + req, self.sparse_client.sparse_embedding(text=text) + ) self.pg_client.add_doc(req) @time_it @@ -48,18 +60,24 @@ def rank( req: QueryDocRequest, text_res: list[DefaultTable], vector_res: list[DefaultTable], + sparse_res: list[DefaultTable], ) -> list[DefaultTable]: - docs = self.querier.combine_vector_text(vec_res=vector_res, text_res=text_res) + docs = self.querier.combine_vector_text( + vec_res=vector_res, sparse_res=sparse_res, text_res=text_res + ) ranked = self.ranker.rank(req.to_record(), docs) return [DefaultTable.from_record(record) for record in ranked] @time_it def query(self, req: QueryDocRequest) -> list[DefaultTable]: kw_results = self.pg_client.query_text(req) - if self.querier.has_vector_index() and not self.querier.retrieve_vector(req): + if self.querier.has_vector_index() and not req.vector: req.vector = self.emb_client.embedding(req.query) + if self.querier.has_sparse_index() and not req.sparse_vector: + req.sparse_vector = self.sparse_client.sparse_embedding(req.query) vec_results = self.pg_client.query_vector(req) - return self.rank(req, kw_results, vec_results) + sparse_results = self.pg_client.query_sparse_vector(req) + return self.rank(req, kw_results, vec_results, sparse_results) @time_it def highlight(self, req: HighlightRequest) -> HighlightResponse: diff --git a/qtext/highlight_client.py b/qtext/highlight_client.py index c48ab6f..821e869 100644 --- a/qtext/highlight_client.py +++ b/qtext/highlight_client.py @@ -17,7 +17,7 @@ def highlight_score( resp = self.client.post("/inference", json=[query, *docs]) if resp.is_error: logger.info( - "failed to call the highlight service [%d], %s", + "failed to call the highlight service [%d]: %s", resp.status_code, resp.content, ) diff --git a/qtext/pg_client.py b/qtext/pg_client.py index 60c5969..d621a74 100644 --- a/qtext/pg_client.py +++ b/qtext/pg_client.py @@ -8,7 +8,7 @@ from qtext.log import logger from qtext.schema import DefaultTable, Querier -from qtext.spec import AddNamespaceRequest, QueryDocRequest +from qtext.spec import AddNamespaceRequest, QueryDocRequest, SparseEmbedding from qtext.utils import time_it @@ -50,6 +50,34 @@ class VectorTextDumper(VectorDumper): adapters.register_loader(info.oid, VectorLoader) +class SparseVectorDumper(Dumper): + def dump(self, obj): + if isinstance(obj, np.ndarray): + return f"[{','.join(map(str, obj))}]".encode() + if isinstance(obj, SparseEmbedding): + return obj.to_str().encode() + raise ValueError(f"unsupported type {type(obj)}") + + +def register_sparse_vector(conn: psycopg.Connection): + info = TypeInfo.fetch(conn=conn, name="svector") + register_svector_type(conn, info) + + +def register_svector_type(conn: psycopg.Connection, info: TypeInfo): + if info is None: + raise ValueError("vector type not found") + info.register(conn) + + class SparseVectorTextDumper(SparseVectorDumper): + oid = info.oid + + adapters = conn.adapters + adapters.register_dumper(SparseEmbedding, SparseVectorTextDumper) + adapters.register_dumper(np.ndarray, SparseVectorTextDumper) + adapters.register_loader(info.oid, VectorLoader) + + class PgVectorsClient: def __init__(self, path: str, querier: Querier): self.path = path @@ -61,6 +89,7 @@ def connect(self): conn = psycopg.connect(self.path, row_factory=dict_row) conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;") register_vector(conn) + register_sparse_vector(conn) conn.commit() return conn @@ -70,11 +99,15 @@ def close(self): @time_it def add_namespace(self, req: AddNamespaceRequest): try: - create_table_sql = self.querier.create_table(req.name, req.vector_dim) + create_table_sql = self.querier.create_table( + req.name, req.vector_dim, req.sparse_vector_dim + ) vector_index_sql = self.querier.vector_index(req.name) + sparse_index_sql = self.querier.sparse_index(req.name) text_index_sql = self.querier.text_index(req.name) self.conn.execute(create_table_sql) self.conn.execute(vector_index_sql) + self.conn.execute(sparse_index_sql) self.conn.execute(text_index_sql) self.conn.commit() except psycopg.errors.Error as err: @@ -112,7 +145,7 @@ def query_text(self, req: QueryDocRequest) -> list[DefaultTable]: (" | ".join(req.query.strip().split(" ")), req.limit), ) except psycopg.errors.Error as err: - logger.info("pg client query error", exc_info=err) + logger.info("pg client query text error", exc_info=err) self.conn.rollback() raise RuntimeError("query text error") from err return [self.resp_cls(**res) for res in cursor.fetchall()] @@ -128,7 +161,22 @@ def query_vector(self, req: QueryDocRequest) -> list[DefaultTable]: (req.vector, req.limit), ) except psycopg.errors.Error as err: - logger.info("pg client query error", exc_info=err) + logger.info("pg client query vector error", exc_info=err) self.conn.rollback() raise RuntimeError("query vector error") from err return [self.resp_cls(**res) for res in cursor.fetchall()] + + @time_it + def query_sparse_vector(self, req: QueryDocRequest) -> list[DefaultTable]: + if not self.querier.has_sparse_index(): + return [] + try: + cursor = self.conn.execute( + self.querier.sparse_query(req.namespace), + (req.sparse_vector, req.limit), + ) + except psycopg.errors.Error as err: + logger.info("pg client query sparse vector error", exc_info=err) + self.conn.rollback() + raise RuntimeError("query sparse vector error") from err + return [self.resp_cls(**res) for res in cursor.fetchall()] diff --git a/qtext/schema.py b/qtext/schema.py index beda64b..3e24593 100644 --- a/qtext/schema.py +++ b/qtext/schema.py @@ -25,6 +25,9 @@ class DefaultTable: id: int | None = field(default=None, metadata={"primary_key": True}) text: str = field(metadata={"text_index": True}) vector: list[float] = field(default_factory=list, metadata={"vector_index": True}) + sparse_vector: list[float] = field( + default_factory=list, metadata={"sparse_index": True} + ) title: str | None = field(default=None, metadata={"text_index": True}) summary: str | None = None author: str | None = None @@ -69,6 +72,7 @@ def __init__(self, table: Type[DefaultTable]) -> None: self.fields: list[Field] = msgspec.inspect.type_info(table).fields self.primary_key: str | None = None self.vector_column: str | None = None + self.sparse_column: str | None = None self.text_columns: list[str] = [] for f in fields(self.table_type): @@ -76,10 +80,14 @@ def __init__(self, table: Type[DefaultTable]) -> None: self.primary_key = f.name if f.metadata.get("vector_index"): self.vector_column = f.name + if f.metadata.get("sparse_index"): + self.sparse_column = f.name if f.metadata.get("text_index"): self.text_columns.append(f.name) def generate_request_class(self) -> DefaultTable: + """Generate the user request class.""" + @dataclass(kw_only=True) class Request(self.table_type): namespace: str @@ -87,6 +95,8 @@ class Request(self.table_type): return Request def generate_response_class(self) -> DefaultTable: + """Generate the class used by the raw dict data from postgres.""" + @dataclass(kw_only=True) class Response(self.table_type): rank: float @@ -99,18 +109,34 @@ def fill_vector(self, obj, vector: list[float]): def retrieve_vector(self, obj): return getattr(obj, self.vector_column) + def fill_sparse_vector(self, obj, vector: list[float]): + setattr(obj, self.sparse_column, vector) + + def retrieve_sparse_vector(self, obj): + return getattr(obj, self.sparse_column) + def retrieve_text(self, obj): return "\n".join(getattr(obj, t, "") or "" for t in self.text_columns) def combine_vector_text( - self, vec_res: list[DefaultTable], text_res: list[DefaultTable] + self, + vec_res: list[DefaultTable], + sparse_res: list[DefaultTable], + text_res: list[DefaultTable], ) -> list[Record]: + """Combine hybrid search results.""" id_to_record = {} for vec in vec_res: record = vec.to_record() record.vector_sim = vec.rank id_to_record[record.id] = record + for sparse in sparse_res: + record = sparse.to_record() + if record.id not in id_to_record: + id_to_record[record.id] = record + id_to_record[record.id].title_sim = sparse.rank + for text in text_res: record = text.to_record() if record.id not in id_to_record: @@ -139,7 +165,17 @@ def to_pg_type(field_type: msgspec.inspect.Type) -> str: DictType: "JSONB", }[field_type.__class__] - def create_table(self, name: str, dim: int) -> str: + def create_table(self, name: str, dim: int, sparse_dim: int) -> str: + # check the vector dimension provided + if self.has_vector_index() and dim == 0: + raise ValueError( + "Vector dimension is required when schema has vector index" + ) + if self.has_sparse_index() and sparse_dim == 0: + raise ValueError( + "Sparse vector dimension is required when schema has sparse index" + ) + sql = f"CREATE TABLE IF NOT EXISTS {name} (" for i, f in enumerate(self.fields): if f.name == self.primary_key: @@ -147,6 +183,8 @@ def create_table(self, name: str, dim: int) -> str: continue elif f.name == self.vector_column: sql += f"{f.name} vector({dim}) " + elif f.name == self.sparse_column: + sql += f"{f.name} svector({sparse_dim}) " else: sql += f"{f.name} {Querier.to_pg_type(f.type)} " @@ -162,6 +200,9 @@ def create_table(self, name: str, dim: int) -> str: def has_vector_index(self) -> bool: return self.vector_column is not None + def has_sparse_index(self) -> bool: + return self.sparse_column is not None + def has_text_index(self) -> bool: return len(self.text_columns) > 0 @@ -177,6 +218,14 @@ def vector_index(self, table: str) -> str: f"vectors ({self.vector_column} vector_dot_ops);" ) + def sparse_index(self, table: str) -> str: + if not self.has_sparse_index(): + return "" + return ( + f"CREATE INDEX IF NOT EXISTS {table}_sparse ON {table} USING " + f"vectors ({self.sparse_column} svector_dot_ops);" + ) + def text_index(self, table: str) -> str: if not self.has_text_index(): return "" @@ -194,6 +243,13 @@ def vector_query(self, table: str) -> str: f"FROM {table} ORDER by rank LIMIT %s;" ) + def sparse_query(self, table: str) -> str: + columns = ", ".join(f.name for f in self.fields) + return ( + f"SELECT {columns}, {self.sparse_column} <#> %s AS rank " + f"FROM {table} ORDER by rank LIMIT %s;" + ) + def text_query(self, table: str) -> str: columns = ", ".join(f.name for f in self.fields) return ( diff --git a/qtext/spec.py b/qtext/spec.py index d6e598a..c55c964 100644 --- a/qtext/spec.py +++ b/qtext/spec.py @@ -1,14 +1,27 @@ from __future__ import annotations import msgspec +import numpy as np from reranker import Record +class SparseEmbedding(msgspec.Struct, kw_only=True, frozen=True): + dim: int + indices: list[int] + values: list[float] + + def to_str(self) -> str: + dense = np.zeros(self.dim) + dense[self.indices] = self.values + return f"[{','.join(map(str, dense))}]" + + class QueryDocRequest(msgspec.Struct, kw_only=True): namespace: str query: str limit: int = 10 vector: list[float] | None = None + sparse_vector: SparseEmbedding | None = None metadata: dict | None = None def to_record(self) -> Record: @@ -20,7 +33,8 @@ def to_record(self) -> Record: class AddNamespaceRequest(msgspec.Struct, frozen=True, kw_only=True): name: str - vector_dim: int + vector_dim: int = 0 + sparse_vector_dim: int = 0 class HighlightRequest(msgspec.Struct, kw_only=True, frozen=True): diff --git a/test.py b/test.py index 8ba8dba..9d61fae 100644 --- a/test.py +++ b/test.py @@ -2,8 +2,11 @@ import httpx +namespace = "document" +dim = 768 + client = httpx.Client(base_url="http://127.0.0.1:8000") -resp = client.post("/api/namespace", json={"name": "document", "vector_dim": 768}) +resp = client.post("/api/namespace", json={"name": namespace, "vector_dim": dim}) resp.raise_for_status() for i, text in enumerate( [ @@ -15,7 +18,7 @@ resp = client.post( "/api/doc", json={ - "namespace": "document", + "namespace": namespace, "text": text, "updated_at": str(datetime.now() - timedelta(days=i)), }, @@ -23,7 +26,7 @@ resp.raise_for_status() resp = client.post( - "/api/query", json={"namespace": "document", "query": "Who creates faster Python?"} + "/api/query", json={"namespace": namespace, "query": "Who creates faster Python?"} ) resp.raise_for_status() print([(doc["id"], doc["text"]) for doc in resp.json()]) diff --git a/test_sparse.py b/test_sparse.py index d6685f6..048c27f 100644 --- a/test_sparse.py +++ b/test_sparse.py @@ -1,102 +1,31 @@ -import numpy as np -import psycopg -from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding -from psycopg.adapt import Dumper, Loader -from psycopg.rows import dict_row -from psycopg.types import TypeInfo - -model = SparseTextEmbedding("prithvida/Splade_PP_en_v1") +import httpx vocab = 30522 -vec = next(model.embed("the quick brown fox jumped over the lazy dog")) -print(vec.values.shape, vec.indices.shape) - - -class VectorDumper(Dumper): - def dump(self, obj): - if isinstance(obj, np.ndarray): - return f"[{','.join(map(str, obj))}]".encode() - return str(obj).replace(" ", "").encode() - - -class VectorLoader(Loader): - def load(self, buf): - if isinstance(buf, memoryview): - buf = bytes(buf) - return np.array(buf.decode()[1:-1].split(","), dtype=np.float32) - - -class SVectorDumper(Dumper): - def dump(self, obj): - if isinstance(obj, np.ndarray): - return f"[{','.join(map(str, obj))}]".encode() - return str(obj).replace(" ", "").encode() - - -class SVectorLoader(Loader): - def load(self, buf): - if isinstance(buf, memoryview): - buf = bytes(buf) - return np.array(buf.decode()[1:-1].split(","), dtype=np.float32) - - -def register_vector(conn: psycopg.Connection): - info = TypeInfo.fetch(conn=conn, name="vector") - register_vector_type(conn, info) - sinfo = TypeInfo.fetch(conn=conn, name="svector") - register_svector_type(conn, sinfo) - - -def register_vector_type(conn: psycopg.Connection, info: TypeInfo): - if info is None: - raise ValueError("vector type not found") - info.register(conn) - - class VectorTextDumper(VectorDumper): - oid = info.oid - - adapters = conn.adapters - adapters.register_dumper(list, VectorTextDumper) - adapters.register_dumper(np.ndarray, VectorTextDumper) - adapters.register_loader(info.oid, VectorLoader) - - -def register_svector_type(conn: psycopg.Connection, info: TypeInfo): - if info is None: - raise ValueError("vector type not found") - info.register(conn) - - class SVectorTextDumper(SVectorDumper): - oid = info.oid - - adapters = conn.adapters - adapters.register_dumper(list, SVectorTextDumper) - adapters.register_dumper(np.ndarray, SVectorTextDumper) - adapters.register_loader(info.oid, SVectorLoader) - - -vocab = 30522 - -conn = psycopg.connect( - "postgresql://postgres:password@127.0.0.1:5432/", - autocommit=True, - row_factory=dict_row, -) -register_vector(conn) -conn.execute("create extension if not exists vectors;") -conn.execute( - f"create table if not exists sparse (id serial primary key, vec svector({vocab}), text text);" +dim = 768 +namespace = "sparse_test" +client = httpx.Client(base_url="http://127.0.0.1:8000") +resp = client.post( + "/api/namespace", + json={"name": namespace, "vector_dim": dim, "sparse_vector_dim": vocab}, ) - -indices = np.array([10, 233]) -values = np.array([0.23, 0.11]) -z = np.zeros(30522) -z[indices] = values - -conn.execute( - "insert into sparse (vec, text) values (%s, %s)", (z, "hello there"), binary=True +resp.raise_for_status() + +for text in [ + "the early bird, not really catches the worm", + "Rust is not always faster than Python", + "Life is short, I use Python", +]: + resp = client.post( + "/api/doc", + json={ + "namespace": namespace, + "text": text, + }, + ) + resp.raise_for_status() + +resp = client.post( + "/api/query", json={"namespace": namespace, "query": "Who creates faster Python?"} ) - -cur = conn.execute("select * from sparse;") -for row in cur.fetchall(): - print(row["id"], row["text"]) +resp.raise_for_status() +print([(doc["id"], doc["text"]) for doc in resp.json()])