Skip to content

Commit

Permalink
fix: resolve vars in error and out commands (#535)
Browse files Browse the repository at this point in the history
  • Loading branch information
shueybubbles authored Jun 4, 2024
1 parent 9cd8538 commit f78b382
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
25 changes: 17 additions & 8 deletions pkg/sqlcmd/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,18 @@ func outCommand(s *Sqlcmd, args []string, line uint) error {
if len(args) == 0 || args[0] == "" {
return InvalidCommandError("OUT", line)
}
filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
if err != nil {
return err
}

switch {
case strings.EqualFold(args[0], "stdout"):
case strings.EqualFold(filePath, "stdout"):
s.SetOutput(os.Stdout)
case strings.EqualFold(args[0], "stderr"):
case strings.EqualFold(filePath, "stderr"):
s.SetOutput(os.Stderr)
default:
o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return InvalidFileError(err, args[0])
}
Expand All @@ -290,15 +295,19 @@ func outCommand(s *Sqlcmd, args []string, line uint) error {
// errorCommand changes the error writer to use a file
func errorCommand(s *Sqlcmd, args []string, line uint) error {
if len(args) == 0 || args[0] == "" {
return InvalidCommandError("OUT", line)
return InvalidCommandError("ERROR", line)
}
filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
if err != nil {
return err
}
switch {
case strings.EqualFold(args[0], "stderr"):
case strings.EqualFold(filePath, "stderr"):
s.SetError(os.Stderr)
case strings.EqualFold(args[0], "stdout"):
case strings.EqualFold(filePath, "stdout"):
s.SetError(os.Stdout)
default:
o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return InvalidFileError(err, args[0])
}
Expand Down Expand Up @@ -549,7 +558,7 @@ func xmlCommand(s *Sqlcmd, args []string, line uint) error {
func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) {
var b *strings.Builder
end := len(arg)
for i := 0; i < end; {
for i := 0; i < end && !s.Connect.DisableVariableSubstitution; {
c, next := arg[i], grab(arg, i+1, end)
switch {
case c == '$' && next == '(':
Expand Down
13 changes: 11 additions & 2 deletions pkg/sqlcmd/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,24 +246,28 @@ func TestConnectCommand(t *testing.T) {

func TestErrorCommand(t *testing.T) {
s, buf := setupSqlCmdWithMemoryOutput(t)
defer s.SetError(nil)
defer buf.Close()
file, err := os.CreateTemp("", "sqlcmderr")
assert.NoError(t, err, "os.CreateTemp")
defer os.Remove(file.Name())
fileName := file.Name()
_ = file.Close()
err = errorCommand(s, []string{""}, 1)
assert.EqualError(t, err, InvalidCommandError("OUT", 1).Error(), "errorCommand with empty file name")
assert.EqualError(t, err, InvalidCommandError("ERROR", 1).Error(), "errorCommand with empty file name")
err = errorCommand(s, []string{fileName}, 1)
assert.NoError(t, err, "errorCommand")
// Only some error kinds go to the error output
err = runSqlCmd(t, s, []string{"print N'message'", "RAISERROR(N'Error', 16, 1)", "SELECT 1", ":SETVAR 1", "GO"})
assert.NoError(t, err, "runSqlCmd")
s.SetError(nil)
errText, err := os.ReadFile(file.Name())
if assert.NoError(t, err, "ReadFile") {
assert.Regexp(t, "Msg 50000, Level 16, State 1, Server .*, Line 2"+SqlcmdEol+"Error"+SqlcmdEol, string(errText), "Error file contents: "+string(errText))
}
s.vars.Set("myvar", "stdout")
err = errorCommand(s, []string{"$(myvar)"}, 1)
assert.NoError(t, err, "errorCommand with a variable")
assert.Equal(t, os.Stdout, s.err, "error set to stdout using a variable")
}

func TestOnErrorCommand(t *testing.T) {
Expand Down Expand Up @@ -320,6 +324,11 @@ func TestResolveArgumentVariables(t *testing.T) {
if assert.ErrorContains(t, err, UndefinedVariable("var2").Error(), "fail on unresolved variable") {
assert.Empty(t, actual, "fail on unresolved variable")
}
s.Connect.DisableVariableSubstitution = true
input := "$(var1) notvar"
actual, err = resolveArgumentVariables(s, []rune(input), true)
assert.NoError(t, err)
assert.Equal(t, input, actual, "resolveArgumentVariables when DisableVariableSubstitution is false")
}

func TestExecCommand(t *testing.T) {
Expand Down

0 comments on commit f78b382

Please sign in to comment.