Skip to content

Commit

Permalink
test numeric and numeric nullable
Browse files Browse the repository at this point in the history
  • Loading branch information
livinlefevreloca committed Apr 14, 2024
1 parent 8a52b2c commit eefe8ba
Show file tree
Hide file tree
Showing 17 changed files with 199 additions and 94 deletions.
142 changes: 107 additions & 35 deletions core/src/decoders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use arrow_array::{self, ArrayRef};
use arrow_array::{
BooleanArray, Date32Array, DurationMicrosecondArray, Float32Array, Float64Array,
GenericStringArray, Int16Array, Int32Array, Int64Array, Time64MicrosecondArray,
TimestampMicrosecondArray,
TimestampMicrosecondArray, NullArray
};
use std::fmt::Debug;
use std::sync::Arc;
Expand Down Expand Up @@ -87,6 +87,7 @@ pub(crate) trait Decode {
fn finish(&mut self, column_len: usize) -> ArrayRef;
fn column_len(&self) -> usize;
fn name(&self) -> String;
fn is_null(&self) -> bool;
}

macro_rules! impl_decode {
Expand Down Expand Up @@ -121,11 +122,18 @@ macro_rules! impl_decode {
}

fn finish(&mut self, column_len: usize) -> ArrayRef {
let mut data = std::mem::take(&mut self.arr);
data.resize(column_len, None);
self.arr.resize(column_len, None);
if self.is_null() {
return Arc::new(NullArray::new(self.arr.len())) as ArrayRef;
}
let data = std::mem::take(&mut self.arr);
let array = Arc::new($array_kind::from(data));
array as ArrayRef
}

fn is_null(&self) -> bool {
self.arr.iter().all(|v| v.is_none())
}
}
};
}
Expand Down Expand Up @@ -176,11 +184,18 @@ macro_rules! impl_decode_fallible {
}

fn finish(&mut self, column_len: usize) -> ArrayRef {
let mut data = std::mem::take(&mut self.arr);
data.resize(column_len, None);
self.arr.resize(column_len, None);
if self.is_null() {
return Arc::new(NullArray::new(self.arr.len())) as ArrayRef;
}
let data = std::mem::take(&mut self.arr);
let array = Arc::new($array_kind::from(data));
array as ArrayRef
}

fn is_null(&self) -> bool {
self.arr.iter().all(|v| v.is_none())
}
}
};
}
Expand Down Expand Up @@ -235,11 +250,18 @@ macro_rules! impl_decode_variable_size {
}

fn finish(&mut self, column_len: usize) -> ArrayRef {
let mut data = std::mem::take(&mut self.arr);
data.resize(column_len, None);
self.arr.resize(column_len, None);
if self.is_null() {
return Arc::new(NullArray::new(self.arr.len())) as ArrayRef;
}
let data = std::mem::take(&mut self.arr);
let array = Arc::new($array_kind::<$offset_size>::from(data));
array as ArrayRef
}

fn is_null(&self) -> bool {
self.arr.iter().all(|v| v.is_none())
}
}
};
}
Expand Down Expand Up @@ -293,7 +315,6 @@ fn convert_pg_timestamp_to_arrow_timestamp_microseconds(
timestamp_us: i64,
) -> Result<i64, ErrorKind> {
// adjust the timestamp from microseconds since 2000-01-01 to microseconds since 1970-01-01 checking for overflows and underflow
println!("timestamp_us: {}", timestamp_us);
timestamp_us
.checked_add(PG_BASE_TIMESTAMP_OFFSET_US)
.ok_or_else(|| ErrorKind::Decode {
Expand Down Expand Up @@ -352,13 +373,20 @@ impl Decode for TimestampTzMicrosecondDecoder {
}

fn finish(&mut self, column_len: usize) -> ArrayRef {
let mut data = std::mem::take(&mut self.arr);
data.resize(column_len, None);
self.arr.resize(column_len, None);
if self.is_null() {
return Arc::new(NullArray::new(self.arr.len())) as ArrayRef;
}
let data = std::mem::take(&mut self.arr);
let array = Arc::new(
TimestampMicrosecondArray::from(data).with_timezone(self.timezone.to_string()),
);
array as ArrayRef
}

fn is_null(&self) -> bool {
self.arr.iter().all(|v| v.is_none())
}
}

/// Convert Postgres dates (days since 2000-01-01) to Arrow dates (days since 1970-01-01)
Expand Down Expand Up @@ -464,7 +492,7 @@ impl_decode_variable_size!(
},
0,
GenericStringArray,
i64
i32
);

pub struct BinaryDecoder {
Expand Down Expand Up @@ -495,10 +523,13 @@ impl Decode for BinaryDecoder {
}

fn finish(&mut self, column_len: usize) -> ArrayRef {
let mut data = std::mem::take(&mut self.arr);
data.resize(column_len, None);
self.arr.resize(column_len, None);
if self.is_null() {
return Arc::new(NullArray::new(self.arr.len())) as ArrayRef;
}

let mut builder: GenericByteBuilder<GenericBinaryType<i64>> = GenericByteBuilder::new();
let data = std::mem::take(&mut self.arr);
let mut builder: GenericByteBuilder<GenericBinaryType<i32>> = GenericByteBuilder::new();
for v in data {
match v {
Some(v) => builder.append_value(v),
Expand All @@ -507,6 +538,10 @@ impl Decode for BinaryDecoder {
}
Arc::new(builder.finish()) as ArrayRef
}

fn is_null(&self) -> bool {
self.arr.iter().all(|v| v.is_none())
}
}

pub struct JsonbDecoder {
Expand Down Expand Up @@ -544,7 +579,7 @@ fn parse_pg_decimal_to_string(data: Vec<u8>) -> Result<String, ErrorKind> {
let weight = i16::from_be_bytes(data[2..4].try_into().unwrap());
let sign = i16::from_be_bytes(data[4..6].try_into().unwrap());
let scale = i16::from_be_bytes(data[6..8].try_into().unwrap());
let digits: Vec<i16> = data[8..8 + ndigits as usize]
let digits: Vec<i16> = data[8..8 + (ndigits as usize) * (std::mem::size_of::<i16>())]
.chunks(2)
.map(|c| i16::from_be_bytes(c.try_into().unwrap()))
.collect();
Expand Down Expand Up @@ -604,22 +639,20 @@ fn parse_pg_decimal_to_string(data: Vec<u8>) -> Result<String, ErrorKind> {
dig -= d1 * place;
putit |= d1 > 0;
if putit {
decimal.push(d1 as u8 + b'0')
decimal.push(d1 as u8 + b'0');
}
place /= 10;
}
decimal.push(dig as u8 + b'0');
digits_index += 1;
}

// If scale is > 0 we have digits after the decimal point
if scale > 0 {
decimal.push(b'.');
}
}

let mut i = 0;
while i < scale {
// If scale is > 0 we have digits after the decimal point
if scale > 0 {
decimal.push(b'.');
}
while digits_index < ndigits {
let mut dig = if digits_index >= 0 && digits_index < ndigits {
digits[digits_index as usize]
} else {
Expand All @@ -635,24 +668,42 @@ fn parse_pg_decimal_to_string(data: Vec<u8>) -> Result<String, ErrorKind> {
place /= 10;
}
decimal.push(dig as u8 + b'0');
i += 1;
digits_index += 1;
}

// unwrap will not fail here as we know ever value in our decimal vec is ascii;
Ok(String::from_utf8(decimal).unwrap())
// trim trailing zeros and return the string
Ok(trim_trailing_zeros(decimal, scale))
}

pub struct DecimalDecoder {
pub fn trim_trailing_zeros(mut v: Vec<u8>, scale: i16) -> String {
let decimal_point_idx = v.iter().position(|&c| c == b'.');
match decimal_point_idx {

Check failure on line 680 in core/src/decoders.rs

View workflow job for this annotation

GitHub Actions / Clippy

you seem to be trying to use `match` for destructuring a single pattern. Consider using `if let`

Check failure on line 680 in core/src/decoders.rs

View workflow job for this annotation

GitHub Actions / Clippy

you seem to be trying to use `match` for destructuring a single pattern. Consider using `if let`
Some(idx) => {
if idx == v.len() - 1 {
v.push(b'0')
} else {
while v.len() - idx - 1 > scale as usize {
v.pop();
}
}
},
None => {}
};
let result = String::from_utf8(v).unwrap();
result

Check failure on line 693 in core/src/decoders.rs

View workflow job for this annotation

GitHub Actions / Clippy

returning the result of a `let` binding from a block

Check failure on line 693 in core/src/decoders.rs

View workflow job for this annotation

GitHub Actions / Clippy

returning the result of a `let` binding from a block
}

pub struct NumericDecoder {
name: String,
arr: Vec<Option<String>>,
}

impl_decode_variable_size!(
DecimalDecoder,
NumericDecoder,
parse_pg_decimal_to_string,
0,
GenericStringArray,
i64
i32
);

//
Expand All @@ -663,7 +714,7 @@ pub enum Decoder {
Int64(Int64Decoder),
Float32(Float32Decoder),
Float64(Float64Decoder),
Decimal(DecimalDecoder),
Numeric(NumericDecoder),
TimestampMicrosecond(TimestampMicrosecondDecoder),
TimestampTzMicrosecond(TimestampTzMicrosecondDecoder),
Date32(Date32Decoder),
Expand Down Expand Up @@ -703,7 +754,7 @@ impl Decoder {
name: name.to_string(),
arr: vec![],
}),
PostgresType::Decimal => Decoder::Decimal(DecimalDecoder {
PostgresType::Numeric => Decoder::Numeric(NumericDecoder {
name: name.to_string(),
arr: vec![],
}),
Expand Down Expand Up @@ -757,7 +808,7 @@ impl Decoder {
Decoder::Int64(ref mut decoder) => decoder.decode(buf),
Decoder::Float32(ref mut decoder) => decoder.decode(buf),
Decoder::Float64(ref mut decoder) => decoder.decode(buf),
Decoder::Decimal(ref mut decoder) => decoder.decode(buf),
Decoder::Numeric(ref mut decoder) => decoder.decode(buf),
Decoder::TimestampMicrosecond(ref mut decoder) => decoder.decode(buf),
Decoder::TimestampTzMicrosecond(ref mut decoder) => decoder.decode(buf),
Decoder::Date32(ref mut decoder) => decoder.decode(buf),
Expand All @@ -769,6 +820,7 @@ impl Decoder {
}
}

#[allow(dead_code)]
pub(crate) fn name(&self) -> String {
match *self {
Decoder::Boolean(ref decoder) => decoder.name(),
Expand All @@ -777,7 +829,7 @@ impl Decoder {
Decoder::Int64(ref decoder) => decoder.name(),
Decoder::Float32(ref decoder) => decoder.name(),
Decoder::Float64(ref decoder) => decoder.name(),
Decoder::Decimal(ref decoder) => decoder.name(),
Decoder::Numeric(ref decoder) => decoder.name(),
Decoder::TimestampMicrosecond(ref decoder) => decoder.name(),
Decoder::TimestampTzMicrosecond(ref decoder) => decoder.name(),
Decoder::Date32(ref decoder) => decoder.name(),
Expand All @@ -797,7 +849,7 @@ impl Decoder {
Decoder::Int64(ref decoder) => decoder.column_len(),
Decoder::Float32(ref decoder) => decoder.column_len(),
Decoder::Float64(ref decoder) => decoder.column_len(),
Decoder::Decimal(ref decoder) => decoder.column_len(),
Decoder::Numeric(ref decoder) => decoder.column_len(),
Decoder::TimestampMicrosecond(ref decoder) => decoder.column_len(),
Decoder::TimestampTzMicrosecond(ref decoder) => decoder.column_len(),
Decoder::Date32(ref decoder) => decoder.column_len(),
Expand All @@ -817,7 +869,7 @@ impl Decoder {
Decoder::Int64(ref mut decoder) => decoder.finish(column_len),
Decoder::Float32(ref mut decoder) => decoder.finish(column_len),
Decoder::Float64(ref mut decoder) => decoder.finish(column_len),
Decoder::Decimal(ref mut decoder) => decoder.finish(column_len),
Decoder::Numeric(ref mut decoder) => decoder.finish(column_len),
Decoder::TimestampMicrosecond(ref mut decoder) => decoder.finish(column_len),
Decoder::TimestampTzMicrosecond(ref mut decoder) => decoder.finish(column_len),
Decoder::Date32(ref mut decoder) => decoder.finish(column_len),
Expand All @@ -828,4 +880,24 @@ impl Decoder {
Decoder::Jsonb(ref mut decoder) => decoder.finish(column_len),
}
}

pub(crate) fn is_null(&self) -> bool {
match *self {
Decoder::Boolean(ref decoder) => decoder.is_null(),
Decoder::Int16(ref decoder) => decoder.is_null(),
Decoder::Int32(ref decoder) => decoder.is_null(),
Decoder::Int64(ref decoder) => decoder.is_null(),
Decoder::Float32(ref decoder) => decoder.is_null(),
Decoder::Float64(ref decoder) => decoder.is_null(),
Decoder::Numeric(ref decoder) => decoder.is_null(),
Decoder::TimestampMicrosecond(ref decoder) => decoder.is_null(),
Decoder::TimestampTzMicrosecond(ref decoder) => decoder.is_null(),
Decoder::Date32(ref decoder) => decoder.is_null(),
Decoder::Time64Microsecond(ref decoder) => decoder.is_null(),
Decoder::DurationMicrosecond(ref decoder) => decoder.is_null(),
Decoder::String(ref decoder) => decoder.is_null(),
Decoder::Binary(ref decoder) => decoder.is_null(),
Decoder::Jsonb(ref decoder) => decoder.is_null(),
}
}
}
16 changes: 13 additions & 3 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;
use std::io::{BufRead, Seek};

use arrow_array::{Array, RecordBatch};
use arrow_array::RecordBatch;
use arrow_schema::Fields;
use arrow_schema::Schema;
use bytes::{Buf, BufMut, BytesMut};
Expand Down Expand Up @@ -305,7 +305,6 @@ impl<R: BufRead + Seek> PostgresBinaryToArrowDecoder<R> {
let tuple_len: u16 = match local_buf.consume_into_u16() {
Ok(len) => len,
Err(e) => {
println!("Error reading tuple length: {:?}", e);
return BatchDecodeResult::Error(e);
}
};
Expand Down Expand Up @@ -378,15 +377,26 @@ impl<R: BufRead + Seek> PostgresBinaryToArrowDecoder<R> {
// we are in a partial consume state. We will truncate the columns to the length
// of the shortest column and pick up the lost data in the next batch.
let column_len = self.decoders.iter().map(|d| d.column_len()).min().unwrap();

// Determine which columns in the batch are fully null so that we can alter the schema
// to reflect this.
let null_columns = self
.decoders
.iter()
.filter(|decoder| decoder.is_null())
.map(|decoder| decoder.name())
.collect::<Vec<String>>();

// For each decoder call its finish method to coerce the data into an Arrow array.
// and append the array to the columns vector.
let columns = self
.decoders
.iter_mut()
.map(|decoder| decoder.finish(column_len))
.collect();

// Create a new RecordBatch from the columns vector and return it.
let record_batch = RecordBatch::try_new(self.schema.clone().into(), columns)?;
let record_batch = RecordBatch::try_new(self.schema.clone().nullify_columns(&null_columns).into(), columns)?;

Ok(record_batch)
}
Expand Down
Loading

0 comments on commit eefe8ba

Please sign in to comment.