Skip to content

Commit

Permalink
left joins working
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 17, 2020
1 parent aa7cabc commit 97e49be
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl Display for AnyType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let width = 15;
match self {
AnyType::Null => write!(f, "{:width$}", "null", width = width),
AnyType::Null => write!(f, "{:>width$}", "null", width = width),
AnyType::U32(v) => write!(f, "{:width$}", v, width = width),
AnyType::I32(v) => write!(f, "{:width$}", v, width = width),
AnyType::I64(v) => write!(f, "{:width$}", v, width = width),
Expand Down
162 changes: 135 additions & 27 deletions src/frame/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@ use crate::{datatypes::UInt32Chunked, prelude::*, series::chunked_array::Chunked
use arrow::compute::TakeOptions;
use arrow::datatypes::{ArrowPrimitiveType, Field, Schema};
use fnv::{FnvBuildHasher, FnvHashMap};
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::hash::Hash;

/// Hash join a and b.
/// b should be the shorter relation.
fn hash_join<T>(
a: impl Iterator<Item = Option<T>>,
fn prepare_hashed_relation<T>(
b: impl Iterator<Item = Option<T>>,
) -> Vec<(usize, usize)>
) -> HashMap<T, Vec<usize>, FnvBuildHasher>
where
T: Hash + Eq + Copy,
{
Expand All @@ -24,6 +21,19 @@ where
hashmap.entry(key).or_insert_with(Vec::new).push(idx)
}
});
hashmap
}

/// Hash join a and b.
/// b should be the shorter relation.
fn hash_join<T>(
a: impl Iterator<Item = Option<T>>,
b: impl Iterator<Item = Option<T>>,
) -> Vec<(usize, usize)>
where
T: Hash + Eq + Copy,
{
let hashmap = prepare_hashed_relation(b);

let mut results = Vec::new();
a.enumerate().for_each(|(idx_a, o)| {
Expand All @@ -37,8 +47,38 @@ where
results
}

/// Compute the index pairs of a left join of `a` against `b`.
///
/// `b` is first turned into a key -> row-indexes lookup via
/// `prepare_hashed_relation`. Every row of `a` then yields at least one
/// tuple `(idx_a, right)`:
///
/// * one `(idx_a, Some(idx_b))` per matching row of `b`, in the order the
///   lookup stores them;
/// * a single `(idx_a, None)` when the left key is null or has no entry in
///   the lookup.
///
/// Tuples are emitted in the iteration order of `a`.
fn hash_join_left<T>(
    a: impl Iterator<Item = Option<T>>,
    b: impl Iterator<Item = Option<T>>,
) -> Vec<(usize, Option<usize>)>
where
    T: Hash + Eq + Copy,
{
    let lookup = prepare_hashed_relation(b);
    let mut pairs = Vec::new();

    for (left_idx, left_key) in a.enumerate() {
        // A null left key and a key without an entry both end up as `None` here,
        // so the two "right side is null" cases share one arm.
        match left_key.and_then(|key| lookup.get(&key)) {
            Some(right_rows) => pairs.extend(
                right_rows
                    .iter()
                    .map(|&right_idx| (left_idx, Some(right_idx))),
            ),
            None => pairs.push((left_idx, None)),
        }
    }
    pairs
}

/// Hash-based join operations between two `ChunkedArray<T>` key columns.
///
/// Both methods return a `(left, right)` pair of `UInt32Chunked` row-index
/// arrays; `DataFrame::inner_join` and `DataFrame::left_join` pass these to
/// `take` on the respective frames to materialise the joined rows.
pub trait HashJoin<T> {
/// Row indexes for joining `self` with `other`; this is the method
/// `DataFrame::inner_join` dispatches to.
/// NOTE(review): presumably only matching rows are emitted — confirm against the impl.
fn hash_join(&self, other: &ChunkedArray<T>) -> (UInt32Chunked, UInt32Chunked);
/// Row indexes for a left join of `self` with `other`; this is the method
/// `DataFrame::left_join` dispatches to. The right index is null for join
/// tuples that have no right-hand row.
fn hash_join_left(&self, other: &ChunkedArray<T>) -> (UInt32Chunked, UInt32Chunked);
}

impl<T> HashJoin<T> for ChunkedArray<T>
Expand Down Expand Up @@ -82,28 +122,36 @@ where
});
(left.finish(), right.finish())
}
}

impl DataFrame {
pub fn join(&self, other: &DataFrame, left_on: &str, right_on: &str) -> Result<DataFrame> {
let s_left = self.select(left_on).ok_or(PolarsError::NotFound)?;
let s_right = other.select(right_on).ok_or(PolarsError::NotFound)?;

macro_rules! hash_join {
($s_right:ident, $ca_left:ident, $type_:ident) => {{
let ca_right = $s_right.$type_()?;
$ca_left.hash_join(ca_right)
}};
}
fn hash_join_left(&self, other: &ChunkedArray<T>) -> (UInt32Chunked, UInt32Chunked) {
let join_tuples = hash_join_left(self.iter(), other.iter());
// Create the UInt32Chunked arrays. These can be used to take values from both the dataframes.
let mut left =
PrimitiveChunkedBuilder::<UInt32Type>::new("left_take_idx", join_tuples.len());
let mut right =
PrimitiveChunkedBuilder::<UInt32Type>::new("right_take_idx", join_tuples.len());
join_tuples
.into_iter()
.for_each(|(idx_left, opt_idx_right)| {
left.append_value(idx_left as u32);

let (take_left, take_right) = match s_left {
Series::UInt32(ca_left) => hash_join!(s_right, ca_left, u32),
Series::Int32(ca_left) => hash_join!(s_right, ca_left, i32),
Series::Int64(ca_left) => hash_join!(s_right, ca_left, i64),
Series::Bool(ca_left) => hash_join!(s_right, ca_left, bool),
_ => unimplemented!(),
};
match opt_idx_right {
Some(idx) => right.append_value(idx as u32),
None => right.append_null(),
}
});
(left.finish(), right.finish())
}
}

impl DataFrame {
fn finish_join(
&self,
other: &DataFrame,
take_left: &UInt32Chunked,
take_right: &UInt32Chunked,
right_on: &str,
) -> Result<DataFrame> {
let mut df_left = self.take(&take_left, Some(TakeOptions::default()))?;
let mut df_right = other.take(&take_right, Some(TakeOptions::default()))?;
df_right.drop(right_on);
Expand All @@ -129,14 +177,61 @@ impl DataFrame {
df_left.hstack(&df_right.columns);
Ok(df_left)
}

/// Join `self` with `other` on the given key columns using `HashJoin::hash_join`.
///
/// * `left_on` - name of the key column in `self`.
/// * `right_on` - name of the key column in `other`.
///
/// # Errors
/// Returns `PolarsError::NotFound` when either key column cannot be selected,
/// and propagates the error raised when the right key column cannot be
/// accessed as the same type as the left one.
///
/// # Panics
/// Hits `unimplemented!()` for left key types other than `UInt32`, `Int32`,
/// `Int64` and `Bool`.
pub fn inner_join(
    &self,
    other: &DataFrame,
    left_on: &str,
    right_on: &str,
) -> Result<DataFrame> {
    let left_key = self.select(left_on).ok_or(PolarsError::NotFound)?;
    let right_key = other.select(right_on).ok_or(PolarsError::NotFound)?;

    // Access the right key column through the accessor matching the left
    // column's type, then compute the take indexes for both frames.
    macro_rules! join_indexes {
        ($right_series:ident, $left_ca:ident, $accessor:ident) => {
            $left_ca.hash_join($right_series.$accessor()?)
        };
    }

    let (left_take_idx, right_take_idx) = match left_key {
        Series::UInt32(left_ca) => join_indexes!(right_key, left_ca, u32),
        Series::Int32(left_ca) => join_indexes!(right_key, left_ca, i32),
        Series::Int64(left_ca) => join_indexes!(right_key, left_ca, i64),
        Series::Bool(left_ca) => join_indexes!(right_key, left_ca, bool),
        _ => unimplemented!(),
    };

    self.finish_join(other, &left_take_idx, &right_take_idx, right_on)
}

/// Join `self` with `other` on the given key columns using
/// `HashJoin::hash_join_left`.
///
/// * `left_on` - name of the key column in `self`.
/// * `right_on` - name of the key column in `other`.
///
/// # Errors
/// Returns `PolarsError::NotFound` when either key column cannot be selected,
/// and propagates the error raised when the right key column cannot be
/// accessed as the same type as the left one.
///
/// # Panics
/// Hits `unimplemented!()` for left key types other than `UInt32`, `Int32`,
/// `Int64` and `Bool`.
pub fn left_join(&self, other: &DataFrame, left_on: &str, right_on: &str) -> Result<DataFrame> {
    let left_key = self.select(left_on).ok_or(PolarsError::NotFound)?;
    let right_key = other.select(right_on).ok_or(PolarsError::NotFound)?;

    // Access the right key column through the accessor matching the left
    // column's type, then compute the left-join take indexes.
    macro_rules! left_join_indexes {
        ($right_series:ident, $left_ca:ident, $accessor:ident) => {
            $left_ca.hash_join_left($right_series.$accessor()?)
        };
    }

    let (left_take_idx, right_take_idx) = match left_key {
        Series::UInt32(left_ca) => left_join_indexes!(right_key, left_ca, u32),
        Series::Int32(left_ca) => left_join_indexes!(right_key, left_ca, i32),
        Series::Int64(left_ca) => left_join_indexes!(right_key, left_ca, i64),
        Series::Bool(left_ca) => left_join_indexes!(right_key, left_ca, bool),
        _ => unimplemented!(),
    };

    self.finish_join(other, &left_take_idx, &right_take_idx, right_on)
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_hash_join() {
fn test_inner_join() {
let s0 = Series::init("days", [0, 1, 2].as_ref());
let s1 = Series::init("temp", [22.1, 19.9, 7.].as_ref());
let s2 = Series::init("rain", [0.2, 0.1, 0.3].as_ref());
Expand All @@ -146,7 +241,20 @@ mod test {
let s1 = Series::init("rain", [0.1, 0.2, 0.3, 0.4].as_ref());
let rain = DataFrame::new_from_columns(vec![s0, s1]).unwrap();

let joined = temp.join(&rain, "days", "days");
let joined = temp.inner_join(&rain, "days", "days");
println!("{}", joined.unwrap())
}

/// Smoke test: left-joins a five-row temperature frame with a two-row rain
/// frame on "days" and prints the result. It only checks that the join
/// returns `Ok`; the printed output is not asserted.
#[test]
fn test_left_join() {
    // Left frame: five days of temperatures.
    let temp_days = Series::init("days", [0, 1, 2, 3, 4].as_ref());
    let temp_values = Series::init("temp", [22.1, 19.9, 7., 2., 3.].as_ref());
    let temp = DataFrame::new_from_columns(vec![temp_days, temp_values]).unwrap();

    // Right frame: rain measurements for days 1 and 2 only.
    let rain_days = Series::init("days", [1, 2].as_ref());
    let rain_values = Series::init("rain", [0.1, 0.2].as_ref());
    let rain = DataFrame::new_from_columns(vec![rain_days, rain_values]).unwrap();

    let joined = temp.left_join(&rain, "days", "days").unwrap();
    println!("{}", joined)
}
}

0 comments on commit 97e49be

Please sign in to comment.