diff --git a/qtext/pg_client.py b/qtext/pg_client.py index 6448bd4..1d18216 100644 --- a/qtext/pg_client.py +++ b/qtext/pg_client.py @@ -4,6 +4,7 @@ import numpy as np import psycopg +from psycopg import sql from psycopg.adapt import Dumper, Loader from psycopg.rows import dict_row from psycopg.types import TypeInfo @@ -131,9 +132,14 @@ def add_doc(self, req): attributes.remove(primary_id) placeholders = [getattr(req, key) for key in attributes] self.conn.execute( - ( - f"INSERT INTO {req.namespace} ({','.join(attributes)})" - f"VALUES ({','.join(['%s']*len(placeholders))})" + sql.SQL( + "INSERT INTO {table} ({fields}) VALUES ({placeholders})" + ).format( + table=sql.Identifier(req.namespace), + fields=sql.SQL(",").join(map(sql.Identifier, attributes)), + placeholders=sql.SQL(",").join( + sql.Placeholder() for _ in range(len(placeholders)) + ), ), placeholders, ) diff --git a/qtext/schema.py b/qtext/schema.py index 3389fd8..e5918d1 100644 --- a/qtext/schema.py +++ b/qtext/schema.py @@ -17,6 +17,7 @@ StrType, UnionType, ) +from psycopg import sql from qtext.spec import Record @@ -182,26 +183,26 @@ def create_table(self, name: str, dim: int, sparse_dim: int) -> str: "Sparse vector dimension is required when schema has sparse index" ) - sql = f"CREATE TABLE IF NOT EXISTS {name} (" + create_table_sql = f"CREATE TABLE IF NOT EXISTS {name} (" for i, f in enumerate(self.fields): if f.name == self.primary_key: - sql += f"{f.name} SERIAL PRIMARY KEY, " + create_table_sql += f"{f.name} SERIAL PRIMARY KEY, " continue elif f.name == self.vector_column: - sql += f"{f.name} vector({dim}) " + create_table_sql += f"{f.name} vector({dim}) " elif f.name == self.sparse_column: - sql += f"{f.name} svector({sparse_dim}) " + create_table_sql += f"{f.name} svector({sparse_dim}) " else: - sql += f"{f.name} {Querier.to_pg_type(f.type)} " + create_table_sql += f"{f.name} {Querier.to_pg_type(f.type)} " if f.required: - sql += "NOT NULL " + create_table_sql += "NOT NULL " if f.default not in (NODEFAULT, None): - sql += f"DEFAULT {f.default} " + create_table_sql += f"DEFAULT {f.default} " if i < len(self.fields) - 1: - sql += ", " - return sql + ");" + create_table_sql += ", " + return create_table_sql + ");" def has_vector_index(self) -> bool: return self.vector_column is not None @@ -212,66 +213,88 @@ def has_sparse_index(self) -> bool: def has_text_index(self) -> bool: return len(self.text_columns) > 0 - def vector_index(self, table: str) -> str: + def vector_index(self, table: str) -> sql.SQL: """ This assumes that all the vectors are normalized, so inner product is used since it can be computed efficiently. """ if not self.has_vector_index(): return "" - return ( - f"CREATE INDEX IF NOT EXISTS {table}_vectors ON {table} USING " - f"vectors ({self.vector_column} vector_dot_ops);" + return sql.SQL( + "CREATE INDEX IF NOT EXISTS {vector_index} ON {table} USING " + "vectors ({vector_column} vector_dot_ops);" + ).format( + table=sql.Identifier(table), + vector_index=sql.Identifier(f"{table}_vectors"), + vector_column=sql.Identifier(self.vector_column), ) - def sparse_index(self, table: str) -> str: + def sparse_index(self, table: str) -> sql.SQL: if not self.has_sparse_index(): return "" - return ( - f"CREATE INDEX IF NOT EXISTS {table}_sparse ON {table} USING " - f"vectors ({self.sparse_column} svector_dot_ops);" + return sql.SQL( + "CREATE INDEX IF NOT EXISTS {sparse_index} ON {table} USING " + "vectors ({sparse_column} svector_dot_ops);" + ).format( + table=sql.Identifier(table), + sparse_index=sql.Identifier(f"{table}_sparse"), + sparse_column=sql.Identifier(self.sparse_column), ) - def text_index(self, table: str) -> str: + def text_index(self, table: str) -> sql.SQL: """ refer to https://dba.stackexchange.com/a/164081 """ if not self.has_text_index(): return "" indexed_columns = ( - self.text_columns[0] + sql.Identifier(self.text_columns[0]) if len(self.text_columns) == 1 else f"immutable_concat_ws('. ', {', '.join(self.text_columns)})" ) - return ( + return sql.SQL( "CREATE OR REPLACE FUNCTION immutable_concat_ws(text, VARIADIC text[]) " "RETURNS text LANGUAGE sql IMMUTABLE PARALLEL SAFE " "RETURN array_to_string($2, $1);" - f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS fts_vector tsvector " - f"GENERATED ALWAYS AS (to_tsvector('english', {indexed_columns})) stored; " - f"CREATE INDEX IF NOT EXISTS ts_idx ON {table} USING GIN (fts_vector);" + "ALTER TABLE {table} ADD COLUMN IF NOT EXISTS fts_vector tsvector " + "GENERATED ALWAYS AS (to_tsvector('english', {indexed_columns})) stored; " + "CREATE INDEX IF NOT EXISTS ts_idx ON {table} USING GIN (fts_vector);" + ).format( + table=sql.Identifier(table), + indexed_columns=indexed_columns, ) - def vector_query(self, table: str) -> str: - columns = ", ".join(f.name for f in self.fields) - return ( - f"SELECT {columns}, {self.vector_column} <#> %s AS rank " - f"FROM {table} ORDER by rank LIMIT %s;" + def vector_query(self, table: str) -> sql.SQL: + columns = sql.SQL(", ").join(sql.Identifier(f.name) for f in self.fields) + return sql.SQL( + "SELECT {columns}, {vector_column} <#> %s AS rank " + "FROM {table} ORDER by rank LIMIT %s;" + ).format( + table=sql.Identifier(table), + columns=columns, + vector_column=sql.Identifier(self.vector_column), ) - def sparse_query(self, table: str) -> str: - columns = ", ".join(f.name for f in self.fields) - return ( - f"SELECT {columns}, {self.sparse_column} <#> %s AS rank " - f"FROM {table} ORDER by rank LIMIT %s;" + def sparse_query(self, table: str) -> sql.SQL: + columns = sql.SQL(", ").join(sql.Identifier(f.name) for f in self.fields) + return sql.SQL( + "SELECT {columns}, {sparse_column} <#> %s AS rank " + "FROM {table} ORDER by rank LIMIT %s;" + ).format( + table=sql.Identifier(table), + columns=columns, + sparse_column=sql.Identifier(self.sparse_column), ) - def text_query(self, table: str) -> str: - columns = ", ".join(f.name for f in self.fields) - return ( - f"SELECT {columns}, ts_rank_cd(fts_vector, query) AS rank " - f"FROM {table}, to_tsquery(%s) query " + def text_query(self, table: str) -> sql.SQL: + columns = sql.SQL(", ").join(sql.Identifier(f.name) for f in self.fields) + return sql.SQL( + "SELECT {columns}, ts_rank_cd(fts_vector, query) AS rank " + "FROM {table}, to_tsquery(%s) query " "WHERE fts_vector @@ query order by rank desc LIMIT %s;" + ).format( + table=sql.Identifier(table), + columns=columns, ) def columns(self) -> list[str]: diff --git a/test.py b/test.py index 9d61fae..40ae6f3 100644 --- a/test.py +++ b/test.py @@ -4,9 +4,13 @@ namespace = "document" dim = 768 +vocab = 30522 client = httpx.Client(base_url="http://127.0.0.1:8000") -resp = client.post("/api/namespace", json={"name": namespace, "vector_dim": dim}) +resp = client.post( + "/api/namespace", + json={"name": namespace, "vector_dim": dim, "sparse_vector_dim": vocab}, +) resp.raise_for_status() for i, text in enumerate( [ diff --git a/test_cohere_wiki.py b/test_cohere_wiki.py index 4885dc2..670053d 100644 --- a/test_cohere_wiki.py +++ b/test_cohere_wiki.py @@ -7,8 +7,12 @@ namespace = "cohere_wiki" dim = 768 +vocab = 30522 client = httpx.Client(base_url="http://127.0.0.1:8000") -resp = client.post("/api/namespace", json={"name": namespace, "vector_dim": dim}) +resp = client.post( + "/api/namespace", + json={"name": namespace, "vector_dim": dim, "sparse_vector_dim": vocab}, +) resp.raise_for_status() docs = load_dataset( diff --git a/test_sparse.py b/test_sparse.py deleted file mode 100644 index 048c27f..0000000 --- a/test_sparse.py +++ /dev/null @@ -1,31 +0,0 @@ -import httpx - -vocab = 30522 -dim = 768 -namespace = "sparse_test" -client = httpx.Client(base_url="http://127.0.0.1:8000") -resp = client.post( - "/api/namespace", - json={"name": namespace, "vector_dim": dim, "sparse_vector_dim": vocab}, -) -resp.raise_for_status() - -for text in [ - "the early bird, not really catches the worm", - "Rust is not always faster than Python", - "Life is short, I use Python", -]: - resp = client.post( - "/api/doc", - json={ - "namespace": namespace, - "text": text, - }, - ) - resp.raise_for_status() - -resp = client.post( - "/api/query", json={"namespace": namespace, "query": "Who creates faster Python?"} -) -resp.raise_for_status() -print([(doc["id"], doc["text"]) for doc in resp.json()])