diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index a28ef0ed..eb62ae2d 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -9,12 +9,12 @@ jobs: name: lint-pr-changes runs-on: ubuntu-latest steps: - - uses: actions/setup-go@v3 + - uses: actions/setup-go@v5 with: - go-version: 1.18 - - uses: actions/checkout@v3 + go-version: stable + - uses: actions/checkout@v4 - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: version: latest only-new-issues: true diff --git a/pkg/sqlcmd/batch.go b/pkg/sqlcmd/batch.go index 4602cd9d..c949a3cf 100644 --- a/pkg/sqlcmd/batch.go +++ b/pkg/sqlcmd/batch.go @@ -35,15 +35,20 @@ type Batch struct { linevarmap map[int]string // cmd is the set of Commands available cmd Commands + // ParseVariables is a function that returns true if Next should parse variables + ParseVariables batchParseVariables } type batchScan func() (string, error) +type batchParseVariables func() bool + // NewBatch creates a Batch which converts runes provided by reader into SQL batches func NewBatch(reader batchScan, cmd Commands) *Batch { b := &Batch{ - read: reader, - cmd: cmd, + read: reader, + cmd: cmd, + ParseVariables: func() bool { return true }, } b.Reset(nil) return b @@ -125,7 +130,7 @@ parse: case b.quote != 0 || b.comment: // Handle variable references - case c == '$' && next == '(': + case (b.ParseVariables == nil || b.ParseVariables()) && c == '$' && next == '(': vi, ok := readVariableReference(b.raw, i+2, b.rawlen) if ok { b.addVariableLocation(i, string(b.raw[i+2:vi])) @@ -231,7 +236,7 @@ func (b *Batch) readString(r []rune, i, end int, quote rune, line uint) (int, bo for ; i < end; i++ { c, next = r[i], grab(r, i+1, end) switch { - case c == '$' && next == '(': + case (b.ParseVariables == nil || b.ParseVariables()) && c == '$' && next == '(': vl, ok := readVariableReference(r, i+2, end) if ok { b.addVariableLocation(i, string(r[i+2:vl])) diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index 4ee661dc..bf4f30b0 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -103,6 +103,7 @@ func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd { colorizer: color.New(false), } s.batch = NewBatch(s.scanNext, s.Cmd) + s.batch.ParseVariables = func() bool { return !s.Connect.DisableVariableSubstitution } mssql.SetContextLogger(s) s.PrintError = func(msg string, severity uint8) bool { return false diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index 3e4729b5..36f0d5dc 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -242,22 +242,31 @@ func TestGetRunnableQuery(t *testing.T) { q string } tests := []test{ - {"$(var1)", "v1"}, - {"$ (var2)", "$ (var2)"}, - {"select '$(VAR1) $(VAR2)' as c", "select 'v1 variable2' as c"}, - {" $(VAR1) ' $(VAR2) ' as $(VAR1)", " v1 ' variable2 ' as v1"}, - {"í $(VAR1)", "í v1"}, + // {"$(var1)", "v1"}, + // {"$ (var2)", "$ (var2)"}, + // {"select '$(VAR1) $(VAR2)' as c", "select 'v1 variable2' as c"}, + // {" $(VAR1) ' $(VAR2) ' as $(VAR1)", " v1 ' variable2 ' as v1"}, + // {"í $(VAR1)", "í v1"}, + {"select '$('", ""}, } s := New(nil, "", v) for _, test := range tests { s.batch.Reset([]rune(test.raw)) - _, _, _ = s.batch.Next() s.Connect.DisableVariableSubstitution = false - t.Log(test.raw) - r := s.getRunnableQuery(test.raw) - assert.Equalf(t, test.q, r, `runnableQuery for "%s"`, test.raw) + _, _, err := s.batch.Next() + if test.q == "" { + assert.Error(t, err, "expected variable parsing error") + } else { + assert.NoError(t, err, "Next should have succeeded") + t.Log(test.raw) + r := s.getRunnableQuery(test.raw) + assert.Equalf(t, test.q, r, `runnableQuery for "%s"`, test.raw) + } + s.batch.Reset([]rune(test.raw)) s.Connect.DisableVariableSubstitution = true - r = s.getRunnableQuery(test.raw) + _, _, err = s.batch.Next() + assert.NoError(t, err, "expected no variable parsing error") + r := s.getRunnableQuery(test.raw) assert.Equalf(t, test.raw, r, `runnableQuery without variable subs for "%s"`, test.raw) } }