diff --git a/docs/source/driver/snowflake.rst b/docs/source/driver/snowflake.rst
index bf44534967..f6f6abefa6 100644
--- a/docs/source/driver/snowflake.rst
+++ b/docs/source/driver/snowflake.rst
@@ -259,37 +259,61 @@ Bulk Ingestion
 Bulk ingestion is supported. The mapping from Arrow types to Snowflake types
 is provided below.
 
-Partitioned Result Sets
------------------------
+Bulk ingestion is implemented by writing Arrow data to Parquet file(s) and uploading (via PUT) to a temporary internal stage.
+One or more COPY queries are executed in order to load the data into the target table.
 
-Partitioned result sets are not currently supported.
+In order for the driver to leverage this temporary stage, the user must have
+the ``CREATE STAGE`` privilege on the schema. In addition,
+the current database and schema for the session must be set. If these are not set, the ``CREATE TEMPORARY STAGE`` command
+executed by the driver can fail with the following error:
 
-Performance
-----------
+.. code-block:: sql
 
-Formal benchmarking is forthcoming. Snowflake does provide an Arrow native
-format for requesting results, but bulk ingestion is still currently executed
-using the REST API. As described in the `Snowflake Documentation
-`
-the driver will potentially attempt to improve performance by streaming the data
-(without creating files on the local machine) to a temporary stage for ingestion
-if the number of values exceeds some threshold.
+   CREATE TEMPORARY STAGE ADBC$BIND FILE_FORMAT = (TYPE = PARQUET USE_LOGICAL_TYPE = TRUE BINARY_AS_TEXT = FALSE)
+   CANNOT perform CREATE STAGE. This session does not have a current schema. Call 'USE SCHEMA' or use a qualified name.
 
-In order for the driver to leverage this temporary stage, the user must have
-the ``CREATE STAGE`` privilege on the schema. If the user does not have this
-privilege, the driver will fall back to sending the data with the query
-to the snowflake database.
+The following informal benchmark demonstrates expected performance using default ingestion settings::
 
-In addition, the current database and schema for the session must be set. If
-these are not set, the ``CREATE TEMPORARY STAGE`` command executed by the driver
-can fail with the following error:
+    Running on GCP e2-standard-4 (4 vCPU, 16GB RAM)
+    Snowflake warehouse size M, same GCP region as Snowflake account
+    Default ingestion settings
 
-.. code-block:: sql
+    TPC-H Lineitem (16 Columns):
+        Scale Factor 1 (6M Rows): 9.5s
+        Scale Factor 10 (60M Rows): 45s
 
-   CREATE TEMPORARY STAGE SYSTEM$BIND file_format=(type=csv field_optionally_enclosed_by='"')
-   CANNOT perform CREATE STAGE. This session does not have a current schema. Call 'USE SCHEMA' or use a qualified name.
+The default settings for ingestion should be well balanced for many real-world configurations. If required, performance
+and resource usage may be tuned with the following options on the :cpp:class:`AdbcStatement` object:
+
+``adbc.snowflake.statement.ingest_writer_concurrency``
+    Number of Parquet files to write in parallel. Default attempts to maximize workers based on logical cores detected,
+    but may need to be adjusted if running in a constrained environment. If set to 0, default value is used. Cannot be negative.
+
+``adbc.snowflake.statement.ingest_upload_concurrency``
+    Number of Parquet files to upload in parallel. Greater concurrency can smooth out TCP congestion and help make
+    use of available network bandwidth, but will increase memory utilization. Default is 8. If set to 0, default value is used.
+    Cannot be negative.
+
+``adbc.snowflake.statement.ingest_copy_concurrency``
+    Maximum number of COPY operations to run concurrently. Bulk ingestion performance is optimized by executing COPY
+    queries as files are still being uploaded. Snowflake COPY speed scales with warehouse size, so smaller warehouses
+    may benefit from setting this value higher to ensure long-running COPY queries do not block newly uploaded files
+    from being loaded. Default is 4. If set to 0, only a single COPY query will be executed as part of ingestion,
+    once all files have finished uploading. Cannot be negative.
+
+``adbc.snowflake.statement.ingest_target_file_size``
+    Approximate size of Parquet files written during ingestion. Actual size will be slightly larger, depending on
+    size of footer/metadata. Default is 10 MB. If set to 0, file size has no limit. Cannot be negative.
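+
+For illustration, these options can be set through the standard statement APIs before executing a bulk ingestion.
+The snippet below is a minimal sketch using the Go ADBC interfaces, not part of the driver itself;
+``cnxn``, ``rdr``, the table name ``my_table``, and the chosen option values are placeholders for this example:
+
+.. code-block:: go
+
+   package example
+
+   import (
+       "context"
+
+       "github.com/apache/arrow-adbc/go/adbc"
+       "github.com/apache/arrow/go/v15/arrow/array"
+   )
+
+   // IngestWithTunedOptions bulk ingests rdr into my_table, overriding a few of
+   // the default ingestion settings. cnxn is assumed to be an open connection
+   // created from the Snowflake driver.
+   func IngestWithTunedOptions(ctx context.Context, cnxn adbc.Connection, rdr array.RecordReader) (int64, error) {
+       stmt, err := cnxn.NewStatement()
+       if err != nil {
+           return 0, err
+       }
+       defer stmt.Close()
+
+       // Table to load the data into ("my_table" is a placeholder).
+       if err := stmt.SetOption(adbc.OptionKeyIngestTargetTable, "my_table"); err != nil {
+           return 0, err
+       }
+       // Optional tuning: run at most two concurrent COPY queries and write
+       // ~100 MB Parquet files instead of the 10 MB default.
+       if err := stmt.SetOption("adbc.snowflake.statement.ingest_copy_concurrency", "2"); err != nil {
+           return 0, err
+       }
+       if err := stmt.SetOption("adbc.snowflake.statement.ingest_target_file_size", "100000000"); err != nil {
+           return 0, err
+       }
+
+       // Bind the stream and execute the ingestion; the returned count is the
+       // number of rows in the target table after the COPY queries complete.
+       if err := stmt.BindStream(ctx, rdr); err != nil {
+           return 0, err
+       }
+       return stmt.ExecuteUpdate(ctx)
+   }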
+
+Partitioned Result Sets
+-----------------------
+
+Partitioned result sets are not currently supported.
+
+Performance
+-----------
 
-In addition, results are potentially fetched in parallel from multiple endpoints.
+When querying Snowflake data, results are potentially fetched in parallel from multiple endpoints.
 A limited number of batches are queued per endpoint, though data is always
 returned to the client in the order of the endpoints.
@@ -490,16 +514,19 @@ indicated are done to ensure consistency of the stream of record batches.
      - Notes
 
    * - integral types
-     - int64
-     - All integral types in Snowflake are stored as 64-bit integers.
+     - number(38, 0)
+     - All integral types in Snowflake are stored as numbers for which neither
+       precision nor scale can be specified.
 
    * - float/double
     - float64
     - Snowflake does not distinguish between float or double. Both are 64-bit values.
 
   * - decimal/numeric
-     - int64/float64
-     - If scale == 0, then int64 is used, else float64.
+     - numeric
+     - Snowflake will respect the precision/scale of the Arrow type. See the
+       ``adbc.snowflake.sql.client_option.use_high_precision`` option for exceptions to this
+       behavior.
 
   * - time
     - time64[ns]
@@ -513,8 +540,9 @@ indicated are done to ensure consistency of the stream of record batches.
       | timestamp_ntz
       | timestamp_tz
     - timestamp[ns]
-     - Local time zone will be used. No timezone will be specified in
-       the Arrow type. Values will be converted to UTC.
+     - Local time zone will be used, except for timestamp_ntz, which is not an instant.
+       In this case no timezone will be present in the type. Physical values will be
+       UTC-normalized.
 
   * - | variant
       | object
@@ -523,7 +551,9 @@ indicated are done to ensure consistency of the stream of record batches.
     - Snowflake does not provide information about nested types. Values
       will be strings in a format similar to JSON that can be parsed. The
       Arrow type will contain a metadata key
-       ``logicalType`` with the Snowflake field type.
+       ``logicalType`` with the Snowflake field type. Arrow Struct and
+       Map types will be stored as objects when ingested. List types will
+       be stored as arrays.
 
   * - | geography
       | geometry
diff --git a/go/adbc/driver/snowflake/bulk_ingestion.go b/go/adbc/driver/snowflake/bulk_ingestion.go
new file mode 100644
index 0000000000..18b9a7b7b4
--- /dev/null
+++ b/go/adbc/driver/snowflake/bulk_ingestion.go
@@ -0,0 +1,586 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package snowflake + +import ( + "bufio" + "bytes" + "compress/flate" + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io" + "math" + "runtime" + "strings" + "sync" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow/go/v15/arrow" + "github.com/apache/arrow/go/v15/arrow/array" + "github.com/apache/arrow/go/v15/arrow/memory" + "github.com/apache/arrow/go/v15/parquet" + "github.com/apache/arrow/go/v15/parquet/compress" + "github.com/apache/arrow/go/v15/parquet/pqarrow" + "github.com/snowflakedb/gosnowflake" + "golang.org/x/sync/errgroup" +) + +const ( + bindStageName = "ADBC$BIND" + createTemporaryStageStmt = "CREATE OR REPLACE TEMPORARY STAGE " + bindStageName + " FILE_FORMAT = (TYPE = PARQUET USE_LOGICAL_TYPE = TRUE BINARY_AS_TEXT = FALSE)" + putQueryTmpl = "PUT 'file:///tmp/placeholder/%s' @" + bindStageName + " OVERWRITE = TRUE" + copyQuery = "COPY INTO IDENTIFIER(?) FROM @" + bindStageName + " MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE" + countQuery = "SELECT COUNT(*) FROM IDENTIFIER(?)" + megabyte = 1024 * 1024 +) + +var ( + defaultTargetFileSize uint = 10 * megabyte + defaultWriterConcurrency uint = uint(runtime.NumCPU()) + defaultUploadConcurrency uint = 8 + defaultCopyConcurrency uint = 4 + + defaultCompressionCodec compress.Compression = compress.Codecs.Snappy + defaultCompressionLevel int = flate.DefaultCompression +) + +// Options for configuring bulk ingestion. +// +// Values should be updated with appropriate calls to stmt.SetOption(). +type ingestOptions struct { + // Approximate size of Parquet files written during ingestion. + // + // Actual size will be slightly larger, depending on size of footer/metadata. + // Default is 10 MB. If set to 0, file size has no limit. Cannot be negative. + targetFileSize uint + // Number of Parquet files to write in parallel. + // + // Default attempts to maximize workers based on logical cores detected, but + // may need to be adjusted if running in a constrained environment. + // If set to 0, default value is used. Cannot be negative. + writerConcurrency uint + // Number of Parquet files to upload in parallel. + // + // Greater concurrency can smooth out TCP congestion and help make use of + // available network bandwith, but will increase memory utilization. + // Default is 8. If set to 0, default value is used. Cannot be negative. + uploadConcurrency uint + // Maximum number of COPY operations to run concurrently. + // + // Bulk ingestion performance is optimized by executing COPY queries as files are + // still being uploaded. Snowflake COPY speed scales with warehouse size, so smaller + // warehouses may benefit from setting this value higher to ensure long-running + // COPY queries do not block newly uploaded files from being loaded. + // Default is 4. If set to 0, only a single COPY query will be executed as part of ingestion, + // once all files have finished uploading. Cannot be negative. 
+ copyConcurrency uint + // Compression codec to use for Parquet files. + // + // When network speeds are high, it is generally faster to use a faster codec with + // a lower compression ratio. The opposite is true if the network is slow by CPU is + // available. + // Default is Snappy. + compressionCodec compress.Compression + // Compression level for Parquet files. + // + // The compression level is codec-specific. Some codecs do not support setting it, + // notably Snappy. + // Default is the default level for the specified compressionCodec. + compressionLevel int +} + +func DefaultIngestOptions() *ingestOptions { + return &ingestOptions{ + targetFileSize: defaultTargetFileSize, + writerConcurrency: defaultWriterConcurrency, + uploadConcurrency: defaultUploadConcurrency, + copyConcurrency: defaultCopyConcurrency, + compressionCodec: defaultCompressionCodec, + compressionLevel: defaultCompressionLevel, + } +} + +// ingestRecord performs bulk ingestion of a single Record and returns the +// number of rows affected. +// +// The Record must already be bound by calling stmt.Bind(), and will be released +// and reset upon completion. +func (st *statement) ingestRecord(ctx context.Context) (nrows int64, err error) { + defer func() { + // Record already released by writeParquet() + st.bound = nil + }() + + parquetProps, arrowProps := newWriterProps(st.alloc, st.ingestOptions) + g := errgroup.Group{} + + // writeParquet takes a channel of Records, but we only have one Record to write + recordCh := make(chan arrow.Record, 1) + recordCh <- st.bound + close(recordCh) + + // Read the Record from the channel and write it into the provided writer + schema := st.bound.Schema() + r, w := io.Pipe() + bw := bufio.NewWriter(w) + g.Go(func() error { + defer r.Close() + defer bw.Flush() + + err = writeParquet(schema, bw, recordCh, 0, parquetProps, arrowProps) + if err != io.EOF { + return err + } + return nil + }) + + // Create a temporary stage, we can't start uploading until it has been created + _, err = st.cnxn.cn.ExecContext(ctx, createTemporaryStageStmt, nil) + if err != nil { + return + } + + // Start uploading the file to Snowflake + fileName := "0.parquet" // Only writing 1 file, so use same name as first file written by ingestStream() for consistency + err = uploadStream(ctx, st.cnxn.cn, r, fileName) + if err != nil { + return + } + + // Parquet writing is already done if the upload finished, so we're just checking for any errors + err = g.Wait() + if err != nil { + return + } + + // Load the uploaded file into the target table + _, err = st.cnxn.cn.ExecContext(ctx, copyQuery, []driver.NamedValue{{Value: st.targetTable}}) + if err != nil { + return + } + + // Check final row count of target table to get definitive rows affected + nrows, err = countRowsInTable(ctx, st.cnxn.sqldb, st.targetTable) + return +} + +// ingestStream performs bulk ingestion of a RecordReader and returns the +// number of rows affected. +// +// The RecordReader must already be bound by calling stmt.BindStream(), and will +// be released and reset upon completion. +func (st *statement) ingestStream(ctx context.Context) (nrows int64, err error) { + defer func() { + st.streamBind.Release() + st.streamBind = nil + }() + defer func() { + // Always check the resulting row count, even in the case of an error. We may have ingested part of the data. 
+ ctx := context.Background() // TODO(joellubi): switch to context.WithoutCancel(ctx) once we're on Go 1.21 + n, countErr := countRowsInTable(ctx, st.cnxn.sqldb, st.targetTable) + nrows = n + + // Ingestion, row-count check, or both could have failed + // Wrap any failures as ADBC errors + + // TODO(joellubi): simplify / improve with errors.Join(err, countErr) once we're on Go 1.20 + if err == nil { + err = errToAdbcErr(adbc.StatusInternal, countErr) + return + } + + // Failure in the pipeline itself + if errors.Is(err, context.Canceled) { + err = errToAdbcErr(adbc.StatusCancelled, err) + } else { + err = errToAdbcErr(adbc.StatusInternal, err) + } + }() + + parquetProps, arrowProps := newWriterProps(st.alloc, st.ingestOptions) + g, gCtx := errgroup.WithContext(ctx) + + // Read records into channel + records := make(chan arrow.Record, st.ingestOptions.writerConcurrency) + g.Go(func() error { + return readRecords(gCtx, st.streamBind, records) + }) + + // Read records from channel and write Parquet files in parallel to buffer pool + schema := st.streamBind.Schema() + pool := newBufferPool(int(st.ingestOptions.targetFileSize)) + buffers := make(chan *bytes.Buffer, st.ingestOptions.writerConcurrency) + g.Go(func() error { + return runParallelParquetWriters( + gCtx, + schema, + int(st.ingestOptions.targetFileSize), + int(st.ingestOptions.writerConcurrency), + parquetProps, + arrowProps, + pool.GetBuffer, + records, + buffers, + ) + }) + + // Create a temporary stage, we can't start uploading until it has been created + _, err = st.cnxn.cn.ExecContext(ctx, createTemporaryStageStmt, nil) + if err != nil { + return + } + + // Kickoff background tasks to COPY Parquet files into Snowflake table as they are uploaded + fileReady, finishCopy, cancelCopy := runCopyTasks(ctx, st.cnxn.cn, st.targetTable, int(st.ingestOptions.copyConcurrency)) + + // Read Parquet files from buffer pool and upload to Snowflake stage in parallel + g.Go(func() error { + return uploadAllStreams(gCtx, st.cnxn.cn, buffers, int(st.ingestOptions.uploadConcurrency), pool.PutBuffer, fileReady) + }) + + // Wait until either all files have been uploaded to Snowflake or the pipeline has failed / been canceled + if err = g.Wait(); err != nil { + // If the pipeline failed, we can cancel in-progress COPYs and avoid starting the final one + cancelCopy() + return + } + + // Start final COPY and wait for all COPY tasks to complete + err = finishCopy() + return +} + +func newWriterProps(mem memory.Allocator, opts *ingestOptions) (*parquet.WriterProperties, pqarrow.ArrowWriterProperties) { + parquetProps := parquet.NewWriterProperties( + parquet.WithAllocator(mem), + parquet.WithCompression(opts.compressionCodec), + parquet.WithCompressionLevel(opts.compressionLevel), + // Overhead for dict building often reduces throughput more than filesize reduction benefits; may expose as config in future + parquet.WithDictionaryDefault(false), + // Stats won't be used since the file is dropped after ingestion completes + parquet.WithStats(false), + parquet.WithMaxRowGroupLength(math.MaxInt64), + ) + arrowProps := pqarrow.NewArrowWriterProperties(pqarrow.WithAllocator(mem)) + + return parquetProps, arrowProps +} + +func readRecords(ctx context.Context, rdr array.RecordReader, out chan<- arrow.Record) error { + defer close(out) + + for rdr.Next() { + rec := rdr.Record() + rec.Retain() + + select { + case out <- rec: + case <-ctx.Done(): + return ctx.Err() + } + } + + return nil +} + +func writeParquet( + schema *arrow.Schema, + w io.Writer, + in <-chan 
arrow.Record, + targetSize int, + parquetProps *parquet.WriterProperties, + arrowProps pqarrow.ArrowWriterProperties, +) error { + limitWr := &limitWriter{w: w, limit: targetSize} + pqWriter, err := pqarrow.NewFileWriter(schema, limitWr, parquetProps, arrowProps) + if err != nil { + return err + } + defer pqWriter.Close() + + for rec := range in { + err = pqWriter.Write(rec) + rec.Release() + if err != nil { + return err + } + if limitWr.LimitExceeded() { + return nil + } + } + + // Input channel closed, signal that all parquet writing is done + return io.EOF +} + +func runParallelParquetWriters( + ctx context.Context, + schema *arrow.Schema, + targetSize int, + concurrency int, + parquetProps *parquet.WriterProperties, + arrowProps pqarrow.ArrowWriterProperties, + newBuffer func() *bytes.Buffer, + in <-chan arrow.Record, + out chan<- *bytes.Buffer, +) error { + var once sync.Once + defer close(out) + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(concurrency) + + done := make(chan interface{}) + finished := func() { + once.Do(func() { close(done) }) + } + + tryWriteBuffer := func(buf *bytes.Buffer) { + select { + case out <- buf: + case <-ctx.Done(): + // If the context is canceled, the buffer may be full but we don't want to block indefinitely + } + } + + for { + select { + case <-done: + return g.Wait() + default: + } + + g.Go(func() error { + select { + case <-done: + // Channel may have already closed while goroutine was waiting to get scheduled + return nil + case <-ctx.Done(): + finished() + return ctx.Err() + default: + // Proceed to write a new file + } + + buf := newBuffer() + err := writeParquet(schema, buf, in, targetSize, parquetProps, arrowProps) + if err == io.EOF { + tryWriteBuffer(buf) + finished() + return nil + } + if err == nil { + tryWriteBuffer(buf) + } + return err + }) + } +} + +func uploadStream(ctx context.Context, cn snowflakeConn, r io.Reader, name string) error { + putQuery := fmt.Sprintf(putQueryTmpl, name) + putQuery = strings.ReplaceAll(putQuery, "\\", "\\\\") // Windows compatibility + + _, err := cn.ExecContext(gosnowflake.WithFileStream(ctx, r), putQuery, nil) + if err != nil { + return err + } + + return nil +} + +func uploadAllStreams( + ctx context.Context, + cn snowflakeConn, + streams <-chan *bytes.Buffer, + concurrency int, + freeBuffer func(*bytes.Buffer), + uploadCallback func(), +) error { + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(concurrency) + + var i int + // Loop through buffers as they become ready and assign to available upload workers + for buf := range streams { + select { + case <-ctx.Done(): + // The context is canceled on error, so we wait for graceful shutdown of in-progress uploads. + // The gosnowflake.snowflakeFileTransferAgent does not currently propogate context, so we + // have to wait for uploads to finish for proper shutdown. (https://github.com/snowflakedb/gosnowflake/issues/1028) + return g.Wait() + default: + } + + buf := buf // mutable loop variable + fileName := fmt.Sprintf("%d.parquet", i) + g.Go(func() error { + defer freeBuffer(buf) + defer uploadCallback() + + return uploadStream(ctx, cn, buf, fileName) + }) + i++ + } + return g.Wait() +} + +func runCopyTasks(ctx context.Context, cn snowflakeConn, tableName string, concurrency int) (func(), func() error, func()) { + ctx, cancel := context.WithCancel(ctx) + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(concurrency) + + readyCh := make(chan struct{}, 1) + stopCh := make(chan interface{}) + + // Handler to signal that new data has been uploaded. 
+ // Executing COPY will be a no-op if this has not been called since the last COPY was dispatched, so we wait for the signal. + readyFn := func() { + + // readyFn is a no-op if the shutdown signal has already been recieved + select { + case _, ok := <-stopCh: + if !ok { + return + } + default: + // Continue + } + + // readyFn is a no-op if it already knows that at least 1 file is ready to be loaded + select { + case readyCh <- struct{}{}: + default: + return + } + } + + // Handler to signal that ingestion pipeline has completed successfully. + // Executes COPY to finalize ingestion (may no-op if all files already loaded by bg workers) and waits for all COPYs to complete. + stopFn := func() error { + defer cancel() + close(stopCh) + close(readyCh) + + _, err := cn.ExecContext(ctx, copyQuery, []driver.NamedValue{{Value: tableName}}) + if err != nil { + return err + } + + return g.Wait() + } + + // Handler to signal that ingestion pipeline failed and COPY operations should not proceed. + // Stops the dispatch of new bg workers and cancels all in-progress COPY operations. + cancelFn := func() { + defer cancel() + close(stopCh) + close(readyCh) + } + + go func() { + for { + + // Block until there is at least 1 new file available for copy, or it's time to shutdown + select { + case <-stopCh: + return + case <-ctx.Done(): + return + case _, ok := <-readyCh: + if !ok { + return + } + // Proceed to start a new COPY job + } + + g.Go(func() error { + _, err := cn.ExecContext(ctx, copyQuery, []driver.NamedValue{{Value: tableName}}) + return err + }) + } + }() + + return readyFn, stopFn, cancelFn +} + +func countRowsInTable(ctx context.Context, db *sql.DB, tableName string) (int64, error) { + var nrows int64 + + row := db.QueryRowContext(ctx, countQuery, tableName) + if err := row.Scan(&nrows); err != nil { + return 0, err + } + + return nrows, nil +} + +// Initializes a sync.Pool of *bytes.Buffer. +// Extra space is preallocated so that the Parquet footer can be written after reaching target file size without growing the buffer +func newBufferPool(size int) *bufferPool { + buffers := sync.Pool{ + New: func() interface{} { + extraSpace := 1 * megabyte // TODO(joellubi): Generally works, but can this be smarter? + buf := make([]byte, 0, size+extraSpace) + return bytes.NewBuffer(buf) + }, + } + + return &bufferPool{&buffers} +} + +type bufferPool struct { + *sync.Pool +} + +func (bp *bufferPool) GetBuffer() *bytes.Buffer { + return bp.Pool.Get().(*bytes.Buffer) +} + +func (bp *bufferPool) PutBuffer(buf *bytes.Buffer) { + buf.Reset() + bp.Pool.Put(buf) +} + +// Wraps an io.Writer and specifies a limit. +// Keeps track of how many bytes have been written and can report whether the limit has been exceeded. +// TODO(ARROW-39789): We prefer to use RowGroupTotalBytesWritten on the ParquetWriter, but there seems to be a discrepency with the count. 
+type limitWriter struct { + w io.Writer + limit int + + bytesWritten int +} + +func (lw *limitWriter) Write(p []byte) (int, error) { + n, err := lw.w.Write(p) + lw.bytesWritten += n + + return n, err +} + +func (lw *limitWriter) LimitExceeded() bool { + if lw.limit > 0 { + return lw.bytesWritten > lw.limit + } + // Limit disabled + return false +} diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 681d90384b..06284b0ca7 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -1052,12 +1052,14 @@ func (c *cnxn) Rollback(_ context.Context) error { // NewStatement initializes a new statement object tied to this connection func (c *cnxn) NewStatement() (adbc.Statement, error) { + defaultIngestOptions := DefaultIngestOptions() return &statement{ alloc: c.db.Alloc, cnxn: c, queueSize: defaultStatementQueueSize, prefetchConcurrency: defaultPrefetchConcurrency, useHighPrecision: c.useHighPrecision, + ingestOptions: defaultIngestOptions, }, nil } diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 850fcd5be6..7dd4043809 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -18,6 +18,7 @@ package snowflake_test import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -343,6 +344,11 @@ func (suite *SnowflakeTests) TearDownSuite() { func (suite *SnowflakeTests) TestSqlIngestTimestamp() { suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest")) + sessionTimezone := "UTC" + suite.Require().NoError(suite.stmt.SetSqlQuery(fmt.Sprintf(`ALTER SESSION SET TIMEZONE = "%s"`, sessionTimezone))) + _, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + sc := arrow.NewSchema([]arrow.Field{{ Name: "col", Type: arrow.FixedWidthTypes.Timestamp_us, Nullable: true, @@ -388,6 +394,959 @@ func (suite *SnowflakeTests) TestSqlIngestTimestamp() { suite.Require().NoError(rdr.Err()) } +func (suite *SnowflakeTests) TestSqlIngestRecordAndStreamAreEquivalent() { + suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest_bind")) + suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest_bind_stream")) + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(suite.T(), 0) + + sc := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_float64", Type: arrow.PrimitiveTypes.Float64, + Nullable: true, + }, + { + Name: "col_string", Type: arrow.BinaryTypes.String, + Nullable: true, + }, + { + Name: "col_binary", Type: arrow.BinaryTypes.Binary, + Nullable: true, + }, + { + Name: "col_boolean", Type: arrow.FixedWidthTypes.Boolean, + Nullable: true, + }, + { + Name: "col_date32", Type: arrow.FixedWidthTypes.Date32, + Nullable: true, + }, + { + Name: "col_time64ns", Type: arrow.FixedWidthTypes.Time64ns, + Nullable: true, + }, + { + Name: "col_time64us", Type: arrow.FixedWidthTypes.Time64us, + Nullable: true, + }, + { + Name: "col_time32ms", Type: arrow.FixedWidthTypes.Time32ms, + Nullable: true, + }, + { + Name: "col_time32s", Type: arrow.FixedWidthTypes.Time32s, + Nullable: true, + }, + { + Name: "col_timestamp_ns", Type: arrow.FixedWidthTypes.Timestamp_ns, + Nullable: true, + }, + { + Name: "col_timestamp_us", Type: arrow.FixedWidthTypes.Timestamp_us, + Nullable: true, + }, + { + Name: "col_timestamp_s", Type: arrow.FixedWidthTypes.Timestamp_s, + Nullable: true, + }, + }, 
nil) + + bldr := array.NewRecordBuilder(mem, sc) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{-1, 0, 25}, nil) + bldr.Field(1).(*array.Float64Builder).AppendValues([]float64{-1.1, 0, 25.95}, nil) + bldr.Field(2).(*array.StringBuilder).AppendValues([]string{"first", "second", "third"}, nil) + bldr.Field(3).(*array.BinaryBuilder).AppendValues([][]byte{[]byte("first"), []byte("second"), []byte("third")}, nil) + bldr.Field(4).(*array.BooleanBuilder).AppendValues([]bool{true, false, true}, nil) + bldr.Field(5).(*array.Date32Builder).AppendValues([]arrow.Date32{1, 2, 3}, nil) + bldr.Field(6).(*array.Time64Builder).AppendValues([]arrow.Time64{1, 2, 3}, nil) + bldr.Field(7).(*array.Time64Builder).AppendValues([]arrow.Time64{1, 2, 3}, nil) + bldr.Field(8).(*array.Time32Builder).AppendValues([]arrow.Time32{1, 2, 3}, nil) + bldr.Field(9).(*array.Time32Builder).AppendValues([]arrow.Time32{1, 2, 3}, nil) + bldr.Field(10).(*array.TimestampBuilder).AppendValues([]arrow.Timestamp{1, 2, 3}, nil) + bldr.Field(11).(*array.TimestampBuilder).AppendValues([]arrow.Timestamp{1, 2, 3}, nil) + bldr.Field(12).(*array.TimestampBuilder).AppendValues([]arrow.Timestamp{1, 2, 3}, nil) + + rec := bldr.NewRecord() + defer rec.Release() + + stream, err := array.NewRecordReader(sc, []arrow.Record{rec}) + suite.Require().NoError(err) + defer stream.Release() + + suite.Require().NoError(suite.stmt.Bind(suite.ctx, rec)) + suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_bind")) + n, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + suite.EqualValues(3, n) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_bind ORDER BY \"col_int64\" ASC")) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(3, n) + suite.True(rdr.Next()) + resultBind := rdr.Record() + + // New session to clean up TEMPORARY resources in Snowflake associated with the previous one + suite.NoError(suite.stmt.Close()) + suite.NoError(suite.cnxn.Close()) + suite.cnxn, err = suite.db.Open(suite.ctx) + suite.NoError(err) + suite.stmt, err = suite.cnxn.NewStatement() + suite.NoError(err) + + suite.Require().NoError(suite.stmt.BindStream(suite.ctx, stream)) + suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_bind_stream")) + n, err = suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + suite.EqualValues(3, n) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_bind_stream ORDER BY \"col_int64\" ASC")) + rdr, n, err = suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(3, n) + suite.True(rdr.Next()) + resultBindStream := rdr.Record() + + suite.Truef(array.RecordEqual(resultBind, resultBindStream), "expected: %s\ngot: %s", resultBind, resultBindStream) + suite.False(rdr.Next()) + + suite.Require().NoError(rdr.Err()) +} + +func (suite *SnowflakeTests) TestSqlIngestRoundtripTypes() { + suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest_roundtrip")) + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(suite.T(), 0) + + sc := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_float64", Type: arrow.PrimitiveTypes.Float64, + Nullable: true, + }, + { + Name: "col_string", Type: arrow.BinaryTypes.String, + 
Nullable: true, + }, + { + Name: "col_binary", Type: arrow.BinaryTypes.Binary, + Nullable: true, + }, + { + Name: "col_boolean", Type: arrow.FixedWidthTypes.Boolean, + Nullable: true, + }, + { + Name: "col_date32", Type: arrow.FixedWidthTypes.Date32, + Nullable: true, + }, + { + Name: "col_time64ns", Type: arrow.FixedWidthTypes.Time64ns, + Nullable: true, + }, + { + Name: "col_time64us", Type: arrow.FixedWidthTypes.Time64us, + Nullable: true, + }, + { + Name: "col_time32ms", Type: arrow.FixedWidthTypes.Time32ms, + Nullable: true, + }, + { + Name: "col_time32s", Type: arrow.FixedWidthTypes.Time32s, + Nullable: true, + }, + }, nil) + + bldr := array.NewRecordBuilder(mem, sc) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{-1, 0, 25}, nil) + bldr.Field(1).(*array.Float64Builder).AppendValues([]float64{-1.1, 0, 25.95}, nil) + bldr.Field(2).(*array.StringBuilder).AppendValues([]string{"first", "second", "third"}, nil) + bldr.Field(3).(*array.BinaryBuilder).AppendValues([][]byte{[]byte("first"), []byte("second"), []byte("third")}, nil) + bldr.Field(4).(*array.BooleanBuilder).AppendValues([]bool{true, false, true}, nil) + bldr.Field(5).(*array.Date32Builder).AppendValues([]arrow.Date32{1, 2, 3}, nil) + bldr.Field(6).(*array.Time64Builder).AppendValues([]arrow.Time64{1, 2, 3}, nil) + bldr.Field(7).(*array.Time64Builder).AppendValues([]arrow.Time64{1, 2, 3}, nil) + bldr.Field(8).(*array.Time32Builder).AppendValues([]arrow.Time32{1, 2, 3}, nil) + bldr.Field(9).(*array.Time32Builder).AppendValues([]arrow.Time32{1, 2, 3}, nil) + + rec := bldr.NewRecord() + defer rec.Release() + + suite.Require().NoError(suite.stmt.Bind(suite.ctx, rec)) + suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_roundtrip")) + n, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + suite.EqualValues(3, n) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_roundtrip ORDER BY \"col_int64\" ASC")) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(3, n) + suite.True(rdr.Next()) + result := rdr.Record() + suite.Truef(array.RecordEqual(rec, result), "expected: %s\ngot: %s", rec, result) + suite.False(rdr.Next()) + + suite.Require().NoError(rdr.Err()) +} + +func (suite *SnowflakeTests) TestSqlIngestTimestampTypes() { + suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest_timestamps")) + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(suite.T(), 0) + + sessionTimezone := "America/Phoenix" + suite.Require().NoError(suite.stmt.SetSqlQuery(fmt.Sprintf(`ALTER SESSION SET TIMEZONE = "%s"`, sessionTimezone))) + _, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + + sc := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_timestamp_ns", Type: arrow.FixedWidthTypes.Timestamp_ns, + Nullable: true, + }, + { + Name: "col_timestamp_us", Type: arrow.FixedWidthTypes.Timestamp_us, + Nullable: true, + }, + { + Name: "col_timestamp_ms", Type: arrow.FixedWidthTypes.Timestamp_ms, + Nullable: true, + }, + { + Name: "col_timestamp_s", Type: arrow.FixedWidthTypes.Timestamp_s, + Nullable: true, + }, + { + Name: "col_timestamp_s_tz", Type: &arrow.TimestampType{Unit: arrow.Second, TimeZone: "EST"}, + Nullable: true, + }, + { + Name: "col_timestamp_s_ntz", Type: &arrow.TimestampType{Unit: arrow.Second}, 
+ Nullable: true, + }, + }, nil) + + bldr := array.NewRecordBuilder(mem, sc) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil) + bldr.Field(1).(*array.TimestampBuilder).AppendValues([]arrow.Timestamp{1, 2, 3}, nil) + bldr.Field(2).(*array.TimestampBuilder).AppendValues([]arrow.Timestamp{1, 2, 3}, nil) + bldr.Field(3).(*array.TimestampBuilder).AppendValues([]arrow.Timestamp{1, 2, 3}, nil) + bldr.Field(4).(*array.TimestampBuilder).AppendValues([]arrow.Timestamp{1, 2, 3}, nil) + bldr.Field(5).(*array.TimestampBuilder).AppendValues([]arrow.Timestamp{1, 2, 3}, nil) + bldr.Field(6).(*array.TimestampBuilder).AppendValues([]arrow.Timestamp{1, 2, 3}, nil) + + rec := bldr.NewRecord() + defer rec.Release() + + suite.Require().NoError(suite.stmt.Bind(suite.ctx, rec)) + suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_timestamps")) + n, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + suite.EqualValues(3, n) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_timestamps ORDER BY \"col_int64\" ASC")) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(3, n) + suite.True(rdr.Next()) + result := rdr.Record() + + expectedSchema := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_timestamp_ns", Type: &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: sessionTimezone}, + Nullable: true, + }, + { + Name: "col_timestamp_us", Type: &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: sessionTimezone}, + Nullable: true, + }, + { + Name: "col_timestamp_ms", Type: &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: sessionTimezone}, + Nullable: true, + }, + { + Name: "col_timestamp_s", Type: &arrow.TimestampType{Unit: arrow.Second, TimeZone: sessionTimezone}, + Nullable: true, + }, + { + Name: "col_timestamp_s_tz", Type: &arrow.TimestampType{Unit: arrow.Second, TimeZone: sessionTimezone}, + Nullable: true, + }, + { + Name: "col_timestamp_s_ntz", Type: &arrow.TimestampType{Unit: arrow.Second}, + Nullable: true, + }, + }, nil) + + expectedRecord, _, err := array.RecordFromJSON(mem, expectedSchema, bytes.NewReader([]byte(` + [ + { + "col_int64": 1, + "col_timestamp_ns": 1, + "col_timestamp_us": 1, + "col_timestamp_ms": 1, + "col_timestamp_s": 1, + "col_timestamp_s_tz": 1, + "col_timestamp_s_ntz": 1 + }, + { + "col_int64": 2, + "col_timestamp_ns": 2, + "col_timestamp_us": 2, + "col_timestamp_ms": 2, + "col_timestamp_s": 2, + "col_timestamp_s_tz": 2, + "col_timestamp_s_ntz": 2 + }, + { + "col_int64": 3, + "col_timestamp_ns": 3, + "col_timestamp_us": 3, + "col_timestamp_ms": 3, + "col_timestamp_s": 3, + "col_timestamp_s_tz": 3, + "col_timestamp_s_ntz": 3 + } + ] + `))) + suite.Require().NoError(err) + defer expectedRecord.Release() + + suite.Truef(array.RecordEqual(expectedRecord, result), "expected: %s\ngot: %s", expectedRecord, result) + + suite.False(rdr.Next()) + suite.Require().NoError(rdr.Err()) +} + +func (suite *SnowflakeTests) TestSqlIngestDate64Type() { + suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest_date64")) + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(suite.T(), 0) + + sc := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_date64", Type: 
arrow.FixedWidthTypes.Date64, + Nullable: true, + }, + }, nil) + + bldr := array.NewRecordBuilder(mem, sc) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil) + bldr.Field(1).(*array.Date64Builder).AppendValues([]arrow.Date64{86400000, 172800000, 259200000}, nil) // 1,2,3 days of milliseconds + + rec := bldr.NewRecord() + defer rec.Release() + + suite.Require().NoError(suite.stmt.Bind(suite.ctx, rec)) + suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_date64")) + n, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + suite.EqualValues(3, n) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_date64 ORDER BY \"col_int64\" ASC")) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(3, n) + suite.True(rdr.Next()) + result := rdr.Record() + + expectedSchema := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_date64", Type: arrow.FixedWidthTypes.Date32, + Nullable: true, + }, + }, nil) + + expectedRecord, _, err := array.RecordFromJSON(mem, expectedSchema, bytes.NewReader([]byte(` + [ + { + "col_int64": 1, + "col_date64": 1 + }, + { + "col_int64": 2, + "col_date64": 2 + }, + { + "col_int64": 3, + "col_date64": 3 + } + ] + `))) + suite.Require().NoError(err) + defer expectedRecord.Release() + + suite.Truef(array.RecordEqual(expectedRecord, result), "expected: %s\ngot: %s", expectedRecord, result) + + suite.False(rdr.Next()) + suite.Require().NoError(rdr.Err()) +} + +func (suite *SnowflakeTests) TestSqlIngestHighPrecision() { + suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest_high_precision")) + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(suite.T(), 0) + + sc := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_float64", Type: arrow.PrimitiveTypes.Float64, + Nullable: true, + }, + { + Name: "col_decimal128_whole", Type: &arrow.Decimal128Type{Precision: 38, Scale: 0}, + Nullable: true, + }, + { + Name: "col_decimal128_fractional", Type: &arrow.Decimal128Type{Precision: 38, Scale: 2}, + Nullable: true, + }, + }, nil) + + bldr := array.NewRecordBuilder(mem, sc) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil) + bldr.Field(1).(*array.Float64Builder).AppendValues([]float64{1.2, 2.34, 3.456}, nil) + bldr.Field(2).(*array.Decimal128Builder).AppendValues([]decimal128.Num{decimal128.FromI64(123), decimal128.FromI64(456), decimal128.FromI64(789)}, nil) + num1, err := decimal128.FromString("123", 38, 2) + suite.Require().NoError(err) + num2, err := decimal128.FromString("456.7", 38, 2) + suite.Require().NoError(err) + num3, err := decimal128.FromString("891.01", 38, 2) + suite.Require().NoError(err) + bldr.Field(3).(*array.Decimal128Builder).AppendValues([]decimal128.Num{num1, num2, num3}, nil) + + rec := bldr.NewRecord() + defer rec.Release() + + suite.Require().NoError(suite.stmt.Bind(suite.ctx, rec)) + suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_high_precision")) + n, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + suite.EqualValues(3, n) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_high_precision ORDER BY \"col_int64\" 
ASC")) + suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueEnabled)) + defer func() { + suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueDisabled)) + }() + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(3, n) + suite.True(rdr.Next()) + result := rdr.Record() + + expectedSchema := arrow.NewSchema([]arrow.Field{ + { // INT64 -> DECIMAL(38, 0) on roundtrip + Name: "col_int64", Type: &arrow.Decimal128Type{Precision: 38, Scale: 0}, + Nullable: true, + }, + { // Preserved on roundtrip + Name: "col_float64", Type: arrow.PrimitiveTypes.Float64, + Nullable: true, + }, + { // Preserved on roundtrip + Name: "col_decimal128_whole", Type: &arrow.Decimal128Type{Precision: 38, Scale: 0}, + Nullable: true, + }, + { // Preserved on roundtrip + Name: "col_decimal128_fractional", Type: &arrow.Decimal128Type{Precision: 38, Scale: 2}, + Nullable: true, + }, + }, nil) + + expectedRecord, _, err := array.RecordFromJSON(mem, expectedSchema, bytes.NewReader([]byte(` + [ + { + "col_int64": 1, + "col_float64": 1.2, + "col_decimal128_whole": 123, + "col_decimal128_fractional": 123.00 + }, + { + "col_int64": 2, + "col_float64": 2.34, + "col_decimal128_whole": 456, + "col_decimal128_fractional": 456.70 + }, + { + "col_int64": 3, + "col_float64": 3.456, + "col_decimal128_whole": 789, + "col_decimal128_fractional": 891.01 + } + ] + `))) + suite.Require().NoError(err) + defer expectedRecord.Release() + + suite.Truef(array.RecordEqual(expectedRecord, result), "expected: %s\ngot: %s", expectedRecord, result) + + suite.False(rdr.Next()) + suite.Require().NoError(rdr.Err()) +} + +func (suite *SnowflakeTests) TestSqlIngestLowPrecision() { + suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest_high_precision")) + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(suite.T(), 0) + + sc := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_float64", Type: arrow.PrimitiveTypes.Float64, + Nullable: true, + }, + { + Name: "col_decimal128_whole", Type: &arrow.Decimal128Type{Precision: 38, Scale: 0}, + Nullable: true, + }, + { + Name: "col_decimal128_fractional", Type: &arrow.Decimal128Type{Precision: 38, Scale: 2}, + Nullable: true, + }, + }, nil) + + bldr := array.NewRecordBuilder(mem, sc) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil) + bldr.Field(1).(*array.Float64Builder).AppendValues([]float64{1.2, 2.34, 3.456}, nil) + bldr.Field(2).(*array.Decimal128Builder).AppendValues([]decimal128.Num{decimal128.FromI64(123), decimal128.FromI64(456), decimal128.FromI64(789)}, nil) + num1, err := decimal128.FromString("123", 38, 2) + suite.Require().NoError(err) + num2, err := decimal128.FromString("456.7", 38, 2) + suite.Require().NoError(err) + num3, err := decimal128.FromString("891.01", 38, 2) + suite.Require().NoError(err) + bldr.Field(3).(*array.Decimal128Builder).AppendValues([]decimal128.Num{num1, num2, num3}, nil) + + rec := bldr.NewRecord() + defer rec.Release() + + suite.Require().NoError(suite.stmt.Bind(suite.ctx, rec)) + suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_high_precision")) + n, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + suite.EqualValues(3, n) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * 
FROM bulk_ingest_high_precision ORDER BY \"col_int64\" ASC")) + // OptionUseHighPrecision already disabled + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(3, n) + suite.True(rdr.Next()) + result := rdr.Record() + + expectedSchema := arrow.NewSchema([]arrow.Field{ + { // Preserved on roundtrip + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { // Preserved on roundtrip + Name: "col_float64", Type: arrow.PrimitiveTypes.Float64, + Nullable: true, + }, + { // DECIMAL(38, 0) -> INT64 on roundtrip + Name: "col_decimal128_whole", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { // DECIMAL(38, 2) -> FLOAT64 on roundtrip + Name: "col_decimal128_fractional", Type: arrow.PrimitiveTypes.Float64, + Nullable: true, + }, + }, nil) + + expectedRecord, _, err := array.RecordFromJSON(mem, expectedSchema, bytes.NewReader([]byte(` + [ + { + "col_int64": 1, + "col_float64": 1.2, + "col_decimal128_whole": 123, + "col_decimal128_fractional": 123.00 + }, + { + "col_int64": 2, + "col_float64": 2.34, + "col_decimal128_whole": 456, + "col_decimal128_fractional": 456.70 + }, + { + "col_int64": 3, + "col_float64": 3.456, + "col_decimal128_whole": 789, + "col_decimal128_fractional": 891.01 + } + ] + `))) + suite.Require().NoError(err) + defer expectedRecord.Release() + + suite.Truef(array.RecordEqual(expectedRecord, result), "expected: %s\ngot: %s", expectedRecord, result) + + suite.False(rdr.Next()) + suite.Require().NoError(rdr.Err()) +} + +func (suite *SnowflakeTests) TestSqlIngestStructType() { + suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest_struct")) + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(suite.T(), 0) + + sc := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_struct", Type: arrow.StructOf([]arrow.Field{ + {Name: "name", Type: arrow.BinaryTypes.String}, + {Name: "age", Type: arrow.PrimitiveTypes.Int64}, + }...), + Nullable: true, + }, + { + Name: "col_struct_of_struct", Type: arrow.StructOf([]arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int64}, + {Name: "nested", Type: arrow.StructOf([]arrow.Field{ + {Name: "nested_id", Type: arrow.PrimitiveTypes.Int64}, + {Name: "ready", Type: arrow.FixedWidthTypes.Boolean}, + }...)}, + }...), + Nullable: true, + }, + }, nil) + + bldr := array.NewRecordBuilder(mem, sc) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil) + + struct1bldr := bldr.Field(1).(*array.StructBuilder) + struct1bldr.AppendValues([]bool{true, true, true}) + struct1bldr.FieldBuilder(0).(*array.StringBuilder).AppendValues([]string{"one", "two", "three"}, nil) + struct1bldr.FieldBuilder(1).(*array.Int64Builder).AppendValues([]int64{10, 20, 30}, nil) + + struct2bldr := bldr.Field(2).(*array.StructBuilder) + struct2bldr.AppendValues([]bool{true, false, true}) + struct2bldr.FieldBuilder(0).(*array.Int64Builder).AppendValues([]int64{1, 0, 3}, nil) + + struct3bldr := struct2bldr.FieldBuilder(1).(*array.StructBuilder) + struct3bldr.AppendValues([]bool{true, false, true}) + struct3bldr.FieldBuilder(0).(*array.Int64Builder).AppendValues([]int64{1, 0, 3}, nil) + struct3bldr.FieldBuilder(1).(*array.BooleanBuilder).AppendValues([]bool{true, false, false}, nil) + + rec := bldr.NewRecord() + defer rec.Release() + + suite.Require().NoError(suite.stmt.Bind(suite.ctx, rec)) + 
suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_struct")) + n, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + suite.EqualValues(3, n) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_struct ORDER BY \"col_int64\" ASC")) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(3, n) + suite.True(rdr.Next()) + result := rdr.Record() + + expectedSchema := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_struct", Type: arrow.BinaryTypes.String, + Nullable: true, + }, + { + Name: "col_struct_of_struct", Type: arrow.BinaryTypes.String, + Nullable: true, + }, + }, nil) + + expectedRecord, _, err := array.RecordFromJSON(mem, expectedSchema, bytes.NewReader([]byte(` + [ + { + "col_int64": 1, + "col_struct": "{\n \"age\": 10,\n \"name\": \"one\"\n}", + "col_struct_of_struct": "{\n \"id\": 1,\n \"nested\": {\n \"nested_id\": 1,\n \"ready\": true\n }\n}" + }, + { + "col_int64": 2, + "col_struct": "{\n \"age\": 20,\n \"name\": \"two\"\n}" + }, + { + "col_int64": 3, + "col_struct": "{\n \"age\": 30,\n \"name\": \"three\"\n}", + "col_struct_of_struct": "{\n \"id\": 3,\n \"nested\": {\n \"nested_id\": 3,\n \"ready\": false\n }\n}" + } + ] + `))) + suite.Require().NoError(err) + defer expectedRecord.Release() + + suite.Truef(array.RecordEqual(expectedRecord, result), "expected: %s\ngot: %s", expectedRecord, result) + logicalTypeStruct, ok := result.Schema().Field(1).Metadata.GetValue("logicalType") + suite.True(ok) + suite.Equal("OBJECT", logicalTypeStruct) + logicalTypeStructStruct, ok := result.Schema().Field(2).Metadata.GetValue("logicalType") + suite.True(ok) + suite.Equal("OBJECT", logicalTypeStructStruct) + + suite.False(rdr.Next()) + suite.Require().NoError(rdr.Err()) +} + +func (suite *SnowflakeTests) TestSqlIngestMapType() { + suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest_map")) + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(suite.T(), 0) + + sc := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_map", Type: arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int64), + Nullable: true, + }, + }, nil) + + bldr := array.NewRecordBuilder(mem, sc) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil) + + mapbldr := bldr.Field(1).(*array.MapBuilder) + keybldr := mapbldr.KeyBuilder().(*array.StringBuilder) + itembldr := mapbldr.ItemBuilder().(*array.Int64Builder) + + mapbldr.Append(true) + keybldr.Append("key1") + itembldr.Append(1) + // keybldr.Append("key1a") TODO(joellubi): Snowflake returns 'SQL execution internal error', seemingly for repetition levels > 0 + // itembldr.Append(11) + mapbldr.Append(true) + keybldr.Append("key2") + itembldr.Append(2) + mapbldr.Append(true) + keybldr.Append("key3") + itembldr.Append(3) + + rec := bldr.NewRecord() + defer rec.Release() + + suite.Require().NoError(suite.stmt.Bind(suite.ctx, rec)) + suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_map")) + n, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + suite.EqualValues(3, n) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_map ORDER BY \"col_int64\" ASC")) + rdr, n, err := 
suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(3, n) + suite.True(rdr.Next()) + result := rdr.Record() + + expectedSchema := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_map", Type: arrow.BinaryTypes.String, + Nullable: true, + }, + }, nil) + + expectedRecord, _, err := array.RecordFromJSON(mem, expectedSchema, bytes.NewReader([]byte(` + [ + { + "col_int64": 1, + "col_map": "{\n \"key_value\": [\n {\n \"key\": \"key1\",\n \"value\": 1\n }\n ]\n}" + }, + { + "col_int64": 2, + "col_map": "{\n \"key_value\": [\n {\n \"key\": \"key2\",\n \"value\": 2\n }\n ]\n}" + }, + { + "col_int64": 3, + "col_map": "{\n \"key_value\": [\n {\n \"key\": \"key3\",\n \"value\": 3\n }\n ]\n}" + } + ] + `))) + suite.Require().NoError(err) + defer expectedRecord.Release() + + suite.Truef(array.RecordEqual(expectedRecord, result), "expected: %s\ngot: %s", expectedRecord, result) + logicalTypeMap, ok := result.Schema().Field(1).Metadata.GetValue("logicalType") + suite.True(ok) + suite.Equal("OBJECT", logicalTypeMap) + + suite.False(rdr.Next()) + suite.Require().NoError(rdr.Err()) +} + +func (suite *SnowflakeTests) TestSqlIngestListType() { + suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest_list")) + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(suite.T(), 0) + + sc := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_list", Type: arrow.ListOf(arrow.BinaryTypes.String), + Nullable: true, + }, + }, nil) + + bldr := array.NewRecordBuilder(mem, sc) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil) + + listbldr := bldr.Field(1).(*array.ListBuilder) + listvalbldr := listbldr.ValueBuilder().(*array.StringBuilder) + listbldr.Append(true) + listvalbldr.Append("one") + // listvalbldr.Append("one2") TODO(joellubi): Snowflake returns 'SQL execution internal error', seemingly for repetition levels > 0 + listbldr.Append(true) + listvalbldr.Append("two") + listbldr.Append(true) + listvalbldr.Append("three") + + rec := bldr.NewRecord() + defer rec.Release() + + suite.Require().NoError(suite.stmt.Bind(suite.ctx, rec)) + suite.Require().NoError(suite.stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_list")) + n, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + suite.EqualValues(3, n) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_list ORDER BY \"col_int64\" ASC")) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(3, n) + suite.True(rdr.Next()) + result := rdr.Record() + + expectedSchema := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_list", Type: arrow.BinaryTypes.String, + Nullable: true, + }, + }, nil) + + expectedRecord, _, err := array.RecordFromJSON(mem, expectedSchema, bytes.NewReader([]byte(` + [ + { + "col_int64": 1, + "col_list": "[\n \"one\"\n]" + }, + { + "col_int64": 2, + "col_list": "[\n \"two\"\n]" + }, + { + "col_int64": 3, + "col_list": "[\n \"three\"\n]" + } + ] + `))) + suite.Require().NoError(err) + defer expectedRecord.Release() + + suite.Truef(array.RecordEqual(expectedRecord, result), "expected: %s\ngot: %s", expectedRecord, result) + logicalTypeList, ok 
:= result.Schema().Field(1).Metadata.GetValue("logicalType") + suite.True(ok) + suite.Equal("ARRAY", logicalTypeList) + + suite.False(rdr.Next()) + suite.Require().NoError(rdr.Err()) +} + func (suite *SnowflakeTests) TestStatementEmptyResultSet() { // Regression test for https://github.com/apache/arrow-adbc/issues/863 suite.NoError(suite.stmt.SetSqlQuery("SHOW WAREHOUSES")) diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go index 44057d4ee0..2f5335a109 100644 --- a/go/adbc/driver/snowflake/record_reader.go +++ b/go/adbc/driver/snowflake/record_reader.go @@ -20,6 +20,7 @@ package snowflake import ( "context" "encoding/hex" + "fmt" "math" "strconv" "strings" @@ -36,6 +37,8 @@ import ( "golang.org/x/sync/errgroup" ) +const MetadataKeySnowflakeType = "SNOWFLAKE_TYPE" + func identCol(_ context.Context, a arrow.Array) (arrow.Array, error) { a.Retain() return a, nil @@ -212,13 +215,7 @@ func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader, useHighP continue } - q := int64(t) / int64(math.Pow10(int(srcMeta.Scale))) - r := int64(t) % int64(math.Pow10(int(srcMeta.Scale))) - v, err := arrow.TimestampFromTime(time.Unix(q, r), dt.Unit) - if err != nil { - return nil, err - } - tb.Append(v) + tb.Append(arrow.Timestamp(t)) } } return tb.NewArray(), nil @@ -313,7 +310,7 @@ func rowTypesToArrowSchema(ctx context.Context, ld gosnowflake.ArrowStreamLoader Name: srcMeta.Name, Nullable: srcMeta.Nullable, Metadata: arrow.MetadataFrom(map[string]string{ - "SNOWFLAKE_TYPE": srcMeta.Type, + MetadataKeySnowflakeType: srcMeta.Type, }), } switch srcMeta.Type { @@ -384,69 +381,58 @@ func jsonDataToArrow(ctx context.Context, bldr *array.RecordBuilder, ld gosnowfl fb.Append(arrow.Time64(sec*1e9 + nsec)) case *array.TimestampBuilder: - tz, err := fb.Type().(*arrow.TimestampType).GetZone() - if err != nil { - return nil, err + snowflakeType, ok := bldr.Schema().Field(i).Metadata.GetValue(MetadataKeySnowflakeType) + if !ok { + return nil, errToAdbcErr( + adbc.StatusInvalidData, + fmt.Errorf("key %s not found in metadata for field %s", MetadataKeySnowflakeType, bldr.Schema().Field(i).Name), + ) } - if tz != time.UTC { - sec, nsec, err := extractTimestamp(col) - if err != nil { - return nil, err + if snowflakeType == "timestamp_tz" { + // "timestamp_tz" should be value + offset separated by space + tm := strings.Split(*col, " ") + if len(tm) != 2 { + return nil, adbc.Error{ + Msg: "invalid TIMESTAMP_TZ data. value doesn't consist of two numeric values separated by a space: " + *col, + SqlState: [5]byte{'2', '2', '0', '0', '7'}, + VendorCode: 268000, + Code: adbc.StatusInvalidData, + } } - val := time.Unix(sec, nsec).In(tz) - ts, err := arrow.TimestampFromTime(val, arrow.Nanosecond) + sec, nsec, err := extractTimestamp(&tm[0]) if err != nil { return nil, err } - fb.Append(ts) - break - } + offset, err := strconv.ParseInt(tm[1], 10, 64) + if err != nil { + return nil, adbc.Error{ + Msg: "invalid TIMESTAMP_TZ data. 
diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go
index 44057d4ee0..2f5335a109 100644
--- a/go/adbc/driver/snowflake/record_reader.go
+++ b/go/adbc/driver/snowflake/record_reader.go
@@ -20,6 +20,7 @@ package snowflake
 import (
 	"context"
 	"encoding/hex"
+	"fmt"
 	"math"
 	"strconv"
 	"strings"
@@ -36,6 +37,8 @@ import (
 	"golang.org/x/sync/errgroup"
 )
 
+const MetadataKeySnowflakeType = "SNOWFLAKE_TYPE"
+
 func identCol(_ context.Context, a arrow.Array) (arrow.Array, error) {
 	a.Retain()
 	return a, nil
@@ -212,13 +215,7 @@ func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader, useHighP
 					continue
 				}
 
-				q := int64(t) / int64(math.Pow10(int(srcMeta.Scale)))
-				r := int64(t) % int64(math.Pow10(int(srcMeta.Scale)))
-				v, err := arrow.TimestampFromTime(time.Unix(q, r), dt.Unit)
-				if err != nil {
-					return nil, err
-				}
-				tb.Append(v)
+				tb.Append(arrow.Timestamp(t))
 			}
 		}
 		return tb.NewArray(), nil
@@ -313,7 +310,7 @@ func rowTypesToArrowSchema(ctx context.Context, ld gosnowflake.ArrowStreamLoader
 			Name:     srcMeta.Name,
 			Nullable: srcMeta.Nullable,
 			Metadata: arrow.MetadataFrom(map[string]string{
-				"SNOWFLAKE_TYPE": srcMeta.Type,
+				MetadataKeySnowflakeType: srcMeta.Type,
 			}),
 		}
 		switch srcMeta.Type {
@@ -384,69 +381,58 @@ func jsonDataToArrow(ctx context.Context, bldr *array.RecordBuilder, ld gosnowfl
 				fb.Append(arrow.Time64(sec*1e9 + nsec))
 			case *array.TimestampBuilder:
-				tz, err := fb.Type().(*arrow.TimestampType).GetZone()
-				if err != nil {
-					return nil, err
+				snowflakeType, ok := bldr.Schema().Field(i).Metadata.GetValue(MetadataKeySnowflakeType)
+				if !ok {
+					return nil, errToAdbcErr(
+						adbc.StatusInvalidData,
+						fmt.Errorf("key %s not found in metadata for field %s", MetadataKeySnowflakeType, bldr.Schema().Field(i).Name),
+					)
 				}
-				if tz != time.UTC {
-					sec, nsec, err := extractTimestamp(col)
-					if err != nil {
-						return nil, err
+				if snowflakeType == "timestamp_tz" {
+					// "timestamp_tz" should be value + offset separated by space
+					tm := strings.Split(*col, " ")
+					if len(tm) != 2 {
+						return nil, adbc.Error{
+							Msg:        "invalid TIMESTAMP_TZ data. value doesn't consist of two numeric values separated by a space: " + *col,
+							SqlState:   [5]byte{'2', '2', '0', '0', '7'},
+							VendorCode: 268000,
+							Code:       adbc.StatusInvalidData,
+						}
 					}
-					val := time.Unix(sec, nsec).In(tz)
-					ts, err := arrow.TimestampFromTime(val, arrow.Nanosecond)
+					sec, nsec, err := extractTimestamp(&tm[0])
 					if err != nil {
 						return nil, err
 					}
-					fb.Append(ts)
-					break
-				}
+					offset, err := strconv.ParseInt(tm[1], 10, 64)
+					if err != nil {
+						return nil, adbc.Error{
+							Msg:        "invalid TIMESTAMP_TZ data. offset value is not an integer: " + tm[1],
+							SqlState:   [5]byte{'2', '2', '0', '0', '7'},
+							VendorCode: 268000,
+							Code:       adbc.StatusInvalidData,
+						}
+					}
 
-				snowflakeType, _ := bldr.Schema().Field(i).Metadata.GetValue("SNOWFLAKE_TYPE")
-				if snowflakeType == "timestamp_ntz" {
-					sec, nsec, err := extractTimestamp(col)
+					loc := gosnowflake.Location(int(offset) - 1440)
+					tt := time.Unix(sec, nsec).In(loc)
+					ts, err := arrow.TimestampFromTime(tt, arrow.Nanosecond)
 					if err != nil {
 						return nil, err
 					}
-
-					fb.Append(arrow.Timestamp(sec*1e9 + nsec))
+					fb.Append(ts)
 					break
 				}
 
-				// "timestamp_tz" should be value + offset separated by space
-				tm := strings.Split(*col, " ")
-				if len(tm) != 2 {
-					return nil, adbc.Error{
-						Msg:        "invalid TIMESTAMP_TZ data. value doesn't consist of two numeric values separated by a space: " + *col,
-						SqlState:   [5]byte{'2', '2', '0', '0', '7'},
-						VendorCode: 268000,
-						Code:       adbc.StatusInvalidData,
-					}
-				}
-
-				sec, nsec, err := extractTimestamp(&tm[0])
+				// otherwise timestamp_ntz or timestamp_ltz, which have the same physical representation
+				sec, nsec, err := extractTimestamp(col)
 				if err != nil {
 					return nil, err
 				}
-				offset, err := strconv.ParseInt(tm[1], 10, 64)
-				if err != nil {
-					return nil, adbc.Error{
-						Msg:        "invalid TIMESTAMP_TZ data. offset value is not an integer: " + tm[1],
-						SqlState:   [5]byte{'2', '2', '0', '0', '7'},
-						VendorCode: 268000,
-						Code:       adbc.StatusInvalidData,
-					}
-				}
-				loc := gosnowflake.Location(int(offset) - 1440)
-				tt := time.Unix(sec, nsec).In(loc)
-				ts, err := arrow.TimestampFromTime(tt, arrow.Nanosecond)
-				if err != nil {
-					return nil, err
-				}
-				fb.Append(ts)
+				fb.Append(arrow.Timestamp(sec*1e9 + nsec))
+
 			case *array.BinaryBuilder:
 				b, err := hex.DecodeString(*col)
 				if err != nil {
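// A minimal, self-contained sketch of the TIMESTAMP_TZ decoding performed in the
// hunk above: the JSON result carries "<epochSeconds.fraction> <offsetMinutes+1440>",
// and the offset becomes a fixed zone via gosnowflake.Location. The sample input
// and the simplified fraction handling are assumptions for illustration; the
// driver itself uses its internal extractTimestamp helper.
package main

import (
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/snowflakedb/gosnowflake"
)

func parseTimestampTZ(raw string) (time.Time, error) {
	parts := strings.Split(raw, " ")
	if len(parts) != 2 {
		return time.Time{}, fmt.Errorf("expected \"value offset\", got %q", raw)
	}

	secAndFrac := strings.SplitN(parts[0], ".", 2)
	sec, err := strconv.ParseInt(secAndFrac[0], 10, 64)
	if err != nil {
		return time.Time{}, err
	}
	var nsec int64
	if len(secAndFrac) == 2 {
		// pad the fractional part out to nanoseconds before parsing
		frac := secAndFrac[1] + strings.Repeat("0", 9)
		nsec, err = strconv.ParseInt(frac[:9], 10, 64)
		if err != nil {
			return time.Time{}, err
		}
	}

	offset, err := strconv.ParseInt(parts[1], 10, 64)
	if err != nil {
		return time.Time{}, err
	}

	// Snowflake encodes the zone as minutes east of UTC plus 1440.
	loc := gosnowflake.Location(int(offset) - 1440)
	return time.Unix(sec, nsec).In(loc), nil
}

func main() {
	// Hypothetical value: 2024-01-15 00:00:00 UTC with a +01:00 offset (1440 + 60).
	t, err := parseTimestampTZ("1705276800.000000000 1500")
	fmt.Println(t, err)
}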
diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go
index edbaa826a4..61cb0b62e2 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -19,7 +19,6 @@ package snowflake
 import (
 	"context"
-	"database/sql/driver"
 	"fmt"
 	"strconv"
 	"strings"
@@ -29,12 +28,17 @@ import (
 	"github.com/apache/arrow/go/v15/arrow/array"
 	"github.com/apache/arrow/go/v15/arrow/memory"
 	"github.com/snowflakedb/gosnowflake"
-	"golang.org/x/exp/constraints"
 )
 
 const (
-	OptionStatementQueueSize           = "adbc.rpc.result_queue_size"
-	OptionStatementPrefetchConcurrency = "adbc.snowflake.rpc.prefetch_concurrency"
+	OptionStatementQueueSize               = "adbc.rpc.result_queue_size"
+	OptionStatementPrefetchConcurrency     = "adbc.snowflake.rpc.prefetch_concurrency"
+	OptionStatementIngestWriterConcurrency = "adbc.snowflake.statement.ingest_writer_concurrency"
+	OptionStatementIngestUploadConcurrency = "adbc.snowflake.statement.ingest_upload_concurrency"
+	OptionStatementIngestCopyConcurrency   = "adbc.snowflake.statement.ingest_copy_concurrency"
+	OptionStatementIngestTargetFileSize    = "adbc.snowflake.statement.ingest_target_file_size"
+	OptionStatementIngestCompressionCodec  = "adbc.snowflake.statement.ingest_compression_codec" // TODO(GH-1473): Implement option
+	OptionStatementIngestCompressionLevel  = "adbc.snowflake.statement.ingest_compression_level" // TODO(GH-1473): Implement option
 )
 
 type statement struct {
@@ -44,9 +48,10 @@ type statement struct {
 	prefetchConcurrency int
 	useHighPrecision    bool
 
-	query       string
-	targetTable string
-	ingestMode  string
+	query         string
+	targetTable   string
+	ingestMode    string
+	ingestOptions *ingestOptions
 
 	bound      arrow.Record
 	streamBind array.RecordReader
@@ -143,6 +148,42 @@ func (st *statement) SetOption(key string, val string) error {
 			}
 		}
 		return st.SetOptionInt(key, int64(concurrency))
+	case OptionStatementIngestWriterConcurrency:
+		concurrency, err := strconv.Atoi(val)
+		if err != nil {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("[Snowflake] could not parse '%s' as int for option '%s'", val, key),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		return st.SetOptionInt(key, int64(concurrency))
+	case OptionStatementIngestUploadConcurrency:
+		concurrency, err := strconv.Atoi(val)
+		if err != nil {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("[Snowflake] could not parse '%s' as int for option '%s'", val, key),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		return st.SetOptionInt(key, int64(concurrency))
+	case OptionStatementIngestCopyConcurrency:
+		concurrency, err := strconv.Atoi(val)
+		if err != nil {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("[Snowflake] could not parse '%s' as int for option '%s'", val, key),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		return st.SetOptionInt(key, int64(concurrency))
+	case OptionStatementIngestTargetFileSize:
+		size, err := strconv.Atoi(val)
+		if err != nil {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("[Snowflake] could not parse '%s' as int for option '%s'", val, key),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		return st.SetOptionInt(key, int64(size))
 	case OptionUseHighPrecision:
 		switch val {
 		case adbc.OptionValueEnabled:
@@ -191,6 +232,52 @@ func (st *statement) SetOptionInt(key string, value int64) error {
 		}
 		st.prefetchConcurrency = int(value)
 		return nil
+	case OptionStatementIngestWriterConcurrency:
+		if value < 0 {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("invalid value ('%d') for option '%s', must be >= 0", value, key),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		if value == 0 {
+			st.ingestOptions.writerConcurrency = defaultWriterConcurrency
+			return nil
+		}
+
+		st.ingestOptions.writerConcurrency = uint(value)
+		return nil
+	case OptionStatementIngestUploadConcurrency:
+		if value < 0 {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("invalid value ('%d') for option '%s', must be >= 0", value, key),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		if value == 0 {
+			st.ingestOptions.uploadConcurrency = defaultUploadConcurrency
+			return nil
+		}
+
+		st.ingestOptions.uploadConcurrency = uint(value)
+		return nil
+	case OptionStatementIngestCopyConcurrency:
+		if value < 0 {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("invalid value ('%d') for option '%s', must be >= 0", value, key),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		st.ingestOptions.copyConcurrency = uint(value)
+		return nil
+	case OptionStatementIngestTargetFileSize:
+		if value < 0 {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("invalid value ('%d') for option '%s', must be >= 0", value, key),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		st.ingestOptions.targetFileSize = uint(value)
+		return nil
 	}
 	return adbc.Error{
 		Msg:  fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
@@ -253,7 +340,7 @@ func toSnowflakeType(dt arrow.DataType) string {
 		if ts.TimeZone == "" {
 			return fmt.Sprintf("timestamp_ntz(%d)", prec)
 		}
-		return fmt.Sprintf("timestamp_tz(%d)", prec)
+		return fmt.Sprintf("timestamp_ltz(%d)", prec)
 	case arrow.DENSE_UNION, arrow.SPARSE_UNION:
 		return "variant"
 	case arrow.LIST, arrow.LARGE_LIST, arrow.FIXED_SIZE_LIST:
@@ -265,9 +352,9 @@ func toSnowflakeType(dt arrow.DataType) string {
 	return ""
 }
 
-func (st *statement) initIngest(ctx context.Context) (string, error) {
+func (st *statement) initIngest(ctx context.Context) error {
 	var (
-		createBldr, insertBldr strings.Builder
+		createBldr strings.Builder
 	)
 
 	createBldr.WriteString("CREATE TABLE ")
@@ -277,10 +364,6 @@ func (st *statement) initIngest(ctx context.Context) (string, error) {
 	createBldr.WriteString(st.targetTable)
 	createBldr.WriteString(" (")
 
-	insertBldr.WriteString("INSERT INTO ")
-	insertBldr.WriteString(st.targetTable)
-	insertBldr.WriteString(" VALUES (")
-
 	var schema *arrow.Schema
 	if st.bound != nil {
 		schema = st.bound.Schema()
@@ -290,7 +373,6 @@ func (st *statement) initIngest(ctx context.Context) (string, error) {
 
 	for i, f := range schema.Fields() {
 		if i != 0 {
-			insertBldr.WriteString(", ")
 			createBldr.WriteString(", ")
 		}
 
@@ -298,7 +380,7 @@ func (st *statement) initIngest(ctx context.Context) (string, error) {
 		createBldr.WriteString(" ")
 		ty := toSnowflakeType(f.Type)
 		if ty == "" {
-			return "", adbc.Error{
+			return adbc.Error{
 				Msg:  fmt.Sprintf("unimplemented type conversion for field %s, arrow type: %s", f.Name, f.Type),
 				Code: adbc.StatusNotImplemented,
 			}
@@ -308,12 +390,9 @@ func (st *statement) initIngest(ctx context.Context) (string, error) {
 		if !f.Nullable {
 			createBldr.WriteString(" NOT NULL")
 		}
-
-		insertBldr.WriteString("?")
 	}
 
 	createBldr.WriteString(")")
-	insertBldr.WriteString(")")
 
 	switch st.ingestMode {
 	case adbc.OptionValueIngestModeAppend:
@@ -322,7 +401,7 @@ func (st *statement) initIngest(ctx context.Context) (string, error) {
 		replaceQuery := "DROP TABLE IF EXISTS " + st.targetTable
 		_, err := st.cnxn.cn.ExecContext(ctx, replaceQuery, nil)
 		if err != nil {
-			return "", errToAdbcErr(adbc.StatusInternal, err)
+			return errToAdbcErr(adbc.StatusInternal, err)
 		}
 
 		fallthrough
@@ -335,124 +414,11 @@ func (st *statement) initIngest(ctx context.Context) (string, error) {
 		createQuery := createBldr.String()
 		_, err := st.cnxn.cn.ExecContext(ctx, createQuery, nil)
 		if err != nil {
-			return "", errToAdbcErr(adbc.StatusInternal, err)
-		}
-	}
-
-	return insertBldr.String(), nil
-}
-
-type nativeArrowArr[T string | []byte] interface {
-	arrow.Array
-	Value(int) T
-}
-
-func convToArr[T string | []byte](arr nativeArrowArr[T]) interface{} {
-	if arr.Len() == 1 {
-		if arr.IsNull(0) {
-			return nil
-		}
-
-		return arr.Value(0)
-	}
-
-	v := make([]interface{}, arr.Len())
-	for i := 0; i < arr.Len(); i++ {
-		if arr.IsNull(i) {
-			continue
-		}
-		v[i] = arr.Value(i)
-	}
-	return gosnowflake.Array(&v)
-}
-
-func convMarshal(arr arrow.Array) interface{} {
-	if arr.Len() == 0 {
-		if arr.IsNull(0) {
-			return nil
-		}
-		return arr.ValueStr(0)
-	}
-
-	v := make([]interface{}, arr.Len())
-	for i := 0; i < arr.Len(); i++ {
-		if arr.IsNull(i) {
-			continue
-		}
-		v[i] = arr.ValueStr(i)
-	}
-	return gosnowflake.Array(&v)
-}
-
-// snowflake driver bindings only support specific types
-// int/int32/int64/float64/float32/bool/string/byte/time
-// so we have to cast anything else appropriately
-func convToSlice[T, O constraints.Integer | constraints.Float](arr arrow.Array, vals []T) interface{} {
-	if arr.Len() == 1 {
-		if arr.IsNull(0) {
-			return nil
-		}
-
-		return vals[0]
-	}
-
-	out := make([]interface{}, arr.Len())
-	for i, v := range vals {
-		if arr.IsNull(i) {
-			continue
+			return errToAdbcErr(adbc.StatusInternal, err)
 		}
-		out[i] = O(v)
 	}
-	return gosnowflake.Array(&out)
-}
 
-func getQueryArg(arr arrow.Array) interface{} {
-	switch arr := arr.(type) {
-	case *array.Int8:
-		v := arr.Int8Values()
-		return convToSlice[int8, int32](arr, v)
-	case *array.Uint8:
-		v := arr.Uint8Values()
-		return convToSlice[uint8, int32](arr, v)
-	case *array.Int16:
-		v := arr.Int16Values()
-		return convToSlice[int16, int32](arr, v)
-	case *array.Uint16:
-		v := arr.Uint16Values()
-		return convToSlice[uint16, int32](arr, v)
-	case *array.Int32:
-		v := arr.Int32Values()
-		return convToSlice[int32, int32](arr, v)
-	case *array.Uint32:
-		v := arr.Uint32Values()
-		return convToSlice[uint32, int64](arr, v)
-	case *array.Int64:
-		v := arr.Int64Values()
-		return convToSlice[int64, int64](arr, v)
-	case *array.Uint64:
-		v := arr.Uint64Values()
-		return convToSlice[uint64, int64](arr, v)
-	case *array.Float32:
-		v := arr.Float32Values()
-		return convToSlice[float32, float64](arr, v)
-	case *array.Float64:
-		v := arr.Float64Values()
-		return convToSlice[float64, float64](arr, v)
-	case *array.LargeBinary:
-		return convToArr[[]byte](arr)
-	case *array.Binary:
-		return convToArr[[]byte](arr)
-	case *array.LargeString:
-		return convToArr[string](arr)
-	case *array.String:
-		return convToArr[string](arr)
-	default:
-		// default convert to array of strings and pass to snowflake driver
-		// not the most efficient, but snowflake doesn't really give a better
-		// route currently short of writing everything out to a Parquet file
-		// and then uploading that (which might be preferable)
-		return convMarshal(arr)
-	}
+	return nil
 }
 
 func (st *statement) executeIngest(ctx context.Context) (int64, error) {
@@ -463,60 +429,16 @@ func (st *statement) executeIngest(ctx context.Context) (int64, error) {
 		}
 	}
 
-	insertQuery, err := st.initIngest(ctx)
+	err := st.initIngest(ctx)
 	if err != nil {
 		return -1, err
 	}
 
-	// if the ingestion is large enough it might make more sense to
-	// write this out to a temporary file / stage / etc. and use
-	// the snowflake bulk loader that way.
-	//
-	// on the other hand, according to the documentation,
-	// https://pkg.go.dev/github.com/snowflakedb/gosnowflake#hdr-Batch_Inserts_and_Binding_Parameters
-	// snowflake's internal driver work should already be doing this.
-
-	var n int64
-	exec := func(rec arrow.Record, args []driver.NamedValue) error {
-		for i, c := range rec.Columns() {
-			args[i].Ordinal = i
-			args[i].Value = getQueryArg(c)
-		}
-
-		r, err := st.cnxn.cn.ExecContext(ctx, insertQuery, args)
-		if err != nil {
-			return errToAdbcErr(adbc.StatusInternal, err)
-		}
-
-		rows, err := r.RowsAffected()
-		if err == nil {
-			n += rows
-		}
-		return nil
-	}
-
 	if st.bound != nil {
-		defer func() {
-			st.bound.Release()
-			st.bound = nil
-		}()
-		args := make([]driver.NamedValue, len(st.bound.Schema().Fields()))
-		return n, exec(st.bound, args)
+		return st.ingestRecord(ctx)
 	}
 
-	defer func() {
-		st.streamBind.Release()
-		st.streamBind = nil
-	}()
-	args := make([]driver.NamedValue, len(st.streamBind.Schema().Fields()))
-	for st.streamBind.Next() {
-		rec := st.streamBind.Record()
-		if err := exec(rec, args); err != nil {
-			return n, err
-		}
-	}
-
-	return n, nil
+	return st.ingestStream(ctx)
 }
 
 // ExecuteQuery executes the current query or prepared statement
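// A minimal sketch of how the ingestion knobs introduced above might be applied
// from client code. The connection setup is elided and the concrete values are
// illustrative assumptions; only the option constants, the string-valued
// SetOption path, and the Bind/ExecuteUpdate flow come from the patch itself.
package example

import (
	"context"

	"github.com/apache/arrow-adbc/go/adbc"
	"github.com/apache/arrow-adbc/go/adbc/driver/snowflake"
	"github.com/apache/arrow/go/v15/arrow"
)

func ingestWithTuning(ctx context.Context, cnxn adbc.Connection, rec arrow.Record) (int64, error) {
	stmt, err := cnxn.NewStatement()
	if err != nil {
		return -1, err
	}
	defer stmt.Close()

	// The target table is a standard ADBC ingestion option.
	if err := stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_tuned"); err != nil {
		return -1, err
	}

	// Snowflake-specific tuning added by this patch; string values are parsed as
	// ints, and 0 falls back to the default (or "single COPY" / "no size limit").
	if err := stmt.SetOption(snowflake.OptionStatementIngestWriterConcurrency, "4"); err != nil {
		return -1, err
	}
	if err := stmt.SetOption(snowflake.OptionStatementIngestUploadConcurrency, "8"); err != nil {
		return -1, err
	}
	if err := stmt.SetOption(snowflake.OptionStatementIngestCopyConcurrency, "4"); err != nil {
		return -1, err
	}
	if err := stmt.SetOption(snowflake.OptionStatementIngestTargetFileSize, "10485760"); err != nil {
		return -1, err
	}

	if err := stmt.Bind(ctx, rec); err != nil {
		return -1, err
	}
	return stmt.ExecuteUpdate(ctx)
}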
diff --git a/go/adbc/go.mod b/go/adbc/go.mod
index 8db9d3cbbc..9d83b175b4 100644
--- a/go/adbc/go.mod
+++ b/go/adbc/go.mod
@@ -41,7 +41,9 @@ require (
 	github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.1 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.2.1 // indirect
 	github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect
+	github.com/andybalholm/brotli v1.0.5 // indirect
 	github.com/apache/arrow/go/v14 v14.0.2 // indirect
+	github.com/apache/thrift v0.17.0 // indirect
 	github.com/aws/aws-sdk-go-v2 v1.24.1 // indirect
 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.4 // indirect
 	github.com/aws/aws-sdk-go-v2/credentials v1.16.16 // indirect
@@ -63,13 +65,17 @@ require (
 	github.com/gabriel-vasile/mimetype v1.4.3 // indirect
 	github.com/goccy/go-json v0.10.2 // indirect
 	github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect
+	github.com/golang/snappy v0.0.4 // indirect
 	github.com/google/flatbuffers v23.5.26+incompatible // indirect
 	github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect
 	github.com/jmespath/go-jmespath v0.4.0 // indirect
 	github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
+	github.com/klauspost/asmfmt v1.3.2 // indirect
 	github.com/klauspost/compress v1.17.4 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.6 // indirect
 	github.com/mattn/go-isatty v0.0.19 // indirect
+	github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect
+	github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect
 	github.com/mtibben/percent v0.2.1 // indirect
 	github.com/pierrec/lz4/v4 v4.1.21 // indirect
 	github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
diff --git a/go/adbc/go.sum b/go/adbc/go.sum
index 679e46fe23..627fb052ac 100644
--- a/go/adbc/go.sum
+++ b/go/adbc/go.sum
@@ -13,10 +13,14 @@ github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.2.1/go.mod h1:uwfk06ZBcv
 github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1 h1:WpB/QDNLpMw72xHJc34BNNykqSOeEJDAWkhf0u12/Jk=
 github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU=
 github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk=
+github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
+github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
 github.com/apache/arrow/go/v14 v14.0.2 h1:N8OkaJEOfI3mEZt07BIkvo4sC6XDbL+48MBPWO5IONw=
 github.com/apache/arrow/go/v14 v14.0.2/go.mod h1:u3fgh3EdgN/YQ8cVQRguVW3R+seMybFg8QBQ5LU+eBY=
 github.com/apache/arrow/go/v15 v15.0.0-20240119162530-143a7da1038c h1:tGuAIZ7IXzDnfeuOcfFUufeVvXlRTWBGyyv2A8E5mnI=
 github.com/apache/arrow/go/v15 v15.0.0-20240119162530-143a7da1038c/go.mod h1:DGXsR3ajT524njufqf95822i+KTh+yea1jass9YXgjA=
+github.com/apache/thrift v0.17.0 h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo=
+github.com/apache/thrift v0.17.0/go.mod h1:OLxhMRJxomX+1I/KUw03qoV3mMz16BwaKI+d4fPBx7Q=
 github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU=
 github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.4 h1:OCs21ST2LrepDfD3lwlQiOqIGp6JiEUqG84GzTDoyJs=
@@ -73,6 +77,8 @@ github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJ
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
 github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
+github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
+github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/google/flatbuffers v23.5.26+incompatible h1:M9dgRyhJemaM4Sw8+66GHBu8ioaQmyPLg1b8VwK5WJg=
 github.com/google/flatbuffers v23.5.26+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
 github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@@ -88,6 +94,8 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw
 github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
+github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4=
+github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
 github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
 github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
 github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
@@ -100,6 +108,10 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
 github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
 github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
+github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
+github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
+github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
+github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE=
 github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs=
 github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns=
 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
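// The new indirect requirements above (thrift, snappy, brotli and the asm
// helpers) appear to come in through the Arrow Parquet support used for the
// staged files. A rough, self-contained sketch of writing one record batch to a
// Parquet buffer is shown below; it only approximates what bulk_ingestion.go
// does, and the compression choice is an illustrative assumption.
package main

import (
	"bytes"
	"fmt"

	"github.com/apache/arrow/go/v15/arrow"
	"github.com/apache/arrow/go/v15/arrow/array"
	"github.com/apache/arrow/go/v15/arrow/memory"
	"github.com/apache/arrow/go/v15/parquet"
	"github.com/apache/arrow/go/v15/parquet/compress"
	"github.com/apache/arrow/go/v15/parquet/pqarrow"
)

func main() {
	mem := memory.DefaultAllocator
	schema := arrow.NewSchema([]arrow.Field{
		{Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
	}, nil)

	// Build a small record batch to serialize.
	bldr := array.NewRecordBuilder(mem, schema)
	defer bldr.Release()
	bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil)
	rec := bldr.NewRecord()
	defer rec.Release()

	tbl := array.NewTableFromRecords(schema, []arrow.Record{rec})
	defer tbl.Release()

	// Write the table as Parquet into an in-memory buffer.
	var buf bytes.Buffer
	props := parquet.NewWriterProperties(parquet.WithCompression(compress.Codecs.Snappy))
	if err := pqarrow.WriteTable(tbl, &buf, 1024, props, pqarrow.DefaultWriterProps()); err != nil {
		panic(err)
	}
	fmt.Printf("wrote %d bytes of Parquet\n", buf.Len())
}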