Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: map to risingwave types #279

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module gorm.io/driver/postgres
module github.com/risingwave/gorm-driver/postgres-risingwave

go 1.18

Expand Down
62 changes: 26 additions & 36 deletions postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package postgres

import (
"database/sql"
"fmt"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -200,42 +199,24 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
if field.DataType == schema.Uint {
size++
}
if field.AutoIncrement {
switch {
case size <= 16:
return "smallserial"
case size <= 32:
return "serial"
default:
return "bigserial"
}
} else {
switch {
case size <= 16:
return "smallint"
case size <= 32:
return "integer"
default:
return "bigint"
}

switch {
case size <= 16:
return "smallint"
case size <= 32:
return "integer"
default:
return "bigint"
}

case schema.Float:
if field.Precision > 0 {
if field.Scale > 0 {
return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale)
}
return fmt.Sprintf("numeric(%d)", field.Precision)
}
// RisingWave doesn't support precision and scale for float, return decimal instead
return "decimal"
case schema.String:
if field.Size > 0 {
return fmt.Sprintf("varchar(%d)", field.Size)
}
return "text"
// RisingWave ignores size for text, all map to varchar type
return "varchar"
case schema.Time:
if field.Precision > 0 {
return fmt.Sprintf("timestamptz(%d)", field.Precision)
}
// RisingWave doesn't support timestamptz with precision type, map to timestamp instead
return "timestamptz"
case schema.Bytes:
return "bytea"
Expand All @@ -247,18 +228,18 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
sqlType := string(field.DataType)

if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), "serial") {
if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), "integer") {
size := field.Size
if field.GORMDataType == schema.Uint {
size++
}
switch {
case size <= 16:
sqlType = "smallserial"
sqlType = "smallint"
case size <= 32:
sqlType = "serial"
sqlType = "integer"
default:
sqlType = "bigserial"
sqlType = "bigint"
}
}

Expand All @@ -275,6 +256,7 @@ func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
return nil
}

// RisingWave doesn't support serial, smallserial, bigserial, map to integer, smallint, bigint instead
func getSerialDatabaseType(s string) (dbType string, ok bool) {
switch s {
case "smallserial":
Expand All @@ -283,6 +265,14 @@ func getSerialDatabaseType(s string) (dbType string, ok bool) {
return "integer", true
case "bigserial":
return "bigint", true

// We mapped the serial type to the integer type before, so we need to handle the int type and serial type
case "smallint":
return "smallint", true
case "integer":
return "integer", true
case "bigint":
return "bigint", true
default:
return "", false
}
Expand Down
Loading