diff --git a/dialect/pgdialect/append.go b/dialect/pgdialect/append.go index c95fa86e7..aa2b6850d 100644 --- a/dialect/pgdialect/append.go +++ b/dialect/pgdialect/append.go @@ -67,9 +67,9 @@ func appendMapStringString(b []byte, m map[string]string) []byte { b = append(b, '\'') for key, value := range m { - b = arrayAppendString(b, key) + b = appendStringElem(b, key) b = append(b, '=', '>') - b = arrayAppendString(b, value) + b = appendStringElem(b, value) b = append(b, ',') } if len(m) > 0 { diff --git a/dialect/pgdialect/array.go b/dialect/pgdialect/array.go index b0d9a2331..4567057f1 100644 --- a/dialect/pgdialect/array.go +++ b/dialect/pgdialect/array.go @@ -3,13 +3,11 @@ package pgdialect import ( "database/sql" "database/sql/driver" - "encoding/hex" "fmt" "math" "reflect" "strconv" "time" - "unicode/utf8" "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/internal" @@ -146,44 +144,21 @@ func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc { } switch typ.Kind() { case reflect.String: - return arrayAppendStringValue + return appendStringElemValue case reflect.Slice: if typ.Elem().Kind() == reflect.Uint8 { - return arrayAppendBytesValue + return appendBytesElemValue } } return schema.Appender(d, typ) } -func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { - switch v := v.(type) { - case int64: - return strconv.AppendInt(b, v, 10) - case float64: - return arrayAppendFloat64(b, v) - case bool: - return dialect.AppendBool(b, v) - case []byte: - return arrayAppendBytes(b, v) - case string: - return arrayAppendString(b, v) - case time.Time: - b = append(b, '"') - b = appendTime(b, v) - b = append(b, '"') - return b - default: - err := fmt.Errorf("pgdialect: can't append %T", v) - return dialect.AppendError(b, err) - } -} - -func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - return arrayAppendString(b, v.String()) +func appendStringElemValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return appendStringElem(b, v.String()) } -func arrayAppendBytesValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - return arrayAppendBytes(b, v.Bytes()) +func appendBytesElemValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return appendBytesElem(b, v.Bytes()) } func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { @@ -191,7 +166,7 @@ func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) [ if err != nil { return dialect.AppendError(b, err) } - return arrayAppend(fmter, b, iface) + return appendElem(b, iface) } func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { @@ -208,7 +183,7 @@ func appendStringSlice(b []byte, ss []string) []byte { b = append(b, '{') for _, s := range ss { - b = arrayAppendString(b, s) + b = appendStringElem(b, s) b = append(b, ',') } if len(ss) > 0 { @@ -617,50 +592,3 @@ func toBytes(src interface{}) ([]byte, error) { return nil, fmt.Errorf("pgdialect: got %T, wanted []byte or string", src) } } - -//------------------------------------------------------------------------------ - -func arrayAppendBytes(b []byte, bs []byte) []byte { - if bs == nil { - return dialect.AppendNull(b) - } - - b = append(b, `"\\x`...) - - s := len(b) - b = append(b, make([]byte, hex.EncodedLen(len(bs)))...) - hex.Encode(b[s:], bs) - - b = append(b, '"') - - return b -} - -func arrayAppendString(b []byte, s string) []byte { - b = append(b, '"') - for _, r := range s { - switch r { - case 0: - // ignore - case '\'': - b = append(b, "''"...) - case '"': - b = append(b, '\\', '"') - case '\\': - b = append(b, '\\', '\\') - default: - if r < utf8.RuneSelf { - b = append(b, byte(r)) - break - } - l := len(b) - if cap(b)-l < utf8.UTFMax { - b = append(b, make([]byte, utf8.UTFMax)...) - } - n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) - b = b[:l+n] - } - } - b = append(b, '"') - return b -} diff --git a/dialect/pgdialect/elem.go b/dialect/pgdialect/elem.go new file mode 100644 index 000000000..7fbec3778 --- /dev/null +++ b/dialect/pgdialect/elem.go @@ -0,0 +1,87 @@ +package pgdialect + +import ( + "database/sql/driver" + "encoding/hex" + "fmt" + "strconv" + "time" + "unicode/utf8" + + "github.com/uptrace/bun/dialect" +) + +func appendElem(buf []byte, val interface{}) []byte { + switch val := val.(type) { + case int64: + return strconv.AppendInt(buf, val, 10) + case float64: + return arrayAppendFloat64(buf, val) + case bool: + return dialect.AppendBool(buf, val) + case []byte: + return appendBytesElem(buf, val) + case string: + return appendStringElem(buf, val) + case time.Time: + buf = append(buf, '"') + buf = appendTime(buf, val) + buf = append(buf, '"') + return buf + case driver.Valuer: + val2, err := val.Value() + if err != nil { + err := fmt.Errorf("pgdialect: can't append elem value: %w", err) + return dialect.AppendError(buf, err) + } + return appendElem(buf, val2) + default: + err := fmt.Errorf("pgdialect: can't append elem %T", val) + return dialect.AppendError(buf, err) + } +} + +func appendBytesElem(b []byte, bs []byte) []byte { + if bs == nil { + return dialect.AppendNull(b) + } + + b = append(b, `"\\x`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(bs)))...) + hex.Encode(b[s:], bs) + + b = append(b, '"') + + return b +} + +func appendStringElem(b []byte, s string) []byte { + b = append(b, '"') + for _, r := range s { + switch r { + case 0: + // ignore + case '\'': + b = append(b, "''"...) + case '"': + b = append(b, '\\', '"') + case '\\': + b = append(b, '\\', '\\') + default: + if r < utf8.RuneSelf { + b = append(b, byte(r)) + break + } + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + } + b = append(b, '"') + return b +} diff --git a/dialect/pgdialect/parser.go b/dialect/pgdialect/parser.go new file mode 100644 index 000000000..08f4727db --- /dev/null +++ b/dialect/pgdialect/parser.go @@ -0,0 +1,107 @@ +package pgdialect + +import ( + "bytes" + "encoding/hex" + + "github.com/uptrace/bun/internal/parser" +) + +type pgparser struct { + parser.Parser + buf []byte +} + +func newParser(b []byte) *pgparser { + p := new(pgparser) + p.Reset(b) + return p +} + +func (p *pgparser) ReadLiteral(ch byte) []byte { + p.Unread() + lit, _ := p.ReadSep(',') + return lit +} + +func (p *pgparser) ReadUnescapedSubstring(ch byte) ([]byte, error) { + return p.readSubstring(ch, false) +} + +func (p *pgparser) ReadSubstring(ch byte) ([]byte, error) { + return p.readSubstring(ch, true) +} + +func (p *pgparser) readSubstring(ch byte, escaped bool) ([]byte, error) { + ch, err := p.ReadByte() + if err != nil { + return nil, err + } + + p.buf = p.buf[:0] + for { + if ch == '"' { + break + } + + next, err := p.ReadByte() + if err != nil { + return nil, err + } + + if ch == '\\' { + switch next { + case '\\', '"': + p.buf = append(p.buf, next) + + ch, err = p.ReadByte() + if err != nil { + return nil, err + } + default: + p.buf = append(p.buf, '\\') + ch = next + } + continue + } + + if escaped && ch == '\'' && next == '\'' { + p.buf = append(p.buf, next) + ch, err = p.ReadByte() + if err != nil { + return nil, err + } + continue + } + + p.buf = append(p.buf, ch) + ch = next + } + + if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 { + data := p.buf[2:] + buf := make([]byte, hex.DecodedLen(len(data))) + n, err := hex.Decode(buf, data) + if err != nil { + return nil, err + } + return buf[:n], nil + } + + return p.buf, nil +} + +func (p *pgparser) ReadRange(ch byte) ([]byte, error) { + p.buf = p.buf[:0] + p.buf = append(p.buf, ch) + + for p.Valid() { + ch = p.Read() + p.buf = append(p.buf, ch) + if ch == ']' || ch == ')' { + break + } + } + + return p.buf, nil +} diff --git a/dialect/pgdialect/range.go b/dialect/pgdialect/range.go index b942a068e..936ad5521 100644 --- a/dialect/pgdialect/range.go +++ b/dialect/pgdialect/range.go @@ -1,15 +1,12 @@ package pgdialect import ( - "bytes" "database/sql" - "encoding/hex" "fmt" "io" "time" "github.com/uptrace/bun/internal" - "github.com/uptrace/bun/internal/parser" "github.com/uptrace/bun/schema" ) @@ -41,7 +38,10 @@ func NewRange[T any](lower, upper T) Range[T] { var _ sql.Scanner = (*Range[any])(nil) func (r *Range[T]) Scan(anySrc any) (err error) { - src := anySrc.([]byte) + src, ok := anySrc.([]byte) + if !ok { + return fmt.Errorf("pgdialect: Range can't scan %T", anySrc) + } if len(src) == 0 { return io.ErrUnexpectedEOF @@ -90,18 +90,6 @@ func (r *Range[T]) AppendQuery(fmt schema.Formatter, buf []byte) ([]byte, error) return buf, nil } -func appendElem(buf []byte, val any) []byte { - switch val := val.(type) { - case time.Time: - buf = append(buf, '"') - buf = appendTime(buf, val) - buf = append(buf, '"') - return buf - default: - panic(fmt.Errorf("unsupported range type: %T", val)) - } -} - func scanElem(ptr any, src []byte) ([]byte, error) { switch ptr := ptr.(type) { case *time.Time: @@ -117,6 +105,17 @@ func scanElem(ptr any, src []byte) ([]byte, error) { *ptr = tm return src, nil + + case sql.Scanner: + src, str, err := readStringLiteral(src) + if err != nil { + return nil, err + } + if err := ptr.Scan(str); err != nil { + return nil, err + } + return src, nil + default: panic(fmt.Errorf("unsupported range type: %T", ptr)) } @@ -137,104 +136,3 @@ func readStringLiteral(src []byte) ([]byte, []byte, error) { src = p.Remaining() return src, str, nil } - -//------------------------------------------------------------------------------ - -type pgparser struct { - parser.Parser - buf []byte -} - -func newParser(b []byte) *pgparser { - p := new(pgparser) - p.Reset(b) - return p -} - -func (p *pgparser) ReadLiteral(ch byte) []byte { - p.Unread() - lit, _ := p.ReadSep(',') - return lit -} - -func (p *pgparser) ReadUnescapedSubstring(ch byte) ([]byte, error) { - return p.readSubstring(ch, false) -} - -func (p *pgparser) ReadSubstring(ch byte) ([]byte, error) { - return p.readSubstring(ch, true) -} - -func (p *pgparser) readSubstring(ch byte, escaped bool) ([]byte, error) { - ch, err := p.ReadByte() - if err != nil { - return nil, err - } - - p.buf = p.buf[:0] - for { - if ch == '"' { - break - } - - next, err := p.ReadByte() - if err != nil { - return nil, err - } - - if ch == '\\' { - switch next { - case '\\', '"': - p.buf = append(p.buf, next) - - ch, err = p.ReadByte() - if err != nil { - return nil, err - } - default: - p.buf = append(p.buf, '\\') - ch = next - } - continue - } - - if escaped && ch == '\'' && next == '\'' { - p.buf = append(p.buf, next) - ch, err = p.ReadByte() - if err != nil { - return nil, err - } - continue - } - - p.buf = append(p.buf, ch) - ch = next - } - - if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 { - data := p.buf[2:] - buf := make([]byte, hex.DecodedLen(len(data))) - n, err := hex.Decode(buf, data) - if err != nil { - return nil, err - } - return buf[:n], nil - } - - return p.buf, nil -} - -func (p *pgparser) ReadRange(ch byte) ([]byte, error) { - p.buf = p.buf[:0] - p.buf = append(p.buf, ch) - - for p.Valid() { - ch = p.Read() - p.buf = append(p.buf, ch) - if ch == ']' || ch == ')' { - break - } - } - - return p.buf, nil -}