diff --git a/Cargo.lock b/Cargo.lock
index 01a85e7..a0de25b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -410,14 +410,14 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
 [[package]]
 name = "chrono"
-version = "0.4.31"
+version = "0.4.38"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
+checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401"
 dependencies = [
  "android-tzdata",
  "iana-time-zone",
  "num-traits",
- "windows-targets",
+ "windows-targets 0.52.6",
 ]
 
 [[package]]
@@ -1115,9 +1115,9 @@ dependencies = [
 
 [[package]]
 name = "memchr"
-version = "2.6.3"
+version = "2.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c"
+checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
 
 [[package]]
 name = "memoffset"
@@ -1255,7 +1255,7 @@ dependencies = [
  "libc",
  "redox_syscall",
  "smallvec",
- "windows-targets",
+ "windows-targets 0.48.5",
 ]
 
 [[package]]
@@ -1680,9 +1680,9 @@ dependencies = [
 
 [[package]]
 name = "ryu"
-version = "1.0.15"
+version = "1.0.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741"
+checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
 
 [[package]]
 name = "same-file"
@@ -2114,7 +2114,7 @@ version = "0.48.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f"
 dependencies = [
- "windows-targets",
+ "windows-targets 0.48.5",
 ]
 
 [[package]]
@@ -2123,7 +2123,7 @@ version = "0.48.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
 dependencies = [
- "windows-targets",
+ "windows-targets 0.48.5",
 ]
 
 [[package]]
@@ -2132,13 +2132,29 @@ version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
 dependencies = [
- "windows_aarch64_gnullvm",
- "windows_aarch64_msvc",
- "windows_i686_gnu",
- "windows_i686_msvc",
- "windows_x86_64_gnu",
- "windows_x86_64_gnullvm",
- "windows_x86_64_msvc",
+ "windows_aarch64_gnullvm 0.48.5",
+ "windows_aarch64_msvc 0.48.5",
+ "windows_i686_gnu 0.48.5",
+ "windows_i686_msvc 0.48.5",
+ "windows_x86_64_gnu 0.48.5",
+ "windows_x86_64_gnullvm 0.48.5",
+ "windows_x86_64_msvc 0.48.5",
+]
+
+[[package]]
+name = "windows-targets"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
+dependencies = [
+ "windows_aarch64_gnullvm 0.52.6",
+ "windows_aarch64_msvc 0.52.6",
+ "windows_i686_gnu 0.52.6",
+ "windows_i686_gnullvm",
+ "windows_i686_msvc 0.52.6",
+ "windows_x86_64_gnu 0.52.6",
+ "windows_x86_64_gnullvm 0.52.6",
+ "windows_x86_64_msvc 0.52.6",
 ]
 
 [[package]]
@@ -2147,42 +2163,90 @@ version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
 
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
+
 [[package]]
 name = "windows_aarch64_msvc"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
 
+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
+
 [[package]]
 name = "windows_i686_gnu"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
 
+[[package]]
+name = "windows_i686_gnu"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
+
+[[package]]
+name = "windows_i686_gnullvm"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
+
 [[package]]
 name = "windows_i686_msvc"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
 
+[[package]]
+name = "windows_i686_msvc"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
+
 [[package]]
 name = "windows_x86_64_gnu"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
 
+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
+
 [[package]]
 name = "windows_x86_64_gnullvm"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
 
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
+
 [[package]]
 name = "windows_x86_64_msvc"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
 
+[[package]]
+name = "windows_x86_64_msvc"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
+
 [[package]]
 name = "zstd"
 version = "0.12.4"
diff --git a/core/Cargo.toml b/core/Cargo.toml
index 2e6d606..e5e0c2e 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -17,6 +17,7 @@ arrow-schema = ">=46.0.0"
 enum_dispatch = "0.3.11"
 anyhow = "1.0.70"
 thiserror = "1.0.40"
+arrow = "^46.0.0"
 
 [dependencies.arrow-array]
 version = ">=46.0.0"
diff --git a/core/src/buffer_view.rs b/core/src/buffer_view.rs
new file mode 100644
index 0000000..0cfa583
--- /dev/null
+++ b/core/src/buffer_view.rs
@@ -0,0 +1,70 @@
+use std::fmt::Debug;
+use crate::error::ErrorKind;
+
+pub(crate) struct BufferView<'a> {
+    inner: &'a [u8],
+    consumed: usize,
+}
+
+impl Debug for BufferView<'_> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", &self.inner[self.consumed..])
+    }
+}
+
+impl BufferView<'_> {
+    pub fn new(inner: &'_ [u8]) -> BufferView<'_> {
+        BufferView { inner, consumed: 0 }
+    }
+
+    pub fn consume_into_u32(&mut self) -> Result<u32, ErrorKind> {
+        if self.consumed + 4 > self.inner.len() {
+            return Err(ErrorKind::IncompleteData);
+        }
+        let res = u32::from_be_bytes(
+            self.inner[self.consumed..self.consumed + 4]
+                .try_into()
+                .unwrap(),
+        );
+        self.consumed += 4;
+        Ok(res)
+    }
+
+    pub fn consume_into_u16(&mut self) -> Result<u16, ErrorKind> {
+        if self.consumed + 2 > self.inner.len() {
+            return Err(ErrorKind::IncompleteData);
+        }
+        let res = u16::from_be_bytes(
+            self.inner[self.consumed..self.consumed + 2]
+                .try_into()
+                .unwrap(),
+        );
+        self.consumed += 2;
+        Ok(res)
+    }
+
+    pub fn consume_into_vec_n(&mut self, n: usize) -> Result<Vec<u8>, ErrorKind> {
+        if self.consumed + n > self.inner.len() {
+            return Err(ErrorKind::IncompleteData);
+        }
+        let data = self.inner[self.consumed..self.consumed + n].to_vec();
+        self.consumed += n;
+        if data.len() != n {
+            return Err(ErrorKind::IncompleteData);
+        }
+        Ok(data)
+    }
+
+    pub fn remaining(&self) -> usize {
+        self.inner.len() - self.consumed
+    }
+
+    pub fn consumed(&self) -> usize {
+        self.consumed
+    }
+
+    pub fn swallow(&mut self, n: usize) {
+        self.consumed += n;
+    }
+}
+
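`BufferView` is a cursor over a borrowed byte slice: each `consume_into_*` call bounds-checks, decodes a big-endian value, and advances `consumed`, so a decoder can bail out with `IncompleteData` without corrupting its position. A minimal sketch of the intended use, written as a crate-internal test since the type is `pub(crate)`:

```rust
#[cfg(test)]
mod buffer_view_usage {
    use super::BufferView;

    #[test]
    fn consumes_big_endian_values_in_order() {
        // A 4-byte u32 (1), a 2-byte u16 (2), then 3 raw bytes.
        let bytes = [0, 0, 0, 1, 0, 2, 0xAA, 0xBB, 0xCC];
        let mut view = BufferView::new(&bytes);

        assert_eq!(view.consume_into_u32().unwrap(), 1);
        assert_eq!(view.consume_into_u16().unwrap(), 2);
        assert_eq!(view.consume_into_vec_n(3).unwrap(), vec![0xAA, 0xBB, 0xCC]);
        assert_eq!(view.consumed(), 9);
        assert_eq!(view.remaining(), 0);

        // Reading past the end reports IncompleteData rather than panicking.
        assert!(view.consume_into_u32().is_err());
    }
}
```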
diff --git a/core/src/decoders.rs b/core/src/decoders.rs
new file mode 100644
index 0000000..6e1a734
--- /dev/null
+++ b/core/src/decoders.rs
@@ -0,0 +1,984 @@
+#![allow(clippy::redundant_closure_call)]
+
+use std::sync::Arc;
+use arrow::compute::concat;
+use arrow::buffer::{OffsetBuffer, NullBuffer, BooleanBuffer};
+use arrow_schema::{DataType, Field};
+use arrow_array::builder::GenericByteBuilder;
+use arrow_array::types::GenericBinaryType;
+use arrow_array::{self, ArrayRef, Array};
+use arrow_array::{
+    BooleanArray, Date32Array, DurationMicrosecondArray, Float32Array, Float64Array,
+    GenericStringArray, Int16Array, Int32Array, Int64Array, Time64MicrosecondArray,
+    TimestampMicrosecondArray, NullArray, GenericListArray
+};
+
+use crate::encoders::{PG_BASE_DATE_OFFSET, PG_BASE_TIMESTAMP_OFFSET_US};
+use crate::error::ErrorKind;
+use crate::buffer_view::BufferView;
+use crate::pg_schema::{PostgresSchema, PostgresType, Column};
+
+
+/// Trait defining the methods needed to decode a Postgres type into an Arrow array
+pub(crate) trait Decoder {
+    fn decode(&mut self, buf: &mut BufferView) -> Result<(), ErrorKind>;
+    fn finish(&mut self, column_len: usize) -> ArrayRef;
+    fn column_len(&self) -> usize;
+    fn name(&self) -> String;
+    fn is_null(&self) -> bool;
+}
+
+
+macro_rules! impl_decode {
+    ($struct_name:ident, $size:expr, $transform:expr, $array_kind:ident) => {
+        impl Decoder for $struct_name {
+            fn decode(&mut self, buf: &mut BufferView<'_>) -> Result<(), ErrorKind> {
+                let field_size = buf.consume_into_u32()?;
+                if field_size == u32::MAX {
+                    self.arr.push(None);
+                    return Ok(());
+                }
+                if field_size != $size {
+                    return Err(ErrorKind::IncompleteData);
+                }
+
+                let data = buf.consume_into_vec_n(field_size as usize)?;
+                // Unwrap is safe here because we have checked the field size is the expected size
+                // above
+                let value = $transform(data.try_into().unwrap());
+                self.arr.push(Some(value));
+
+                Ok(())
+            }
+
+            fn column_len(&self) -> usize {
+                self.arr.len()
+            }
+
+            fn name(&self) -> String {
+                self.name.to_string()
+            }
+
+            fn finish(&mut self, column_len: usize) -> ArrayRef {
+                self.arr.resize(column_len, None);
+                if self.is_null() {
+                    return Arc::new(NullArray::new(self.arr.len())) as ArrayRef;
+                }
+                let data = std::mem::take(&mut self.arr);
+                let array = Arc::new($array_kind::from(data));
+                array as ArrayRef
+            }
+
+            fn is_null(&self) -> bool {
+                self.arr.iter().all(|v| v.is_none())
+            }
+        }
+    };
+}
+
+macro_rules! impl_decode_fallible {
+    ($struct_name:ident, $size:expr, $transform:expr, $array_kind:ident) => {
+        impl Decoder for $struct_name {
+            fn decode(&mut self, buf: &mut BufferView<'_>) -> Result<(), ErrorKind> {
+                let field_size = buf.consume_into_u32()?;
+
+                if field_size == u32::MAX {
+                    self.arr.push(None);
+                    return Ok(());
+                }
+
+                if field_size != $size {
+                    return Err(ErrorKind::IncompleteData);
+                }
+
+                let data = buf.consume_into_vec_n(field_size as usize)?;
+
+                // Unwrap is safe here because we have checked the field size is the expected size
+                // above
+                match $transform(data.try_into().unwrap()) {
+                    Ok(v) => self.arr.push(Some(v)),
+                    Err(e) => {
+                        // If the error is a decode error, return a decode error with the name of the field
+                        return match e {
+                            ErrorKind::Decode { reason, .. } => {
+                                return Err(ErrorKind::Decode {
+                                    reason,
+                                    name: self.name.to_string(),
+                                })
+                            }
+                            _ => Err(e),
+                        };
+                    }
+                };
+
+                Ok(())
+            }
+
+            fn column_len(&self) -> usize {
+                self.arr.len()
+            }
+
+            fn name(&self) -> String {
+                self.name.to_string()
+            }
+
+            fn finish(&mut self, column_len: usize) -> ArrayRef {
+                self.arr.resize(column_len, None);
+                if self.is_null() {
+                    return Arc::new(NullArray::new(self.arr.len())) as ArrayRef;
+                }
+                let data = std::mem::take(&mut self.arr);
+                let array = Arc::new($array_kind::from(data));
+                array as ArrayRef
+            }
+
+            fn is_null(&self) -> bool {
+                self.arr.iter().all(|v| v.is_none())
+            }
+        }
+    };
+}
+
+macro_rules! impl_decode_variable_size {
+    ($struct_name:ident, $transform:expr, $extra_bytes:expr, $array_kind:ident, $offset_size:ident) => {
+        impl Decoder for $struct_name {
+            fn decode(&mut self, buf: &mut BufferView<'_>) -> Result<(), ErrorKind> {
+                let field_size = buf.consume_into_u32()?;
+                if field_size == u32::MAX {
+                    self.arr.push(None);
+                    return Ok(());
+                }
+
+                if field_size > buf.remaining() as u32 {
+                    return Err(ErrorKind::IncompleteData);
+                }
+
+                // Consume any extra data that is not part of the field.
+                // This is needed for example on jsonb fields where the first
+                // byte is the version number. This is more efficient than
+                // using remove in the transform function or creating a new
+                // vec from a slice of the data passed into it.
+                buf.swallow($extra_bytes);
+
+                let data = buf.consume_into_vec_n(field_size as usize)?;
+                match $transform(data.try_into().unwrap()) {
+                    Ok(v) => self.arr.push(Some(v)),
+                    Err(e) => {
+                        // If the error is a decode error, return a decode error with the name of the field
+                        return match e {
+                            ErrorKind::Decode { reason, .. } => {
+                                return Err(ErrorKind::Decode {
+                                    reason,
+                                    name: self.name.to_string(),
+                                })
+                            }
+                            _ => Err(e),
+                        };
+                    }
+                };
+
+                Ok(())
+            }
+
+            fn column_len(&self) -> usize {
+                self.arr.len()
+            }
+
+            fn name(&self) -> String {
+                self.name.to_string()
+            }
+
+            fn finish(&mut self, column_len: usize) -> ArrayRef {
+                self.arr.resize(column_len, None);
+                if self.is_null() {
+                    return Arc::new(NullArray::new(self.arr.len())) as ArrayRef;
+                }
+                let data = std::mem::take(&mut self.arr);
+                let array = Arc::new($array_kind::<$offset_size>::from(data));
+                array as ArrayRef
+            }
+
+            fn is_null(&self) -> bool {
+                self.arr.iter().all(|v| v.is_none())
+            }
+        }
+    };
+}
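All three macros parse the same per-field framing produced by `COPY ... WITH (FORMAT binary)`: a big-endian 4-byte length followed by that many payload bytes, with a length of -1 (read here as `u32::MAX`) marking NULL. A sketch of that framing from the writing side (hypothetical helper, not part of this PR, using only `std`):

```rust
/// Frame one field the way COPY BINARY does: a big-endian i32 byte
/// length, then the payload; a length of -1 means NULL.
fn write_field(out: &mut Vec<u8>, payload: Option<&[u8]>) {
    match payload {
        Some(bytes) => {
            out.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
            out.extend_from_slice(bytes);
        }
        None => out.extend_from_slice(&(-1i32).to_be_bytes()),
    }
}

fn main() {
    let mut buf = Vec::new();
    write_field(&mut buf, Some(&7i32.to_be_bytes())); // a non-null int4
    write_field(&mut buf, None); // NULL: just 0xFFFFFFFF
    assert_eq!(buf, [0, 0, 0, 4, 0, 0, 0, 7, 0xff, 0xff, 0xff, 0xff]);
}
```

`Int32Decoder::decode` reads the first field back as `Some(7)` and the second as `None`.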
+
+#[allow(dead_code)]
+#[derive(Debug)]
+pub struct BooleanDecoder {
+    name: String,
+    arr: Vec<Option<bool>>,
+}
+
+impl_decode!(BooleanDecoder, 1, |b: [u8; 1]| b[0] == 1, BooleanArray);
+
+#[derive(Debug)]
+pub struct Int16Decoder {
+    name: String,
+    arr: Vec<Option<i16>>,
+}
+
+impl_decode!(Int16Decoder, 2, i16::from_be_bytes, Int16Array);
+
+#[derive(Debug)]
+pub struct Int32Decoder {
+    name: String,
+    arr: Vec<Option<i32>>,
+}
+
+impl_decode!(Int32Decoder, 4, i32::from_be_bytes, Int32Array);
+
+#[derive(Debug)]
+pub struct Int64Decoder {
+    name: String,
+    arr: Vec<Option<i64>>,
+}
+
+impl_decode!(Int64Decoder, 8, i64::from_be_bytes, Int64Array);
+
+#[derive(Debug)]
+pub struct Float32Decoder {
+    name: String,
+    arr: Vec<Option<f32>>,
+}
+
+impl_decode!(Float32Decoder, 4, f32::from_be_bytes, Float32Array);
+
+#[derive(Debug)]
+pub struct Float64Decoder {
+    name: String,
+    arr: Vec<Option<f64>>,
+}
+
+impl_decode!(Float64Decoder, 8, f64::from_be_bytes, Float64Array);
+
+/// Convert Postgres timestamps (microseconds since 2000-01-01) to Arrow timestamps (microseconds since 1970-01-01)
+#[inline(always)]
+fn convert_pg_timestamp_to_arrow_timestamp_microseconds(
+    timestamp_us: i64,
+) -> Result<i64, ErrorKind> {
+    // adjust the timestamp from microseconds since 2000-01-01 to microseconds since 1970-01-01,
+    // checking for overflow and underflow
+    timestamp_us
+        .checked_add(PG_BASE_TIMESTAMP_OFFSET_US)
+        .ok_or_else(|| ErrorKind::Decode {
+            reason: "Overflow converting microseconds since 2000-01-01 (Postgres) to microseconds since 1970-01-01 (Arrow)".to_string(),
+            name: "".to_string(),
+        })
+}
+
+#[derive(Debug)]
+pub struct TimestampMicrosecondDecoder {
+    name: String,
+    arr: Vec<Option<i64>>,
+}
+
+impl_decode_fallible!(
+    TimestampMicrosecondDecoder,
+    8,
+    |b| {
+        let timestamp_us = i64::from_be_bytes(b);
+        convert_pg_timestamp_to_arrow_timestamp_microseconds(timestamp_us)
+    },
+    TimestampMicrosecondArray
+);
+
+#[derive(Debug)]
+pub struct TimestampTzMicrosecondDecoder {
+    name: String,
+    arr: Vec<Option<i64>>,
+    timezone: String,
+}
+
+impl Decoder for TimestampTzMicrosecondDecoder {
+    fn decode(&mut self, buf: &mut BufferView<'_>) -> Result<(), ErrorKind> {
+        let field_size = buf.consume_into_u32()?;
+        if field_size == u32::MAX {
+            self.arr.push(None);
+            return Ok(());
+        }
+
+        if field_size != 8 {
+            return Err(ErrorKind::IncompleteData);
+        }
+
+        let data = buf.consume_into_vec_n(field_size as usize)?;
+        let timestamp_us = i64::from_be_bytes(data.try_into().unwrap());
+        let timestamp_us = convert_pg_timestamp_to_arrow_timestamp_microseconds(timestamp_us)?;
+        self.arr.push(Some(timestamp_us));
+
+        Ok(())
+    }
+
+    fn column_len(&self) -> usize {
+        self.arr.len()
+    }
+
+    fn name(&self) -> String {
+        self.name.to_string()
+    }
+
+    fn finish(&mut self, column_len: usize) -> ArrayRef {
+        self.arr.resize(column_len, None);
+        if self.is_null() {
+            return Arc::new(NullArray::new(self.arr.len())) as ArrayRef;
+        }
+        let data = std::mem::take(&mut self.arr);
+        let array = Arc::new(
+            TimestampMicrosecondArray::from(data).with_timezone(self.timezone.to_string()),
+        );
+        array as ArrayRef
+    }
+
+    fn is_null(&self) -> bool {
+        self.arr.iter().all(|v| v.is_none())
+    }
+}
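`PG_BASE_TIMESTAMP_OFFSET_US` (defined in `encoders.rs` and reused above) is the gap between the Unix epoch and Postgres's 2000-01-01 epoch. A quick arithmetic check of the constant, independent of the crate:

```rust
fn main() {
    // 30 years from 1970-01-01 to 2000-01-01, 7 of them leap years
    // (1972, 1976, 1980, 1984, 1988, 1992, 1996).
    let days: i64 = 30 * 365 + 7;
    assert_eq!(days, 10_957); // PG_BASE_DATE_OFFSET
    let offset_us = days * 24 * 60 * 60 * 1_000_000;
    assert_eq!(offset_us, 946_684_800_000_000); // PG_BASE_TIMESTAMP_OFFSET_US

    // A Postgres timestamp of 0 (2000-01-01T00:00:00) therefore becomes
    // 946_684_800_000_000 microseconds since the Unix epoch in Arrow.
}
```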
+
+/// Convert Postgres dates (days since 2000-01-01) to Arrow dates (days since 1970-01-01)
+#[inline(always)]
+fn convert_pg_date_to_arrow_date(date: i32) -> Result<i32, ErrorKind> {
+    date.checked_add(PG_BASE_DATE_OFFSET).ok_or_else(|| ErrorKind::Decode {
+        reason: "Overflow converting days since 2000-01-01 (Postgres) to days since 1970-01-01 (Arrow)".to_string(),
+        name: "".to_string(),
+    })
+}
+
+#[derive(Debug)]
+pub struct Date32Decoder {
+    name: String,
+    arr: Vec<Option<i32>>,
+}
+
+impl_decode_fallible!(
+    Date32Decoder,
+    4,
+    |b| {
+        let date = i32::from_be_bytes(b);
+        convert_pg_date_to_arrow_date(date)
+    },
+    Date32Array
+);
+
+#[derive(Debug)]
+pub struct Time64MicrosecondDecoder {
+    name: String,
+    arr: Vec<Option<i64>>,
+}
+
+impl_decode!(
+    Time64MicrosecondDecoder,
+    8,
+    i64::from_be_bytes,
+    Time64MicrosecondArray
+);
+
+/// Convert Postgres durations to Arrow durations (microseconds)
+fn convert_pg_duration_to_arrow_duration(
+    duration_us: i64,
+    duration_days: i32,
+    duration_months: i32,
+) -> Result<i64, ErrorKind> {
+    let days = (duration_days as i64)
+        .checked_mul(24 * 60 * 60 * 1_000_000)
+        .ok_or_else(|| ErrorKind::Decode {
+            reason: "Overflow converting days to microseconds".to_string(),
+            name: "".to_string(),
+        })?;
+    let months = (duration_months as i64)
+        .checked_mul(30 * 24 * 60 * 60 * 1_000_000)
+        .ok_or_else(|| ErrorKind::Decode {
+            reason: "Overflow converting months to microseconds".to_string(),
+            name: "".to_string(),
+        })?;
+
+    duration_us
+        .checked_add(days)
+        .ok_or_else(|| ErrorKind::Decode {
+            reason: "Overflow adding days in microseconds to duration microseconds".to_string(),
+            name: "".to_string(),
+        })
+        .map(|v| {
+            v.checked_add(months).ok_or_else(|| ErrorKind::Decode {
+                reason: "Overflow adding months in microseconds to duration microseconds"
+                    .to_string(),
+                name: "".to_string(),
+            })
+        })?
+}
+
+#[derive(Debug)]
+pub struct DurationMicrosecondDecoder {
+    name: String,
+    arr: Vec<Option<i64>>,
+}
+
+impl_decode_fallible!(
+    DurationMicrosecondDecoder,
+    16,
+    |b: [u8; 16]| {
+        // Unwrap here since we know the exact size of the array we are passing
+        let duration_us = i64::from_be_bytes(b[..8].try_into().unwrap());
+        let duration_days = i32::from_be_bytes(b[8..12].try_into().unwrap());
+        let duration_months = i32::from_be_bytes(b[12..16].try_into().unwrap());
+        convert_pg_duration_to_arrow_duration(duration_us, duration_days, duration_months)
+    },
+    DurationMicrosecondArray
+);
+
+#[derive(Debug)]
+pub struct StringDecoder {
+    name: String,
+    arr: Vec<Option<String>>,
+}
+
+impl_decode_variable_size!(
+    StringDecoder,
+    |b: Vec<u8>| {
+        String::from_utf8(b).map_err(|_| ErrorKind::Decode {
+            reason: "Invalid UTF-8 string".to_string(),
+            name: "".to_string(),
+        })
+    },
+    0,
+    GenericStringArray,
+    i32
+);
+
+#[derive(Debug)]
+pub struct BinaryDecoder {
+    name: String,
+    arr: Vec<Option<Vec<u8>>>,
+}
+
+impl Decoder for BinaryDecoder {
+    fn decode(&mut self, buf: &mut BufferView<'_>) -> Result<(), ErrorKind> {
+        let field_size = buf.consume_into_u32()?;
+        if field_size == u32::MAX {
+            self.arr.push(None);
+            return Ok(());
+        }
+
+        let data = buf.consume_into_vec_n(field_size as usize)?;
+        self.arr.push(Some(data));
+
+        Ok(())
+    }
+
+    fn column_len(&self) -> usize {
+        self.arr.len()
+    }
+
+    fn name(&self) -> String {
+        self.name.to_string()
+    }
+
+    fn finish(&mut self, column_len: usize) -> ArrayRef {
+        self.arr.resize(column_len, None);
+        if self.is_null() {
+            return Arc::new(NullArray::new(self.arr.len())) as ArrayRef;
+        }
+
+        let data = std::mem::take(&mut self.arr);
+        let mut builder: GenericByteBuilder<GenericBinaryType<i32>> = GenericByteBuilder::new();
+        for v in data {
+            match v {
+                Some(v) => builder.append_value(v),
+                None => builder.append_null(),
+            }
+        }
+        Arc::new(builder.finish()) as ArrayRef
+    }
+
+    fn is_null(&self) -> bool {
+        self.arr.iter().all(|v| v.is_none())
+    }
+}
+
+#[derive(Debug)]
+pub struct JsonbDecoder {
+    name: String,
+    arr: Vec<Option<String>>,
+}
+
+impl_decode_variable_size!(
+    JsonbDecoder,
+    |b: Vec<u8>| {
+        String::from_utf8(b).map_err(|_| ErrorKind::Decode {
+            reason: "Invalid UTF-8 string".to_string(),
+            name: "".to_string(),
+        })
+    },
+    // Remove the first byte which is the version number
+    // https://www.postgresql.org/docs/13/datatype-json.html
+    1,
+    GenericStringArray,
+    i64
+);
+
+// const used for stringifying postgres decimals
+const DEC_DIGITS: i16 = 4;
+// const used for determining sign of numeric
+const NUMERIC_NEG: i16 = 0x4000;
+
+/// Parse a Postgres numeric type in binary format into a string
+fn parse_pg_decimal_to_string(data: Vec<u8>) -> Result<String, ErrorKind> {
+    // Logic ported from src/backend/utils/adt/numeric.c:get_str_from_var
+    // Decimals will be decoded to strings since Rust does not have a ubiquitous
+    // decimal type and Arrow does not implement `From` for any of them. Arrow
+    // does have a Decimal128 array but it only accepts i128s as input
+    // TODO: Seems like there could be a fast path here for simpler numbers
+    let ndigits = i16::from_be_bytes(data[0..2].try_into().unwrap());
+    let weight = i16::from_be_bytes(data[2..4].try_into().unwrap());
+    let sign = i16::from_be_bytes(data[4..6].try_into().unwrap());
+    let scale = i16::from_be_bytes(data[6..8].try_into().unwrap());
+    let digits: Vec<i16> = data[8..8 + (ndigits as usize) * (std::mem::size_of::<i16>())]
+        .chunks(2)
+        .map(|c| i16::from_be_bytes(c.try_into().unwrap()))
+        .collect();
+
+    // the number of digits before the decimal place
+    let pre_decimal = (weight + 1) * DEC_DIGITS;
+
+    // scale is the number of digits after the decimal place.
+    // 2 is for a possible sign and decimal point
+    let str_len: usize = (pre_decimal + DEC_DIGITS + 2 + scale) as usize;
+
+    // -1 because we don't need to account for the null terminator
+    let mut decimal: Vec<u8> = Vec::with_capacity(str_len - 1);
+
+    // put a negative sign if the numeric is negative
+    if sign == NUMERIC_NEG {
+        decimal.push(b'-');
+    }
+
+    let mut digits_index = 0;
+    // If weight is less than 0 we have a fractional number.
+    // Put a 0 before the decimal.
+    if weight < 0 {
+        decimal.push(b'0');
+    // Otherwise put digits in the decimal string by computing the value for each place in decimal
+    } else {
+        while digits_index <= weight {
+            let mut dig = if digits_index < ndigits {
+                digits[digits_index as usize]
+            } else {
+                0
+            };
+            let mut putit = digits_index > 0;
+
+            /* the loop below is equivalent to this unrolled C code from numeric.c:
+            d1 = dig / 1000;
+            dig -= d1 * 1000;
+            putit |= (d1 > 0);
+            if (putit)
+                *cp++ = d1 + '0';
+            d1 = dig / 100;
+            dig -= d1 * 100;
+            putit |= (d1 > 0);
+            if (putit)
+                *cp++ = d1 + '0';
+            d1 = dig / 10;
+            dig -= d1 * 10;
+            putit |= (d1 > 0);
+            if (putit)
+                *cp++ = d1 + '0';
+            *cp++ = dig + '0';
+            */
+
+            let mut place = 1000;
+            while place > 1 {
+                let d1 = dig / place;
+                dig -= d1 * place;
+                putit |= d1 > 0;
+                if putit {
+                    decimal.push(d1 as u8 + b'0');
+                }
+                place /= 10;
+            }
+            decimal.push(dig as u8 + b'0');
+            digits_index += 1;
+        }
+
+    }
+    // If scale is > 0 we have digits after the decimal point
+    if scale > 0 {
+        decimal.push(b'.');
+    }
+    while digits_index < ndigits {
+        let mut dig = if digits_index >= 0 && digits_index < ndigits {
+            digits[digits_index as usize]
+        } else {
+            0
+        };
+        let mut place = 1000;
+        // Same as the loop above but no putit since all digits prior to the
+        // scale-th digit are significant
+        while place > 1 {
+            let d1 = dig / place;
+            dig -= d1 * place;
+            decimal.push(d1 as u8 + b'0');
+            place /= 10;
+        }
+        decimal.push(dig as u8 + b'0');
+        digits_index += 1;
+    }
+
+    // trim trailing zeros and return the string
+    Ok(truncate_and_finalize(decimal, scale))
+}
+
+/// Truncate any trailing zeros at the end of the string if they are not significant.
+/// If the number ends in a decimal point, add a zero
+fn truncate_and_finalize(mut v: Vec<u8>, scale: i16) -> String {
+    // if there is a decimal point do some general cleanup
+    let decimal_point_idx = v.iter().position(|&c| c == b'.');
+    match decimal_point_idx {
+        Some(idx) => {
+            // If the number ends in a decimal point, add a zero
+            if idx == v.len() - 1 {
+                v.push(b'0')
+            // Strip any trailing zeros after the decimal point
+            } else {
+                while v.len() - idx - 1 > scale as usize {
+                    v.pop();
+                }
+            }
+        },
+        None => {}
+    };
+    String::from_utf8(v).unwrap()
+}
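As a worked example of the layout `parse_pg_decimal_to_string` expects: the header is four big-endian i16s (`ndigits`, `weight`, `sign`, `dscale`) followed by `ndigits` base-10000 digit groups. `123.45` is two digit groups `[123, 4500]` with weight 0 and scale 2. The bytes below are hand-assembled and traced against the algorithm above rather than captured from a live server; the snippet assumes it sits next to the (private) parser:

```rust
fn main() {
    // ndigits=2, weight=0, sign=0x0000 (positive), dscale=2, digits=[123, 4500]
    let mut data: Vec<u8> = Vec::new();
    for half_word in [2i16, 0, 0, 2, 123, 4500] {
        data.extend_from_slice(&half_word.to_be_bytes());
    }
    // The integer part prints as "123", the fraction as "4500", and
    // truncate_and_finalize trims it back to the declared scale of 2.
    assert_eq!(parse_pg_decimal_to_string(data).unwrap(), "123.45");
}
```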
+
+#[derive(Debug)]
+pub struct NumericDecoder {
+    name: String,
+    arr: Vec<Option<String>>,
+}
+
+impl_decode_variable_size!(
+    NumericDecoder,
+    parse_pg_decimal_to_string,
+    0,
+    GenericStringArray,
+    i32
+);
+
+#[derive(Debug)]
+pub struct ArrayDecoder {
+    name: String,
+    arr: Vec<ArrayRef>,
+    inner: Box<PostgresDecoder>,
+}
+
+
+impl Decoder for ArrayDecoder {
+    fn decode(&mut self, buf: &mut BufferView<'_>) -> Result<(), ErrorKind> {
+        let field_size = buf.consume_into_u32()?;
+        if field_size == u32::MAX {
+            self.arr.push(Arc::new(NullArray::new(0)));
+            return Ok(());
+        }
+
+        let ndim = buf.consume_into_u32()?;
+        let _flags = buf.consume_into_u32()?;
+        let _elemtype = buf.consume_into_u32()?;
+
+        let mut dims = vec![];
+        for _ in 0..ndim {
+            let dim = buf.consume_into_u32()?;
+            dims.push(dim);
+            // consume the lbound which we don't need
+            buf.consume_into_u32()?;
+        }
+
+        let nitems = dims.iter().product::<u32>() as usize;
+
+        for _ in 0..nitems {
+            self.inner.decode(buf)?;
+        }
+        let array = self.inner.finish(nitems);
+        self.arr.push(array);
+
+        Ok(())
+    }
+
+    fn finish(&mut self, _column_len: usize) -> ArrayRef {
+        // Check if all the arrays are null and return a null array if so
+        if self.is_null() {
+            return Arc::new(NullArray::new(0)) as ArrayRef;
+        }
+
+        let arrays = std::mem::take(&mut self.arr);
+
+        // Build the offset buffer for the combined ListArray using
+        // the lengths of the component arrays
+        let mut offset_values = vec![0 as i32];
+        for i in 1..arrays.len() + 1 {
+            offset_values.push(arrays[i - 1].len() as i32 + offset_values[i - 1]);
+        }
+        let offsets = OffsetBuffer::new(offset_values.into());
+
+        // Concatenate the data of the component arrays
+        // to create the child data of the ListArray
+        let array_refs: Vec<&dyn Array> = arrays.iter().filter(
+            |a| !matches!(a.data_type(), DataType::Null)
+        ).map(|a| a.as_ref()).collect();
+        let child_data = concat(&array_refs).unwrap();
+
+        // Calculate the null buffer for the ListArray
+        // by checking if the component arrays are null
+        let null_values = arrays.iter().map(|a| {
+            match a.data_type() {
+                DataType::Null => false,
+                _ => true,
+            }
+        }).collect::<Vec<bool>>();
+
+        // If there are no nulls, return None for the null buffer
+        let nulls = if null_values.iter().all(|v| *v) {
+            None
+        } else {
+            Some(NullBuffer::from(BooleanBuffer::from(null_values)))
+        };
+
+        // Construct the ListArray from parts
+        Arc::new(GenericListArray::new(
+            Arc::new(Field::new(self.name().replace("list_", ""), arrays[0].data_type().clone(), true)),
+            offsets,
+            child_data,
+            nulls,
+        )) as ArrayRef
+    }
+
+    fn is_null(&self) -> bool {
+        self.arr.iter().all(|a| matches!(a.data_type(), DataType::Null))
+    }
+
+    fn column_len(&self) -> usize {
+        self.arr.len()
+    }
+
+    fn name(&self) -> String {
+        self.name.to_string()
+    }
+}
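In `finish`, the `ListArray` offsets are simply the running sum of the per-row child lengths; null rows contribute a zero-length slot plus a `false` entry in the null buffer. The same offset shape can be produced with arrow-rs directly (values invented for illustration):

```rust
use arrow_array::types::Int32Type;
use arrow_array::{Array, ListArray};

fn main() {
    // Three rows: [1, 2], [], [3, 4, 5] -> child lengths 2, 0, 3,
    // so the offsets are the running sums [0, 2, 2, 5].
    let rows = vec![
        Some(vec![Some(1), Some(2)]),
        Some(vec![]),
        Some(vec![Some(3), Some(4), Some(5)]),
    ];
    let list = ListArray::from_iter_primitive::<Int32Type, _, _>(rows);
    assert_eq!(list.value_offsets(), &[0, 2, 2, 5]);
    assert_eq!(list.len(), 3);
}
```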
+
+
+#[derive(Debug)]
+pub enum PostgresDecoder {
+    Boolean(BooleanDecoder),
+    Int16(Int16Decoder),
+    Int32(Int32Decoder),
+    Int64(Int64Decoder),
+    Float32(Float32Decoder),
+    Float64(Float64Decoder),
+    Numeric(NumericDecoder),
+    TimestampMicrosecond(TimestampMicrosecondDecoder),
+    TimestampTzMicrosecond(TimestampTzMicrosecondDecoder),
+    Date32(Date32Decoder),
+    Time64Microsecond(Time64MicrosecondDecoder),
+    DurationMicrosecond(DurationMicrosecondDecoder),
+    String(StringDecoder),
+    Binary(BinaryDecoder),
+    Jsonb(JsonbDecoder),
+    List(ArrayDecoder),
+}
+
+
+pub(crate) fn create_decoders(schema: &PostgresSchema) -> Vec<PostgresDecoder> {
+    schema.iter().map(|(name, column)| PostgresDecoder::new(name, column)).collect()
+}
+
+impl PostgresDecoder {
+    pub fn new(name: &str, column: &Column) -> Self {
+        match column.data_type {
+            PostgresType::Bool => PostgresDecoder::Boolean(BooleanDecoder {
+                name: name.to_string(),
+                arr: vec![],
+            }),
+            PostgresType::Int2 => PostgresDecoder::Int16(Int16Decoder {
+                name: name.to_string(),
+                arr: vec![],
+            }),
+            PostgresType::Int4 => PostgresDecoder::Int32(Int32Decoder {
+                name: name.to_string(),
+                arr: vec![],
+            }),
+            PostgresType::Int8 => PostgresDecoder::Int64(Int64Decoder {
+                name: name.to_string(),
+                arr: vec![],
+            }),
+            PostgresType::Float4 => PostgresDecoder::Float32(Float32Decoder {
+                name: name.to_string(),
+                arr: vec![],
+            }),
+            PostgresType::Float8 => PostgresDecoder::Float64(Float64Decoder {
+                name: name.to_string(),
+                arr: vec![],
+            }),
+            PostgresType::Numeric => PostgresDecoder::Numeric(NumericDecoder {
+                name: name.to_string(),
+                arr: vec![],
+            }),
+            PostgresType::Timestamp => {
+                PostgresDecoder::TimestampMicrosecond(TimestampMicrosecondDecoder {
+                    name: name.to_string(),
+                    arr: vec![],
+                })
+            }
+            PostgresType::TimestampTz(ref timezone) => {
+                PostgresDecoder::TimestampTzMicrosecond(TimestampTzMicrosecondDecoder {
+                    name: name.to_string(),
+                    arr: vec![],
+                    timezone: timezone.to_string(),
+                })
+            }
+            PostgresType::Date => PostgresDecoder::Date32(Date32Decoder {
+                name: name.to_string(),
+                arr: vec![],
+            }),
+            PostgresType::Time => PostgresDecoder::Time64Microsecond(Time64MicrosecondDecoder {
+                name: name.to_string(),
+                arr: vec![],
+            }),
+            PostgresType::Interval => {
+                PostgresDecoder::DurationMicrosecond(DurationMicrosecondDecoder {
+                    name: name.to_string(),
+                    arr: vec![],
+                })
+            }
+            PostgresType::Text | PostgresType::Char | PostgresType::Json => {
+                PostgresDecoder::String(StringDecoder {
+                    name: name.to_string(),
+                    arr: vec![],
+                })
+            }
+            PostgresType::Bytea => PostgresDecoder::Binary(BinaryDecoder {
+                name: name.to_string(),
+                arr: vec![],
+            }),
+            PostgresType::List(ref inner) => {
+                let (name, column) = inner;
+                let inner_decoder = Box::new(PostgresDecoder::new(name, column));
+                PostgresDecoder::List(ArrayDecoder {
+                    name: name.to_string(),
+                    arr: vec![],
+                    inner: inner_decoder,
+                })
+            }
+            _ => unimplemented!(),
+        }
+    }
+
+    pub(crate) fn decode(&mut self, buf: &mut BufferView) -> Result<(), ErrorKind> {
+        match *self {
+            PostgresDecoder::Boolean(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::Int16(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::Int32(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::Int64(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::Float32(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::Float64(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::Numeric(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::TimestampMicrosecond(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::TimestampTzMicrosecond(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::Date32(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::Time64Microsecond(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::DurationMicrosecond(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::String(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::Binary(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::Jsonb(ref mut decoder) => decoder.decode(buf),
+            PostgresDecoder::List(ref mut decoder) => decoder.decode(buf),
+        }
+    }
+
+    #[allow(dead_code)]
+    pub(crate) fn name(&self) -> String {
+        match *self {
+            PostgresDecoder::Boolean(ref decoder) => decoder.name(),
+            PostgresDecoder::Int16(ref decoder) => decoder.name(),
+            PostgresDecoder::Int32(ref decoder) => decoder.name(),
+            PostgresDecoder::Int64(ref decoder) => decoder.name(),
+            PostgresDecoder::Float32(ref decoder) => decoder.name(),
+            PostgresDecoder::Float64(ref decoder) => decoder.name(),
+            PostgresDecoder::Numeric(ref decoder) => decoder.name(),
+            PostgresDecoder::TimestampMicrosecond(ref decoder) => decoder.name(),
+            PostgresDecoder::TimestampTzMicrosecond(ref decoder) => decoder.name(),
+            PostgresDecoder::Date32(ref decoder) => decoder.name(),
+            PostgresDecoder::Time64Microsecond(ref decoder) => decoder.name(),
+            PostgresDecoder::DurationMicrosecond(ref decoder) => decoder.name(),
+            PostgresDecoder::String(ref decoder) => decoder.name(),
+            PostgresDecoder::Binary(ref decoder) => decoder.name(),
+            PostgresDecoder::Jsonb(ref decoder) => decoder.name(),
+            PostgresDecoder::List(ref decoder) => decoder.name(),
+        }
+    }
+
+    pub(crate) fn column_len(&self) -> usize {
+        match *self {
+            PostgresDecoder::Boolean(ref decoder) => decoder.column_len(),
+            PostgresDecoder::Int16(ref decoder) => decoder.column_len(),
+            PostgresDecoder::Int32(ref decoder) => decoder.column_len(),
+            PostgresDecoder::Int64(ref decoder) => decoder.column_len(),
+            PostgresDecoder::Float32(ref decoder) => decoder.column_len(),
+            PostgresDecoder::Float64(ref decoder) => decoder.column_len(),
+            PostgresDecoder::Numeric(ref decoder) => decoder.column_len(),
+            PostgresDecoder::TimestampMicrosecond(ref decoder) => decoder.column_len(),
+            PostgresDecoder::TimestampTzMicrosecond(ref decoder) => decoder.column_len(),
+            PostgresDecoder::Date32(ref decoder) => decoder.column_len(),
+            PostgresDecoder::Time64Microsecond(ref decoder) => decoder.column_len(),
+            PostgresDecoder::DurationMicrosecond(ref decoder) => decoder.column_len(),
+            PostgresDecoder::String(ref decoder) => decoder.column_len(),
+            PostgresDecoder::Binary(ref decoder) => decoder.column_len(),
+            PostgresDecoder::Jsonb(ref decoder) => decoder.column_len(),
+            PostgresDecoder::List(ref decoder) => decoder.column_len(),
+        }
+    }
+
+    pub(crate) fn finish(&mut self, column_len: usize) -> ArrayRef {
+        match *self {
+            PostgresDecoder::Boolean(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::Int16(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::Int32(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::Int64(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::Float32(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::Float64(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::Numeric(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::TimestampMicrosecond(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::TimestampTzMicrosecond(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::Date32(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::Time64Microsecond(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::DurationMicrosecond(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::String(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::Binary(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::Jsonb(ref mut decoder) => decoder.finish(column_len),
+            PostgresDecoder::List(ref mut decoder) => decoder.finish(column_len),
+        }
+    }
+
+    pub(crate) fn is_null(&self) -> bool {
+        match *self {
+            PostgresDecoder::Boolean(ref decoder) => decoder.is_null(),
+            PostgresDecoder::Int16(ref decoder) => decoder.is_null(),
+            PostgresDecoder::Int32(ref decoder) => decoder.is_null(),
+            PostgresDecoder::Int64(ref decoder) => decoder.is_null(),
+            PostgresDecoder::Float32(ref decoder) => decoder.is_null(),
+            PostgresDecoder::Float64(ref decoder) => decoder.is_null(),
+            PostgresDecoder::Numeric(ref decoder) => decoder.is_null(),
+            PostgresDecoder::TimestampMicrosecond(ref decoder) => decoder.is_null(),
+            PostgresDecoder::TimestampTzMicrosecond(ref decoder) => decoder.is_null(),
+            PostgresDecoder::Date32(ref decoder) => decoder.is_null(),
+            PostgresDecoder::Time64Microsecond(ref decoder) => decoder.is_null(),
+            PostgresDecoder::DurationMicrosecond(ref decoder) => decoder.is_null(),
+            PostgresDecoder::String(ref decoder) => decoder.is_null(),
+            PostgresDecoder::Binary(ref decoder) => decoder.is_null(),
+            PostgresDecoder::Jsonb(ref decoder) => decoder.is_null(),
+            PostgresDecoder::List(ref decoder) => decoder.is_null(),
+        }
+    }
+}
diff --git a/core/src/encoders.rs b/core/src/encoders.rs
index b5bfa89..f89d715 100644
--- a/core/src/encoders.rs
+++ b/core/src/encoders.rs
@@ -10,7 +10,7 @@ use crate::error::ErrorKind;
 use crate::pg_schema::{Column, PostgresType, TypeSize};
 
 #[inline]
-fn downcast_checked<'a, T: 'static>(arr: &'a dyn Array, field: &str) -> Result<&'a T, ErrorKind> {
+pub(crate) fn downcast_checked<'a, T: 'static>(arr: &'a dyn Array, field: &str) -> Result<&'a T, ErrorKind> {
     match arr.as_any().downcast_ref::<T>() {
         Some(v) => Ok(v),
         None => Err(ErrorKind::mismatched_column_type(
@@ -234,7 +234,7 @@ impl_encode!(
     BufMut::put_f64
 );
 
-const PG_BASE_TIMESTAMP_OFFSET_US: i64 = 946_684_800_000_000; // microseconds between 2000-01-01 at midnight (Postgres's epoch) and 1970-01-01 (Arrow's / UNIX epoch)
+pub(crate) const PG_BASE_TIMESTAMP_OFFSET_US: i64 = 946_684_800_000_000; // microseconds between 2000-01-01 at midnight (Postgres's epoch) and 1970-01-01 (Arrow's / UNIX epoch)
 const PG_BASE_TIMESTAMP_OFFSET_MS: i64 = 946_684_800_000; // milliseconds between 2000-01-01 at midnight (Postgres's epoch) and 1970-01-01 (Arrow's / UNIX epoch)
 const PG_BASE_TIMESTAMP_OFFSET_S: i64 = 946_684_800; // seconds between 2000-01-01 at midnight (Postgres's epoch) and 1970-01-01 (Arrow's / UNIX epoch)
 
@@ -320,7 +320,7 @@ impl_encode_fallible!(
     BufMut::put_i64
 );
 
-const PG_BASE_DATE_OFFSET: i32 = 10_957; // Number of days between PostgreSQL's epoch (2000-01-01) and Arrow's / UNIX epoch (1970-01-01)
+pub(crate) const PG_BASE_DATE_OFFSET: i32 = 10_957; // Number of days between PostgreSQL's epoch (2000-01-01) and Arrow's / UNIX epoch (1970-01-01)
 
 #[inline(always)]
 fn convert_arrow_date32_to_postgres_date(_field: &str, date: i32) -> Result<i32, ErrorKind> {
@@ -1123,9 +1123,9 @@ macro_rules! impl_list_encoder_builder {
         }
         fn schema(&self) -> Column {
             Column {
-                data_type: PostgresType::List(Box::new(
+                data_type: PostgresType::List(("".to_string(), Box::new(
                     self.inner_encoder_builder.schema().clone(),
-                )),
+                ))),
                 nullable: self.field.is_nullable(),
             }
         }
diff --git a/core/src/error.rs b/core/src/error.rs
index 077fe31..8450176 100644
--- a/core/src/error.rs
+++ b/core/src/error.rs
@@ -1,4 +1,4 @@
-use arrow_schema::DataType;
+use arrow_schema::{ArrowError, DataType};
 use thiserror::Error;
 
 use crate::pg_schema::PostgresType;
@@ -40,6 +40,41 @@ pub enum ErrorKind {
     EncoderMissing { field: String },
     #[error("No fields match supplied encoder fields: {fields:?}")]
     UnknownFields { fields: Vec<String> },
+
+    // Decoding
+    #[error("Error decoding data: {reason}")]
+    Decode { reason: String, name: String },
+    #[error("Got invalid binary file header {bytes:?}")]
+    InvalidBinaryHeader { bytes: [u8; 11] },
+    #[error("Reached EOF in the middle of a tuple. partial tuple: {remaining_bytes:?}")]
+    IncompleteDecode { remaining_bytes: Vec<u8> },
+    #[error("Expected data size was not found")]
+    IncompleteData,
+    #[error("Invalid column specification: {spec}")]
+    InvalidColumnSpec { spec: String },
+    #[error("Invalid column type found while parsing schema: {typ}")]
+    UnsupportedColumnType { typ: String },
+    #[error("Got an error in an IO Operation: {io_error:?}")]
+    IOError { io_error: std::io::Error },
+    #[error("Got an error: {name} in Arrow while decoding: {reason}")]
+    ArrowErrorDecode { reason: String, name: String },
+    #[error("ArrowType: {typ:?} not currently supported for decoding")]
+    UnsupportedArrowType { typ: DataType },
+}
+
+impl From<std::io::Error> for ErrorKind {
+    fn from(io_error: std::io::Error) -> Self {
+        ErrorKind::IOError { io_error }
+    }
+}
+
+impl From<ArrowError> for ErrorKind {
+    fn from(arrow_error: ArrowError) -> Self {
+        ErrorKind::ArrowErrorDecode {
+            reason: arrow_error.to_string(),
+            name: "ArrowError".to_string(),
+        }
+    }
 }
 
 impl ErrorKind {
diff --git a/core/src/lib.rs b/core/src/lib.rs
index c30e5fd..dac2897 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -1,15 +1,20 @@
 use std::collections::HashMap;
+use std::io::BufRead;
 
 use arrow_array::RecordBatch;
 use arrow_schema::Fields;
 use arrow_schema::Schema;
-use bytes::{BufMut, BytesMut};
+use bytes::{Buf, BufMut, BytesMut};
 use error::ErrorKind;
 
+pub mod decoders;
 pub mod encoders;
 pub mod error;
 pub mod pg_schema;
+mod buffer_view;
 
+use crate::decoders::{PostgresDecoder, create_decoders};
+use crate::buffer_view::BufferView;
 use crate::encoders::{BuildEncoder, Encode, EncoderBuilder};
 use crate::pg_schema::PostgresSchema;
 
@@ -150,6 +155,271 @@ impl ArrowToPostgresBinaryEncoder {
     }
 }
 
+enum BatchDecodeResult {
+    Batch(RecordBatch),
+    Incomplete(usize),
+    Error(ErrorKind),
+    PartialConsume { batch: RecordBatch, consumed: usize },
+}
+
+pub struct PostgresBinaryToArrowDecoder<R: BufRead> {
+    schema: PostgresSchema,
+    decoders: Vec<PostgresDecoder>,
+    source: R,
+    state: EncoderState,
+    capacity: usize,
+}
+
+impl<R: BufRead> PostgresBinaryToArrowDecoder<R> {
+    pub fn new(schema: PostgresSchema, source: R, capacity: usize) -> Result<Self, ErrorKind> {
+        let decoders = create_decoders(&schema);
+        Ok(PostgresBinaryToArrowDecoder {
+            schema,
+            decoders,
+            source,
+            state: EncoderState::Created,
+            capacity,
+        })
+    }
+
+    /// Try to create a new decoder using an arrow schema. Will fail if the types in the schema
+    /// are not supported by the decoder.
+    pub fn try_new_with_arrow_schema(
+        schema: Schema,
+        source: R,
+        capacity: usize,
+    ) -> Result<Self, ErrorKind> {
+        let pg_schema = PostgresSchema::try_from(schema)?;
+        let decoders = create_decoders(&pg_schema);
+        Ok(PostgresBinaryToArrowDecoder {
+            decoders,
+            schema: pg_schema,
+            source,
+            state: EncoderState::Created,
+            capacity,
+        })
+    }
+
+    /// Reads the header from the source and validates it.
+    pub fn read_header(&mut self) -> Result<(), ErrorKind> {
+        assert_eq!(self.state, EncoderState::Created);
+
+        // Header is always 11 bytes long. b"PGCOPY\n\xff\r\n\0"
+        let mut header = [0; 11];
+        self.source.read_exact(&mut header)?;
+        if header != HEADER_MAGIC_BYTES {
+            return Err(ErrorKind::InvalidBinaryHeader { bytes: header });
+        }
+
+        // read flags and header extension, both of which we ignore the values of.
+        let mut flags = [0; 4];
+        self.source.read_exact(&mut flags)?;
+
+        let mut header_extension = [0; 4];
+        self.source.read_exact(&mut header_extension)?;
+
+        self.state = EncoderState::Encoding;
+        Ok(())
+    }
+
+    /// Read batches of bytes from the source and decode them into RecordBatches.
+    pub fn decode_batches(&mut self) -> Result<Vec<RecordBatch>, ErrorKind> {
+        let mut batches = Vec::new();
+        let mut buf = BytesMut::with_capacity(self.capacity);
+        let mut eof = false;
+        while self.state == EncoderState::Encoding {
+            // Read loop. Read from source until buf is full.
+            let mut data = self.source.fill_buf()?;
+            loop {
+                // If there is no data left in the source, break the loop and finish the batch.
+                // Set the eof flag to true to indicate that the source has been fully read.
+                if data.is_empty() {
+                    eof = true;
+                    break;
+                // If data was read from the source, put it into the buffer and signal
+                // to the source that we have consumed it by calling BufRead::consume.
+                } else {
+                    buf.put(data);
+                    let read = data.len();
+                    self.source.consume(read);
+                }
+                data = self.source.fill_buf()?;
+
+                // If the remaining capacity of the buffer is less than the length of the data,
+                // break the loop.
+                let remaining = buf.capacity() - buf.len();
+                if remaining < data.len() {
+                    break;
+                }
+            }
+
+            // If the eof flag is not set and there is capacity remaining in the buffer, read
+            // up to that much data from the source and put it into the buffer.
+            let remaining = buf.capacity() - buf.len();
+            if !eof && remaining > 0 {
+                let read = std::cmp::min(data.len(), buf.remaining());
+                buf.put(&data[..read]);
+                self.source.consume(read);
+            }
+
+            // If we have reached the end of the source, decode the batch and return an error,
+            // signalled by an IncompleteDecode or PartialConsume BatchDecodeResult, if there
+            // is any remaining data in the buffer.
+            if eof {
+                if !buf.is_empty() {
+                    match self.decode_batch(&mut buf) {
+                        BatchDecodeResult::Batch(batch) => batches.push(batch),
+                        BatchDecodeResult::Error(e) => return Err(e),
+                        BatchDecodeResult::Incomplete(consumed)
+                        | BatchDecodeResult::PartialConsume { batch: _, consumed } => {
+                            return Err(ErrorKind::IncompleteDecode {
+                                remaining_bytes: buf[consumed..].to_vec(),
+                            })
+                        }
+                    }
+                }
+                self.state = EncoderState::Finished;
+            // If we have not reached the end of the source, decode the batch.
+            } else {
+                match self.decode_batch(&mut buf) {
+                    BatchDecodeResult::Batch(batch) => {
+                        batches.push(batch);
+                        buf.clear()
+                    }
+                    // If we receive a PartialConsume BatchDecodeResult, store the batches we did
+                    // manage to decode and continue reading from the source with the remaining
+                    // data from the previous read in the buffer.
+                    BatchDecodeResult::PartialConsume { batch, consumed } => {
+                        batches.push(batch);
+                        let old_buf = buf;
+                        buf = BytesMut::with_capacity(self.capacity);
+                        buf.put(&old_buf[consumed..]);
+                    }
+                    // If we receive an Incomplete BatchDecodeResult, increase the capacity of the
+                    // buffer, read more data from the source, and try to decode the batch again.
+                    BatchDecodeResult::Incomplete(_) => {
+                        // increase the capacity attribute of the decoder by a factor of 2.
+                        buf.reserve(self.capacity);
+                        self.capacity *= 2;
+                    }
+                    BatchDecodeResult::Error(e) => return Err(e),
+                }
+            }
+        }
+        Ok(batches)
+    }
+
+    /// Decode a single batch of bytes from the buffer. This method is called by decode_batches
+    /// and has several different completion states, each represented by a BatchDecodeResult.
+    fn decode_batch(&mut self, buf: &mut BytesMut) -> BatchDecodeResult {
+        // ensure that the decoder is in the correct state before proceeding.
+        assert_eq!(self.state, EncoderState::Encoding);
+
+        // create a new BufferView from the buffer.
+        let mut local_buf = BufferView::new(buf);
+        // Keep track of the number of rows in the batch.
+        let mut rows = 0;
+
+        // Each iteration of the loop reads a tuple from the data.
+        // If we are not able to read a tuple, return a BatchDecodeResult::Incomplete.
+        // If we were able to read some tuples, but not all the data in the buffer was consumed,
+        // return a BatchDecodeResult::PartialConsume.
+        while local_buf.remaining() > 0 {
+            // Store the number of bytes consumed before reading the tuple.
+            let consumed = local_buf.consumed();
+
+            // Read the number of columns in the tuple. This number is
+            // stored as a 16-bit integer. This value is the same for all
+            // tuples in the batch.
+            let tuple_len: u16 = match local_buf.consume_into_u16() {
+                Ok(len) => len,
+                Err(e) => {
+                    return BatchDecodeResult::Error(e);
+                }
+            };
+
+            // If the tuple length is 0xffff we have reached the end of the
+            // snapshot and we can break the loop and finish the batch.
+            if tuple_len == 0xffff {
+                break;
+            }
+
+            // Each iteration of the loop reads a column from the tuple using the
+            // decoder specified via the schema. Index into `self.decoders` rather
+            // than holding an `iter_mut` borrow so that `finish_batch` can borrow
+            // `self` in the early-return paths below.
+            for i in 0..self.decoders.len() {
+                // If local_buf has been fully consumed and we have not read any rows,
+                // return a BatchDecodeResult::Incomplete.
+                if local_buf.remaining() == 0 && rows == 0 {
+                    return BatchDecodeResult::Incomplete(local_buf.consumed());
+                // If local_buf has been fully consumed and we have read some rows,
+                // return a BatchDecodeResult::PartialConsume, passing the number of bytes
+                // consumed before reading the tuple to the caller so it can know how much data
+                // was consumed.
+                } else if local_buf.remaining() == 0 {
+                    return match self.finish_batch() {
+                        Ok(batch) => BatchDecodeResult::PartialConsume { batch, consumed },
+                        Err(e) => BatchDecodeResult::Error(e),
+                    };
+                }
+                // Apply the decoder to the local_buf, consuming data from the buffer as needed.
+                match self.decoders[i].decode(&mut local_buf) {
+                    // If the decoder was able to decode the data, continue to the next column.
+                    Ok(_) => {}
+                    // If we receive an IncompleteData error, we have reached the end of the data in
+                    // the buffer. If we have decoded some tuples, return a BatchDecodeResult::PartialConsume,
+                    // otherwise return a BatchDecodeResult::Incomplete.
+                    Err(ErrorKind::IncompleteData) => {
+                        // If we have not read any rows, return a BatchDecodeResult::Incomplete.
+                        if rows == 0 {
+                            return BatchDecodeResult::Incomplete(local_buf.consumed());
+                        } else {
+                            // If we have read some rows, return a BatchDecodeResult::PartialConsume.
+                            return match self.finish_batch() {
+                                Ok(batch) => BatchDecodeResult::PartialConsume { batch, consumed },
+                                Err(e) => BatchDecodeResult::Error(e),
+                            };
+                        }
+                    }
+                    Err(e) => return BatchDecodeResult::Error(e),
+                }
+            }
+            // Increment the number of rows in the batch.
+            rows += 1;
+        }
+
+        match self.finish_batch() {
+            Ok(batch) => BatchDecodeResult::Batch(batch),
+            Err(e) => BatchDecodeResult::Error(e),
+        }
+    }
+
+    fn finish_batch(&mut self) -> Result<RecordBatch, ErrorKind> {
+        // Find the minimum length column in the decoders. These can be different if
+        // we are in a partial consume state. We will truncate the columns to the length
+        // of the shortest column and pick up the lost data in the next batch.
+        let column_len = self.decoders.iter().map(|d| d.column_len()).min().unwrap();
+
+        // Determine which columns in the batch are fully null so that we can alter the schema
+        // to reflect this.
+        let null_columns = self
+            .decoders
+            .iter()
+            .filter(|decoder| decoder.is_null())
+            .map(|decoder| decoder.name())
+            .collect::<Vec<String>>();
+
+        // For each decoder call its finish method to coerce the data into an Arrow array
+        // and append the array to the columns vector.
+        let columns = self
+            .decoders
+            .iter_mut()
+            .map(|decoder| decoder.finish(column_len))
+            .collect();
+
+        Ok(RecordBatch::try_new(self.schema.clone().nullify_columns(&null_columns).into(), columns)?)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::{collections::HashMap, sync::Arc};
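Taken together with `read_header`, this gives the decoder a small one-shot surface: hand it a schema and any `BufRead` source, read the header, then drain the stream into batches. A usage sketch (the file path and two-column schema are invented for the example; `from_reader` and its input format are defined in `pg_schema.rs` below):

```rust
use std::fs::File;
use std::io::BufReader;

use pgpq::pg_schema::PostgresSchema;
use pgpq::PostgresBinaryToArrowDecoder;

fn main() -> Result<(), pgpq::error::ErrorKind> {
    // Schema in the "name,type,nullable" format parsed by from_reader.
    let schema = PostgresSchema::from_reader(
        "id,bigint,f\nname,text,t\n".as_bytes(),
        ',',
        "UTC".to_string(),
    )?;

    // events.bin: hypothetical output of COPY ... WITH (FORMAT binary).
    let file = File::open("events.bin")?;
    let reader = BufReader::new(file);

    let mut decoder = PostgresBinaryToArrowDecoder::new(schema, reader, 1024 * 1024)?;
    decoder.read_header()?;
    let batches = decoder.decode_batches()?;
    println!("decoded {} batches", batches.len());
    Ok(())
}
```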
diff --git a/core/src/pg_schema.rs b/core/src/pg_schema.rs
index 1aaa3af..22b6523 100644
--- a/core/src/pg_schema.rs
+++ b/core/src/pg_schema.rs
@@ -1,3 +1,7 @@
+use crate::error::ErrorKind;
+use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
+use std::sync::Arc;
+
 #[derive(Debug, Clone, PartialEq)]
 pub enum TypeSize {
     Fixed(usize),
@@ -17,11 +21,14 @@ pub enum PostgresType {
     Jsonb,
     Float4,
     Float8,
+    Numeric,
     Date,
     Time,
     Timestamp,
+    TimestampTz(String),
     Interval,
-    List(Box<Column>),
+    List((String, Box<Column>)),
+    Null,
 }
 
 impl PostgresType {
@@ -41,8 +48,11 @@ impl PostgresType {
             PostgresType::Date => TypeSize::Fixed(4),
             PostgresType::Time => TypeSize::Fixed(8),
             PostgresType::Timestamp => TypeSize::Fixed(8),
-            PostgresType::Interval => TypeSize::Fixed(16),
+            PostgresType::TimestampTz(_) => TypeSize::Fixed(8),
+            PostgresType::Interval => TypeSize::Fixed(16),
+            PostgresType::Numeric => TypeSize::Variable,
             PostgresType::List(_) => TypeSize::Variable,
+            PostgresType::Null => TypeSize::Fixed(0),
         }
     }
     pub fn oid(&self) -> Option<u32> {
@@ -58,11 +68,14 @@
             PostgresType::Jsonb => Some(3802),
             PostgresType::Float4 => Some(700),
             PostgresType::Float8 => Some(701),
+            PostgresType::Numeric => Some(1700),
             PostgresType::Date => Some(1082),
             PostgresType::Time => Some(1083),
             PostgresType::Timestamp => Some(1114),
+            PostgresType::TimestampTz(_) => Some(1184),
             PostgresType::Interval => Some(1186),
             PostgresType::List(_) => None,
+            PostgresType::Null => None,
         }
     }
     pub fn name(&self) -> Option<String> {
@@ -78,27 +91,266 @@
             PostgresType::Jsonb => "JSONB".to_string(),
             PostgresType::Float4 => "FLOAT4".to_string(),
             PostgresType::Float8 => "FLOAT8".to_string(),
+            PostgresType::Numeric => "DECIMAL".to_string(),
             PostgresType::Date => "DATE".to_string(),
             PostgresType::Time => "TIME".to_string(),
             PostgresType::Timestamp => "TIMESTAMP".to_string(),
+            PostgresType::TimestampTz(_) => "TIMESTAMP WITH TIME ZONE".to_string(),
             PostgresType::Interval => "INTERVAL".to_string(),
-            PostgresType::List(inner) => {
+            PostgresType::List((_, column)) => {
                 // arrays of structs and such are not supported
-                let inner_tp = inner.data_type.name().unwrap();
+                let inner_tp = column.data_type.name().unwrap();
                 format!("{inner_tp}[]")
             }
+            PostgresType::Null => "NULL".to_string(),
         };
         Some(v)
     }
 }
+
+
+impl From<PostgresType> for DataType {
+    fn from(pg_type: PostgresType) -> Self {
+        match pg_type {
+            PostgresType::Bool => DataType::Boolean,
+            PostgresType::Bytea => DataType::Binary,
+            PostgresType::Int8 => DataType::Int64,
+            PostgresType::Int2 => DataType::Int16,
+            PostgresType::Int4 => DataType::Int32,
+            PostgresType::Char => DataType::Utf8,
+            PostgresType::Text => DataType::Utf8,
+            PostgresType::Json => DataType::Utf8,
+            PostgresType::Jsonb => DataType::Utf8,
+            PostgresType::Float4 => DataType::Float32,
+            PostgresType::Float8 => DataType::Float64,
+            PostgresType::Numeric => DataType::Utf8,
+            PostgresType::Date => DataType::Date32,
+            PostgresType::Time => DataType::Time64(TimeUnit::Microsecond),
+            PostgresType::Timestamp => DataType::Timestamp(TimeUnit::Microsecond, None),
+            PostgresType::TimestampTz(timezone) => {
+                DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into()))
+            }
+            PostgresType::Interval => DataType::Duration(TimeUnit::Microsecond),
+            PostgresType::List((name, column)) => {
+                let name = name.replace("list_", "");
+                DataType::List(Arc::new(Field::new(
+                    &name,
+                    column.data_type.clone().into(),
+                    column.nullable,
+                )))
+            },
+            PostgresType::Null => DataType::Null,
+        }
+    }
+}
+
+
+impl TryFrom<DataType> for PostgresType {
+    type Error = ErrorKind;
+    fn try_from(data_type: DataType) -> Result<Self, Self::Error> {
+        let pg_type = match data_type {
+            DataType::Boolean => PostgresType::Bool,
+            DataType::Binary => PostgresType::Bytea,
+            DataType::Int64 => PostgresType::Int8,
+            DataType::Int32 => PostgresType::Int4,
+            DataType::Int16 => PostgresType::Int2,
+            DataType::Utf8 => PostgresType::Text,
+            DataType::Float32 => PostgresType::Float4,
+            DataType::Float64 => PostgresType::Float8,
+            DataType::Date32 => PostgresType::Date,
+            DataType::Time64(_) => PostgresType::Time,
+            DataType::Timestamp(_, tz) => {
+                if let Some(timezone) = tz {
+                    PostgresType::TimestampTz(timezone.to_string())
+                } else {
+                    PostgresType::Timestamp
+                }
+            }
+            DataType::Duration(_) => PostgresType::Interval,
+            DataType::Null => PostgresType::Null,
+            _ => return Err(ErrorKind::UnsupportedArrowType { typ: data_type }),
+        };
+        Ok(pg_type)
+    }
+}
+
 #[derive(Debug, Clone, PartialEq)]
 pub struct Column {
     pub data_type: PostgresType,
     pub nullable: bool,
 }
 
+impl Column {
+    pub fn from_parts(name: &str, type_str: &str, nullable: &str, timezone: String) -> Result<Self, ErrorKind> {
+        match type_str {
+            "boolean" => Ok(Column {
+                data_type: PostgresType::Bool,
+                nullable: nullable == "t",
+            }),
+            "bytea" => Ok(Column {
+                data_type: PostgresType::Bytea,
+                nullable: nullable == "t",
+            }),
+            "bigint" => Ok(Column {
+                data_type: PostgresType::Int8,
+                nullable: nullable == "t",
+            }),
+            "smallint" => Ok(Column {
+                data_type: PostgresType::Int2,
+                nullable: nullable == "t",
+            }),
+            "integer" => Ok(Column {
+                data_type: PostgresType::Int4,
+                nullable: nullable == "t",
+            }),
+            "character" | "character varying" => Ok(Column {
+                data_type: PostgresType::Char,
+                nullable: nullable == "t",
+            }),
+            "text" => Ok(Column {
+                data_type: PostgresType::Text,
+                nullable: nullable == "t",
+            }),
+            "json" => Ok(Column {
+                data_type: PostgresType::Json,
+                nullable: nullable == "t",
+            }),
+            "jsonb" => Ok(Column {
+                data_type: PostgresType::Jsonb,
+                nullable: nullable == "t",
+            }),
+            "real" => Ok(Column {
+                data_type: PostgresType::Float4,
+                nullable: nullable == "t",
+            }),
+            "double precision" => Ok(Column {
+                data_type: PostgresType::Float8,
+                nullable: nullable == "t",
+            }),
+            "numeric" => Ok(Column {
+                data_type: PostgresType::Numeric,
+                nullable: nullable == "t",
+            }),
+            "date" => Ok(Column {
+                data_type: PostgresType::Date,
+                nullable: nullable == "t",
+            }),
+            "time" => Ok(Column {
+                data_type: PostgresType::Time,
+                nullable: nullable == "t",
+            }),
+            "timestamp without time zone" => Ok(Column {
+                data_type: PostgresType::Timestamp,
+                nullable: nullable == "t",
+            }),
+            "timestamp with time zone" => Ok(Column {
+                data_type: PostgresType::TimestampTz(timezone),
+                nullable: nullable == "t",
+            }),
+            "interval" => Ok(Column {
+                data_type: PostgresType::Interval,
+                nullable: nullable == "t",
+            }),
+            typ if typ.ends_with("[]") => {
+                Ok(Column {
+                    data_type: PostgresType::List((name.to_string(), Box::new(Column {
+                        data_type: Column::from_parts(name, &typ[..typ.len() - 2], "f", timezone)?.data_type,
+                        nullable: true,
+                    }))),
+                    nullable: nullable == "t",
+                })
+            },
+            _ => Err(ErrorKind::UnsupportedColumnType {
+                typ: type_str.to_string(),
+            }),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct PostgresSchema {
     pub columns: Vec<(String, Column)>,
 }
+
+impl From<PostgresSchema> for SchemaRef {
+    fn from(pg_schema: PostgresSchema) -> Self {
+        let fields: Vec<Field> = pg_schema
+            .columns
+            .iter()
+            .map(|(name, col)| Field::new(name, col.data_type.clone().into(), col.nullable))
+            .collect();
+        Arc::new(Schema::new(fields))
+    }
+}
+
+impl TryFrom<Schema> for PostgresSchema {
+    type Error = ErrorKind;
+    fn try_from(schema: Schema) -> Result<Self, Self::Error> {
+        let columns: Result<Vec<(String, Column)>, ErrorKind> = schema
+            .fields()
+            .iter()
+            .map(|field| {
+                let name = field.name().to_string();
+                let data_type = field.data_type().clone();
+                let nullable = field.is_nullable();
+                let data_type = PostgresType::try_from(data_type)?;
+                let col = Column {
+                    data_type,
+                    nullable,
+                };
+                Ok((name, col))
+            })
+            .collect();
+
+        Ok(PostgresSchema { columns: columns? })
+    }
+}
+
+impl PostgresSchema {
+    pub fn from_reader<R: Read>(
+        mut reader: R,
+        delim: char,
+        timezone: String,
+    ) -> Result<Self, ErrorKind> {
+        let mut schema_str = String::new();
+        reader.read_to_string(&mut schema_str)?;
+
+        let schema = schema_str
+            .split('\n')
+            .filter(|s| !s.is_empty())
+            .map(|s| {
+                let parts: Vec<&str> = s.splitn(3, delim).collect();
+                if parts.len() != 3 {
+                    return Err(ErrorKind::InvalidColumnSpec {
+                        spec: s.to_string(),
+                    });
+                }
+                let name = parts[0];
+                let typ = parts[1];
+                let nullable = parts[2];
+                let col = Column::from_parts(name, typ, nullable, timezone.to_string())?;
+                Ok((name.to_string(), col))
+            })
+            .collect::<Result<Vec<(String, Column)>, ErrorKind>>()
+            .map(|columns| PostgresSchema { columns })?;
+
+        Ok(schema)
+    }
+
+    pub fn nullify_columns(mut self, columns: &[String]) -> Self {
+        for (name, col) in self.columns.iter_mut() {
+            if columns.contains(name) {
+                col.data_type = PostgresType::Null;
+            }
+        }
+        self
+    }
+
+    pub fn iter(&self) -> impl Iterator<Item = &(String, Column)> {
+        self.columns.iter()
+    }
+}
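`from_reader` defines the on-disk format of the `.schema` fixtures added under core/tests/decoding/ further down: one `name<delim>type<delim>nullable` spec per line, with nullability spelled `t`/`f` and the timezone attached to any `timestamp with time zone` column. A usage sketch (not part of the diff; assumes the `pgpq::pg_schema` exports the tests use):

```rust
use pgpq::pg_schema::{PostgresSchema, PostgresType};

fn main() {
    let spec = "id,bigint,f\ncreated_at,timestamp with time zone,t\ntags,text[],t\n";
    let schema = PostgresSchema::from_reader(
        spec.as_bytes(), // any std::io::Read works; the tests use a BufReader over a file
        ',',
        "America/New_York".to_string(),
    )
    .unwrap();
    assert_eq!(schema.iter().count(), 3);

    // nullify_columns rewrites the named columns to PostgresType::Null in place.
    let schema = schema.nullify_columns(&["tags".to_string()]);
    assert!(schema
        .iter()
        .any(|(name, col)| name == "tags" && col.data_type == PostgresType::Null));
}
```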
+ if case.contains("list") && case.contains("nullable") { + expected_batches = expected_batches + .into_iter() + .map(|batch| { + let new_fields: Vec> = (*(*batch.schema()).clone().fields) + .to_vec() + .clone() + .into_iter() + .map(|f| { + let new_field = (*f).clone(); + let new_field = new_field.with_nullable(false); + if new_field.data_type().to_string().contains("List") { + Arc::new(new_field.with_nullable(true)) + } else { + Arc::new(new_field) + } + }) + .collect(); + let new_schema = Schema::new(new_fields); + RecordBatch::try_new(Arc::new(new_schema), batch.columns().to_vec()).unwrap() + }) + .collect::>(); + } + + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), expected_batches.iter().map(|b| b.num_rows()).sum::()); + let batch_schemas = batches.iter().map(|b| b.schema()).collect::>(); + let expected_batch_schemas = expected_batches.iter().map(|b| b.schema()).collect::>(); + assert_eq!(batch_schemas, expected_batch_schemas); + assert_eq!(batches, expected_batches); + + Ok(()) +} + +#[test] +fn test_bool() -> Result<(), ErrorKind> { + run_test_case("bool", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_int16() -> Result<(), ErrorKind> { + run_test_case("int16", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_int32() -> Result<(), ErrorKind> { + run_test_case("int32", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_int64() -> Result<(), ErrorKind> { + run_test_case("int64", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_float32() -> Result<(), ErrorKind> { + run_test_case("float32", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_float64() -> Result<(), ErrorKind> { + run_test_case("float64", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_numeric() -> Result<(), ErrorKind> { + run_test_case("numeric", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_timestamp_us_notz() -> Result<(), ErrorKind> { + run_test_case("timestamp_us_notz", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_timestamp_us_tz() -> Result<(), ErrorKind> { + run_test_case("timestamp_us_tz", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_time_us() -> Result<(), ErrorKind> { + run_test_case("time_us", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_date32() -> Result<(), ErrorKind> { + run_test_case("date32", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_duration_us() -> Result<(), ErrorKind> { + run_test_case("duration_us", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_binary() -> Result<(), ErrorKind> { + run_test_case("binary", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_string() -> Result<(), ErrorKind> { + run_test_case("string", "America/New_York".to_string())?; + Ok(()) +} + +// nullable types + +#[test] +fn test_bool_nullable() -> Result<(), ErrorKind> { + run_test_case("bool_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_int16_nullable() -> Result<(), ErrorKind> { + run_test_case("int16_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_int32_nullable() -> Result<(), ErrorKind> { + run_test_case("int32_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_int64_nullable() -> Result<(), ErrorKind> { + run_test_case("int64_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_float32_nullable() -> Result<(), ErrorKind> { + run_test_case("float32_nullable", 
"America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_float64_nullable() -> Result<(), ErrorKind> { + run_test_case("float64_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_timestamp_us_notz_nullable() -> Result<(), ErrorKind> { + run_test_case("timestamp_us_notz_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_timestamp_us_tz_nullable() -> Result<(), ErrorKind> { + run_test_case("timestamp_us_tz_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_time_us_nullable() -> Result<(), ErrorKind> { + run_test_case("time_us_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_date32_nullable() -> Result<(), ErrorKind> { + run_test_case("date32_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_duration_us_nullable() -> Result<(), ErrorKind> { + run_test_case("duration_us_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_binary_nullable() -> Result<(), ErrorKind> { + run_test_case("binary_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_string_nullable() -> Result<(), ErrorKind> { + run_test_case("string_nullable", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_numeric_nullable() -> Result<(), ErrorKind> { + run_test_case("numeric_nullable", "America/New_York".to_string())?; + Ok(()) +} + +// Nested types non-nullable + + +#[test] +fn test_list_int16() -> Result<(), ErrorKind> { + run_test_case("list_int16", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_int32() -> Result<(), ErrorKind> { + run_test_case("list_int32", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_int64() -> Result<(), ErrorKind> { + run_test_case("list_int64", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_float32() -> Result<(), ErrorKind> { + run_test_case("list_float32", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_float64() -> Result<(), ErrorKind> { + run_test_case("list_float64", "America/New_York".to_string())?; + Ok(()) +} + +// Needed to add expected data for list_numeric +// #[test] +// fn test_list_numeric() -> Result<(), ErrorKind> { +// run_test_case("list_numeric", "America/New_York".to_string())?; +// Ok(()) +// } + +#[test] +fn test_list_timestamp_us_notz() -> Result<(), ErrorKind> { + run_test_case("list_timestamp_us_notz", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_timestamp_us_tz() -> Result<(), ErrorKind> { + run_test_case("list_timestamp_us_tz", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_time_us() -> Result<(), ErrorKind> { + run_test_case("list_time_us", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_date32() -> Result<(), ErrorKind> { + run_test_case("list_date32", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_duration_us() -> Result<(), ErrorKind> { + run_test_case("list_duration_us", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_binary() -> Result<(), ErrorKind> { + run_test_case("list_binary", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_string() -> Result<(), ErrorKind> { + run_test_case("list_string", "America/New_York".to_string())?; + Ok(()) +} + +#[test] +fn test_list_bool() -> Result<(), ErrorKind> { + run_test_case("list_bool", "America/New_York".to_string())?; + Ok(()) +} + +// Nested types nullable + +#[test] +fn 
+fn test_list_bool_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_bool_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_int16_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_int16_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_int32_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_int32_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_int64_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_int64_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_float32_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_float32_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_float64_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_float64_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_timestamp_us_notz_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_timestamp_us_notz_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_timestamp_us_tz_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_timestamp_us_tz_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_time_us_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_time_us_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_date32_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_date32_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_duration_us_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_duration_us_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_binary_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_binary_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+#[test]
+fn test_list_string_nullable() -> Result<(), ErrorKind> {
+    run_test_case("list_string_nullable", "America/New_York".to_string())?;
+    Ok(())
+}
+
+// TODO: add expected data for list_numeric_nullable before enabling this test.
+// #[test]
+// fn test_list_numeric_nullable() -> Result<(), ErrorKind> {
+//     run_test_case("list_numeric_nullable", "America/New_York".to_string())?;
+//     Ok(())
+// }
diff --git a/core/tests/decoding/binary.schema b/core/tests/decoding/binary.schema
new file mode 100644
index 0000000..9abe242
--- /dev/null
+++ b/core/tests/decoding/binary.schema
@@ -0,0 +1 @@
+binary,bytea,f
diff --git a/core/tests/decoding/binary_nullable.schema b/core/tests/decoding/binary_nullable.schema
new file mode 100644
index 0000000..e3faaf6
--- /dev/null
+++ b/core/tests/decoding/binary_nullable.schema
@@ -0,0 +1 @@
+binary_nullable,bytea,t
diff --git a/core/tests/decoding/bool.schema b/core/tests/decoding/bool.schema
new file mode 100644
index 0000000..a0e6dbf
--- /dev/null
+++ b/core/tests/decoding/bool.schema
@@ -0,0 +1 @@
+bool,boolean,f
\ No newline at end of file
diff --git a/core/tests/decoding/bool_nullable.schema b/core/tests/decoding/bool_nullable.schema
new file mode 100644
index 0000000..5e11eac
--- /dev/null
+++ b/core/tests/decoding/bool_nullable.schema
@@ -0,0 +1 @@
+bool_nullable,boolean,t
diff --git a/core/tests/decoding/date32.schema b/core/tests/decoding/date32.schema
new file mode 100644
index 0000000..9fb0ee6
--- /dev/null
+++ b/core/tests/decoding/date32.schema
@@ -0,0 +1 @@
+date32,date,f
\ No newline at end of file
diff --git a/core/tests/decoding/date32_nullable.schema b/core/tests/decoding/date32_nullable.schema
new file mode 100644
index 0000000..f0082d7
--- /dev/null
+++ b/core/tests/decoding/date32_nullable.schema
@@ -0,0 +1 @@
+date32_nullable,date,t
diff --git a/core/tests/decoding/duration_us.schema b/core/tests/decoding/duration_us.schema
new file mode 100644
index 0000000..64fd6f8
--- /dev/null
+++ b/core/tests/decoding/duration_us.schema
@@ -0,0 +1 @@
+duration_us,interval,f
\ No newline at end of file
diff --git a/core/tests/decoding/duration_us_nullable.schema b/core/tests/decoding/duration_us_nullable.schema
new file mode 100644
index 0000000..8a013b7
--- /dev/null
+++ b/core/tests/decoding/duration_us_nullable.schema
@@ -0,0 +1 @@
+duration_us_nullable,interval,t
diff --git a/core/tests/decoding/float32.schema b/core/tests/decoding/float32.schema
new file mode 100644
index 0000000..1905873
--- /dev/null
+++ b/core/tests/decoding/float32.schema
@@ -0,0 +1 @@
+float32,real,f
\ No newline at end of file
diff --git a/core/tests/decoding/float32_nullable.schema b/core/tests/decoding/float32_nullable.schema
new file mode 100644
index 0000000..b3708a9
--- /dev/null
+++ b/core/tests/decoding/float32_nullable.schema
@@ -0,0 +1 @@
+float32_nullable,real,t
diff --git a/core/tests/decoding/float64.schema b/core/tests/decoding/float64.schema
new file mode 100644
index 0000000..9c68dd4
--- /dev/null
+++ b/core/tests/decoding/float64.schema
@@ -0,0 +1 @@
+float64,double precision,f
\ No newline at end of file
diff --git a/core/tests/decoding/float64_nullable.schema b/core/tests/decoding/float64_nullable.schema
new file mode 100644
index 0000000..a8e547c
--- /dev/null
+++ b/core/tests/decoding/float64_nullable.schema
@@ -0,0 +1 @@
+float64_nullable,double precision,t
diff --git a/core/tests/decoding/int16.schema b/core/tests/decoding/int16.schema
new file mode 100644
index 0000000..fa3f2bd
--- /dev/null
+++ b/core/tests/decoding/int16.schema
@@ -0,0 +1 @@
+int16,smallint,f
\ No newline at end of file
diff --git a/core/tests/decoding/int16_nullable.schema b/core/tests/decoding/int16_nullable.schema
new file mode 100644
index 0000000..9eb14de
--- /dev/null
+++ b/core/tests/decoding/int16_nullable.schema
@@ -0,0 +1 @@
+int16_nullable,smallint,t
diff --git a/core/tests/decoding/int32.schema b/core/tests/decoding/int32.schema
new file mode 100644
index 0000000..6409207
--- /dev/null
+++ b/core/tests/decoding/int32.schema
@@ -0,0 +1 @@
+int32,integer,f
\ No newline at end of file
diff --git a/core/tests/decoding/int32_nullable.schema b/core/tests/decoding/int32_nullable.schema
new file mode 100644
index 0000000..e88df62
--- /dev/null
+++ b/core/tests/decoding/int32_nullable.schema
@@ -0,0 +1 @@
+int32_nullable,integer,t
diff --git a/core/tests/decoding/int64.schema b/core/tests/decoding/int64.schema
new file mode 100644
index 0000000..fc3fd4c
--- /dev/null
+++ b/core/tests/decoding/int64.schema
@@ -0,0 +1 @@
+int64,bigint,f
\ No newline at end of file
diff --git a/core/tests/decoding/int64_nullable.schema b/core/tests/decoding/int64_nullable.schema
new file mode 100644
index 0000000..7f0fdab
--- /dev/null
+++ b/core/tests/decoding/int64_nullable.schema
@@ -0,0 +1 @@
+int64_nullable,bigint,t
diff --git a/core/tests/decoding/list_binary.schema b/core/tests/decoding/list_binary.schema
new file mode 100644
index 0000000..05ed96b
--- /dev/null
+++ b/core/tests/decoding/list_binary.schema
@@ -0,0 +1 @@
+list_binary,bytea[],f
diff --git a/core/tests/decoding/list_binary_nullable.schema b/core/tests/decoding/list_binary_nullable.schema
new file mode 100644
index 0000000..a728383
--- /dev/null
+++ b/core/tests/decoding/list_binary_nullable.schema
@@ -0,0 +1 @@
+list_binary_nullable,bytea[],t
diff --git a/core/tests/decoding/list_bool.schema b/core/tests/decoding/list_bool.schema
new file mode 100644
index 0000000..fa56254
--- /dev/null
+++ b/core/tests/decoding/list_bool.schema
@@ -0,0 +1 @@
+list_bool,boolean[],f
diff --git a/core/tests/decoding/list_bool_nullable.schema b/core/tests/decoding/list_bool_nullable.schema
new file mode 100644
index 0000000..7bb29af
--- /dev/null
+++ b/core/tests/decoding/list_bool_nullable.schema
@@ -0,0 +1 @@
+list_bool_nullable,boolean[],t
diff --git a/core/tests/decoding/list_date32.schema b/core/tests/decoding/list_date32.schema
new file mode 100644
index 0000000..10f6597
--- /dev/null
+++ b/core/tests/decoding/list_date32.schema
@@ -0,0 +1 @@
+list_date32,date[],f
diff --git a/core/tests/decoding/list_date32_nullable.schema b/core/tests/decoding/list_date32_nullable.schema
new file mode 100644
index 0000000..c61c08b
--- /dev/null
+++ b/core/tests/decoding/list_date32_nullable.schema
@@ -0,0 +1 @@
+list_date32_nullable,date[],t
diff --git a/core/tests/decoding/list_duration_us.schema b/core/tests/decoding/list_duration_us.schema
new file mode 100644
index 0000000..34a89bc
--- /dev/null
+++ b/core/tests/decoding/list_duration_us.schema
@@ -0,0 +1 @@
+list_duration_us,interval[],f
diff --git a/core/tests/decoding/list_duration_us_nullable.schema b/core/tests/decoding/list_duration_us_nullable.schema
new file mode 100644
index 0000000..8fb3609
--- /dev/null
+++ b/core/tests/decoding/list_duration_us_nullable.schema
@@ -0,0 +1 @@
+list_duration_us_nullable,interval[],t
diff --git a/core/tests/decoding/list_float32.schema b/core/tests/decoding/list_float32.schema
new file mode 100644
index 0000000..127f883
--- /dev/null
+++ b/core/tests/decoding/list_float32.schema
@@ -0,0 +1 @@
+list_float32,real[],f
diff --git a/core/tests/decoding/list_float32_nullable.schema b/core/tests/decoding/list_float32_nullable.schema
new file mode 100644
index 0000000..0220ad4
--- /dev/null
+++ b/core/tests/decoding/list_float32_nullable.schema
@@ -0,0 +1 @@
+list_float32_nullable,real[],t
diff --git a/core/tests/decoding/list_float64.schema b/core/tests/decoding/list_float64.schema
new file mode 100644
index 0000000..1d51647
--- /dev/null
+++ b/core/tests/decoding/list_float64.schema
@@ -0,0 +1 @@
+list_float64,double precision[],f
diff --git a/core/tests/decoding/list_float64_nullable.schema b/core/tests/decoding/list_float64_nullable.schema
new file mode 100644
index 0000000..7afa553
--- /dev/null
+++ b/core/tests/decoding/list_float64_nullable.schema
@@ -0,0 +1 @@
+list_float64_nullable,double precision[],t
diff --git a/core/tests/decoding/list_int16.schema b/core/tests/decoding/list_int16.schema
new file mode 100644
index 0000000..b3cd100
--- /dev/null
+++ b/core/tests/decoding/list_int16.schema
@@ -0,0 +1 @@
+list_int16,smallint[],f
diff --git a/core/tests/decoding/list_int16_nullable.schema b/core/tests/decoding/list_int16_nullable.schema
new file mode 100644
index 0000000..b393c8f
--- /dev/null
+++ b/core/tests/decoding/list_int16_nullable.schema
@@ -0,0 +1 @@
+list_int16_nullable,smallint[],t
diff --git a/core/tests/decoding/list_int32.schema b/core/tests/decoding/list_int32.schema
new file mode 100644
index 0000000..a9cdc49
--- /dev/null
+++ b/core/tests/decoding/list_int32.schema
@@ -0,0 +1 @@
+list_int32,integer[],f
diff --git a/core/tests/decoding/list_int32_nullable.schema b/core/tests/decoding/list_int32_nullable.schema
new file mode 100644
index 0000000..bd9cc8e
--- /dev/null
+++ b/core/tests/decoding/list_int32_nullable.schema
@@ -0,0 +1 @@
+list_int32_nullable,integer[],t
diff --git a/core/tests/decoding/list_int64.schema b/core/tests/decoding/list_int64.schema
new file mode 100644
index 0000000..e1fb72e
--- /dev/null
+++ b/core/tests/decoding/list_int64.schema
@@ -0,0 +1 @@
+list_int64,bigint[],f
diff --git a/core/tests/decoding/list_int64_nullable.schema b/core/tests/decoding/list_int64_nullable.schema
new file mode 100644
index 0000000..1af3130
--- /dev/null
+++ b/core/tests/decoding/list_int64_nullable.schema
@@ -0,0 +1 @@
+list_int64_nullable,bigint[],t
diff --git a/core/tests/decoding/list_numeric.schema b/core/tests/decoding/list_numeric.schema
new file mode 100644
index 0000000..52af8c2
--- /dev/null
+++ b/core/tests/decoding/list_numeric.schema
@@ -0,0 +1 @@
+list_numeric,numeric[],f
diff --git a/core/tests/decoding/list_numeric_nullable.schema b/core/tests/decoding/list_numeric_nullable.schema
new file mode 100644
index 0000000..294b201
--- /dev/null
+++ b/core/tests/decoding/list_numeric_nullable.schema
@@ -0,0 +1 @@
+list_numeric_nullable,numeric[],t
diff --git a/core/tests/decoding/list_string.schema b/core/tests/decoding/list_string.schema
new file mode 100644
index 0000000..b4bd045
--- /dev/null
+++ b/core/tests/decoding/list_string.schema
@@ -0,0 +1 @@
+list_string,text[],f
diff --git a/core/tests/decoding/list_string_nullable.schema b/core/tests/decoding/list_string_nullable.schema
new file mode 100644
index 0000000..5fe1da9
--- /dev/null
+++ b/core/tests/decoding/list_string_nullable.schema
@@ -0,0 +1 @@
+list_string_nullable,text[],t
diff --git a/core/tests/decoding/list_time_us.schema b/core/tests/decoding/list_time_us.schema
new file mode 100644
index 0000000..d6032a9
--- /dev/null
+++ b/core/tests/decoding/list_time_us.schema
@@ -0,0 +1 @@
+list_time_us,time[],f
diff --git a/core/tests/decoding/list_time_us_nullable.schema b/core/tests/decoding/list_time_us_nullable.schema
new file mode 100644
index 0000000..f4c40f4
--- /dev/null
+++ b/core/tests/decoding/list_time_us_nullable.schema
@@ -0,0 +1 @@
+list_time_us_nullable,time[],t
diff --git a/core/tests/decoding/list_timestamp_us_notz.schema b/core/tests/decoding/list_timestamp_us_notz.schema
new file mode 100644
index 0000000..40e94ca
--- /dev/null
+++ b/core/tests/decoding/list_timestamp_us_notz.schema
@@ -0,0 +1 @@
+list_timestamp_us_notz,timestamp without time zone[],f
diff --git a/core/tests/decoding/list_timestamp_us_notz_nullable.schema b/core/tests/decoding/list_timestamp_us_notz_nullable.schema
new file mode 100644
index 0000000..18a924c
--- /dev/null
+++ b/core/tests/decoding/list_timestamp_us_notz_nullable.schema
@@ -0,0 +1 @@
+list_timestamp_us_notz_nullable,timestamp without time zone[],t
diff --git a/core/tests/decoding/list_timestamp_us_tz.schema b/core/tests/decoding/list_timestamp_us_tz.schema
new file mode 100644
index 0000000..69a52d4
--- /dev/null
+++ b/core/tests/decoding/list_timestamp_us_tz.schema
@@ -0,0 +1 @@
+list_timestamp_us_tz,timestamp with time zone[],f
diff --git a/core/tests/decoding/list_timestamp_us_tz_nullable.schema b/core/tests/decoding/list_timestamp_us_tz_nullable.schema
new file mode 100644
index 0000000..abd71f8
--- /dev/null
+++ b/core/tests/decoding/list_timestamp_us_tz_nullable.schema
@@ -0,0 +1 @@
+list_timestamp_us_tz_nullable,timestamp with time zone[],t
diff --git a/core/tests/decoding/numeric.schema b/core/tests/decoding/numeric.schema
new file mode 100644
index 0000000..ce2108f
--- /dev/null
+++ b/core/tests/decoding/numeric.schema
@@ -0,0 +1 @@
+numeric,numeric,f
\ No newline at end of file
diff --git a/core/tests/decoding/numeric_nullable.schema b/core/tests/decoding/numeric_nullable.schema
new file mode 100644
index 0000000..32445c4
--- /dev/null
+++ b/core/tests/decoding/numeric_nullable.schema
@@ -0,0 +1 @@
+numeric_nullable,numeric,t
diff --git a/core/tests/decoding/string.schema b/core/tests/decoding/string.schema
new file mode 100644
index 0000000..62a7265
--- /dev/null
+++ b/core/tests/decoding/string.schema
@@ -0,0 +1 @@
+string,text,f
diff --git a/core/tests/decoding/string_nullable.schema b/core/tests/decoding/string_nullable.schema
new file mode 100644
index 0000000..576a89c
--- /dev/null
+++ b/core/tests/decoding/string_nullable.schema
@@ -0,0 +1 @@
+string_nullable,text,t
diff --git a/core/tests/decoding/time_us.schema b/core/tests/decoding/time_us.schema
new file mode 100644
index 0000000..e327d79
--- /dev/null
+++ b/core/tests/decoding/time_us.schema
@@ -0,0 +1 @@
+time_us,time,f
\ No newline at end of file
diff --git a/core/tests/decoding/time_us_nullable.schema b/core/tests/decoding/time_us_nullable.schema
new file mode 100644
index 0000000..8e6d71f
--- /dev/null
+++ b/core/tests/decoding/time_us_nullable.schema
@@ -0,0 +1 @@
+time_us_nullable,time,t
diff --git a/core/tests/decoding/timestamp_us_notz.schema b/core/tests/decoding/timestamp_us_notz.schema
new file mode 100644
index 0000000..705cf01
--- /dev/null
+++ b/core/tests/decoding/timestamp_us_notz.schema
@@ -0,0 +1 @@
+timestamp_us_notz,timestamp without time zone,f
diff --git a/core/tests/decoding/timestamp_us_notz_nullable.schema b/core/tests/decoding/timestamp_us_notz_nullable.schema
new file mode 100644
index 0000000..1a3360a
--- /dev/null
+++ b/core/tests/decoding/timestamp_us_notz_nullable.schema
@@ -0,0 +1 @@
+timestamp_us_notz_nullable,timestamp without time zone,t
diff --git a/core/tests/decoding/timestamp_us_tz.schema b/core/tests/decoding/timestamp_us_tz.schema
new file mode 100644
index 0000000..64ec24b
--- /dev/null
+++ b/core/tests/decoding/timestamp_us_tz.schema
@@ -0,0 +1 @@
+timestamp_us_tz,timestamp with time zone,f
\ No newline at end of file
diff --git a/core/tests/decoding/timestamp_us_tz_nullable.schema b/core/tests/decoding/timestamp_us_tz_nullable.schema
new file mode 100644
index 0000000..e16aede
--- /dev/null
+++ b/core/tests/decoding/timestamp_us_tz_nullable.schema
@@ -0,0 +1 @@
+timestamp_us_tz_nullable,timestamp with time zone,t
diff --git a/core/tests/snapshots/numeric.bin b/core/tests/snapshots/numeric.bin
new file mode 100644
index 0000000..8171175
Binary files /dev/null and b/core/tests/snapshots/numeric.bin differ
diff --git a/core/tests/snapshots/numeric_nullable.bin b/core/tests/snapshots/numeric_nullable.bin
new file mode 100644
index 0000000..c152cff
Binary files /dev/null and b/core/tests/snapshots/numeric_nullable.bin differ
diff --git a/core/tests/testdata/numeric.arrow b/core/tests/testdata/numeric.arrow
new file mode 100644
index 0000000..1180dde
Binary files /dev/null and b/core/tests/testdata/numeric.arrow differ
diff --git a/core/tests/testdata/numeric_nullable.arrow b/core/tests/testdata/numeric_nullable.arrow
new file mode 100644
index 0000000..a9a70b7
Binary files /dev/null and b/core/tests/testdata/numeric_nullable.arrow differ
diff --git a/py/src/pg_schema.rs b/py/src/pg_schema.rs
index c5f5a41..5588a58 100644
--- a/py/src/pg_schema.rs
+++ b/py/src/pg_schema.rs
@@ -229,6 +229,7 @@ impl From<pgpq::pg_schema::PostgresType> for PostgresType {
             pgpq::pg_schema::PostgresType::Time => PostgresType::Time(Time),
             pgpq::pg_schema::PostgresType::Timestamp => PostgresType::Timestamp(Timestamp),
             pgpq::pg_schema::PostgresType::Interval => PostgresType::Interval(Interval),
+            pgpq::pg_schema::PostgresType::Null => todo!(),
             pgpq::pg_schema::PostgresType::List(inner) => {
                 PostgresType::List(List::new((*inner).into()))
             }