feat: use secure sql format for identifiers
Signed-off-by: Keming <[email protected]>
kemingy committed Apr 5, 2024
1 parent c8940db commit 7473f30
Showing 5 changed files with 80 additions and 74 deletions.
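The change replaces Python f-string interpolation of table and column names with psycopg's sql composition objects, so identifiers are quoted by the driver and values keep travelling as bind parameters. A minimal sketch of the pattern, assuming psycopg 3; the table name, column names, and connection string below are invented for illustration:

from psycopg import connect, sql

query = sql.SQL("INSERT INTO {table} ({fields}) VALUES ({placeholders})").format(
    table=sql.Identifier("my docs"),  # rendered as a double-quoted identifier
    fields=sql.SQL(", ").join(map(sql.Identifier, ["id", "body"])),
    placeholders=sql.SQL(", ").join(sql.Placeholder() for _ in range(2)),
)

with connect("dbname=qtext_demo") as conn:  # hypothetical database and table
    print(query.as_string(conn))
    # INSERT INTO "my docs" ("id", "body") VALUES (%s, %s)
    conn.execute(query, (1, "hello"))  # values are still bound server-side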
12 changes: 9 additions & 3 deletions qtext/pg_client.py
@@ -4,6 +4,7 @@
 
 import numpy as np
 import psycopg
+from psycopg import sql
 from psycopg.adapt import Dumper, Loader
 from psycopg.rows import dict_row
 from psycopg.types import TypeInfo
@@ -131,9 +132,14 @@ def add_doc(self, req):
         attributes.remove(primary_id)
         placeholders = [getattr(req, key) for key in attributes]
         self.conn.execute(
-            (
-                f"INSERT INTO {req.namespace} ({','.join(attributes)})"
-                f"VALUES ({','.join(['%s']*len(placeholders))})"
+            sql.SQL(
+                "INSERT INTO {table} ({fields}) VALUES ({placeholders})"
+            ).format(
+                table=sql.Identifier(req.namespace),
+                fields=sql.SQL(",").join(map(sql.Identifier, attributes)),
+                placeholders=sql.SQL(",").join(
+                    sql.Placeholder() for _ in range(len(placeholders))
+                ),
             ),
            placeholders,
         )
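Why this counts as the secure form: previously req.namespace was pasted into the statement verbatim, whereas sql.Identifier renders any string as a single double-quoted identifier, so a hostile namespace can no longer terminate the statement. A quick illustration (conn is assumed to be an open psycopg connection; the namespace value is invented):

from psycopg import sql

namespace = 'docs"; DROP TABLE users; --'
print(sql.Identifier(namespace).as_string(conn))
# "docs""; DROP TABLE users; --"
# The embedded quote is doubled, so the whole value remains one (odd-looking)
# table name instead of becoming extra SQL.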
99 changes: 61 additions & 38 deletions qtext/schema.py
@@ -17,6 +17,7 @@
     StrType,
     UnionType,
 )
+from psycopg import sql
 
 from qtext.spec import Record
 
@@ -182,26 +183,26 @@ def create_table(self, name: str, dim: int, sparse_dim: int) -> str:
                 "Sparse vector dimension is required when schema has sparse index"
             )
 
-        sql = f"CREATE TABLE IF NOT EXISTS {name} ("
+        create_table_sql = f"CREATE TABLE IF NOT EXISTS {name} ("
         for i, f in enumerate(self.fields):
             if f.name == self.primary_key:
-                sql += f"{f.name} SERIAL PRIMARY KEY, "
+                create_table_sql += f"{f.name} SERIAL PRIMARY KEY, "
                 continue
             elif f.name == self.vector_column:
-                sql += f"{f.name} vector({dim}) "
+                create_table_sql += f"{f.name} vector({dim}) "
             elif f.name == self.sparse_column:
-                sql += f"{f.name} svector({sparse_dim}) "
+                create_table_sql += f"{f.name} svector({sparse_dim}) "
             else:
-                sql += f"{f.name} {Querier.to_pg_type(f.type)} "
+                create_table_sql += f"{f.name} {Querier.to_pg_type(f.type)} "
 
             if f.required:
-                sql += "NOT NULL "
+                create_table_sql += "NOT NULL "
             if f.default not in (NODEFAULT, None):
-                sql += f"DEFAULT {f.default} "
+                create_table_sql += f"DEFAULT {f.default} "
 
             if i < len(self.fields) - 1:
-                sql += ", "
-        return sql + ");"
+                create_table_sql += ", "
+        return create_table_sql + ");"
 
     def has_vector_index(self) -> bool:
         return self.vector_column is not None
@@ -212,66 +213,88 @@ def has_sparse_index(self) -> bool:
     def has_text_index(self) -> bool:
         return len(self.text_columns) > 0
 
-    def vector_index(self, table: str) -> str:
+    def vector_index(self, table: str) -> sql.SQL:
         """
         This assumes that all the vectors are normalized, so inner product
         is used since it can be computed efficiently.
         """
         if not self.has_vector_index():
             return ""
-        return (
-            f"CREATE INDEX IF NOT EXISTS {table}_vectors ON {table} USING "
-            f"vectors ({self.vector_column} vector_dot_ops);"
+        return sql.SQL(
+            "CREATE INDEX IF NOT EXISTS {vector_index} ON {table} USING "
+            "vectors ({vector_column} vector_dot_ops);"
+        ).format(
+            table=sql.Identifier(table),
+            vector_index=sql.Identifier(f"{table}_vectors"),
+            vector_column=sql.Identifier(self.vector_column),
         )
 
-    def sparse_index(self, table: str) -> str:
+    def sparse_index(self, table: str) -> sql.SQL:
         if not self.has_sparse_index():
             return ""
-        return (
-            f"CREATE INDEX IF NOT EXISTS {table}_sparse ON {table} USING "
-            f"vectors ({self.sparse_column} svector_dot_ops);"
+        return sql.SQL(
+            "CREATE INDEX IF NOT EXISTS {sparse_index} ON {table} USING "
+            "vectors ({sparse_column} svector_dot_ops);"
+        ).format(
+            table=sql.Identifier(table),
+            sparse_index=sql.Identifier(f"{table}_sparse"),
+            sparse_column=sql.Identifier(self.sparse_column),
         )
 
-    def text_index(self, table: str) -> str:
+    def text_index(self, table: str) -> sql.SQL:
         """
        refer to https://dba.stackexchange.com/a/164081
         """
         if not self.has_text_index():
             return ""
         indexed_columns = (
-            self.text_columns[0]
+            sql.Identifier(self.text_columns[0])
             if len(self.text_columns) == 1
             else f"immutable_concat_ws('. ', {', '.join(self.text_columns)})"
         )
-        return (
+        return sql.SQL(
             "CREATE OR REPLACE FUNCTION immutable_concat_ws(text, VARIADIC text[]) "
             "RETURNS text LANGUAGE sql IMMUTABLE PARALLEL SAFE "
             "RETURN array_to_string($2, $1);"
-            f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS fts_vector tsvector "
-            f"GENERATED ALWAYS AS (to_tsvector('english', {indexed_columns})) stored; "
-            f"CREATE INDEX IF NOT EXISTS ts_idx ON {table} USING GIN (fts_vector);"
+            "ALTER TABLE {table} ADD COLUMN IF NOT EXISTS fts_vector tsvector "
+            "GENERATED ALWAYS AS (to_tsvector('english', {indexed_columns})) stored; "
+            "CREATE INDEX IF NOT EXISTS ts_idx ON {table} USING GIN (fts_vector);"
+        ).format(
+            table=sql.Identifier(table),
+            indexed_columns=indexed_columns,
         )
 
-    def vector_query(self, table: str) -> str:
-        columns = ", ".join(f.name for f in self.fields)
-        return (
-            f"SELECT {columns}, {self.vector_column} <#> %s AS rank "
-            f"FROM {table} ORDER by rank LIMIT %s;"
+    def vector_query(self, table: str) -> sql.SQL:
+        columns = sql.SQL(", ").join(sql.Identifier(f.name) for f in self.fields)
+        return sql.SQL(
+            "SELECT {columns}, {vector_column} <#> %s AS rank "
+            "FROM {table} ORDER by rank LIMIT %s;"
+        ).format(
+            table=sql.Identifier(table),
+            columns=columns,
+            vector_column=sql.Identifier(self.vector_column),
         )
 
-    def sparse_query(self, table: str) -> str:
-        columns = ", ".join(f.name for f in self.fields)
-        return (
-            f"SELECT {columns}, {self.sparse_column} <#> %s AS rank "
-            f"FROM {table} ORDER by rank LIMIT %s;"
+    def sparse_query(self, table: str) -> sql.SQL:
+        columns = sql.SQL(", ").join(sql.Identifier(f.name) for f in self.fields)
+        return sql.SQL(
+            "SELECT {columns}, {sparse_column} <#> %s AS rank "
+            "FROM {table} ORDER by rank LIMIT %s;"
+        ).format(
+            table=sql.Identifier(table),
+            columns=columns,
+            sparse_column=sql.Identifier(self.sparse_column),
        )
 
-    def text_query(self, table: str) -> str:
-        columns = ", ".join(f.name for f in self.fields)
-        return (
-            f"SELECT {columns}, ts_rank_cd(fts_vector, query) AS rank "
-            f"FROM {table}, to_tsquery(%s) query "
+    def text_query(self, table: str) -> sql.SQL:
+        columns = sql.SQL(", ").join(sql.Identifier(f.name) for f in self.fields)
+        return sql.SQL(
+            "SELECT {columns}, ts_rank_cd(fts_vector, query) AS rank "
+            "FROM {table}, to_tsquery(%s) query "
             "WHERE fts_vector @@ query order by rank desc LIMIT %s;"
+        ).format(
+            table=sql.Identifier(table),
+            columns=columns,
         )
 
     def columns(self) -> list[str]:
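The vector_index docstring above assumes the stored vectors are normalized, so ranking by inner product gives the same order as cosine similarity while being cheaper to compute. A small numpy check of that claim, purely illustrative and not part of the repository:

import numpy as np

rng = np.random.default_rng(0)
docs = rng.normal(size=(5, 768))
docs /= np.linalg.norm(docs, axis=1, keepdims=True)  # unit-length document vectors
query = rng.normal(size=768)
query /= np.linalg.norm(query)

dot = docs @ query
cosine = dot / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query))
assert np.allclose(dot, cosine)  # identical scores, hence identical ranking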
6 changes: 5 additions & 1 deletion test.py
@@ -4,9 +4,13 @@
 
 namespace = "document"
 dim = 768
+vocab = 30522
 
 client = httpx.Client(base_url="http://127.0.0.1:8000")
-resp = client.post("/api/namespace", json={"name": namespace, "vector_dim": dim})
+resp = client.post(
+    "/api/namespace",
+    json={"name": namespace, "vector_dim": dim, "sparse_vector_dim": vocab},
+)
 resp.raise_for_status()
 for i, text in enumerate(
     [
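Here vocab = 30522 matches the vocabulary size of the standard uncased BERT WordPiece tokenizer, presumably because the sparse embedding carries one weight per vocabulary token; the same constant is used in test_cohere_wiki.py below.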
6 changes: 5 additions & 1 deletion test_cohere_wiki.py
@@ -7,8 +7,12 @@
 
 namespace = "cohere_wiki"
 dim = 768
+vocab = 30522
 client = httpx.Client(base_url="http://127.0.0.1:8000")
-resp = client.post("/api/namespace", json={"name": namespace, "vector_dim": dim})
+resp = client.post(
+    "/api/namespace",
+    json={"name": namespace, "vector_dim": dim, "sparse_vector_dim": vocab},
+)
 resp.raise_for_status()
 
 docs = load_dataset(
31 changes: 0 additions & 31 deletions test_sparse.py

This file was deleted.
