Add unit test for feature flags
wesm committed Apr 3, 2024
1 parent 8a29d9b commit f01c38e
Showing 2 changed files with 61 additions and 109 deletions.
File 1 of 2:
@@ -112,9 +112,7 @@ def _recompute_if_needed(self) -> bool:
         return False
 
     def get_schema(self, request: GetSchemaRequest):
-        return self._get_schema(
-            request.params.start_index, request.params.num_columns
-        ).dict()
+        return self._get_schema(request.params.start_index, request.params.num_columns).dict()
 
     def search_schema(self, request: SearchSchemaRequest):
         return self._search_schema(
@@ -309,9 +307,7 @@ def __init__(
         # search term changes, we discard the last search result. We
         # might add an LRU cache here or something if it helps
         # performance.
-        self._search_schema_last_result: Optional[
-            Tuple[str, List[ColumnSchema]]
-        ] = None
+        self._search_schema_last_result: Optional[Tuple[str, List[ColumnSchema]]] = None
 
         # Putting this here rather than in the class body because
         # Python < 3.10 has fussier rules about staticmethods
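Note: the comment above describes a deliberately simple one-entry cache keyed on the last search term, with an LRU cache floated as a possible upgrade. A minimal self-contained sketch of that alternative (all names hypothetical, not from this commit):

    from functools import lru_cache
    from typing import List, Tuple

    class SchemaSearcher:
        def __init__(self, column_names: List[str]):
            self.column_names = column_names

        @lru_cache(maxsize=32)  # caveat: lru_cache on a method keeps self alive
        def search(self, term: str) -> Tuple[str, ...]:
            # Return a tuple so cached results cannot be mutated by callers
            return tuple(n for n in self.column_names if term.lower() in n.lower())

    # SchemaSearcher(["alpha", "beta", "albedo"]).search("al") -> ('alpha', 'albedo')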
@@ -390,9 +386,7 @@ def _search_schema(
             total_num_matches=len(matches),
         )
 
-    def _search_schema_get_matches(
-        self, search_term: str
-    ) -> List[ColumnSchema]:
+    def _search_schema_get_matches(self, search_term: str) -> List[ColumnSchema]:
         matches = []
         for column_index in range(len(self.table.columns)):
             column_raw_name = self.table.columns[column_index]
@@ -411,9 +405,7 @@ def _get_inferred_dtype(self, column_index: int):
         from pandas.api.types import infer_dtype
 
         if column_index not in self._inferred_dtypes:
-            self._inferred_dtypes[column_index] = infer_dtype(
-                self.table.iloc[:, column_index]
-            )
+            self._inferred_dtypes[column_index] = infer_dtype(self.table.iloc[:, column_index])
         return self._inferred_dtypes[column_index]
 
     def _get_single_column_schema(self, column_index: int):
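Note: `_get_inferred_dtype` memoizes pandas' `infer_dtype`, which scans a column's values and can be expensive on large frames. The same pattern in a self-contained sketch (class and attribute names here are illustrative, not from this commit):

    import pandas as pd
    from pandas.api.types import infer_dtype

    class InferredDtypeCache:
        """Compute each column's inferred dtype at most once."""

        def __init__(self, table: pd.DataFrame):
            self.table = table
            self._inferred_dtypes = {}  # column index -> dtype string

        def get(self, column_index: int) -> str:
            if column_index not in self._inferred_dtypes:
                self._inferred_dtypes[column_index] = infer_dtype(self.table.iloc[:, column_index])
            return self._inferred_dtypes[column_index]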
@@ -470,9 +462,7 @@ def _get_data_values(
         else:
             # No filtering or sorting, just slice directly
             indices = self.table.index[row_start : row_start + num_rows]
-            columns = [
-                col.iloc[row_start : row_start + num_rows] for col in columns
-            ]
+            columns = [col.iloc[row_start : row_start + num_rows] for col in columns]
 
         formatted_columns = [_pandas_format_values(col) for col in columns]
 
@@ -608,9 +598,7 @@ def _sort_data(self) -> None:
                 self.view_indices = self.filtered_indices.take(sort_indexer)
             else:
                 # Data is not filtered
-                self.view_indices = nargsort(
-                    column, kind="mergesort", ascending=key.ascending
-                )
+                self.view_indices = nargsort(column, kind="mergesort", ascending=key.ascending)
         elif len(self.sort_keys) > 1:
             # Multiple sorting keys
             cols_to_sort = []
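Note: `nargsort` is pandas' internal NaN-aware argsort helper; `kind="mergesort"` matters because mergesort is stable, so rows that compare equal keep their existing relative order. A rough public-API equivalent using NumPy (illustrative sketch, not part of this commit):

    import numpy as np

    values = np.array([3.0, np.nan, 1.0, 3.0, 2.0])
    order = np.argsort(values, kind="mergesort")  # stable; NaNs sort to the end
    # For a stable descending sort, argsort the negated values rather than
    # reversing this result, which would break stability among ties.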
@@ -679,9 +667,7 @@ def _summarize_string(col: "pd.Series"):
 
         return ColumnSummaryStats(
             type_display=ColumnDisplayType.String,
-            string_stats=SummaryStatsString(
-                num_empty=num_empty, num_unique=num_unique
-            ),
+            string_stats=SummaryStatsString(num_empty=num_empty, num_unique=num_unique),
         )
 
     @staticmethod
@@ -692,9 +678,7 @@ def _summarize_boolean(col: "pd.Series"):
 
         return ColumnSummaryStats(
             type_display=ColumnDisplayType.Boolean,
-            boolean_stats=SummaryStatsBoolean(
-                true_count=true_count, false_count=false_count
-            ),
+            boolean_stats=SummaryStatsBoolean(true_count=true_count, false_count=false_count),
         )
 
     def _prof_freq_table(self, column_index: int):
@@ -705,9 +689,7 @@ def _prof_histogram(self, column_index: int):
 
     def _get_state(self) -> TableState:
         return TableState(
-            table_shape=TableShape(
-                num_rows=self.table.shape[0], num_columns=self.table.shape[1]
-            ),
+            table_shape=TableShape(num_rows=self.table.shape[0], num_columns=self.table.shape[1]),
             row_filters=self.filters,
             sort_keys=self.sort_keys,
         )
@@ -923,9 +905,7 @@ def handle_variable_updated(self, variable_name, new_variable):
         for comm_id in list(self.path_to_comm_ids[path]):
             self._update_explorer_for_comm(comm_id, path, new_variable)
 
-    def _update_explorer_for_comm(
-        self, comm_id: str, path: PathKey, new_variable
-    ):
+    def _update_explorer_for_comm(self, comm_id: str, path: PathKey, new_variable):
         """
         If a variable is updated, we have to handle the different scenarios:
@@ -959,9 +939,7 @@ def _update_explorer_for_comm(
         # data explorer open for a nested value, then we need to use
         # the same variables inspection logic to resolve it here.
         if len(path) > 1:
-            is_found, new_table = _resolve_value_from_path(
-                new_variable, path[1:]
-            )
+            is_found, new_table = _resolve_value_from_path(new_variable, path[1:])
             if not is_found:
                 raise KeyError(f"Path {', '.join(path)} not found in value")
         else:
@@ -980,9 +958,7 @@ def _fire_data_update():
 
         def _fire_schema_update(discard_state=False):
             msg = SchemaUpdateParams(discard_state=discard_state)
-            comm.send_event(
-                DataExplorerFrontendEvent.SchemaUpdate.value, msg.dict()
-            )
+            comm.send_event(DataExplorerFrontendEvent.SchemaUpdate.value, msg.dict())
 
         if type(new_table) is not type(table_view.table):  # noqa: E721
             # Data type has changed. For now, we will signal the UI to
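Note: `_fire_schema_update` pushes a `schema_update` notification to the front end over the comm. The tests below assert the resulting message with `json_rpc_notification("schema_update", {"discard_state": ...})`; as a standard JSON-RPC 2.0 notification the payload would look roughly like this (an assumption about the envelope, written as a Python dict):

    expected = {
        "jsonrpc": "2.0",
        "method": "schema_update",
        "params": {"discard_state": True},
    }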
@@ -1027,9 +1003,7 @@ def _fire_schema_update(discard_state=False):
         else:
             _fire_data_update()
 
-    def handle_msg(
-        self, msg: CommMessage[DataExplorerBackendMessageContent], raw_msg
-    ):
+    def handle_msg(self, msg: CommMessage[DataExplorerBackendMessageContent], raw_msg):
         """
         Handle messages received from the client via the
         positron.data_explorer comm.
File 2 of 2:
@@ -144,9 +144,7 @@ def test_explorer_open_close_delete(
     assert len(de_service.table_views) == 0
 
 
-def _assign_variables(
-    shell: PositronShell, variables_comm: DummyComm, **variables
-):
+def _assign_variables(shell: PositronShell, variables_comm: DummyComm, **variables):
     # A hack to make sure that change events are fired when we
     # manipulate user_ns
     shell.kernel.variables_service.snapshot_user_ns()
@@ -187,9 +185,7 @@ def _check_delete_variable(name):
         assert len(paths) > 0
 
         comms = [
-            de_service.comms[comm_id]
-            for p in paths
-            for comm_id in de_service.path_to_comm_ids[p]
+            de_service.comms[comm_id] for p in paths for comm_id in de_service.path_to_comm_ids[p]
         ]
         variables_comm.handle_msg(msg)
 
@@ -205,22 +201,14 @@ def _check_delete_variable(name):
     _check_delete_variable("y")
 
 
-def _check_update_variable(
-    de_service, name, update_type="schema", discard_state=True
-):
+def _check_update_variable(de_service, name, update_type="schema", discard_state=True):
     paths = de_service.get_paths_for_variable(name)
     assert len(paths) > 0
 
-    comms = [
-        de_service.comms[comm_id]
-        for p in paths
-        for comm_id in de_service.path_to_comm_ids[p]
-    ]
+    comms = [de_service.comms[comm_id] for p in paths for comm_id in de_service.path_to_comm_ids[p]]
 
     if update_type == "schema":
-        expected_msg = json_rpc_notification(
-            "schema_update", {"discard_state": discard_state}
-        )
+        expected_msg = json_rpc_notification("schema_update", {"discard_state": discard_state})
     else:
         expected_msg = json_rpc_notification("data_update", {})
 
@@ -287,18 +275,14 @@ def test_explorer_variable_updates(
     'key2': y['key2'].copy()}
     """
     )
-    _check_update_variable(
-        de_service, "y", update_type="update", discard_state=False
-    )
+    _check_update_variable(de_service, "y", update_type="update", discard_state=False)
 
     shell.run_cell(
         """y = {'key1': y['key1'].iloc[:-1, :-1],
     'key2': y['key2'].copy().iloc[:, 1:]}
     """
     )
-    _check_update_variable(
-        de_service, "y", update_type="schema", discard_state=True
-    )
+    _check_update_variable(de_service, "y", update_type="schema", discard_state=True)
 
 
 def test_register_table(de_service: DataExplorerService):
@@ -399,14 +383,10 @@ def set_row_filters(self, table_name, filters=None):
         return self.do_json_rpc(table_name, "set_row_filters", filters=filters)
 
     def set_sort_columns(self, table_name, sort_keys=None):
-        return self.do_json_rpc(
-            table_name, "set_sort_columns", sort_keys=sort_keys
-        )
+        return self.do_json_rpc(table_name, "set_sort_columns", sort_keys=sort_keys)
 
     def get_column_profiles(self, table_name, profiles):
-        return self.do_json_rpc(
-            table_name, "get_column_profiles", profiles=profiles
-        )
+        return self.do_json_rpc(table_name, "get_column_profiles", profiles=profiles)
 
     def check_filter_case(self, table, filter_set, expected_table):
         table_id = guid()
@@ -431,9 +411,7 @@ def check_sort_case(self, table, sort_keys, expected_table, filters=None):
         assert response is None
         self.compare_tables(table_id, ex_id, table.shape)
 
-    def compare_tables(
-        self, table_id: str, expected_id: str, table_shape: tuple
-    ):
+    def compare_tables(self, table_id: str, expected_id: str, table_shape: tuple):
         # Query the data and check it yields the same result as the
         # manually constructed data frame without the filter
         response = self.get_data_values(
@@ -479,7 +457,32 @@ def test_pandas_get_state(dxf: DataExplorerFixture):
 
 
 def test_pandas_get_supported_features(dxf: DataExplorerFixture):
-    pass
+    dxf.register_table("example", SIMPLE_PANDAS_DF)
+    features = dxf.get_supported_features("example")
+
+    search_schema = features["search_schema"]
+    row_filters = features["set_row_filters"]
+    column_profiles = features["get_column_profiles"]
+
+    assert search_schema["supported"]
+
+    assert row_filters["supported"]
+    assert not row_filters["supports_conditions"]
+    assert set(row_filters["supported_types"]) == {
+        "is_null",
+        "not_null",
+        "between",
+        "compare",
+        "not_between",
+        "search",
+        "set_membership",
+    }
+
+    assert column_profiles["supported"]
+    assert set(column_profiles["supported_types"]) == {
+        "null_count",
+        "summary_stats",
+    }
 
 
 def test_pandas_get_schema(dxf: DataExplorerFixture):
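Note: the new test calls `dxf.get_supported_features`, whose definition is outside the hunks shown here. Judging from the sibling `DataExplorerFixture` helpers above (`set_row_filters`, `set_sort_columns`, `get_column_profiles`), it plausibly wraps the same JSON-RPC plumbing; a hypothetical reconstruction:

    def get_supported_features(self, table_name):
        return self.do_json_rpc(table_name, "get_supported_features")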
@@ -586,9 +589,7 @@ def test_pandas_search_schema(dxf: DataExplorerFixture):
 
     # Make a data frame with those column names
     arr = np.arange(10)
-    df = pd.DataFrame(
-        {name: arr for name in column_names}, columns=pd.Index(column_names)
-    )
+    df = pd.DataFrame({name: arr for name in column_names}, columns=pd.Index(column_names))
 
     dxf.register_table("df", df)
 
@@ -647,9 +648,7 @@ def test_pandas_get_data_values(dxf: DataExplorerFixture):
     assert result["row_labels"] == [["0", "1", "2", "3", "4"]]
 
     # Edge cases: request beyond end of table
-    response = dxf.get_data_values(
-        "simple", row_start_index=5, num_rows=10, column_indices=[0]
-    )
+    response = dxf.get_data_values("simple", row_start_index=5, num_rows=10, column_indices=[0])
     assert response["columns"] == [[]]
 
     # Issue #2149 -- return empty result when requesting non-existent
@@ -681,9 +680,7 @@ def _filter(filter_type, column_index, **kwargs):
 
 
 def _compare_filter(column_index, op, value):
-    return _filter(
-        "compare", column_index, compare_params={"op": op, "value": value}
-    )
+    return _filter("compare", column_index, compare_params={"op": op, "value": value})
 
 
 def _between_filter(column_index, left_value, right_value, op="between"):
@@ -695,14 +692,10 @@ def _between_filter(column_index, left_value, right_value, op="between"):
 
 
 def _not_between_filter(column_index, left_value, right_value):
-    return _between_filter(
-        column_index, left_value, right_value, op="not_between"
-    )
+    return _between_filter(column_index, left_value, right_value, op="not_between")
 
 
-def _search_filter(
-    column_index, term, case_sensitive=False, search_type="contains"
-):
+def _search_filter(column_index, term, case_sensitive=False, search_type="contains"):
     return _filter(
         "search",
         column_index,
@@ -745,11 +738,7 @@ def test_pandas_filter_between(dxf: DataExplorerFixture):
     )
     dxf.check_filter_case(
         df,
-        [
-            _not_between_filter(
-                column_index, str(left_value), str(right_value)
-            )
-        ],
+        [_not_between_filter(column_index, str(left_value), str(right_value))],
         ex_not_between,
     )
 
@@ -925,14 +914,11 @@ def test_pandas_set_sort_columns(dxf: DataExplorerFixture):
     ]
 
     # Test sort AND filter
-    filter_cases = {
-        "df2": [(lambda x: x[x["a"] > 0], [_compare_filter(0, ">", 0)])]
-    }
+    filter_cases = {"df2": [(lambda x: x[x["a"] > 0], [_compare_filter(0, ">", 0)])]}
 
     for df_name, keys, expected_params in cases:
         wrapped_keys = [
-            {"column_index": index, "ascending": ascending}
-            for index, ascending in keys
+            {"column_index": index, "ascending": ascending} for index, ascending in keys
         ]
         df = tables[df_name]
 
Expand All @@ -944,9 +930,7 @@ def test_pandas_set_sort_columns(dxf: DataExplorerFixture):

for filter_f, filters in filter_cases.get(df_name, []):
expected_filtered = filter_f(df).sort_values(**expected_params)
dxf.check_sort_case(
df, wrapped_keys, expected_filtered, filters=filters
)
dxf.check_sort_case(df, wrapped_keys, expected_filtered, filters=filters)


def test_pandas_change_schema_after_sort(
@@ -976,9 +960,7 @@ def test_pandas_change_schema_after_sort(
 
     # Sort last column, and we will then change the schema
    shell.run_cell("df = df[['a', 'b']]")
-    _check_update_variable(
-        de_service, "df", update_type="schema", discard_state=True
-    )
+    _check_update_variable(de_service, "df", update_type="schema", discard_state=True)
 
     # Call get_data_values and make sure it works
     dxf.compare_tables("df", "expected_df", expected_df.shape)
@@ -1039,17 +1021,13 @@ def test_pandas_profile_null_counts(dxf: DataExplorerFixture):
     for table_name, profiles, ex_results in cases:
         results = dxf.get_column_profiles(table_name, profiles)
 
-        ex_results = [
-            ColumnProfileResult(null_count=count) for count in ex_results
-        ]
+        ex_results = [ColumnProfileResult(null_count=count) for count in ex_results]
 
         assert results == ex_results
 
     # Test profiling with filter
     # format: (table, filters, filtered_table, profiles)
-    filter_cases = [
-        (df1, [_filter("not_null", 0)], df1[df1["a"].notnull()], all_profiles)
-    ]
+    filter_cases = [(df1, [_filter("not_null", 0)], df1[df1["a"].notnull()], all_profiles)]
     for table, filters, filtered_table, profiles in filter_cases:
         table_id = guid()
         dxf.register_table(table_id, table)
