From b36ce70c5f141c3cfb3a2a6c4c955055f779ab7e Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 5 Aug 2024 15:01:10 -0700 Subject: [PATCH] Added support for CopyFrom with string values - closes #16 --- CHANGELOG.md | 4 ++++ pgx/sparsevec.go | 12 +++++++++++- pgx/vector.go | 12 +++++++++++- pgx_test.go | 21 ++++++++++++++++++++- 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b4b84c..4f146fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.2.2 (unreleased) + +- Added support for `CopyFrom` with `string` values + ## 0.2.1 (2024-07-23) - Added `pgx` package diff --git a/pgx/sparsevec.go b/pgx/sparsevec.go index 6413026..e70b1fe 100644 --- a/pgx/sparsevec.go +++ b/pgx/sparsevec.go @@ -41,14 +41,19 @@ func (encodePlanSparseVectorCodecBinary) Encode(value any, buf []byte) (newBuf [ type scanPlanSparseVectorCodecBinary struct{} +type scanPlanSparseVectorCodecText struct{} + func (SparseVectorCodec) PlanScan(m *pgtype.Map, oid uint32, format int16, target any) pgtype.ScanPlan { _, ok := target.(*pgvector.SparseVector) if !ok { return nil } - if format == pgx.BinaryFormatCode { + switch format { + case pgx.BinaryFormatCode: return scanPlanSparseVectorCodecBinary{} + case pgx.TextFormatCode: + return scanPlanSparseVectorCodecText{} } return nil @@ -59,6 +64,11 @@ func (scanPlanSparseVectorCodecBinary) Scan(src []byte, dst any) error { return v.DecodeBinary(src) } +func (scanPlanSparseVectorCodecText) Scan(src []byte, dst any) error { + v := (dst).(*pgvector.SparseVector) + return v.Scan(src) +} + func (c SparseVectorCodec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) { return c.DecodeValue(m, oid, format, src) } diff --git a/pgx/vector.go b/pgx/vector.go index 70b8f17..a2de21e 100644 --- a/pgx/vector.go +++ b/pgx/vector.go @@ -41,14 +41,19 @@ func (encodePlanVectorCodecBinary) Encode(value any, buf []byte) (newBuf []byte, type scanPlanVectorCodecBinary struct{} +type scanPlanVectorCodecText struct{} + func (VectorCodec) PlanScan(m *pgtype.Map, oid uint32, format int16, target any) pgtype.ScanPlan { _, ok := target.(*pgvector.Vector) if !ok { return nil } - if format == pgx.BinaryFormatCode { + switch format { + case pgx.BinaryFormatCode: return scanPlanVectorCodecBinary{} + case pgx.TextFormatCode: + return scanPlanVectorCodecText{} } return nil @@ -59,6 +64,11 @@ func (scanPlanVectorCodecBinary) Scan(src []byte, dst any) error { return v.DecodeBinary(src) } +func (scanPlanVectorCodecText) Scan(src []byte, dst any) error { + v := (dst).(*pgvector.Vector) + return v.Scan(src) +} + func (c VectorCodec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) { return c.DecodeValue(m, oid, format, src) } diff --git a/pgx_test.go b/pgx_test.go index 56221e1..4e6734e 100644 --- a/pgx_test.go +++ b/pgx_test.go @@ -59,7 +59,7 @@ func TestPgx(t *testing.T) { panic(err) } - _, err = conn.Exec(ctx, "CREATE TABLE pgx_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))") + _, err = conn.Exec(ctx, "CREATE TABLE pgx_items (id bigserial, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3), score float8)") if err != nil { panic(err) } @@ -115,4 +115,23 @@ func TestPgx(t *testing.T) { if distances[0] != 0 || distances[1] != 1 || distances[2] != math.Sqrt(3) { t.Error() } + + var item PgxItem + row := conn.QueryRow(ctx, "SELECT embedding, sparse_embedding FROM pgx_items ORDER BY id LIMIT 1", pgx.QueryResultFormats{pgx.TextFormatCode, pgx.TextFormatCode}) + err = row.Scan(&item.Embedding, &item.SparseEmbedding) + if err != nil { + panic(err) + } + + _, err = conn.CopyFrom( + ctx, + pgx.Identifier{"pgx_items"}, + []string{"embedding", "binary_embedding", "sparse_embedding"}, + pgx.CopyFromSlice(1, func(i int) ([]any, error) { + return []interface{}{"[1,2,3]", "101", "{1:1,2:2,3:3}/3"}, nil + }), + ) + if err != nil { + panic(err) + } }