feat: support sparse vector (#1)
Signed-off-by: Keming <[email protected]>
kemingy authored Mar 27, 2024
1 parent c03ea7f commit 909e0f2
Showing 13 changed files with 240 additions and 123 deletions.
9 changes: 5 additions & 4 deletions README.md
@@ -5,7 +5,7 @@ End-to-end service to query the text.
- [x] full text search (Postgres GIN + text search)
- [x] vector similarity search ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs) HNSW)
- [x] generate vector if not provided
- [ ] sparse search
- [x] sparse search ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs) HNSW)
- [ ] filtering
- [x] reranking with [reranker](https://github.com/kemingy/reranker)
- [x] semantic highlight
@@ -20,7 +20,8 @@ docker compose -f docker/compose.yaml up -d server

Some of the dependent services can be opted out:
- `emb`: used to generate embeddings for queries and documents
- `colbert`: used to provide the semantic highlight feature
- `sparse`: used to generate sparse embeddings for queries and documents
- `highlight`: used to provide the semantic highlight feature
- `encoder`: rerank with a cross-encoder model; you can choose other methods or other online services

For the client example, check:
@@ -29,7 +30,7 @@ For the client example, check:

## API

- `/api/namespace` POST: create a new namespace and configure the text + vector index
- `/api/namespace` POST: create a new namespace and configure the index
- `/api/doc` POST: add a new doc
- `/api/query` POST: query the docs (see the request sketch below)
- `/api/highlight` POST: semantic highlight
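
For a quick smoke test, here is a minimal sketch of calling the query endpoint with `httpx` (already a dependency of this project). The base URL and the request fields are assumptions: `namespace`, `query`, and `limit` are inferred from the `QueryDocRequest` attributes touched in this commit, and the real spec may expect more.

```python
import httpx

# Assumed base URL: point it at wherever the `server` service from
# docker/compose.yaml is exposed on your machine.
client = httpx.Client(base_url="http://127.0.0.1:8000", timeout=30)

# Field names are inferred from QueryDocRequest usage in this commit
# (req.namespace, req.query, req.limit); adjust to the actual spec.
resp = client.post(
    "/api/query",
    json={"namespace": "demo", "query": "how does sparse search work", "limit": 5},
)
resp.raise_for_status()
print(resp.json())
```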
@@ -50,7 +51,7 @@ This project has most of the components you need for the RAG except for the last
> If you already have the table in Postgres, you will be responsible for the text-indexing and vector-indexing part.
1. Define a `dataclass` that includes the **necessary** columns as class attributes
- annotate the `primary_key`, `text_index`, `vector_index` with metadata
- annotate the `primary_key`, `text_index`, `vector_index`, `sparse_index` with metadata (not all of these columns are required, only the ones you need); see the schema sketch after this list
- attributes without a default value or default factory are treated as required when you add new docs
2. Implement the `to_record` and `from_record` methods to be used in the reranking stage
3. Change the `config.vector_store.schema` to the class you have defined
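
The sketch below illustrates step 1 under stated assumptions: the metadata keys mirror the README wording, but the exact annotation convention (and the `to_record`/`from_record` contract from step 2) is defined in `qtext/schema.py`, so treat this as a starting point rather than the canonical schema.

```python
from __future__ import annotations

from dataclasses import dataclass, field

from qtext.spec import SparseEmbedding


@dataclass
class ArticleTable:
    # Metadata keys ("primary_key", "text_index", ...) are assumptions that
    # mirror the README wording; check qtext/schema.py for the real ones.
    doc_id: int = field(metadata={"primary_key": True})
    text: str = field(metadata={"text_index": True})
    # Columns with defaults are optional when adding docs, so the dense and
    # sparse vectors can be filled in by the embedding services.
    vector: list[float] | None = field(default=None, metadata={"vector_index": True})
    sparse_vector: SparseEmbedding | None = field(
        default=None, metadata={"sparse_index": True}
    )
    # Step 2 (to_record / from_record for the reranking stage) is omitted here.
```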
14 changes: 11 additions & 3 deletions docker/compose.yaml
@@ -7,15 +7,15 @@ services:
- "8080:8000"

pg:
image: "tensorchord/pgvecto-rs:pg16-v0.2.1"
image: "tensorchord/pgvecto-rs:pg16-v0.3.0-alpha.1"
environment:
- POSTGRES_PASSWORD=password
ports:
- "5432:5432"
volumes:
- "/tmp/qtext_pgdata:/var/lib/postgresql/data"

colbert:
highlight:
image: "kemingy/colbert-highlight"
ports:
- "8081:8000"
@@ -27,6 +27,13 @@
ports:
- "8082:8000"

sparse:
image: "kemingy/spladepp"
environment:
- MOSEC_TIMEOUT=10000
ports:
- "8083:8000"

server:
build:
context: ../
@@ -36,5 +43,6 @@
depends_on:
- pg
- emb
- colbert
- highlight
- encoder
- sparse
5 changes: 4 additions & 1 deletion docker/config.json
@@ -6,7 +6,10 @@
"api_endpoint": "http://emb:8000"
},
"highlight": {
"addr": "http://colbert:8000"
"addr": "http://highlight:8000"
},
"sparse": {
"addr": "http://sparse:8000"
},
"server": {
"log_level": 10
1 change: 1 addition & 0 deletions pyproject.toml
@@ -17,6 +17,7 @@ dependencies = [
"reranker~=0.2",
"openai~=1.12.0",
"defspec~=0.1.1",
"httpx~=0.27",
]
[project.optional-dependencies]
dev = [
7 changes: 7 additions & 0 deletions qtext/config.py
@@ -31,6 +31,12 @@ class EmbeddingConfig(msgspec.Struct, kw_only=True, frozen=True):
timeout: int = 300


class SparseEmbeddingConfig(msgspec.Struct, kw_only=True, frozen=True):
addr: str = "http://127.0.0.1:8083"
timeout: int = 10
dim: int = 30522


class RankConfig(msgspec.Struct, kw_only=True, frozen=True):
ranker: Type[Ranker] = CrossEncoderClient
params: dict[str, str] = msgspec.field(
@@ -49,6 +55,7 @@ class Config(msgspec.Struct, kw_only=True, frozen=True):
server: ServerConfig = ServerConfig()
vector_store: VectorStoreConfig = VectorStoreConfig()
embedding: EmbeddingConfig = EmbeddingConfig()
sparse: SparseEmbeddingConfig = SparseEmbeddingConfig()
ranker: RankConfig = RankConfig()
highlight: HighlightConfig = HighlightConfig()

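
For reference, a minimal sketch of wiring the new section programmatically. How the project actually loads `docker/config.json` into `Config` is outside this diff, so only direct construction is shown.

```python
from qtext.config import Config, SparseEmbeddingConfig

# Override only what differs from the defaults; `dim` falls back to 30522,
# the BERT vocabulary size used by SPLADE-style sparse embeddings.
config = Config(sparse=SparseEmbeddingConfig(addr="http://sparse:8000", timeout=10))
print(config.sparse.addr, config.sparse.dim)
```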
29 changes: 29 additions & 0 deletions qtext/emb_client.py
@@ -1,7 +1,11 @@
from __future__ import annotations

import httpx
import msgspec
import openai

from qtext.log import logger
from qtext.spec import SparseEmbedding
from qtext.utils import time_it


@@ -23,3 +27,28 @@ def embedding(self, text: str | list[str]) -> list[float]:
if len(response.data) > 1:
return [data.embedding for data in response.data]
return response.data[0].embedding


class SparseEmbeddingClient:
def __init__(self, endpoint: str, dim: int, timeout: int) -> None:
self.dim = dim
self.client = httpx.Client(base_url=endpoint, timeout=timeout)
self.decoder = msgspec.json.Decoder(type=list[SparseEmbedding])

@time_it
def sparse_embedding(
self, text: str | list[str]
) -> list[SparseEmbedding] | SparseEmbedding:
resp = self.client.post("/inference", json=text)
if resp.is_error:
logger.info(
"failed to call sparse embedding [%d]: %s",
resp.status_code,
resp.content,
)
resp.raise_for_status()
sparse = self.decoder.decode(resp.content)

if len(sparse) == 1:
return sparse[0]
return sparse
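
A usage sketch for the new client, using the defaults from `SparseEmbeddingConfig`; the endpoint assumes the `sparse` (SPLADE) service from `docker/compose.yaml` published on host port 8083.

```python
from qtext.emb_client import SparseEmbeddingClient

client = SparseEmbeddingClient(
    endpoint="http://127.0.0.1:8083",  # host port mapped to the SPLADE service
    dim=30522,
    timeout=10,
)

# A single string comes back as one SparseEmbedding and a batch as a list,
# mirroring the single/batch behavior of EmbeddingClient.embedding above.
one = client.sparse_embedding("sparse retrieval with SPLADE")
many = client.sparse_embedding(["first document", "second document"])
```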
30 changes: 24 additions & 6 deletions qtext/engine.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from qtext.config import Config
from qtext.emb_client import EmbeddingClient
from qtext.emb_client import EmbeddingClient, SparseEmbeddingClient
from qtext.highlight_client import ENGLISH_STOPWORDS, HighlightClient
from qtext.pg_client import PgVectorsClient
from qtext.schema import DefaultTable, Querier
@@ -27,6 +31,11 @@ def __init__(self, config: Config) -> None:
endpoint=config.embedding.api_endpoint,
timeout=config.embedding.timeout,
)
self.sparse_client = SparseEmbeddingClient(
endpoint=config.sparse.addr,
dim=config.sparse.dim,
timeout=config.sparse.timeout,
)
self.ranker = config.ranker.ranker(**config.ranker.params)

@time_it
@@ -36,10 +41,17 @@ def add_doc(self, req) -> None:
@time_it
def add_doc(self, req) -> None:
if self.querier.has_vector_index():
text = self.querier.retrieve_text(req)
vector = self.querier.retrieve_vector(req)
if not vector:
if not vector and self.querier.has_text_index():
text = self.querier.retrieve_text(req)
self.querier.fill_vector(req, self.emb_client.embedding(text=text))
if self.querier.has_sparse_index():
sparse = self.querier.retrieve_sparse_vector(req)
if not sparse and self.querier.has_text_index():
text = self.querier.retrieve_text(req)
self.querier.fill_sparse_vector(
req, self.sparse_client.sparse_embedding(text=text)
)
self.pg_client.add_doc(req)

@time_it
@@ -48,18 +60,24 @@ def rank(
req: QueryDocRequest,
text_res: list[DefaultTable],
vector_res: list[DefaultTable],
sparse_res: list[DefaultTable],
) -> list[DefaultTable]:
docs = self.querier.combine_vector_text(vec_res=vector_res, text_res=text_res)
docs = self.querier.combine_vector_text(
vec_res=vector_res, sparse_res=sparse_res, text_res=text_res
)
ranked = self.ranker.rank(req.to_record(), docs)
return [DefaultTable.from_record(record) for record in ranked]

@time_it
def query(self, req: QueryDocRequest) -> list[DefaultTable]:
kw_results = self.pg_client.query_text(req)
if self.querier.has_vector_index() and not self.querier.retrieve_vector(req):
if self.querier.has_vector_index() and not req.vector:
req.vector = self.emb_client.embedding(req.query)
if self.querier.has_sparse_index() and not req.sparse_vector:
req.sparse_vector = self.sparse_client.sparse_embedding(req.query)
vec_results = self.pg_client.query_vector(req)
return self.rank(req, kw_results, vec_results)
sparse_results = self.pg_client.query_sparse_vector(req)
return self.rank(req, kw_results, vec_results, sparse_results)

@time_it
def highlight(self, req: HighlightRequest) -> HighlightResponse:
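
An end-to-end sketch of the hybrid flow these methods implement. The engine class name is outside the visible hunks, so `RetrievalEngine` is a hypothetical placeholder, and `QueryDocRequest` may require fields beyond the three inferred from this commit.

```python
from qtext.config import Config
from qtext.engine import RetrievalEngine  # hypothetical name; see note above
from qtext.spec import QueryDocRequest

engine = RetrievalEngine(Config())
req = QueryDocRequest(namespace="demo", query="hybrid sparse search", limit=10)

# query() fills the dense and sparse query vectors when they are missing,
# runs the text, dense, and sparse searches, then reranks the combined set.
results = engine.query(req)
```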
2 changes: 1 addition & 1 deletion qtext/highlight_client.py
@@ -17,7 +17,7 @@ def highlight_score(
resp = self.client.post("/inference", json=[query, *docs])
if resp.is_error:
logger.info(
"failed to call the highlight service [%d], %s",
"failed to call the highlight service [%d]: %s",
resp.status_code,
resp.content,
)
56 changes: 52 additions & 4 deletions qtext/pg_client.py
@@ -8,7 +8,7 @@

from qtext.log import logger
from qtext.schema import DefaultTable, Querier
from qtext.spec import AddNamespaceRequest, QueryDocRequest
from qtext.spec import AddNamespaceRequest, QueryDocRequest, SparseEmbedding
from qtext.utils import time_it


@@ -50,6 +50,34 @@ class VectorTextDumper(VectorDumper):
adapters.register_loader(info.oid, VectorLoader)


class SparseVectorDumper(Dumper):
def dump(self, obj):
if isinstance(obj, np.ndarray):
return f"[{','.join(map(str, obj))}]".encode()
if isinstance(obj, SparseEmbedding):
return obj.to_str().encode()
raise ValueError(f"unsupported type {type(obj)}")


def register_sparse_vector(conn: psycopg.Connection):
info = TypeInfo.fetch(conn=conn, name="svector")
register_svector_type(conn, info)


def register_svector_type(conn: psycopg.Connection, info: TypeInfo):
if info is None:
raise ValueError("svector type not found")
info.register(conn)

class SparseVectorTextDumper(SparseVectorDumper):
oid = info.oid

adapters = conn.adapters
adapters.register_dumper(SparseEmbedding, SparseVectorTextDumper)
adapters.register_dumper(np.ndarray, SparseVectorTextDumper)
adapters.register_loader(info.oid, VectorLoader)


class PgVectorsClient:
def __init__(self, path: str, querier: Querier):
self.path = path
@@ -61,6 +89,7 @@ def connect(self):
conn = psycopg.connect(self.path, row_factory=dict_row)
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;")
register_vector(conn)
register_sparse_vector(conn)
conn.commit()
return conn

@@ -70,11 +99,15 @@ def close(self):
@time_it
def add_namespace(self, req: AddNamespaceRequest):
try:
create_table_sql = self.querier.create_table(req.name, req.vector_dim)
create_table_sql = self.querier.create_table(
req.name, req.vector_dim, req.sparse_vector_dim
)
vector_index_sql = self.querier.vector_index(req.name)
sparse_index_sql = self.querier.sparse_index(req.name)
text_index_sql = self.querier.text_index(req.name)
self.conn.execute(create_table_sql)
self.conn.execute(vector_index_sql)
self.conn.execute(sparse_index_sql)
self.conn.execute(text_index_sql)
self.conn.commit()
except psycopg.errors.Error as err:
@@ -112,7 +145,7 @@ def query_text(self, req: QueryDocRequest) -> list[DefaultTable]:
(" | ".join(req.query.strip().split(" ")), req.limit),
)
except psycopg.errors.Error as err:
logger.info("pg client query error", exc_info=err)
logger.info("pg client query text error", exc_info=err)
self.conn.rollback()
raise RuntimeError("query text error") from err
return [self.resp_cls(**res) for res in cursor.fetchall()]
@@ -128,7 +161,22 @@ def query_vector(self, req: QueryDocRequest) -> list[DefaultTable]:
(req.vector, req.limit),
)
except psycopg.errors.Error as err:
logger.info("pg client query error", exc_info=err)
logger.info("pg client query vector error", exc_info=err)
self.conn.rollback()
raise RuntimeError("query vector error") from err
return [self.resp_cls(**res) for res in cursor.fetchall()]

@time_it
def query_sparse_vector(self, req: QueryDocRequest) -> list[DefaultTable]:
if not self.querier.has_sparse_index():
return []
try:
cursor = self.conn.execute(
self.querier.sparse_query(req.namespace),
(req.sparse_vector, req.limit),
)
except psycopg.errors.Error as err:
logger.info("pg client query sparse vector error", exc_info=err)
self.conn.rollback()
raise RuntimeError("query sparse vector error") from err
return [self.resp_cls(**res) for res in cursor.fetchall()]
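
A connection sketch mirroring what `connect()` above sets up; the DSN is a placeholder built from the compose defaults. Once `register_sparse_vector` has run, `SparseEmbedding` and `numpy.ndarray` parameters are serialized by `SparseVectorTextDumper` when passed to `execute`.

```python
import psycopg
from psycopg.rows import dict_row

from qtext.pg_client import register_sparse_vector

# Placeholder DSN assembled from the docker/compose.yaml defaults.
conn = psycopg.connect(
    "postgresql://postgres:password@127.0.0.1:5432/postgres", row_factory=dict_row
)
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;")
register_sparse_vector(conn)  # registers the svector dumpers/loader
conn.commit()
```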