diff --git a/mc2mc/internal/client/client.go b/mc2mc/internal/client/client.go index f62f703..ff21ce4 100644 --- a/mc2mc/internal/client/client.go +++ b/mc2mc/internal/client/client.go @@ -33,6 +33,7 @@ type Client struct { // TODO: remove this temporary capability after 15 nov enablePartitionValue bool + enableAutoPartition bool } func NewClient(ctx context.Context, setupFns ...SetupFn) (*Client, error) { @@ -64,7 +65,7 @@ func (c *Client) Execute(ctx context.Context, tableID, queryFilePath string) err if err != nil { return errors.WithStack(err) } - if c.enablePartitionValue { + if c.enablePartitionValue && !c.enableAutoPartition { queryRaw = addPartitionValueColumn(queryRaw) } @@ -76,7 +77,8 @@ func (c *Client) Execute(ctx context.Context, tableID, queryFilePath string) err // prepare query queryToExec := c.Loader.GetQuery(tableID, string(queryRaw)) - if len(partitionNames) > 0 { + if len(partitionNames) > 0 && !c.enableAutoPartition { + // when table is partitioned and auto partition is disabled, then we need to specify partition columns explicitly c.logger.Info(fmt.Sprintf("table %s is partitioned by %s", tableID, strings.Join(partitionNames, ", "))) queryToExec = c.Loader.GetPartitionedQuery(tableID, string(queryRaw), partitionNames) } diff --git a/mc2mc/internal/client/client_test.go b/mc2mc/internal/client/client_test.go index 33ff72a..7d64252 100644 --- a/mc2mc/internal/client/client_test.go +++ b/mc2mc/internal/client/client_test.go @@ -60,7 +60,7 @@ func TestExecute(t *testing.T) { }) t.Run("should return nil when everything is successful", func(t *testing.T) { // arrange - client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("APPEND")) + client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("REPLACE")) require.NoError(t, err) client.OdpsClient = &mockOdpsClient{ partitionResult: func() ([]string, error) { @@ -70,6 +70,42 @@ func TestExecute(t *testing.T) { return nil }, } + client.Loader = &mockLoader{ + getQueryFunc: func(tableID, query string) string { + return "INSERT OVERWRITE TABLE project_test.table_test SELECT * FROM table;" + }, + getPartitionedQueryFunc: func(tableID, query string, partitionNames []string) string { + assert.True(t, true, "should be called") + return "INSERT OVERWRITE TABLE project_test.table_test PARTITION(event_date) SELECT * FROM table;" + }, + } + require.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644)) + // act + err = client.Execute(context.TODO(), "project_test.table_test", "/tmp/query.sql") + // assert + assert.NoError(t, err) + }) + t.Run("should return nil when everything is successful with enable auto partition", func(t *testing.T) { + // arrange + client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("REPLACE"), client.EnableAutoPartition(true)) + require.NoError(t, err) + client.OdpsClient = &mockOdpsClient{ + partitionResult: func() ([]string, error) { + return []string{"_partition_value"}, nil + }, + execSQLResult: func() error { + return nil + }, + } + client.Loader = &mockLoader{ + getQueryFunc: func(tableID, query string) string { + return "INSERT OVERWRITE TABLE project_test.table_test SELECT * FROM table;" + }, + getPartitionedQueryFunc: func(tableID, query string, partitionNames []string) string { + assert.False(t, true, "should not be called") + return "INSERT OVERWRITE TABLE project_test.table_test PARTITION(event_date) SELECT * FROM table;" + }, + } require.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644)) // act err = client.Execute(context.TODO(), "project_test.table_test", "/tmp/query.sql") @@ -90,3 +126,16 @@ func (m *mockOdpsClient) GetPartitionNames(ctx context.Context, tableID string) func (m *mockOdpsClient) ExecSQL(ctx context.Context, query string) error { return m.execSQLResult() } + +type mockLoader struct { + getQueryFunc func(tableID, query string) string + getPartitionedQueryFunc func(tableID, query string, partitionNames []string) string +} + +func (m *mockLoader) GetQuery(tableID, query string) string { + return m.getQueryFunc(tableID, query) +} + +func (m *mockLoader) GetPartitionedQuery(tableID, query string, partitionNames []string) string { + return m.getPartitionedQueryFunc(tableID, query, partitionNames) +} diff --git a/mc2mc/internal/client/setup.go b/mc2mc/internal/client/setup.go index 411a1b0..5e685d9 100644 --- a/mc2mc/internal/client/setup.go +++ b/mc2mc/internal/client/setup.go @@ -59,3 +59,10 @@ func EnablePartitionValue(enabled bool) SetupFn { return nil } } + +func EnableAutoPartition(enabled bool) SetupFn { + return func(c *Client) error { + c.enableAutoPartition = enabled + return nil + } +} diff --git a/mc2mc/internal/config/config.go b/mc2mc/internal/config/config.go index 5f901af..e24ee29 100644 --- a/mc2mc/internal/config/config.go +++ b/mc2mc/internal/config/config.go @@ -18,6 +18,7 @@ type Config struct { ScheduledTime string // TODO: remove this temporary support after 15 nov 2024 DevEnablePartitionValue bool + DevEnableAutoPartition bool } type maxComputeCredentials struct { @@ -41,6 +42,7 @@ func NewConfig() (*Config, error) { ScheduledTime: getEnv("SCHEDULED_TIME", ""), // TODO: delete this after 15 nov DevEnablePartitionValue: getEnv("DEV__ENABLE_PARTITION_VALUE", "false") == "true", + DevEnableAutoPartition: getEnv("DEV__ENABLE_AUTO_PARTITION", "false") == "true", } // ali-odps-go-sdk related config scvAcc := getEnv("MC_SERVICE_ACCOUNT", "") diff --git a/mc2mc/mc2mc.go b/mc2mc/mc2mc.go index d85d3d7..d890b92 100644 --- a/mc2mc/mc2mc.go +++ b/mc2mc/mc2mc.go @@ -31,6 +31,7 @@ func mc2mc() error { client.SetupODPSClient(cfg.GenOdps()), client.SetupLoader(cfg.LoadMethod), client.EnablePartitionValue(cfg.DevEnablePartitionValue), + client.EnableAutoPartition(cfg.DevEnableAutoPartition), ) if err != nil { return errors.WithStack(err)