diff --git a/pkg/proto/rule/rule_test.go b/pkg/proto/rule/rule_test.go index e41bc621..8ebe20e8 100644 --- a/pkg/proto/rule/rule_test.go +++ b/pkg/proto/rule/rule_test.go @@ -19,6 +19,7 @@ package rule import ( "fmt" + "reflect" "testing" ) @@ -149,3 +150,36 @@ func TestRule(t *testing.T) { assert.False(t, ru.Has("student")) assert.False(t, (*Rule)(nil).Has("student")) } + +func TestGetShardColumn_Found(t *testing.T) { + sm := ShardMetadata{ + ShardColumns: []*ShardColumn{ + {Name: "column1"}, + {Name: "column2"}, + {Name: "column3"}, + }, + } + + name := "column2" + expected := &ShardColumn{Name: "column2"} + result := sm.GetShardColumn(name) + + assert.True(t, reflect.DeepEqual(result, expected)) +} + +func TestGetShardColumn_NotFound(t *testing.T) { + sm := ShardMetadata{ + ShardColumns: []*ShardColumn{ + {Name: "column1"}, + {Name: "column2"}, + {Name: "column3"}, + }, + } + + name := "column4" + expected := (*ShardColumn)(nil) + result := sm.GetShardColumn(name) + + assert.Nil(t, result) + assert.True(t, reflect.DeepEqual(result, expected)) +}