Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for computing summary stats #342

Merged
merged 4 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 88 additions & 1 deletion crates/ark/src/data_explorer/r_data_explorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use amalthea::comm::data_explorer_comm::ColumnProfileResult;
use amalthea::comm::data_explorer_comm::ColumnProfileType;
use amalthea::comm::data_explorer_comm::ColumnSchema;
use amalthea::comm::data_explorer_comm::ColumnSortKey;
use amalthea::comm::data_explorer_comm::ColumnSummaryStats;
use amalthea::comm::data_explorer_comm::DataExplorerBackendReply;
use amalthea::comm::data_explorer_comm::DataExplorerBackendRequest;
use amalthea::comm::data_explorer_comm::DataExplorerFrontendEvent;
Expand All @@ -29,6 +30,9 @@ use amalthea::comm::data_explorer_comm::SearchSchemaFeatures;
use amalthea::comm::data_explorer_comm::SetRowFiltersFeatures;
use amalthea::comm::data_explorer_comm::SetRowFiltersParams;
use amalthea::comm::data_explorer_comm::SetSortColumnsParams;
use amalthea::comm::data_explorer_comm::SummaryStatsBoolean;
use amalthea::comm::data_explorer_comm::SummaryStatsNumber;
use amalthea::comm::data_explorer_comm::SummaryStatsString;
use amalthea::comm::data_explorer_comm::SupportedFeatures;
use amalthea::comm::data_explorer_comm::TableData;
use amalthea::comm::data_explorer_comm::TableSchema;
Expand Down Expand Up @@ -473,6 +477,26 @@ impl RDataExplorer {
frequency_table: None,
}
},
ColumnProfileType::SummaryStats => {
let summary_stats =
r_task(|| self.r_summary_stats(request.column_index as i32));
ColumnProfileResult {
null_count: None,
summary_stats: match summary_stats {
Err(err) => {
log::error!(
"Error getting summary stats for column {}: {}",
request.column_index,
err
);
None
},
Ok(stats) => Some(stats),
},
histogram: None,
frequency_table: None,
}
},
_ => {
// Other kinds of column profiles are not yet
// implemented in R
Expand Down Expand Up @@ -577,6 +601,66 @@ impl RDataExplorer {
Ok(result.try_into()?)
}

/// Computes summary statistics for a single column of the table by calling
/// the R helper that matches the column's display type.
///
/// `column_index` is forwarded to `tbl_get_column` to select the column. The
/// current row filter (`self.filtered_indices`), when present, is passed to
/// the R helper as `filtered_indices`; otherwise R `NULL` is passed.
///
/// # Errors
///
/// Returns an error if the column cannot be extracted, if the R call or the
/// conversion of its result fails, if the result lacks an expected field, or
/// if the display type is not number, string or boolean.
fn r_summary_stats(&self, column_index: i32) -> anyhow::Result<ColumnSummaryStats> {
    /// Fetches `key` from the R result, reporting a missing field as an
    /// error instead of panicking the way `map[key]` would.
    fn field<V: Clone>(map: &HashMap<String, V>, key: &str) -> anyhow::Result<V> {
        map.get(key).cloned().ok_or_else(|| {
            anyhow::anyhow!("Summary stats result is missing field `{}`", key)
        })
    }

    // Get the column to compute summary stats for
    let column = tbl_get_column(self.table.get().sexp, column_index, self.shape.kind)?;
    let dtype = display_type(column.sexp);

    // Calls the named summary function in the Positron namespace with the
    // column and the (optional) filtered row indices.
    let call_summary_fn = |fun| {
        RFunction::new("", fun)
            .param("column", column)
            .param("filtered_indices", match &self.filtered_indices {
                Some(indices) => RObject::try_from(indices)?,
                None => RObject::null(),
            })
            .call_in(ARK_ENVS.positron_ns)
    };

    // Only the member matching the display type is filled in below.
    let mut stats = ColumnSummaryStats {
        type_display: dtype.clone(),
        number_stats: None,
        string_stats: None,
        boolean_stats: None,
    };

    match dtype {
        ColumnDisplayType::Number => {
            let r_stats: HashMap<String, String> =
                call_summary_fn("number_summary_stats")?.try_into()?;

            stats.number_stats = Some(SummaryStatsNumber {
                min_value: field(&r_stats, "min_value")?,
                max_value: field(&r_stats, "max_value")?,
                mean: field(&r_stats, "mean")?,
                median: field(&r_stats, "median")?,
                stdev: field(&r_stats, "stdev")?,
            });
        },
        ColumnDisplayType::String => {
            let r_stats: HashMap<String, i32> =
                call_summary_fn("string_summary_stats")?.try_into()?;

            stats.string_stats = Some(SummaryStatsString {
                num_empty: i64::from(field(&r_stats, "num_empty")?),
                num_unique: i64::from(field(&r_stats, "num_unique")?),
            });
        },
        ColumnDisplayType::Boolean => {
            let r_stats: HashMap<String, i32> =
                call_summary_fn("boolean_summary_stats")?.try_into()?;

            stats.boolean_stats = Some(SummaryStatsBoolean {
                true_count: i64::from(field(&r_stats, "true_count")?),
                false_count: i64::from(field(&r_stats, "false_count")?),
            });
        },
        _ => {
            bail!("Summary stats not implemented for type: {:?}", dtype);
        },
    }
    Ok(stats)
}

/// Sort the rows of the data object according to the sort keys in
/// self.sort_keys.
///
Expand Down Expand Up @@ -761,7 +845,10 @@ impl RDataExplorer {
supported_features: SupportedFeatures {
get_column_profiles: GetColumnProfilesFeatures {
supported: true,
supported_types: vec![ColumnProfileType::NullCount],
supported_types: vec![
ColumnProfileType::NullCount,
ColumnProfileType::SummaryStats,
],
},
search_schema: SearchSchemaFeatures { supported: false },
set_row_filters: SetRowFiltersFeatures {
Expand Down
29 changes: 29 additions & 0 deletions crates/ark/src/modules/positron/r_data_explorer.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,35 @@
}
}

# Computes the minimum, maximum, mean, median and standard deviation of a
# numeric column, ignoring NAs, and returns them as a named character vector
# with names `min_value`, `max_value`, `mean`, `median` and `stdev`.
#
# `filtered_indices` is forwarded to `col_filter_indices()` to subset the
# column before the statistics are computed.
number_summary_stats <- function(column, filtered_indices) {
    col <- col_filter_indices(column, filtered_indices)

    values <- c(
        min_value = min(col, na.rm = TRUE),
        max_value = max(col, na.rm = TRUE),
        mean = mean(col, na.rm = TRUE),
        median = stats::median(col, na.rm = TRUE),
        stdev = stats::sd(col, na.rm = TRUE)
    )

    # Format each statistic on its own. Calling `format()` on the whole vector
    # pads every element to a common width and number of decimals (for example
    # "  1" next to "100"), which would leak whitespace into the values.
    vapply(values, format, character(1))
}

# Counts the empty strings and the distinct values in a character column,
# returning a named vector with `num_empty` and `num_unique`.
#
# `filtered_indices` is forwarded to `col_filter_indices()` to subset the
# column before counting.
string_summary_stats <- function(column, filtered_indices) {
    values <- col_filter_indices(column, filtered_indices)

    n_empty <- sum(!nzchar(values))
    n_unique <- length(unique(values))

    c(num_empty = n_empty, num_unique = n_unique)
}

# Counts the TRUE and FALSE entries of a logical column, ignoring NAs, and
# returns a named vector with `true_count` and `false_count`.
#
# `filtered_indices` is forwarded to `col_filter_indices()` to subset the
# column before counting.
boolean_summary_stats <- function(column, filtered_indices) {
    values <- col_filter_indices(column, filtered_indices)

    n_true <- sum(values, na.rm = TRUE)
    n_false <- sum(!values, na.rm = TRUE)

    c(true_count = n_true, false_count = n_false)
}

# Returns `col` subset by `idx`, or `col` unchanged when `idx` is NULL.
col_filter_indices <- function(col, idx = NULL) {
    if (is.null(idx)) {
        return(col)
    }
    col[idx]
}

.ps.filter_rows <- function(table, row_filters) {
# Are we working with a matrix here?
is_matrix <- is.matrix(table)
Expand Down
67 changes: 67 additions & 0 deletions crates/ark/tests/data_explorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ use amalthea::comm::data_explorer_comm::SearchFilterParams;
use amalthea::comm::data_explorer_comm::SearchFilterType;
use amalthea::comm::data_explorer_comm::SetRowFiltersParams;
use amalthea::comm::data_explorer_comm::SetSortColumnsParams;
use amalthea::comm::data_explorer_comm::SummaryStatsBoolean;
use amalthea::comm::data_explorer_comm::SummaryStatsNumber;
use amalthea::comm::data_explorer_comm::SummaryStatsString;
use amalthea::comm::event::CommManagerEvent;
use amalthea::socket;
use ark::data_explorer::r_data_explorer::DataObjectEnvInfo;
Expand Down Expand Up @@ -682,6 +685,70 @@ fn test_data_explorer() {
assert_eq!(num_rows, 3);
});

// --- summary stats ---

// Create a data frame with some numbers, characters and booleans to test
// summary statistics.
r_parse_eval0(
"df <- data.frame(num = c(1, 2, 3, NA), char = c('a', 'a', '', NA), bool = c(TRUE, TRUE, FALSE, NA))",
R_ENVS.global,
)
.unwrap();

// Open the `df` data frame in the data explorer.
let socket = open_data_explorer(String::from("df"));

// Ask for summary stats for the columns
let req = DataExplorerBackendRequest::GetColumnProfiles(GetColumnProfilesParams {
profiles: (0..3)
.map(|i| ColumnProfileRequest {
column_index: i,
profile_type: ColumnProfileType::SummaryStats,
})
.collect(),
});

assert_match!(socket_rpc(&socket, req),
DataExplorerBackendReply::GetColumnProfilesReply(data) => {
// We asked for summary stats for all 3 columns
assert!(data.len() == 3);

// The first column is numeric and has 3 non-NA values.
assert!(data[0].summary_stats.is_some());
let number_stats = data[0].summary_stats.clone().unwrap().number_stats;
assert!(number_stats.is_some());
let number_stats = number_stats.unwrap();
assert_eq!(number_stats, SummaryStatsNumber {
min_value: String::from("1"),
max_value: String::from("3"),
mean: String::from("2"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to have a test with fractional values (if we do need to format on the backend side).

median: String::from("2"),
stdev: String::from("1"),
});

// The second column is a character column
assert!(data[1].summary_stats.is_some());
let string_stats = data[1].summary_stats.clone().unwrap().string_stats;
assert!(string_stats.is_some());
let string_stats = string_stats.unwrap();
assert_eq!(string_stats, SummaryStatsString {
num_empty: 1,
num_unique: 3, // NA's are counted as unique values
});

// The third column is boolean
assert!(data[2].summary_stats.is_some());
let boolean_stats = data[2].summary_stats.clone().unwrap().boolean_stats;
assert!(boolean_stats.is_some());
let boolean_stats = boolean_stats.unwrap();
assert_eq!(boolean_stats, SummaryStatsBoolean {
true_count: 2,
false_count: 1,
});

}
);

// --- search filters ---

// Create a data frame with a bunch of words to use for regex testing.
Expand Down
64 changes: 62 additions & 2 deletions crates/harp/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,8 @@ impl TryFrom<&Vec<i32>> for RObject {
}
}

// Converts an R named character vector to a HashMap<String, String>.
// Note: Duplicated names are silently ignored, and only the first occurrence is kept.
impl TryFrom<RObject> for HashMap<String, String> {
type Error = crate::error::Error;
fn try_from(value: RObject) -> Result<Self, Self::Error> {
Expand All @@ -923,7 +925,7 @@ impl TryFrom<RObject> for HashMap<String, String> {
let n = Rf_xlength(names);
let mut map = HashMap::<String, String>::with_capacity(n as usize);

for i in 0..Rf_xlength(names) {
for i in (0..Rf_xlength(names)).rev() {
// Translate the name and value into Rust strings.
let lhs = r_chr_get_owned_utf8(names, i)?;
let rhs = r_chr_get_owned_utf8(value, i)?;
Expand All @@ -936,6 +938,36 @@ impl TryFrom<RObject> for HashMap<String, String> {
}
}

// Converts a named R integer vector (or a named list that can be coerced to
// one) into a HashMap<String, i32> keyed by element name.
// Note: when a name is duplicated, only its first occurrence is kept.
impl TryFrom<RObject> for HashMap<String, i32> {
    type Error = crate::error::Error;

    fn try_from(value: RObject) -> Result<Self, Self::Error> {
        // NOTE(review): presumably this must run where calling the R API is
        // allowed, and `RProtect` keeps the objects added to it alive until
        // it is dropped — confirm against the helpers' definitions.
        unsafe {
            r_assert_type(*value, &[INTSXP, VECSXP])?;

            let mut protect = RProtect::new();

            // The names attribute must be a character vector; anything else
            // (including a missing attribute) is rejected here.
            let names = protect.add(Rf_getAttrib(*value, R_NamesSymbol));
            r_assert_type(names, &[STRSXP])?;

            let ints = protect.add(Rf_coerceVector(*value, INTSXP));

            let len = Rf_xlength(names);
            let mut out = HashMap::<String, i32>::with_capacity(len as usize);

            // Walk backwards so that an earlier entry is inserted last and
            // therefore overwrites any later duplicate of the same name.
            for idx in (0..len).rev() {
                let key = r_chr_get_owned_utf8(names, idx)?;
                out.insert(key, r_int_get(ints, idx));
            }

            Ok(out)
        }
    }
}

// Converts a named R object into a HashMap<String, RObject> whose names are used as keys.
// Duplicated names are silently ignored, and only the first occurrence is kept.
impl TryFrom<RObject> for HashMap<String, RObject> {
Expand All @@ -952,7 +984,7 @@ impl TryFrom<RObject> for HashMap<String, RObject> {
// iterate in the reverse order to keep the first occurence of a name
for i in (0..n).rev() {
let name = r_chr_get_owned_utf8(names, i)?;
let value = RObject::new(VECTOR_ELT(*value, i));
let value: RObject = RObject::new(VECTOR_ELT(*value, i));
map.insert(name, value);
}

Expand Down Expand Up @@ -1381,6 +1413,34 @@ mod tests {
assert_eq!(out.get("pepperoni").unwrap(), "OK");
assert_eq!(out.get("sausage").unwrap(), "OK");
assert_eq!(out.get("pineapple").unwrap(), "NOT OK");


let v = r_parse_eval0("c(x = 'a', y = 'b', z = 'c')", R_ENVS.global).unwrap();
let out: HashMap<String, String> = v.try_into().unwrap();
assert_eq!(out["x"], "a"); // duplicated name is ignored and first is kept
assert_eq!(out["y"], "b");
}
}

/// Checks that named R integer lists and vectors convert into a
/// `HashMap<String, i32>`, keeping the first value of a duplicated name.
#[test]
// The test name mirrors the `RObject` type name, hence the mixed case.
#[allow(non_snake_case)]
fn test_tryfrom_RObject_hashmap_i32() {
    r_test! {
        // A named list of integers in which the name `x` appears twice.
        let v = r_parse_eval0("list(x = 1L, y = 2L, x = 3L)", R_ENVS.global).unwrap();
        assert_eq!(v.length(), 3 as isize);

        // Convert the list into a map keyed by element name.
        let out: HashMap<String, i32> = v.try_into().unwrap();

        // The duplicated name collapses into a single key holding the first value.
        assert_eq!(out.len(), 2);
        assert_eq!(out["x"], 1); // duplicated name is ignored and first is kept
        assert_eq!(out["y"], 2);

        // The same holds for a named integer vector.
        let v = r_parse_eval0("c(x = 1L, y = 2L, x = 3L)", R_ENVS.global).unwrap();
        let out: HashMap<String, i32> = v.try_into().unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out["x"], 1); // duplicated name is ignored and first is kept
        assert_eq!(out["y"], 2);
    }
}

Expand Down
Loading