From 88d27fa9fad579c8da12129b3a6f90de100aa214 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 3 Jun 2024 19:28:35 -0700 Subject: [PATCH] Added support for JSON serialization - closes #14 --- CHANGELOG.md | 1 + halfvec.go | 11 +++++++++++ halfvec_test.go | 31 +++++++++++++++++++++++++++++++ vector.go | 11 +++++++++++ vector_test.go | 31 +++++++++++++++++++++++++++++++ 5 files changed, 85 insertions(+) create mode 100644 halfvec_test.go create mode 100644 vector_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c5989f..235a9e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.2.0 (unreleased) - Added support for `halfvec` and `sparsevec` types +- Added support for JSON serialization - Improved performance of serialization - Dropped support for Go < 1.21 diff --git a/halfvec.go b/halfvec.go index e3abaa5..739afec 100644 --- a/halfvec.go +++ b/halfvec.go @@ -3,6 +3,7 @@ package pgvector import ( "database/sql" "database/sql/driver" + "encoding/json" "fmt" "strconv" "strings" @@ -75,3 +76,13 @@ var _ driver.Valuer = (*HalfVector)(nil) func (v HalfVector) Value() (driver.Value, error) { return v.String(), nil } + +// MarshalJSON implements the json.Marshaler interface. +func (v HalfVector) MarshalJSON() ([]byte, error) { + return json.Marshal(v.vec) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (v *HalfVector) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &v.vec) +} diff --git a/halfvec_test.go b/halfvec_test.go new file mode 100644 index 0000000..b466f78 --- /dev/null +++ b/halfvec_test.go @@ -0,0 +1,31 @@ +package pgvector_test + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/pgvector/pgvector-go" +) + +func TestHalfVectorMarshal(t *testing.T) { + vec := pgvector.NewHalfVector([]float32{1, 2, 3}) + data, err := json.Marshal(vec) + if err != nil { + panic(err) + } + if string(data) != `[1,2,3]` { + t.Errorf("Bad marshal") + } +} + +func TestHalfVectorUnmarshal(t *testing.T) { + var vec pgvector.HalfVector + err := json.Unmarshal([]byte(`[1,2,3]`), &vec) + if err != nil { + panic(err) + } + if !reflect.DeepEqual(vec.Slice(), []float32{1, 2, 3}) { + t.Errorf("Bad unmarshal") + } +} diff --git a/vector.go b/vector.go index 619a1ef..f29d087 100644 --- a/vector.go +++ b/vector.go @@ -3,6 +3,7 @@ package pgvector import ( "database/sql" "database/sql/driver" + "encoding/json" "fmt" "strconv" "strings" @@ -75,3 +76,13 @@ var _ driver.Valuer = (*Vector)(nil) func (v Vector) Value() (driver.Value, error) { return v.String(), nil } + +// MarshalJSON implements the json.Marshaler interface. +func (v Vector) MarshalJSON() ([]byte, error) { + return json.Marshal(v.vec) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (v *Vector) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &v.vec) +} diff --git a/vector_test.go b/vector_test.go new file mode 100644 index 0000000..118d589 --- /dev/null +++ b/vector_test.go @@ -0,0 +1,31 @@ +package pgvector_test + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/pgvector/pgvector-go" +) + +func TestVectorMarshal(t *testing.T) { + vec := pgvector.NewVector([]float32{1, 2, 3}) + data, err := json.Marshal(vec) + if err != nil { + panic(err) + } + if string(data) != `[1,2,3]` { + t.Errorf("Bad marshal") + } +} + +func TestVectorUnmarshal(t *testing.T) { + var vec pgvector.Vector + err := json.Unmarshal([]byte(`[1,2,3]`), &vec) + if err != nil { + panic(err) + } + if !reflect.DeepEqual(vec.Slice(), []float32{1, 2, 3}) { + t.Errorf("Bad unmarshal") + } +}