Skip to content

Commit

Permalink
NO-SNOW Fix TestPutOverwrite (#1252)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus authored Nov 25, 2024
1 parent ce3db31 commit 06593fc
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 52 deletions.
4 changes: 4 additions & 0 deletions assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ func assertNotEqualF(t *testing.T, actual any, expected any, descriptions ...str
fatalOnNonEmpty(t, validateNotEqual(actual, expected, descriptions...))
}

func assertNotEqualE(t *testing.T, actual any, expected any, descriptions ...string) {
errorOnNonEmpty(t, validateNotEqual(actual, expected, descriptions...))
}

func assertBytesEqualE(t *testing.T, actual []byte, expected []byte, descriptions ...string) {
errorOnNonEmpty(t, validateBytesEqual(actual, expected, descriptions...))
}
Expand Down
106 changes: 54 additions & 52 deletions put_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ func TestPutError(t *testing.T) {
t.Error(err)
}
defer func() {
assertNilF(t, f.Close())
assertNilF(t, f.Close())
}()
_, err = f.WriteString("test1")
assertNilF(t, err)
assertNilF(t, os.Chmod(file1, 0000))
defer func() {
assertNilF(t, os.Chmod(file1, 0644))
assertNilF(t, os.Chmod(file1, 0644))
}()

data := &execResponseData{
Expand Down Expand Up @@ -217,7 +217,7 @@ func TestPutLocalFile(t *testing.T) {
var s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 string
rows := dbt.mustQuery("copy into gotest_putget_t1")
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
for rows.Next() {
assertNilF(t, rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9))
Expand All @@ -228,7 +228,7 @@ func TestPutLocalFile(t *testing.T) {

rows2 := dbt.mustQuery("select count(*) from gotest_putget_t1")
defer func() {
assertNilF(t, rows2.Close())
assertNilF(t, rows2.Close())
}()
var i int
if rows2.Next() {
Expand All @@ -240,7 +240,7 @@ func TestPutLocalFile(t *testing.T) {

rows3 := dbt.mustQuery(`select STATUS from information_schema .load_history where table_name='gotest_putget_t1'`)
defer func() {
assertNilF(t, rows3.Close())
assertNilF(t, rows3.Close())
}()
if rows3.Next() {
assertNilF(t, rows3.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9))
Expand All @@ -263,38 +263,39 @@ func TestPutGetWithAutoCompressFalse(t *testing.T) {
assertNilF(t, err)
assertNilF(t, f.Sync())
defer func() {
assertNilF(t, f.Close())
assertNilF(t, f.Close())
}()

runDBTest(t, func(dbt *DBTest) {
dbt.mustExec("rm @~/test_put_uncompress_file")
stageDir := "test_put_uncompress_file_" + randomString(10)
dbt.mustExec("rm @~/" + stageDir)

// PUT test
sqlText := fmt.Sprintf("put 'file://%v' @~/test_put_uncompress_file auto_compress=FALSE", testData)
sqlText := fmt.Sprintf("put 'file://%v' @~/%v auto_compress=FALSE", testData, stageDir)
sqlText = strings.ReplaceAll(sqlText, "\\", "\\\\")
dbt.mustExec(sqlText)
defer dbt.mustExec("rm @~/test_put_uncompress_file")
rows := dbt.mustQuery("ls @~/test_put_uncompress_file")
defer dbt.mustExec("rm @~/" + stageDir)
rows := dbt.mustQuery("ls @~/" + stageDir)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
var file, s1, s2, s3 string
if rows.Next() {
err = rows.Scan(&file, &s1, &s2, &s3)
assertNilE(t, err)
}
assertTrueF(t, strings.Contains(file, "test_put_uncompress_file/data.txt"), fmt.Sprintf("should contain file. got: %v", file))
assertTrueF(t, strings.Contains(file, stageDir+"/data.txt"), fmt.Sprintf("should contain file. got: %v", file))
assertFalseF(t, strings.Contains(file, "data.txt.gz"), fmt.Sprintf("should not contain file. got: %v", file))

// GET test
var streamBuf bytes.Buffer
ctx := WithFileTransferOptions(context.Background(), &SnowflakeFileTransferOptions{GetFileToStream: true})
ctx = WithFileGetStream(ctx, &streamBuf)
sql := fmt.Sprintf("get @~/test_put_uncompress_file/data.txt 'file://%v'", tmpDir)
sql := fmt.Sprintf("get @~/%v/data.txt 'file://%v'", stageDir, tmpDir)
sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
rows2 := dbt.mustQueryContext(ctx, sqlText)
defer func() {
assertNilF(t, rows2.Close())
assertNilF(t, rows2.Close())
}()
for rows2.Next() {
err = rows2.Scan(&file, &s1, &s2, &s3)
Expand Down Expand Up @@ -335,17 +336,19 @@ func TestPutOverwrite(t *testing.T) {
assertNilF(t, err)
assertNilF(t, f.Close())

stageName := "test_put_overwrite_stage_" + randomString(10)

runDBTest(t, func(dbt *DBTest) {
dbt.mustExec("rm @~/test_put_overwrite")
dbt.mustExec("CREATE OR REPLACE STAGE " + stageName)
defer dbt.mustExec("DROP STAGE " + stageName)

f, _ = os.Open(testData)
rows := dbt.mustQueryContext(
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite",
fmt.Sprintf("put 'file://%v' @"+stageName+"/test_put_overwrite",
strings.ReplaceAll(testData, "\\", "/")))
defer rows.Close()
f.Close()
defer dbt.mustExec("rm @~/test_put_overwrite")
var s0, s1, s2, s3, s4, s5, s6, s7 string
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
Expand All @@ -356,7 +359,7 @@ func TestPutOverwrite(t *testing.T) {
t.Fatalf("expected UPLOADED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
rows = dbt.mustQuery("ls @" + stageName + "/test_put_overwrite")
defer rows.Close()
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
Expand All @@ -367,7 +370,7 @@ func TestPutOverwrite(t *testing.T) {
f, _ = os.Open(testData)
rows = dbt.mustQueryContext(
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite",
fmt.Sprintf("put 'file://%v' @"+stageName+"/test_put_overwrite",
strings.ReplaceAll(testData, "\\", "/")))
defer rows.Close()
f.Close()
Expand All @@ -379,7 +382,7 @@ func TestPutOverwrite(t *testing.T) {
t.Fatalf("expected SKIPPED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
rows = dbt.mustQuery("ls @" + stageName + "/test_put_overwrite")
defer rows.Close()
assertTrueF(t, rows.Next(), "expected new rows")

Expand All @@ -393,7 +396,7 @@ func TestPutOverwrite(t *testing.T) {
f, _ = os.Open(testData)
rows = dbt.mustQueryContext(
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite overwrite=true",
fmt.Sprintf("put 'file://%v' @"+stageName+"/test_put_overwrite overwrite=true",
strings.ReplaceAll(testData, "\\", "/")))
defer rows.Close()
f.Close()
Expand All @@ -405,18 +408,14 @@ func TestPutOverwrite(t *testing.T) {
t.Fatalf("expected UPLOADED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
rows = dbt.mustQuery("ls @" + stageName + "/test_put_overwrite")
defer rows.Close()
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
if s0 != fmt.Sprintf("test_put_overwrite/%v.gz", baseName(testData)) {
t.Fatalf("expected test_put_overwrite/%v.gz, got %v", baseName(testData), s0)
}
if s2 == md5Column {
t.Fatalf("file should have been overwritten.")
}
assertEqualE(t, s0, stageName+"/test_put_overwrite/"+baseName(testData)+".gz")
assertNotEqualE(t, s2, md5Column)
})
}

Expand Down Expand Up @@ -452,7 +451,7 @@ func testPutGet(t *testing.T, isStream bool) {
t.Error(err)
}
defer func() {
assertNilF(t, fileStream.Close())
assertNilF(t, fileStream.Close())
}()

var sqlText string
Expand All @@ -469,7 +468,7 @@ func testPutGet(t *testing.T, isStream bool) {
rows = dbt.mustQuery(sqlText)
}
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()

var s0, s1, s2, s3, s4, s5, s6, s7 string
Expand Down Expand Up @@ -499,7 +498,7 @@ func testPutGet(t *testing.T, isStream bool) {
sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
rows2 := dbt.mustQueryContext(ctx, sqlText)
defer func() {
assertNilF(t, rows2.Close())
assertNilF(t, rows2.Close())
}()
for rows2.Next() {
if err = rows2.Scan(&s0, &s1, &s2, &s3); err != nil {
Expand All @@ -524,7 +523,7 @@ func testPutGet(t *testing.T, isStream bool) {
gz, err := gzip.NewReader(&streamBuf)
assertNilE(t, err)
defer func() {
assertNilF(t, gz.Close())
assertNilF(t, gz.Close())
}()
for {
c := make([]byte, defaultChunkBufferSize)
Expand All @@ -547,13 +546,13 @@ func testPutGet(t *testing.T, isStream bool) {
f, err := os.Open(fileName)
assertNilE(t, err)
defer func() {
assertNilF(t, f.Close())
assertNilF(t, f.Close())
}()

gz, err := gzip.NewReader(f)
assertNilE(t, err)
defer func() {
assertNilF(t, gz.Close())
assertNilF(t, gz.Close())
}()

for {
Expand Down Expand Up @@ -582,7 +581,7 @@ func TestPutGetGcsDownscopedCredential(t *testing.T) {
t.Error(err)
}
defer func() {
assertNilF(t, os.RemoveAll(tmpDir))
assertNilF(t, os.RemoveAll(tmpDir))
}()
fname := filepath.Join(tmpDir, "test_put_get.txt.gz")
originalContents := "123,test1\n456,test2\n"
Expand Down Expand Up @@ -619,7 +618,7 @@ func TestPutGetGcsDownscopedCredential(t *testing.T) {
sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName)
rows = dbt.mustQuery(sqlText)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()

var s0, s1, s2, s3, s4, s5, s6, s7 string
Expand All @@ -645,7 +644,7 @@ func TestPutGetGcsDownscopedCredential(t *testing.T) {
sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
rows2 := dbt.mustQuery(sqlText)
defer func() {
assertNilF(t, rows2.Close())
assertNilF(t, rows2.Close())
}()
for rows2.Next() {
if err = rows2.Scan(&s0, &s1, &s2, &s3); err != nil {
Expand Down Expand Up @@ -704,16 +703,17 @@ func TestPutGetLargeFile(t *testing.T) {
assertNilF(t, err)

runDBTest(t, func(dbt *DBTest) {
dbt.mustExec("rm @~/test_put_largefile")
stageDir := "test_put_largefile_" + randomString(10)
dbt.mustExec("rm @~/" + stageDir)

// PUT test
putQuery := fmt.Sprintf("put file://%v/test_data/largefile.txt @%v", sourceDir, "~/test_put_largefile")
putQuery := fmt.Sprintf("put file://%v/test_data/largefile.txt @~/%v", sourceDir, stageDir)
sqlText := strings.ReplaceAll(putQuery, "\\", "\\\\")
dbt.mustExec(sqlText)
defer dbt.mustExec("rm @~/test_put_largefile")
rows := dbt.mustQuery("ls @~/test_put_largefile")
defer dbt.mustExec("rm @~/" + stageDir)
rows := dbt.mustQuery("ls @~/" + stageDir)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
var file, s1, s2, s3 string
if rows.Next() {
Expand All @@ -729,11 +729,11 @@ func TestPutGetLargeFile(t *testing.T) {
var streamBuf bytes.Buffer
ctx := WithFileTransferOptions(context.Background(), &SnowflakeFileTransferOptions{GetFileToStream: true})
ctx = WithFileGetStream(ctx, &streamBuf)
sql := fmt.Sprintf("get @%v 'file://%v'", "~/test_put_largefile/largefile.txt.gz", t.TempDir())
sql := fmt.Sprintf("get @~/%v/largefile.txt.gz 'file://%v'", stageDir, t.TempDir())
sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
rows2 := dbt.mustQueryContext(ctx, sqlText)
defer func() {
assertNilF(t, rows2.Close())
assertNilF(t, rows2.Close())
}()
for rows2.Next() {
err = rows2.Scan(&file, &s1, &s2, &s3)
Expand All @@ -751,7 +751,7 @@ func TestPutGetLargeFile(t *testing.T) {
gz, err := gzip.NewReader(&streamBuf)
assertNilE(t, err)
defer func() {
assertNilF(t, gz.Close())
assertNilF(t, gz.Close())
}()
for {
c := make([]byte, defaultChunkBufferSize)
Expand Down Expand Up @@ -809,7 +809,7 @@ func TestPutGetMaxLOBSize(t *testing.T) {
fileStream, err := os.Open(fname)
assertNilF(t, err)
defer func() {
assertNilF(t, fileStream.Close())
assertNilF(t, fileStream.Close())
}()

// test PUT command
Expand All @@ -820,7 +820,7 @@ func TestPutGetMaxLOBSize(t *testing.T) {
sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName)
rows = dbt.mustQuery(sqlText)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()

var s0, s1, s2, s3, s4, s5, s6, s7 string
Expand All @@ -845,7 +845,7 @@ func TestPutGetMaxLOBSize(t *testing.T) {
sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
rows2 := dbt.mustQuery(sqlText)
defer func() {
assertNilF(t, rows2.Close())
assertNilF(t, rows2.Close())
}()
for rows2.Next() {
err = rows2.Scan(&s0, &s1, &s2, &s3)
Expand All @@ -864,13 +864,13 @@ func TestPutGetMaxLOBSize(t *testing.T) {
assertNilE(t, err)

defer func() {
assertNilF(t, f.Close())
assertNilF(t, f.Close())
}()
gz, err := gzip.NewReader(f)
assertNilE(t, err)

defer func() {
assertNilF(t, gz.Close())
assertNilF(t, gz.Close())
}()
var contents string
for {
Expand All @@ -895,15 +895,17 @@ func TestPutCancel(t *testing.T) {
assertNilF(t, err)
testData := path.Join(sourceDir, "/test_data/largefile.txt")

stageDir := "test_put_cancel_" + randomString(10)

runDBTest(t, func(dbt *DBTest) {
c := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
go func() {
// attempt to upload a large file, but it should be canceled in 3 seconds
_, err := dbt.conn.ExecContext(
ctx,
fmt.Sprintf("put 'file://%v' @~/test_put_cancel overwrite=true",
strings.ReplaceAll(testData, "\\", "/")))
fmt.Sprintf("put 'file://%v' @~/%v overwrite=true",
strings.ReplaceAll(testData, "\\", "/"), stageDir))
if err != nil {
c <- err
return
Expand Down

0 comments on commit 06593fc

Please sign in to comment.