diff --git a/dialect_mysql.go b/dialect_mysql.go index cf1dbb6f2..8460093dd 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -13,7 +13,7 @@ import ( "strings" "time" - "xorm.io/core" + "github.com/go-xorm/core" ) var ( @@ -220,7 +220,7 @@ func (db *mysql) SqlType(c *core.Column) string { case core.TimeStampz: res = core.Char c.Length = 64 - case core.Enum: // mysql enum + case core.Enum: //mysql enum res = core.Enum res += "(" opts := "" @@ -229,7 +229,7 @@ func (db *mysql) SqlType(c *core.Column) string { } res += strings.TrimLeft(opts, ",") res += ")" - case core.Set: // mysql set + case core.Set: //mysql set res = core.Set res += "(" opts := "" @@ -278,6 +278,10 @@ func (db *mysql) Quote(name string) string { return "`" + name + "`" } +func (db *mysql) QuoteStr() string { + return "`" +} + func (db *mysql) SupportEngine() bool { return true } @@ -345,9 +349,9 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column if colDefault != nil { col.Default = *colDefault - col.DefaultIsEmpty = false - } else { - col.DefaultIsEmpty = true + if col.Default == "" { + col.DefaultIsEmpty = true + } } cts := strings.Split(colType, "(") @@ -356,7 +360,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column var len1, len2 int if len(cts) == 2 { idx := strings.Index(cts[1], ")") - if colType == core.Enum && cts[1][0] == '\'' { // enum + if colType == core.Enum && cts[1][0] == '\'' { //enum options := strings.Split(cts[1][0:idx], ",") col.EnumOptions = make(map[string]int) for k, v := range options { @@ -389,9 +393,6 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column if colType == "FLOAT UNSIGNED" { colType = "FLOAT" } - if colType == "DOUBLE UNSIGNED" { - colType = "DOUBLE" - } col.Length = len1 col.Length2 = len2 if _, ok := core.SqlTypes[colType]; ok { @@ -404,18 +405,20 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column col.IsPrimaryKey = true } if colKey == "UNI" { - // col.is + //col.is } if extra == "auto_increment" { col.IsAutoIncrement = true } - if !col.DefaultIsEmpty { - if col.SQLType.IsText() { - col.Default = "'" + col.Default + "'" - } else if col.SQLType.IsTime() && col.Default != "CURRENT_TIMESTAMP" { + if col.SQLType.IsText() || col.SQLType.IsTime() { + if col.Default != "" { col.Default = "'" + col.Default + "'" + } else { + if col.DefaultIsEmpty { + col.Default = "''" + } } } cols[col.Name] = col @@ -507,11 +510,13 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { var sql string - sql = "CREATE TABLE IF NOT EXISTS " + //sql = "CREATE TABLE IF NOT EXISTS " + sql = "DROP TABLE IF EXISTS " if tableName == "" { tableName = table.Name } - + sql += db.Quote(tableName) + ";\n" + sql += "CREATE TABLE IF NOT EXISTS " sql += db.Quote(tableName) sql += " (" @@ -559,6 +564,36 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars return sql } +func RemoveRepeatedElement(arr []string) (newArr []string) { + newArr = make([]string, 0) + for i := 0; i < len(arr); i++ { + repeat := false + for j := i + 1; j < len(arr); j++ { + if arr[i] == arr[j] { + repeat = true + break + } + } + if !repeat { + newArr = append(newArr, arr[i]) + } + } + return +} +func (db *mysql) CreateIndexSql(tableName string, index *core.Index) string { + quote := db.Quote + var unique string + var idxName string + if index.Type == core.UniqueType { + unique = " UNIQUE" + } + //idxName = index.XName(tableName) + idxName = index.Name + return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, + quote(idxName), quote(tableName), + quote(strings.Join(RemoveRepeatedElement(index.Cols), quote(",")))) +} + func (db *mysql) Filters() []core.Filter { return []core.Filter{&core.IdFilter{}} } @@ -625,7 +660,7 @@ func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error `\/(?P.*?)` + // /dbname `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] matches := dsnPattern.FindStringSubmatch(dataSourceName) - // tlsConfigRegister := make(map[string]*tls.Config) + //tlsConfigRegister := make(map[string]*tls.Config) names := dsnPattern.SubexpNames() uri := &core.Uri{DbType: core.MYSQL} diff --git a/engine.go b/engine.go index 4ed0f77a9..fc883d59c 100644 --- a/engine.go +++ b/engine.go @@ -7,7 +7,6 @@ package xorm import ( "bufio" "bytes" - "context" "database/sql" "encoding/gob" "errors" @@ -20,8 +19,8 @@ import ( "sync" "time" - "xorm.io/builder" - "xorm.io/core" + "github.com/go-xorm/builder" + "github.com/go-xorm/core" ) // Engine is the major struct of xorm, it means a database manager. @@ -53,8 +52,6 @@ type Engine struct { cachers map[string]core.Cacher cacherLock sync.RWMutex - - defaultContext context.Context } func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { @@ -125,7 +122,6 @@ func (engine *Engine) Logger() core.ILogger { // SetLogger set the new logger func (engine *Engine) SetLogger(logger core.ILogger) { engine.logger = logger - engine.showSQL = logger.IsShowSQL() engine.dialect.SetLogger(logger) } @@ -175,6 +171,12 @@ func (engine *Engine) SupportInsertMany() bool { return engine.dialect.SupportInsertMany() } +// QuoteStr Engine's database use which character as quote. +// mysql, sqlite use ` and postgres use " +func (engine *Engine) QuoteStr() string { + return engine.dialect.QuoteStr() +} + func (engine *Engine) quoteColumns(columnStr string) string { columns := strings.Split(columnStr, ",") for i := 0; i < len(columns); i++ { @@ -190,14 +192,17 @@ func (engine *Engine) Quote(value string) string { return value } - buf := strings.Builder{} - engine.QuoteTo(&buf, value) + if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' { + return value + } + + value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1) - return buf.String() + return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr() } // QuoteTo quotes string and writes into the buffer -func (engine *Engine) QuoteTo(buf *strings.Builder, value string) { +func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) { if buf == nil { return } @@ -207,30 +212,20 @@ func (engine *Engine) QuoteTo(buf *strings.Builder, value string) { return } - quotePair := engine.dialect.Quote("") - - if value[0] == '`' || len(quotePair) < 2 || value[0] == quotePair[0] { // no quote - _, _ = buf.WriteString(value) + if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' { + buf.WriteString(value) return - } else { - prefix, suffix := quotePair[0], quotePair[1] - - _ = buf.WriteByte(prefix) - for i := 0; i < len(value); i++ { - if value[i] == '.' { - _ = buf.WriteByte(suffix) - _ = buf.WriteByte('.') - _ = buf.WriteByte(prefix) - } else { - _ = buf.WriteByte(value[i]) - } - } - _ = buf.WriteByte(suffix) } + + value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1) + + buf.WriteString(engine.dialect.QuoteStr()) + buf.WriteString(value) + buf.WriteString(engine.dialect.QuoteStr()) } func (engine *Engine) quote(sql string) string { - return engine.dialect.Quote(sql) + return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() } // SqlType will be deprecated, please use SQLType instead @@ -377,32 +372,6 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session { return session.NoAutoCondition(no...) } -func (engine *Engine) loadTableInfo(table *core.Table) error { - colSeq, cols, err := engine.dialect.GetColumns(table.Name) - if err != nil { - return err - } - for _, name := range colSeq { - table.AddColumn(cols[name]) - } - indexes, err := engine.dialect.GetIndexes(table.Name) - if err != nil { - return err - } - table.Indexes = indexes - - for _, index := range indexes { - for _, name := range index.Cols { - if col := table.GetColumn(name); col != nil { - col.Indexes[index.Name] = index.Type - } else { - return fmt.Errorf("Unknown col %s in index %v of table %v, columns %v", name, index.Name, table.Name, table.ColumnsSeq()) - } - } - } - return nil -} - // DBMetas Retrieve all tables, columns, indexes' informations from database. func (engine *Engine) DBMetas() ([]*core.Table, error) { tables, err := engine.dialect.GetTables() @@ -411,9 +380,28 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) { } for _, table := range tables { - if err = engine.loadTableInfo(table); err != nil { + colSeq, cols, err := engine.dialect.GetColumns(table.Name) + if err != nil { + return nil, err + } + for _, name := range colSeq { + table.AddColumn(cols[name]) + } + indexes, err := engine.dialect.GetIndexes(table.Name) + if err != nil { return nil, err } + table.Indexes = indexes + + for _, index := range indexes { + for _, name := range index.Cols { + if col := table.GetColumn(name); col != nil { + col.Indexes[index.Name] = index.Type + } else { + return nil, fmt.Errorf("Unknown col %s in index %v of table %v, columns %v", name, index.Name, table.Name, table.ColumnsSeq()) + } + } + } } return tables, nil } @@ -493,8 +481,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D } cols := table.ColumnsSeq() - colNames := engine.dialect.Quote(strings.Join(cols, engine.dialect.Quote(", "))) - destColNames := dialect.Quote(strings.Join(cols, dialect.Quote(", "))) + colNames := dialect.Quote(strings.Join(cols, dialect.Quote(", "))) rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.Quote(table.Name)) if err != nil { @@ -509,7 +496,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D return err } - _, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+destColNames+") VALUES (") + _, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+colNames+") VALUES (") if err != nil { return err } @@ -539,11 +526,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D } else if col.SQLType.IsNumeric() { switch reflect.TypeOf(d).Kind() { case reflect.Slice: - if col.SQLType.Name == core.Bool { - temp += fmt.Sprintf(", %v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) - } else { - temp += fmt.Sprintf(", %s", string(d.([]byte))) - } + temp += fmt.Sprintf(", %s", string(d.([]byte))) case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: if col.SQLType.Name == core.Bool { temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Int() > 0)) @@ -580,7 +563,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D // FIXME: Hack for postgres if string(dialect.DBType()) == core.POSTGRES && table.AutoIncrColumn() != nil { - _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quote(table.Name)+"), 1), false);\n") + _, err = io.WriteString(w, "SELECT setval('table_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") FROM "+dialect.Quote(table.Name)+"), 1), false);\n") if err != nil { return err } @@ -736,7 +719,7 @@ func (engine *Engine) Decr(column string, arg ...interface{}) *Session { } // SetExpr provides a update string like "column = {expression}" -func (engine *Engine) SetExpr(column string, expression interface{}) *Session { +func (engine *Engine) SetExpr(column string, expression string) *Session { session := engine.NewSession() session.isAutoClose = true return session.SetExpr(column, expression) @@ -914,15 +897,8 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { fieldType := fieldValue.Type() if ormTagStr != "" { - col = &core.Column{ - FieldName: t.Field(i).Name, - Nullable: true, - IsPrimaryKey: false, - IsAutoIncrement: false, - MapType: core.TWOSIDES, - Indexes: make(map[string]int), - DefaultIsEmpty: true, - } + col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, + IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]int)} tags := splitTag(ormTagStr) if len(tags) > 0 { @@ -938,16 +914,7 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { engine: engine, } - if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { - pStart := strings.Index(tags[0], "(") - if pStart > -1 && strings.HasSuffix(tags[0], ")") { - var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { - return r == '\'' || r == '"' - }) - - ctx.params = []string{tagPrefix} - } - + if strings.ToUpper(tags[0]) == "EXTENDS" { if err := ExtendsTagHandler(&ctx); err != nil { return nil, err } @@ -1379,31 +1346,31 @@ func (engine *Engine) DropIndexes(bean interface{}) error { } // Exec raw sql -func (engine *Engine) Exec(sqlOrArgs ...interface{}) (sql.Result, error) { +func (engine *Engine) Exec(sqlorArgs ...interface{}) (sql.Result, error) { session := engine.NewSession() defer session.Close() - return session.Exec(sqlOrArgs...) + return session.Exec(sqlorArgs...) } // Query a raw sql and return records as []map[string][]byte -func (engine *Engine) Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error) { +func (engine *Engine) Query(sqlorArgs ...interface{}) (resultsSlice []map[string][]byte, err error) { session := engine.NewSession() defer session.Close() - return session.Query(sqlOrArgs...) + return session.Query(sqlorArgs...) } // QueryString runs a raw sql and return records as []map[string]string -func (engine *Engine) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) { +func (engine *Engine) QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) { session := engine.NewSession() defer session.Close() - return session.QueryString(sqlOrArgs...) + return session.QueryString(sqlorArgs...) } // QueryInterface runs a raw sql and return records as []map[string]interface{} -func (engine *Engine) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) { +func (engine *Engine) QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error) { session := engine.NewSession() defer session.Close() - return session.QueryInterface(sqlOrArgs...) + return session.QueryInterface(sqlorArgs...) } // Insert one or more records @@ -1540,8 +1507,20 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) { if atEOF && len(data) == 0 { return 0, nil, nil } + //fmt.Println("--", bytes.IndexByte(data, '\n')) + //if i := bytes.IndexByte(data, '\n'); i >= 0 && data[i-1] == ';' { + // fmt.Println("--", data[i-1] == ';') + // return i + 1, data[0:i], nil + //} if i := bytes.IndexByte(data, ';'); i >= 0 { - return i + 1, data[0:i], nil + //&& len(data) == i + if i+1 >= len(data) { + return i + 1, data[0:i], nil + } else { + if data[i+1] == '\n' { + return i + 1, data[0:i], nil + } + } } // If we're at EOF, we have a final, non-terminated line. Return it. if atEOF { @@ -1596,7 +1575,7 @@ func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{ func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) { switch sqlTypeName { case core.Time: - s := t.Format("2006-01-02 15:04:05") // time.RFC3339 + s := t.Format("2006-01-02 15:04:05") //time.RFC3339 v = s[11:19] case core.Date: v = t.Format("2006-01-02")