diff --git a/schema/table.go b/schema/table.go index e41a2c732..82132c4f1 100644 --- a/schema/table.go +++ b/schema/table.go @@ -246,6 +246,31 @@ func (t *Table) processFields(typ reflect.Type, canAddr bool) { subfield.SQLName = t.quoteIdent(subfield.Name) } t.addField(subfield) + if v, ok := subfield.Tag.Options["unique"]; ok { + t.addUnique(subfield, embfield.prefix, v) + } + } +} + +func (t *Table) addUnique(field *Field, prefix string, tagOptions []string) { + var names []string + if len(tagOptions) == 1 { + // Split the value by comma, this will allow multiple names to be specified. + // We can use this to create multiple named unique constraints where a single column + // might be included in multiple constraints. + names = strings.Split(tagOptions[0], ",") + } else { + names = tagOptions + } + + for _, uname := range names { + if t.Unique == nil { + t.Unique = make(map[string][]*Field) + } + if uname != "" && prefix != "" { + uname = prefix + uname + } + t.Unique[uname] = append(t.Unique[uname], field) } } @@ -460,22 +485,7 @@ func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field { } if v, ok := tag.Options["unique"]; ok { - var names []string - if len(v) == 1 { - // Split the value by comma, this will allow multiple names to be specified. - // We can use this to create multiple named unique constraints where a single column - // might be included in multiple constraints. - names = strings.Split(v[0], ",") - } else { - names = v - } - - for _, uniqueName := range names { - if t.Unique == nil { - t.Unique = make(map[string][]*Field) - } - t.Unique[uniqueName] = append(t.Unique[uniqueName], field) - } + t.addUnique(field, "", v) } if s, ok := tag.Option("default"); ok { field.SQLDefault = s @@ -1057,5 +1067,3 @@ func makeIndex(a, b []int) []int { dest = append(dest, b...) return dest } - - diff --git a/schema/table_test.go b/schema/table_test.go index fed1ef2bd..71b29f4a8 100644 --- a/schema/table_test.go +++ b/schema/table_test.go @@ -112,6 +112,41 @@ func TestTable(t *testing.T) { require.Equal(t, []int{1, 0}, barView.Index) }) + t.Run("embedWithUnique", func(t *testing.T) { + type Perms struct { + View bool + Create bool + UniqueID int `bun:",unique"` + UniqueGroupID int `bun:",unique:groupa"` + } + + type Role struct { + Foo Perms `bun:"embed:foo_"` + Perms + } + + table := tables.Get(reflect.TypeOf((*Role)(nil))) + require.Nil(t, table.StructMap["foo"]) + require.Nil(t, table.StructMap["bar"]) + + fooView, ok := table.FieldMap["foo_view"] + require.True(t, ok) + require.Equal(t, []int{0, 0}, fooView.Index) + + barView, ok := table.FieldMap["view"] + require.True(t, ok) + require.Equal(t, []int{1, 0}, barView.Index) + + require.Equal(t, 3, len(table.Unique)) + require.Equal(t, 2, len(table.Unique[""])) + require.Equal(t, "foo_unique_id", table.Unique[""][0].Name) + require.Equal(t, "unique_id", table.Unique[""][1].Name) + require.Equal(t, 1, len(table.Unique["groupa"])) + require.Equal(t, "unique_group_id", table.Unique["groupa"][0].Name) + require.Equal(t, 1, len(table.Unique["foo_groupa"])) + require.Equal(t, "foo_unique_group_id", table.Unique["foo_groupa"][0].Name) + }) + t.Run("embed scanonly", func(t *testing.T) { type Model1 struct { Foo string