Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch of row filtering improvements, backend RPC cleanup. Add less-equal, greater-equal, null/not-null filter types, put supported features in get_state result #2757

Merged
merged 3 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from .access_keys import decode_access_key
from .data_explorer_comm import (
BackendState,
ColumnFrequencyTable,
ColumnHistogram,
ColumnSummaryStats,
Expand All @@ -41,8 +42,8 @@
GetDataValuesRequest,
GetSchemaRequest,
GetStateRequest,
GetSupportedFeaturesRequest,
RowFilter,
RowFilterCondition,
RowFilterType,
SchemaUpdateParams,
SearchFilterType,
Expand All @@ -59,7 +60,6 @@
TableData,
TableSchema,
TableShape,
TableState,
)
from .positron_comm import CommMessage, PositronComm
from .third_party import pd_
Expand Down Expand Up @@ -166,9 +166,6 @@ def get_column_profiles(self, request: GetColumnProfilesRequest):
def get_state(self, request: GetStateRequest):
return self._get_state().dict()

def get_supported_features(self, request: GetSupportedFeaturesRequest):
return self._get_supported_features().dict()

@abc.abstractmethod
def invalidate_computations(self):
pass
Expand Down Expand Up @@ -225,11 +222,7 @@ def _prof_histogram(self, column_index: int) -> ColumnHistogram:
pass

@abc.abstractmethod
def _get_state(self) -> TableState:
pass

@abc.abstractmethod
def _get_supported_features(self) -> SupportedFeatures:
def _get_state(self) -> BackendState:
pass


Expand Down Expand Up @@ -482,7 +475,7 @@ def _update_view_indices(self):
# reflect the filtered_indices that have just been updated
self._sort_data()

def _set_row_filters(self, filters) -> FilterResult:
def _set_row_filters(self, filters: List[RowFilter]) -> FilterResult:
self.filters = filters

if len(filters) == 0:
Expand All @@ -491,20 +484,32 @@ def _set_row_filters(self, filters) -> FilterResult:
self._update_view_indices()
return FilterResult(selected_num_rows=len(self.table))

# Evaluate all the filters and AND them together
# Evaluate all the filters and combine them using the
# indicated conditions
combined_mask = None
for filt in filters:
if filt.is_valid is False:
# If filter is invalid, do not evaluate it
continue

single_mask = self._eval_filter(filt)
if combined_mask is None:
combined_mask = single_mask
else:
elif filt.condition == RowFilterCondition.And:
combined_mask &= single_mask
elif filt.condition == RowFilterCondition.Or:
combined_mask |= single_mask

self.filtered_indices = combined_mask.nonzero()[0]
if combined_mask is None:
self.filtered_indices = None
selected_num_rows = len(self.table)
else:
self.filtered_indices = combined_mask.nonzero()[0]
selected_num_rows = len(self.filtered_indices)

# Update the view indices, re-sorting if needed
self._update_view_indices()
return FilterResult(selected_num_rows=len(self.filtered_indices))
return FilterResult(selected_num_rows=selected_num_rows)

def _eval_filter(self, filt: RowFilter):
col = self.table.iloc[:, filt.column_index]
Expand All @@ -531,8 +536,12 @@ def _eval_filter(self, filt: RowFilter):
op = COMPARE_OPS[params.op]
# pandas comparison filters return False for null values
mask = op(col, _coerce_value_param(params.value, col.dtype))
elif filt.filter_type == RowFilterType.IsEmpty:
mask = col.str.len() == 0
elif filt.filter_type == RowFilterType.IsNull:
mask = col.isnull()
elif filt.filter_type == RowFilterType.NotEmpty:
mask = col.str.len() != 0
elif filt.filter_type == RowFilterType.NotNull:
mask = col.notnull()
elif filt.filter_type == RowFilterType.SetMembership:
Expand Down Expand Up @@ -687,45 +696,45 @@ def _prof_freq_table(self, column_index: int):
def _prof_histogram(self, column_index: int):
raise NotImplementedError

def _get_state(self) -> TableState:
_row_filter_features = SetRowFiltersFeatures(
supported=True,
supports_conditions=False,
supported_types=[
RowFilterType.Between,
RowFilterType.Compare,
RowFilterType.IsNull,
RowFilterType.NotNull,
RowFilterType.NotBetween,
RowFilterType.Search,
RowFilterType.SetMembership,
],
)

_column_profile_features = GetColumnProfilesFeatures(
supported=True,
supported_types=[
ColumnProfileType.NullCount,
ColumnProfileType.SummaryStats,
],
)

FEATURES = SupportedFeatures(
search_schema=SearchSchemaFeatures(supported=True),
set_row_filters=_row_filter_features,
get_column_profiles=_column_profile_features,
)

def _get_state(self) -> BackendState:
if self.view_indices is not None:
num_rows = len(self.view_indices)
else:
num_rows = self.table.shape[0]

return TableState(
return BackendState(
table_shape=TableShape(num_rows=num_rows, num_columns=self.table.shape[1]),
row_filters=self.filters,
sort_keys=self.sort_keys,
)

def _get_supported_features(self) -> SupportedFeatures:
row_filter_features = SetRowFiltersFeatures(
supported=True,
supports_conditions=False,
supported_types=[
RowFilterType.Between,
RowFilterType.Compare,
RowFilterType.IsNull,
RowFilterType.NotNull,
RowFilterType.NotBetween,
RowFilterType.Search,
RowFilterType.SetMembership,
],
)

column_profile_features = GetColumnProfilesFeatures(
supported=True,
supported_types=[
ColumnProfileType.NullCount,
ColumnProfileType.SummaryStats,
],
)

return SupportedFeatures(
search_schema=SearchSchemaFeatures(supported=True),
set_row_filters=row_filter_features,
get_column_profiles=column_profile_features,
supported_features=self.FEATURES,
)


Expand Down Expand Up @@ -1020,4 +1029,13 @@ def handle_msg(self, msg: CommMessage[DataExplorerBackendMessageContent], raw_ms
table = self.table_views[comm_id]

result = getattr(table, request.method.value)(request)

# To help remember to convert pydantic types to dicts
if result is not None:
if isinstance(result, list):
for x in result:
assert isinstance(x, dict)
else:
assert isinstance(result, dict)

comm.send_result(result)
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ class ColumnDisplayType(str, enum.Enum):
Unknown = "unknown"


@enum.unique
class RowFilterCondition(str, enum.Enum):
"""
Possible values for Condition in RowFilter

Each value names the boolean operator used to combine a row filter
with the row filters that precede it in the filter list.
"""

# Combine this filter's mask with the preceding filters using AND
And = "and"

# Combine this filter's mask with the preceding filters using OR
Or = "or"


@enum.unique
class RowFilterType(str, enum.Enum):
"""
Expand All @@ -52,10 +63,14 @@ class RowFilterType(str, enum.Enum):

Compare = "compare"

IsEmpty = "is_empty"

IsNull = "is_null"

NotBetween = "not_between"

NotEmpty = "not_empty"

NotNull = "not_null"

Search = "search"
Expand Down Expand Up @@ -152,9 +167,9 @@ class FilterResult(BaseModel):
)


class TableState(BaseModel):
class BackendState(BaseModel):
"""
The current backend table state
The current backend state for the data explorer
"""

table_shape: TableShape = Field(
Expand All @@ -169,6 +184,10 @@ class TableState(BaseModel):
description="The set of currently applied sorts",
)

supported_features: SupportedFeatures = Field(
description="The features currently supported by the backend instance",
)


class TableShape(BaseModel):
"""
Expand All @@ -184,24 +203,6 @@ class TableShape(BaseModel):
)


class SupportedFeatures(BaseModel):
"""
For each field, returns flags indicating supported features
"""

search_schema: SearchSchemaFeatures = Field(
description="Support for 'search_schema' RPC and its features",
)

set_row_filters: SetRowFiltersFeatures = Field(
description="Support for 'set_row_filters' RPC and its features",
)

get_column_profiles: GetColumnProfilesFeatures = Field(
description="Support for 'get_column_profiles' RPC and its features",
)


class ColumnSchema(BaseModel):
"""
Schema for a column in a table
Expand Down Expand Up @@ -281,6 +282,15 @@ class RowFilter(BaseModel):
description="Column index to apply filter to",
)

condition: RowFilterCondition = Field(
description="The binary condition to use to combine with preceding row filters",
)

is_valid: Optional[bool] = Field(
default=None,
description="Whether the filter is valid and supported by the backend, if undefined then true",
)

between_params: Optional[BetweenFilterParams] = Field(
default=None,
description="Parameters for the 'between' and 'not_between' filter types",
Expand Down Expand Up @@ -556,6 +566,24 @@ class ColumnSortKey(BaseModel):
)


class SupportedFeatures(BaseModel):
"""
For each field, returns flags indicating supported features

There is one field per RPC whose support can vary by backend. This
model is carried in the ``supported_features`` field of BackendState.
"""

search_schema: SearchSchemaFeatures = Field(
description="Support for 'search_schema' RPC and its features",
)

set_row_filters: SetRowFiltersFeatures = Field(
description="Support for 'set_row_filters' RPC and its features",
)

get_column_profiles: GetColumnProfilesFeatures = Field(
description="Support for 'get_column_profiles' RPC and its features",
)


class SearchSchemaFeatures(BaseModel):
"""
Feature flags for 'search_schema' RPC
Expand Down Expand Up @@ -625,9 +653,6 @@ class DataExplorerBackendRequest(str, enum.Enum):
# Get the state
GetState = "get_state"

# Query the backend to determine supported features
GetSupportedFeatures = "get_supported_features"


class GetSchemaParams(BaseModel):
"""
Expand Down Expand Up @@ -840,22 +865,6 @@ class GetStateRequest(BaseModel):
)


class GetSupportedFeaturesRequest(BaseModel):
"""
Query the backend to determine supported features, to enable feature
toggling
"""

method: Literal[DataExplorerBackendRequest.GetSupportedFeatures] = Field(
description="The JSON-RPC method name (get_supported_features)",
)

jsonrpc: str = Field(
default="2.0",
description="The JSON-RPC version specifier",
)


class DataExplorerBackendMessageContent(BaseModel):
comm_id: str
data: Union[
Expand All @@ -866,7 +875,6 @@ class DataExplorerBackendMessageContent(BaseModel):
SetSortColumnsRequest,
GetColumnProfilesRequest,
GetStateRequest,
GetSupportedFeaturesRequest,
] = Field(..., discriminator="method")


Expand Down Expand Up @@ -899,12 +907,10 @@ class SchemaUpdateParams(BaseModel):

FilterResult.update_forward_refs()

TableState.update_forward_refs()
BackendState.update_forward_refs()

TableShape.update_forward_refs()

SupportedFeatures.update_forward_refs()

ColumnSchema.update_forward_refs()

TableSchema.update_forward_refs()
Expand Down Expand Up @@ -941,6 +947,8 @@ class SchemaUpdateParams(BaseModel):

ColumnSortKey.update_forward_refs()

SupportedFeatures.update_forward_refs()

SearchSchemaFeatures.update_forward_refs()

SetRowFiltersFeatures.update_forward_refs()
Expand Down Expand Up @@ -973,6 +981,4 @@ class SchemaUpdateParams(BaseModel):

GetStateRequest.update_forward_refs()

GetSupportedFeaturesRequest.update_forward_refs()

SchemaUpdateParams.update_forward_refs()
Loading
Loading