Skip to content

Commit

Permalink
Merge pull request uptrace#567 from oGi4i/issue-291
Browse files Browse the repository at this point in the history
Add hstore support
  • Loading branch information
vmihailenco authored Jun 16, 2022
2 parents 11d8c2f + 66b44f7 commit 242e195
Show file tree
Hide file tree
Showing 12 changed files with 561 additions and 37 deletions.
55 changes: 55 additions & 0 deletions dialect/pgdialect/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,58 @@ func arrayAppendString(b []byte, s string) []byte {
b = append(b, '"')
return b
}

//------------------------------------------------------------------------------

var mapStringStringType = reflect.TypeOf(map[string]string(nil))

func (d *Dialect) hstoreAppender(typ reflect.Type) schema.AppenderFunc {
kind := typ.Kind()

switch kind {
case reflect.Ptr:
if fn := d.hstoreAppender(typ.Elem()); fn != nil {
return schema.PtrAppender(fn)
}
case reflect.Map:
// ok:
default:
return nil
}

if typ.Key() == stringType && typ.Elem() == stringType {
return appendMapStringStringValue
}

return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
err := fmt.Errorf("bun: Hstore(unsupported %s)", v.Type())
return dialect.AppendError(b, err)
}
}

func appendMapStringString(b []byte, m map[string]string) []byte {
if m == nil {
return dialect.AppendNull(b)
}

b = append(b, '\'')

for key, value := range m {
b = arrayAppendString(b, key)
b = append(b, '=', '>')
b = arrayAppendString(b, value)
b = append(b, ',')
}
if len(m) > 0 {
b = b[:len(b)-1] // Strip trailing comma.
}

b = append(b, '\'')

return b
}

func appendMapStringStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
m := v.Convert(mapStringStringType).Interface().(map[string]string)
return appendMapStringString(b, m)
}
36 changes: 2 additions & 34 deletions dialect/pgdialect/array_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,13 @@ import (
)

type arrayParser struct {
b []byte
i int

buf []byte
*streamParser
err error
}

func newArrayParser(b []byte) *arrayParser {
p := &arrayParser{
b: b,
i: 1,
streamParser: newStreamParser(b, 1),
}
if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' {
p.err = fmt.Errorf("bun: can't parse array: %q", b)
Expand Down Expand Up @@ -135,31 +131,3 @@ func (p *arrayParser) readSubstring() ([]byte, error) {

return p.buf, nil
}

func (p *arrayParser) valid() bool {
return p.i < len(p.b)
}

func (p *arrayParser) readByte() (byte, error) {
if p.valid() {
c := p.b[p.i]
p.i++
return c, nil
}
return 0, io.EOF
}

func (p *arrayParser) unreadByte() {
p.i--
}

func (p *arrayParser) peek() byte {
if p.valid() {
return p.b[p.i]
}
return 0
}

func (p *arrayParser) skipNext() {
p.i++
}
5 changes: 5 additions & 0 deletions dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ func (d *Dialect) onField(field *schema.Field) {
field.Append = d.arrayAppender(field.StructField.Type)
field.Scan = arrayScanner(field.StructField.Type)
}

if field.DiscoveredSQLType == sqltype.HSTORE {
field.Append = d.hstoreAppender(field.StructField.Type)
field.Scan = hstoreScanner(field.StructField.Type)
}
}

func (d *Dialect) IdentQuote() byte {
Expand Down
73 changes: 73 additions & 0 deletions dialect/pgdialect/hstore.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package pgdialect

import (
"database/sql"
"fmt"
"reflect"

"github.com/uptrace/bun/schema"
)

type HStoreValue struct {
v reflect.Value

append schema.AppenderFunc
scan schema.ScannerFunc
}

// HStore accepts a map[string]string and returns a wrapper for working with PostgreSQL
// hstore data type.
//
// For struct fields you can use hstore tag:
//
// Attrs map[string]string `bun:",hstore"`
func HStore(vi interface{}) *HStoreValue {
v := reflect.ValueOf(vi)
if !v.IsValid() {
panic(fmt.Errorf("bun: HStore(nil)"))
}

typ := v.Type()
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if typ.Kind() != reflect.Map {
panic(fmt.Errorf("bun: Hstore(unsupported %s)", typ))
}

return &HStoreValue{
v: v,

append: pgDialect.hstoreAppender(v.Type()),
scan: hstoreScanner(v.Type()),
}
}

var (
_ schema.QueryAppender = (*HStoreValue)(nil)
_ sql.Scanner = (*HStoreValue)(nil)
)

func (h *HStoreValue) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) {
if h.append == nil {
panic(fmt.Errorf("bun: HStore(unsupported %s)", h.v.Type()))
}
return h.append(fmter, b, h.v), nil
}

func (h *HStoreValue) Scan(src interface{}) error {
if h.scan == nil {
return fmt.Errorf("bun: HStore(unsupported %s)", h.v.Type())
}
if h.v.Kind() != reflect.Ptr {
return fmt.Errorf("bun: HStore(non-pointer %s)", h.v.Type())
}
return h.scan(h.v.Elem(), src)
}

func (h *HStoreValue) Value() interface{} {
if h.v.IsValid() {
return h.v.Interface()
}
return nil
}
142 changes: 142 additions & 0 deletions dialect/pgdialect/hstore_parser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package pgdialect

import (
"bytes"
"fmt"
)

type hstoreParser struct {
*streamParser
err error
}

func newHStoreParser(b []byte) *hstoreParser {
p := &hstoreParser{
streamParser: newStreamParser(b, 0),
}
if len(b) < 6 || b[0] != '"' {
p.err = fmt.Errorf("bun: can't parse hstore: %q", b)
}
return p
}

func (p *hstoreParser) NextKey() (string, error) {
if p.err != nil {
return "", p.err
}

err := p.skipByte('"')
if err != nil {
return "", err
}

key, err := p.readSubstring()
if err != nil {
return "", err
}

const separator = "=>"

for i := range separator {
err = p.skipByte(separator[i])
if err != nil {
return "", err
}
}

return string(key), nil
}

func (p *hstoreParser) NextValue() (string, error) {
if p.err != nil {
return "", p.err
}

c, err := p.readByte()
if err != nil {
return "", err
}

switch c {
case '"':
value, err := p.readSubstring()
if err != nil {
return "", err
}

if p.peek() == ',' {
p.skipNext()
}

if p.peek() == ' ' {
p.skipNext()
}

return string(value), nil
default:
value := p.readSimple()
if bytes.Equal(value, []byte("NULL")) {
value = nil
}

if p.peek() == ',' {
p.skipNext()
}

return string(value), nil
}
}

func (p *hstoreParser) readSimple() []byte {
p.unreadByte()

if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 {
b := p.b[p.i : p.i+i]
p.i += i
return b
}

b := p.b[p.i:len(p.b)]
p.i = len(p.b)
return b
}

func (p *hstoreParser) readSubstring() ([]byte, error) {
c, err := p.readByte()
if err != nil {
return nil, err
}

p.buf = p.buf[:0]
for {
if c == '"' {
break
}

next, err := p.readByte()
if err != nil {
return nil, err
}

if c == '\\' {
switch next {
case '\\', '"':
p.buf = append(p.buf, next)

c, err = p.readByte()
if err != nil {
return nil, err
}
default:
p.buf = append(p.buf, '\\')
c = next
}
continue
}

p.buf = append(p.buf, c)
c = next
}

return p.buf, nil
}
Loading

0 comments on commit 242e195

Please sign in to comment.