Skip to content

Commit

Permalink
Batch of row filtering improvements, consolidate get_state/get_supported_features backend RPCs (#2757)
Browse files Browse the repository at this point in the history

Add less-equal, greater-equal, not-equal, and null/not-null filter types, and put the supported features in the get_state result. Fix a bug in the filter modal when switching from one comparison type to another.
  • Loading branch information
wesm authored Apr 12, 2024
1 parent cefacbd commit b60a5f6
Show file tree
Hide file tree
Showing 10 changed files with 679 additions and 232 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from .access_keys import decode_access_key
from .data_explorer_comm import (
BackendState,
ColumnFrequencyTable,
ColumnHistogram,
ColumnSummaryStats,
Expand All @@ -41,8 +42,8 @@
GetDataValuesRequest,
GetSchemaRequest,
GetStateRequest,
GetSupportedFeaturesRequest,
RowFilter,
RowFilterCondition,
RowFilterType,
SchemaUpdateParams,
SearchFilterType,
Expand All @@ -59,7 +60,6 @@
TableData,
TableSchema,
TableShape,
TableState,
)
from .positron_comm import CommMessage, PositronComm
from .third_party import pd_
Expand Down Expand Up @@ -166,9 +166,6 @@ def get_column_profiles(self, request: GetColumnProfilesRequest):
def get_state(self, request: GetStateRequest):
    """
    Handle the 'get_state' RPC.

    The request carries no parameters that are read here; the reply is
    whatever ``self._get_state()`` produces, converted to a plain dict via
    its ``.dict()`` method so it can be sent back as the RPC result.
    """
    return self._get_state().dict()

def get_supported_features(self, request: GetSupportedFeaturesRequest):
    """
    Handle the 'get_supported_features' RPC.

    The request is not inspected; the reply is the result of
    ``self._get_supported_features()`` converted to a plain dict via its
    ``.dict()`` method.
    """
    return self._get_supported_features().dict()

@abc.abstractmethod
def invalidate_computations(self):
    """
    Abstract hook with no default behavior; concrete views must override it.

    NOTE(review): presumably implementations discard cached/derived results
    when the underlying table changes -- confirm against the subclasses.
    """
    pass
Expand Down Expand Up @@ -225,11 +222,7 @@ def _prof_histogram(self, column_index: int) -> ColumnHistogram:
pass

@abc.abstractmethod
def _get_state(self) -> TableState:
pass

@abc.abstractmethod
def _get_supported_features(self) -> SupportedFeatures:
def _get_state(self) -> BackendState:
pass


Expand Down Expand Up @@ -482,7 +475,7 @@ def _update_view_indices(self):
# reflect the filtered_indices that have just been updated
self._sort_data()

def _set_row_filters(self, filters) -> FilterResult:
def _set_row_filters(self, filters: List[RowFilter]) -> FilterResult:
self.filters = filters

if len(filters) == 0:
Expand All @@ -491,20 +484,32 @@ def _set_row_filters(self, filters) -> FilterResult:
self._update_view_indices()
return FilterResult(selected_num_rows=len(self.table))

# Evaluate all the filters and AND them together
# Evaluate all the filters and combine them using the
# indicated conditions
combined_mask = None
for filt in filters:
if filt.is_valid is False:
# If filter is invalid, do not evaluate it
continue

single_mask = self._eval_filter(filt)
if combined_mask is None:
combined_mask = single_mask
else:
elif filt.condition == RowFilterCondition.And:
combined_mask &= single_mask
elif filt.condition == RowFilterCondition.Or:
combined_mask |= single_mask

self.filtered_indices = combined_mask.nonzero()[0]
if combined_mask is None:
self.filtered_indices = None
selected_num_rows = len(self.table)
else:
self.filtered_indices = combined_mask.nonzero()[0]
selected_num_rows = len(self.filtered_indices)

# Update the view indices, re-sorting if needed
self._update_view_indices()
return FilterResult(selected_num_rows=len(self.filtered_indices))
return FilterResult(selected_num_rows=selected_num_rows)

def _eval_filter(self, filt: RowFilter):
col = self.table.iloc[:, filt.column_index]
Expand All @@ -531,8 +536,12 @@ def _eval_filter(self, filt: RowFilter):
op = COMPARE_OPS[params.op]
# pandas comparison filters return False for null values
mask = op(col, _coerce_value_param(params.value, col.dtype))
elif filt.filter_type == RowFilterType.IsEmpty:
mask = col.str.len() == 0
elif filt.filter_type == RowFilterType.IsNull:
mask = col.isnull()
elif filt.filter_type == RowFilterType.NotEmpty:
mask = col.str.len() != 0
elif filt.filter_type == RowFilterType.NotNull:
mask = col.notnull()
elif filt.filter_type == RowFilterType.SetMembership:
Expand Down Expand Up @@ -687,45 +696,45 @@ def _prof_freq_table(self, column_index: int):
def _prof_histogram(self, column_index: int):
    """
    Histogram profile for the column at ``column_index``.

    Not implemented for this backend: always raises NotImplementedError.
    """
    raise NotImplementedError

def _get_state(self) -> TableState:
_row_filter_features = SetRowFiltersFeatures(
supported=True,
supports_conditions=False,
supported_types=[
RowFilterType.Between,
RowFilterType.Compare,
RowFilterType.IsNull,
RowFilterType.NotNull,
RowFilterType.NotBetween,
RowFilterType.Search,
RowFilterType.SetMembership,
],
)

_column_profile_features = GetColumnProfilesFeatures(
supported=True,
supported_types=[
ColumnProfileType.NullCount,
ColumnProfileType.SummaryStats,
],
)

FEATURES = SupportedFeatures(
search_schema=SearchSchemaFeatures(supported=True),
set_row_filters=_row_filter_features,
get_column_profiles=_column_profile_features,
)

def _get_state(self) -> BackendState:
if self.view_indices is not None:
num_rows = len(self.view_indices)
else:
num_rows = self.table.shape[0]

return TableState(
return BackendState(
table_shape=TableShape(num_rows=num_rows, num_columns=self.table.shape[1]),
row_filters=self.filters,
sort_keys=self.sort_keys,
)

def _get_supported_features(self) -> SupportedFeatures:
row_filter_features = SetRowFiltersFeatures(
supported=True,
supports_conditions=False,
supported_types=[
RowFilterType.Between,
RowFilterType.Compare,
RowFilterType.IsNull,
RowFilterType.NotNull,
RowFilterType.NotBetween,
RowFilterType.Search,
RowFilterType.SetMembership,
],
)

column_profile_features = GetColumnProfilesFeatures(
supported=True,
supported_types=[
ColumnProfileType.NullCount,
ColumnProfileType.SummaryStats,
],
)

return SupportedFeatures(
search_schema=SearchSchemaFeatures(supported=True),
set_row_filters=row_filter_features,
get_column_profiles=column_profile_features,
supported_features=self.FEATURES,
)


Expand Down Expand Up @@ -1020,4 +1029,13 @@ def handle_msg(self, msg: CommMessage[DataExplorerBackendMessageContent], raw_ms
table = self.table_views[comm_id]

result = getattr(table, request.method.value)(request)

# To help remember to convert pydantic types to dicts
if result is not None:
if isinstance(result, list):
for x in result:
assert isinstance(x, dict)
else:
assert isinstance(result, dict)

comm.send_result(result)
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ class ColumnDisplayType(str, enum.Enum):
Unknown = "unknown"


@enum.unique
class RowFilterCondition(str, enum.Enum):
    """
    Possible values for Condition in RowFilter

    Members are also ``str`` instances, so they compare equal to their
    wire values ("and" / "or").
    """

    And = "and"

    Or = "or"


@enum.unique
class RowFilterType(str, enum.Enum):
"""
Expand All @@ -52,10 +63,14 @@ class RowFilterType(str, enum.Enum):

Compare = "compare"

IsEmpty = "is_empty"

IsNull = "is_null"

NotBetween = "not_between"

NotEmpty = "not_empty"

NotNull = "not_null"

Search = "search"
Expand Down Expand Up @@ -152,9 +167,9 @@ class FilterResult(BaseModel):
)


class TableState(BaseModel):
class BackendState(BaseModel):
"""
The current backend table state
The current backend state for the data explorer
"""

table_shape: TableShape = Field(
Expand All @@ -169,6 +184,10 @@ class TableState(BaseModel):
description="The set of currently applied sorts",
)

supported_features: SupportedFeatures = Field(
description="The features currently supported by the backend instance",
)


class TableShape(BaseModel):
"""
Expand All @@ -184,24 +203,6 @@ class TableShape(BaseModel):
)


class SupportedFeatures(BaseModel):
    """
    For each field, returns flags indicating supported features

    Every attribute holds the feature-flag model for one backend RPC.
    """

    # Feature flags for the 'search_schema' RPC
    search_schema: SearchSchemaFeatures = Field(
        description="Support for 'search_schema' RPC and its features",
    )

    # Feature flags for the 'set_row_filters' RPC
    set_row_filters: SetRowFiltersFeatures = Field(
        description="Support for 'set_row_filters' RPC and its features",
    )

    # Feature flags for the 'get_column_profiles' RPC
    get_column_profiles: GetColumnProfilesFeatures = Field(
        description="Support for 'get_column_profiles' RPC and its features",
    )


class ColumnSchema(BaseModel):
"""
Schema for a column in a table
Expand Down Expand Up @@ -281,6 +282,15 @@ class RowFilter(BaseModel):
description="Column index to apply filter to",
)

condition: RowFilterCondition = Field(
description="The binary condition to use to combine with preceding row filters",
)

is_valid: Optional[bool] = Field(
default=None,
description="Whether the filter is valid and supported by the backend, if undefined then true",
)

between_params: Optional[BetweenFilterParams] = Field(
default=None,
description="Parameters for the 'between' and 'not_between' filter types",
Expand Down Expand Up @@ -556,6 +566,24 @@ class ColumnSortKey(BaseModel):
)


class SupportedFeatures(BaseModel):
    """
    For each field, returns flags indicating supported features

    Every attribute holds the feature-flag model for one backend RPC.
    """

    # Feature flags for the 'search_schema' RPC
    search_schema: SearchSchemaFeatures = Field(
        description="Support for 'search_schema' RPC and its features",
    )

    # Feature flags for the 'set_row_filters' RPC
    set_row_filters: SetRowFiltersFeatures = Field(
        description="Support for 'set_row_filters' RPC and its features",
    )

    # Feature flags for the 'get_column_profiles' RPC
    get_column_profiles: GetColumnProfilesFeatures = Field(
        description="Support for 'get_column_profiles' RPC and its features",
    )


class SearchSchemaFeatures(BaseModel):
"""
Feature flags for 'search_schema' RPC
Expand Down Expand Up @@ -625,9 +653,6 @@ class DataExplorerBackendRequest(str, enum.Enum):
# Get the state
GetState = "get_state"

# Query the backend to determine supported features
GetSupportedFeatures = "get_supported_features"


class GetSchemaParams(BaseModel):
"""
Expand Down Expand Up @@ -840,22 +865,6 @@ class GetStateRequest(BaseModel):
)


class GetSupportedFeaturesRequest(BaseModel):
    """
    Query the backend to determine supported features, to enable feature
    toggling
    """

    # Discriminator: pinned to the 'get_supported_features' method name
    method: Literal[DataExplorerBackendRequest.GetSupportedFeatures] = Field(
        description="The JSON-RPC method name (get_supported_features)",
    )

    # Defaults to "2.0" when the sender omits it
    jsonrpc: str = Field(
        default="2.0",
        description="The JSON-RPC version specifier",
    )


class DataExplorerBackendMessageContent(BaseModel):
comm_id: str
data: Union[
Expand All @@ -866,7 +875,6 @@ class DataExplorerBackendMessageContent(BaseModel):
SetSortColumnsRequest,
GetColumnProfilesRequest,
GetStateRequest,
GetSupportedFeaturesRequest,
] = Field(..., discriminator="method")


Expand Down Expand Up @@ -899,12 +907,10 @@ class SchemaUpdateParams(BaseModel):

FilterResult.update_forward_refs()

TableState.update_forward_refs()
BackendState.update_forward_refs()

TableShape.update_forward_refs()

SupportedFeatures.update_forward_refs()

ColumnSchema.update_forward_refs()

TableSchema.update_forward_refs()
Expand Down Expand Up @@ -941,6 +947,8 @@ class SchemaUpdateParams(BaseModel):

ColumnSortKey.update_forward_refs()

SupportedFeatures.update_forward_refs()

SearchSchemaFeatures.update_forward_refs()

SetRowFiltersFeatures.update_forward_refs()
Expand Down Expand Up @@ -973,6 +981,4 @@ class SchemaUpdateParams(BaseModel):

GetStateRequest.update_forward_refs()

GetSupportedFeaturesRequest.update_forward_refs()

SchemaUpdateParams.update_forward_refs()
Loading

0 comments on commit b60a5f6

Please sign in to comment.