// headerPattern matches one leading "set ...;" header statement,
// case-insensitively, together with any whitespace in front of it. The
// statement extends up to and including the first semicolon.
var headerPattern = regexp.MustCompile(`(?i)^\s*set\s+[^;]+;`)

// SeparateHeadersAndQuery splits query into its leading "set ...;" header
// statements and the remaining SQL.
//
// The first result holds the headers, each trimmed and still terminated by
// its semicolon, joined with newlines; it is empty when query starts with no
// header. The second result is the rest of the query with surrounding
// whitespace and any trailing semicolons removed. Semicolons elsewhere in the
// remaining query (for example inside string literals) are left untouched.
func SeparateHeadersAndQuery(query string) (string, string) {
	remaining := strings.TrimSpace(query)

	// Peel header statements off the front until the text no longer starts
	// with one. The pattern is anchored, so a match always begins at offset 0
	// and len(match) is the number of bytes to drop.
	var headers []string
	for {
		match := headerPattern.FindString(remaining)
		if match == "" {
			break
		}
		headers = append(headers, strings.TrimSpace(match))
		remaining = strings.TrimSpace(remaining[len(match):])
	}

	// Drop trailing statement terminators. Re-trim after each one so that
	// "select 1 ;" and "select 1;;" both come back as "select 1".
	for strings.HasSuffix(remaining, ";") {
		remaining = strings.TrimSpace(strings.TrimSuffix(remaining, ";"))
	}

	return strings.Join(headers, "\n"), remaining
}