Skip to content

Commit

Permalink
Kludge the UI to get the statistics to show up minimally
Browse files Browse the repository at this point in the history
  • Loading branch information
wesm committed Apr 3, 2024
1 parent 73f12b8 commit d9342e1
Show file tree
Hide file tree
Showing 6 changed files with 311 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
SearchSchemaResult,
SetRowFiltersRequest,
SetSortColumnsRequest,
SummaryStatsBoolean,
SummaryStatsNumber,
SummaryStatsString,
TableData,
TableSchema,
TableShape,
Expand Down Expand Up @@ -104,9 +107,7 @@ def _recompute_if_needed(self) -> bool:
return False

def get_schema(self, request: GetSchemaRequest):
    """Handle a get_schema RPC request.

    Returns the schema for the requested window of columns
    (``start_index`` .. ``start_index + num_columns``) serialized to a
    dict for the comm layer.
    """
    # NOTE(review): the diff view showed both the pre- and post-format
    # bodies here; only a single delegation to _get_schema is intended.
    return self._get_schema(request.params.start_index, request.params.num_columns).dict()

def search_schema(self, request: SearchSchemaRequest):
return self._search_schema(
Expand Down Expand Up @@ -294,9 +295,11 @@ def __init__(
# search term changes, we discard the last search result. We
# might add an LRU cache here or something if it helps
# performance.
self._search_schema_last_result: Optional[
Tuple[str, List[ColumnSchema]]
] = None
self._search_schema_last_result: Optional[Tuple[str, List[ColumnSchema]]] = None

# squelch a warning from pandas 2.2.0 about the use below of
# fillna
pd_.set_option("future.no_silent_downcasting", True)

def invalidate_computations(self):
self.filtered_indices = self.view_indices = None
Expand Down Expand Up @@ -341,12 +344,7 @@ def _get_schema(self, column_start: int, num_columns: int) -> TableSchema:
column_start,
min(column_start + num_columns, len(self.table.columns)),
):
column_raw_name = self.table.columns[column_index]
column_name = str(column_raw_name)

col_schema = self._get_single_column_schema(
column_index, column_name
)
col_schema = self._get_single_column_schema(column_index)
column_schemas.append(col_schema)

return TableSchema(columns=column_schemas)
Expand All @@ -372,9 +370,7 @@ def _search_schema(
total_num_matches=len(matches),
)

def _search_schema_get_matches(
self, search_term: str
) -> List[ColumnSchema]:
def _search_schema_get_matches(self, search_term: str) -> List[ColumnSchema]:
matches = []
for column_index in range(len(self.table.columns)):
column_raw_name = self.table.columns[column_index]
Expand All @@ -384,9 +380,7 @@ def _search_schema_get_matches(
if search_term not in column_name.lower():
continue

col_schema = self._get_single_column_schema(
column_index, column_name
)
col_schema = self._get_single_column_schema(column_index)
matches.append(col_schema)

return matches
Expand All @@ -395,12 +389,13 @@ def _get_inferred_dtype(self, column_index: int):
from pandas.api.types import infer_dtype

if column_index not in self._inferred_dtypes:
self._inferred_dtypes[column_index] = infer_dtype(
self.table.iloc[:, column_index]
)
self._inferred_dtypes[column_index] = infer_dtype(self.table.iloc[:, column_index])
return self._inferred_dtypes[column_index]

def _get_single_column_schema(self, column_index: int, column_name: str):
def _get_single_column_schema(self, column_index: int):
column_raw_name = self.table.columns[column_index]
column_name = str(column_raw_name)

# TODO: pandas MultiIndex columns
# TODO: time zone for datetimetz datetime64[ns] types
dtype = self.dtypes.iloc[column_index]
Expand Down Expand Up @@ -443,16 +438,15 @@ def _get_data_values(

if self.view_indices is not None:
# If the table is either filtered or sorted, use a slice
# the view_indices to select the virtual range of values for the grid
# the view_indices to select the virtual range of values
# for the grid
view_slice = self.view_indices[row_start : row_start + num_rows]
columns = [col.take(view_slice) for col in columns]
indices = self.table.index.take(view_slice)
else:
# No filtering or sorting, just slice directly
indices = self.table.index[row_start : row_start + num_rows]
columns = [
col.iloc[row_start : row_start + num_rows] for col in columns
]
columns = [col.iloc[row_start : row_start + num_rows] for col in columns]

formatted_columns = [_pandas_format_values(col) for col in columns]

Expand Down Expand Up @@ -560,7 +554,7 @@ def _eval_filter(self, filt: RowFilter):

# Nulls are possible in the mask, so we just fill them if any
if mask.dtype != bool:
mask = mask.fillna(False)
mask = mask.fillna(False).infer_objects(copy=False)

return mask.to_numpy()

Expand All @@ -585,9 +579,7 @@ def _sort_data(self) -> None:
self.view_indices = self.filtered_indices.take(sort_indexer)
else:
# Data is not filtered
self.view_indices = nargsort(
column, kind="mergesort", ascending=key.ascending
)
self.view_indices = nargsort(column, kind="mergesort", ascending=key.ascending)
elif len(self.sort_keys) > 1:
# Multiple sorting keys
cols_to_sort = []
Expand Down Expand Up @@ -618,7 +610,63 @@ def _prof_null_count(self, column_index: int):
return self._get_column(column_index).isnull().sum()

def _prof_summary_stats(self, column_index: int):
    """Compute summary statistics for the column at ``column_index``.

    Dispatches on the column's display type via the ``_SUMMARIZERS``
    table. Columns whose type has no summarizer yet yield an empty
    ``ColumnSummaryStats`` carrying only the display type.
    """
    # Fix: a stray leftover `raise NotImplementedError` from the old stub
    # preceded this body, making the implementation unreachable dead code.
    col_schema = self._get_single_column_schema(column_index)
    col = self._get_column(column_index)

    ui_type = col_schema.type_display
    handler = self._SUMMARIZERS.get(ui_type)

    if handler is None:
        # Return nothing for types we don't yet know how to summarize
        return ColumnSummaryStats(type_display=ui_type)
    else:
        return handler(col)

@staticmethod
def _summarize_number(col: "pd.Series"):
    """Build summary statistics (min/max/mean/median/stdev) for a numeric column.

    Each statistic is stringified for transport; NaN results (e.g. for an
    all-null column) therefore serialize as the string "nan".
    """
    number_stats = SummaryStatsNumber(
        min_value=str(col.min()),
        max_value=str(col.max()),
        mean=str(col.mean()),
        median=str(col.median()),
        stdev=str(col.std()),
    )
    return ColumnSummaryStats(
        type_display=ColumnDisplayType.Number,
        number_stats=number_stats,
    )

@staticmethod
def _summarize_string(col: "pd.Series"):
    """Build summary statistics for a string column.

    Reports the number of empty ("") values and the number of distinct
    non-null values.
    """
    empty_count = (col.str.len() == 0).sum()
    unique_count = col.nunique()

    string_stats = SummaryStatsString(
        num_empty=empty_count,
        num_unique=unique_count,
    )
    return ColumnSummaryStats(
        type_display=ColumnDisplayType.String,
        string_stats=string_stats,
    )

@staticmethod
def _summarize_boolean(col: "pd.Series"):
    """Build summary statistics for a boolean column.

    Counts true and false values; nulls are excluded from both counts
    (false_count = total - true_count - null_count).
    """
    nulls = col.isnull().sum()
    trues = col.sum()
    falses = len(col) - trues - nulls

    boolean_stats = SummaryStatsBoolean(
        true_count=trues,
        false_count=falses,
    )
    return ColumnSummaryStats(
        type_display=ColumnDisplayType.Boolean,
        boolean_stats=boolean_stats,
    )

# Dispatch table mapping a column's display type to its summary-stats
# handler; consulted by _prof_summary_stats, which falls back to an
# empty result for types absent here.
# NOTE(review): the values are staticmethod objects at class-definition
# time; calling them directly as handler(col) requires Python 3.10+ —
# confirm the supported minimum version.
_SUMMARIZERS = {
ColumnDisplayType.Boolean: _summarize_boolean,
ColumnDisplayType.Number: _summarize_number,
ColumnDisplayType.String: _summarize_string,
}

def _prof_freq_table(self, column_index: int):
raise NotImplementedError
Expand All @@ -628,9 +676,7 @@ def _prof_histogram(self, column_index: int):

def _get_state(self) -> TableState:
    """Return the current backend state: table shape, active row
    filters, and sort keys.
    """
    # Fix: the diff view left both the pre- and post-format
    # `table_shape=` arguments in place, which is a duplicate-keyword
    # SyntaxError; keep a single table_shape argument.
    return TableState(
        table_shape=TableShape(
            num_rows=self.table.shape[0],
            num_columns=self.table.shape[1],
        ),
        row_filters=self.filters,
        sort_keys=self.sort_keys,
    )
Expand Down Expand Up @@ -817,9 +863,7 @@ def handle_variable_updated(self, variable_name, new_variable):
for comm_id in list(self.path_to_comm_ids[path]):
self._update_explorer_for_comm(comm_id, path, new_variable)

def _update_explorer_for_comm(
self, comm_id: str, path: PathKey, new_variable
):
def _update_explorer_for_comm(self, comm_id: str, path: PathKey, new_variable):
"""
If a variable is updated, we have to handle the different scenarios:
Expand Down Expand Up @@ -853,9 +897,7 @@ def _update_explorer_for_comm(
# data explorer open for a nested value, then we need to use
# the same variables inspection logic to resolve it here.
if len(path) > 1:
is_found, new_table = _resolve_value_from_path(
new_variable, path[1:]
)
is_found, new_table = _resolve_value_from_path(new_variable, path[1:])
if not is_found:
raise KeyError(f"Path {', '.join(path)} not found in value")
else:
Expand All @@ -874,9 +916,7 @@ def _fire_data_update():

def _fire_schema_update(discard_state=False):
msg = SchemaUpdateParams(discard_state=discard_state)
comm.send_event(
DataExplorerFrontendEvent.SchemaUpdate.value, msg.dict()
)
comm.send_event(DataExplorerFrontendEvent.SchemaUpdate.value, msg.dict())

if type(new_table) is not type(table_view.table): # noqa: E721
# Data type has changed. For now, we will signal the UI to
Expand Down Expand Up @@ -921,9 +961,7 @@ def _fire_schema_update(discard_state=False):
else:
_fire_data_update()

def handle_msg(
self, msg: CommMessage[DataExplorerBackendMessageContent], raw_msg
):
def handle_msg(self, msg: CommMessage[DataExplorerBackendMessageContent], raw_msg):
"""
Handle messages received from the client via the
positron.data_explorer comm.
Expand Down
Loading

0 comments on commit d9342e1

Please sign in to comment.