Skip to content
This repository has been archived by the owner on Dec 10, 2024. It is now read-only.

Commit

Permalink
Refactor to not use dyn array
Browse files (browse the repository at this point in the history)
  • Loading branch information
brad-richardson committed Jul 18, 2024
1 parent 7310917 commit 4f2bd39
Showing 1 changed file with 88 additions and 120 deletions.
208 changes: 88 additions & 120 deletions src/osm_arrow.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use std::alloc::LayoutErr;
use std::fmt;
use std::sync::Arc;

use arrow::array::builder::{
ArrayBuilder, BooleanBuilder, Int64Builder, ListBuilder, MapBuilder, StringBuilder,
StructBuilder,
};
use arrow::array::{make_builder, Float64Builder, Int32Builder};
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use arrow::datatypes::Fields;
use arrow::datatypes::Schema;
use arrow::array::{make_builder, Array, ArrayRef, Float64Builder, Int32Builder, TimestampMillisecondBuilder};
use arrow::datatypes::{DataType, TimeUnit, Field, Fields, Schema};
use arrow::error::ArrowError;
use arrow::ipc::Utf8Builder;
use arrow::record_batch::RecordBatch;
use osmpbf::WayNodeLocationsIter;

#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum OSMType {
Expand Down Expand Up @@ -87,7 +87,7 @@ pub fn osm_arrow_schema() -> Schema {
true,
),
Field::new("changeset", DataType::Int64, true),
Field::new("timestamp", DataType::Int64, true),
Field::new("timestamp", DataType::Timestamp(TimeUnit::Millisecond, None), true),
Field::new("uid", DataType::Int32, true),
Field::new("user", DataType::Utf8, true),
Field::new("version", DataType::Int32, true),
Expand All @@ -96,8 +96,20 @@ pub fn osm_arrow_schema() -> Schema {
}

/// Column-oriented accumulator for OSM elements: one Arrow array builder per
/// output column, plus the schema used when the columns are assembled into a
/// `RecordBatch`.
///
/// NOTE(review): this span is a rendered diff hunk without +/- markers. The
/// `builders` field appears to be the pre-refactor representation that the
/// typed `*_builder` fields replace (commit title: "Refactor to not use dyn
/// array") — confirm against the committed file.
pub struct OSMArrowBuilder {
// Type-erased builder list; each entry had to be downcast before use.
builders: Vec<Box<dyn ArrayBuilder>>,
// Shared schema handle passed to `RecordBatch::try_new`.
schema: Arc<Schema>,

// One concretely typed builder per column.
// Element id column (64-bit integers).
id_builder: Box<Int64Builder>,
// Tag map column: string keys mapped to string values.
tags_builder: Box<MapBuilder<StringBuilder, StringBuilder>>,
// Latitude / longitude columns (64-bit floats).
lat_builder: Box<Float64Builder>,
lon_builder: Box<Float64Builder>,
// List-of-struct column for node references.
nodes_builder: Box<ListBuilder<StructBuilder>>,
// List-of-struct column for members.
members_builder: Box<ListBuilder<StructBuilder>>,
// Changeset id column (64-bit integers).
changeset_builder: Box<Int64Builder>,
// Timestamp column with millisecond resolution.
timestamp_builder: Box<TimestampMillisecondBuilder>,
// User id column (32-bit integers).
uid_builder: Box<Int32Builder>,
// User name column (UTF-8 strings).
user_builder: Box<StringBuilder>,
// Version column (32-bit integers).
version_builder: Box<Int32Builder>,
// Visibility flag column (booleans).
visible_builder: Box<BooleanBuilder>,
}

impl Default for OSMArrowBuilder {
Expand All @@ -108,39 +120,47 @@ impl Default for OSMArrowBuilder {

impl OSMArrowBuilder {
/// Creates an `OSMArrowBuilder` with empty column builders and the schema
/// returned by `osm_arrow_schema()`.
///
/// NOTE(review): this span is a rendered diff hunk without +/- markers, so two
/// construction strategies appear back to back: a schema-driven loop that fills
/// a `Vec<Box<dyn ArrayBuilder>>`, followed by one typed builder per column.
/// The struct literal at the end consequently lists `builders` and two `schema`
/// initializers; presumably only the typed form is live — confirm against the
/// committed file.
pub fn new() -> Self {
// Schema-driven construction: walk the schema fields and push one type-erased
// builder per field.
let schema = osm_arrow_schema();

let mut builders: Vec<Box<dyn ArrayBuilder>> = Vec::new();
for field in schema.fields() {
// Custom builders for `tags`, `nodes`, and `members` as `make_builder` creates a more complex builder structure or doesn't support the type
if field.name() == "tags" {
builders.push(Box::new(MapBuilder::new(
None,
StringBuilder::new(),
StringBuilder::new(),
)));
} else if field.name() == "nds" {
builders.push(Box::new(ListBuilder::new(StructBuilder::from_fields(
vec![Field::new("ref", DataType::Int64, true)],
0,
))));
} else if field.name() == "members" {
builders.push(Box::new(ListBuilder::new(StructBuilder::from_fields(
vec![
Field::new("type", DataType::Utf8, true),
Field::new("ref", DataType::Int64, true),
Field::new("role", DataType::Utf8, true),
],
0,
))));
} else {
// Every other column uses the builder Arrow derives from the field's data type.
builders.push(make_builder(field.data_type(), 0));
}
}
// Typed construction: one concrete builder per column, no downcasting needed later.
let id_builder = Box::new(Int64Builder::new());
// `None` here means no custom map field names are supplied.
let tags_builder = Box::new(MapBuilder::new(
None,
StringBuilder::new(),
StringBuilder::new(),
));
let lat_builder = Box::new(Float64Builder::new());
let lon_builder = Box::new(Float64Builder::new());
// Each list entry is a struct with a single nullable Int64 `ref` field.
let nodes_builder = Box::new(ListBuilder::new(StructBuilder::from_fields(
vec![Field::new("ref", DataType::Int64, true)],
0,
)));
// Each list entry is a struct of nullable `type` (Utf8), `ref` (Int64) and `role` (Utf8).
let members_builder = Box::new(ListBuilder::new(StructBuilder::from_fields(
vec![
Field::new("type", DataType::Utf8, true),
Field::new("ref", DataType::Int64, true),
Field::new("role", DataType::Utf8, true),
],
0,
)));
let changeset_builder = Box::new(Int64Builder::new());
let timestamp_builder = Box::new(TimestampMillisecondBuilder::new());
let uid_builder = Box::new(Int32Builder::new());
let user_builder = Box::new(StringBuilder::new());
let version_builder = Box::new(Int32Builder::new());
let visible_builder = Box::new(BooleanBuilder::new());

OSMArrowBuilder {
builders,
schema: Arc::new(schema),
schema: Arc::new(osm_arrow_schema()),
id_builder,
tags_builder,
lat_builder,
lon_builder,
nodes_builder,
members_builder,
changeset_builder,
timestamp_builder,
uid_builder,
user_builder,
version_builder,
visible_builder,
}
}

Expand Down Expand Up @@ -169,42 +189,20 @@ impl OSMArrowBuilder {
// Track approximate size of inserted data, starting with known constant sizes
let mut est_size_bytes = 64usize;

self.builders[0]
.as_any_mut()
.downcast_mut::<Int64Builder>()
.unwrap()
.append_value(id);
self.id_builder.append_value(id);

let tags_builder = self.builders[1]
.as_any_mut()
.downcast_mut::<MapBuilder<StringBuilder, StringBuilder>>()
.unwrap();
for (key, value) in tags_iter {
est_size_bytes += key.len() + value.len();
tags_builder.keys().append_value(key);
tags_builder.values().append_value(value);
self.tags_builder.keys().append_value(key);
self.tags_builder.values().append_value(value);
}
let _ = tags_builder.append(true);
let _ = self.tags_builder.append(true);

self.builders[2]
.as_any_mut()
.downcast_mut::<Float64Builder>()
.unwrap()
.append_option(lat);
self.builders[3]
.as_any_mut()
.downcast_mut::<Float64Builder>()
.unwrap()
.append_option(lon);
self.lat_builder.append_option(lat);
self.lon_builder.append_option(lon);

// Derived from https://docs.rs/arrow/latest/arrow/array/struct.StructBuilder.html
let nodes_builder = self.builders[4]
.as_any_mut()
.downcast_mut::<ListBuilder<StructBuilder>>()
.unwrap();

let struct_builder = nodes_builder.values();

let struct_builder = self.nodes_builder.values();
for node_id in nodes_iter {
est_size_bytes += 8usize;
struct_builder
Expand All @@ -213,17 +211,9 @@ impl OSMArrowBuilder {
.append_value(node_id);
struct_builder.append(true);
}
self.nodes_builder.append(true);

nodes_builder.append(true);

// Derived from https://docs.rs/arrow/latest/arrow/array/struct.StructBuilder.html
let members_builder = self.builders[5]
.as_any_mut()
.downcast_mut::<ListBuilder<StructBuilder>>()
.unwrap();

let members_struct_builder = members_builder.values();

let members_struct_builder = self.members_builder.values();
for (osm_type, ref_, role) in members_iter {
// Rough size to avoid unwrapping, role should be fairly short.
est_size_bytes += 10usize;
Expand All @@ -245,55 +235,33 @@ impl OSMArrowBuilder {

members_struct_builder.append(true);
}
self.members_builder.append(true);

members_builder.append(true);

self.builders[6]
.as_any_mut()
.downcast_mut::<Int64Builder>()
.unwrap()
.append_option(changeset);
self.builders[7]
.as_any_mut()
.downcast_mut::<Int64Builder>()
.unwrap()
.append_option(timestamp_ms);
self.builders[8]
.as_any_mut()
.downcast_mut::<Int32Builder>()
.unwrap()
.append_option(uid);
self.builders[9]
.as_any_mut()
.downcast_mut::<StringBuilder>()
.unwrap()
.append_option(user);
self.builders[10]
.as_any_mut()
.downcast_mut::<Int32Builder>()
.unwrap()
.append_option(version);
self.builders[11]
.as_any_mut()
.downcast_mut::<BooleanBuilder>()
.unwrap()
.append_option(visible);
self.changeset_builder.append_option(changeset);
self.timestamp_builder.append_option(timestamp_ms);
self.uid_builder.append_option(uid);
self.user_builder.append_option(user);
self.version_builder.append_option(version);
self.visible_builder.append_option(visible);

// // TODO - write this if not writing with partitions
// self.builders[12]
// .as_any_mut()
// .downcast_mut::<StringBuilder>()
// .unwrap()
// .append_value(type_.to_string());
est_size_bytes
}

/// Finishes every column builder and assembles the resulting arrays into a
/// `RecordBatch` using the stored schema.
///
/// # Errors
///
/// Returns whatever `ArrowError` `RecordBatch::try_new` reports for the given
/// schema and arrays.
///
/// NOTE(review): this span is a rendered diff hunk without +/- markers, so
/// `array_refs` is bound twice: first by iterating the type-erased `builders`
/// list, then by an explicit vector of the typed builders. Presumably only the
/// second form is live — confirm against the committed file.
pub fn finish(&mut self) -> Result<RecordBatch, ArrowError> {
// Type-erased form: finish each boxed builder in list order.
let array_refs = self
.builders
.iter_mut()
.map(|builder| builder.finish())
.collect();
// Typed form: the arrays are listed in the order the builders are declared on
// the struct. Presumably this matches the field order of `osm_arrow_schema()`,
// which is only partly visible here — verify.
let array_refs: Vec<ArrayRef> = vec![
Arc::new(self.id_builder.finish()),
Arc::new(self.tags_builder.finish()),
Arc::new(self.lat_builder.finish()),
Arc::new(self.lon_builder.finish()),
Arc::new(self.nodes_builder.finish()),
Arc::new(self.members_builder.finish()),
Arc::new(self.changeset_builder.finish()),
Arc::new(self.timestamp_builder.finish()),
Arc::new(self.uid_builder.finish()),
Arc::new(self.user_builder.finish()),
Arc::new(self.version_builder.finish()),
Arc::new(self.visible_builder.finish()),
];

// Cloning the `Arc<Schema>` only bumps the reference count.
RecordBatch::try_new(self.schema.clone(), array_refs)
}
Expand Down

0 comments on commit 4f2bd39

Please sign in to comment.