diff --git a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer.py b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer.py
index 56bca99fcf2..8b3ddb6576b 100644
--- a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer.py
+++ b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer.py
@@ -715,14 +715,14 @@ def _fire_schema_update(discard_state=False):
                 should_discard_state,
             ) = table_view.ui_should_update_schema(new_table)
 
-            # The schema is the same, but the data has changed. We'll just
-            # set a new table view and preserve the view state and be done
-            # with it.
-            self.table_views[comm_id] = _get_table_view(
-                new_table,
-                filters=table_view.filters,
-                sort_keys=table_view.sort_keys,
-            )
+            if should_discard_state:
+                self.table_views[comm_id] = _get_table_view(new_table)
+            else:
+                self.table_views[comm_id] = _get_table_view(
+                    new_table,
+                    filters=table_view.filters,
+                    sort_keys=table_view.sort_keys,
+                )
 
             if should_update_schema:
                 _fire_schema_update(discard_state=should_discard_state)
diff --git a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/tests/test_data_explorer.py b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/tests/test_data_explorer.py
index 2e47a077e58..9cfeba367e5 100644
--- a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/tests/test_data_explorer.py
+++ b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/tests/test_data_explorer.py
@@ -12,7 +12,12 @@
 from .._vendor.pydantic import BaseModel
 from ..access_keys import encode_access_key
 from ..data_explorer import COMPARE_OPS, DataExplorerService
-from ..data_explorer_comm import ColumnFilter, ColumnSchema, ColumnSortKey, FilterResult
+from ..data_explorer_comm import (
+    ColumnFilter,
+    ColumnSchema,
+    ColumnSortKey,
+    FilterResult,
+)
 from .conftest import DummyComm, PositronShell
 from .test_variables import BIG_ARRAY_LENGTH
 from .utils import json_rpc_notification, json_rpc_request, json_rpc_response
@@ -185,6 +190,23 @@ def _check_delete_variable(name):
     _check_delete_variable("y")
 
 
+def _check_update_variable(de_service, name, update_type="schema", discard_state=True):
+    paths = de_service.get_paths_for_variable(name)
+    assert len(paths) > 0
+
+    comms = [de_service.comms[comm_id] for p in paths for comm_id in de_service.path_to_comm_ids[p]]
+
+    if update_type == "schema":
+        expected_msg = json_rpc_notification("schema_update", {"discard_state": discard_state})
+    else:
+        expected_msg = json_rpc_notification("data_update", {})
+
+    # Check that comms were all closed
+    for comm in comms:
+        last_message = cast(DummyComm, comm.comm).messages[-1]
+        assert last_message == expected_msg
+
+
 def test_explorer_variable_updates(
     shell: PositronShell,
     de_service: DataExplorerService,
@@ -201,23 +223,6 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
     )
 
     # Check updates
-    def _check_update_variable(name, update_type="schema", discard_state=True):
-        paths = de_service.get_paths_for_variable(name)
-        assert len(paths) > 0
-
-        comms = [
-            de_service.comms[comm_id] for p in paths for comm_id in de_service.path_to_comm_ids[p]
-        ]
-
-        if update_type == "schema":
-            expected_msg = json_rpc_notification("schema_update", {"discard_state": discard_state})
-        else:
-            expected_msg = json_rpc_notification("data_update", {})
-
-        # Check that comms were all closed
-        for comm in comms:
-            last_message = cast(DummyComm, comm.comm).messages[-1]
-            assert last_message == expected_msg
 
     path_x = _open_viewer(variables_comm, ["x"])
     _open_viewer(variables_comm, ["big_x"])
@@ -236,7 +241,7 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
     de_service.comms[x_comm_id].comm.handle_msg(msg)
     shell.run_cell("import pandas as pd")
     shell.run_cell("x = pd.DataFrame({'a': [1, 0, 3, 4, 5]})")
-    _check_update_variable("x", update_type="data")
+    _check_update_variable(de_service, "x", update_type="data")
 
     tv = de_service.table_views[x_comm_id]
     assert tv.sort_keys == [ColumnSortKey(**k) for k in x_sort_keys]
@@ -250,7 +255,7 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
 
     # Execute code that triggers an update event for big_x because it's large
     shell.run_cell("print('hello world')")
-    _check_update_variable("big_x", update_type="data")
+    _check_update_variable(de_service, "big_x", update_type="data")
 
     # Update nested values in y and check for schema updates
     shell.run_cell(
@@ -258,14 +263,14 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
             'key2': y['key2'].copy()}
 """
     )
-    _check_update_variable("y", update_type="update", discard_state=False)
+    _check_update_variable(de_service, "y", update_type="update", discard_state=False)
 
     shell.run_cell(
         """y = {'key1': y['key1'].iloc[:-1, :-1],
             'key2': y['key2'].copy().iloc[:, 1:]}
 """
     )
-    _check_update_variable("y", update_type="schema", discard_state=True)
+    _check_update_variable(de_service, "y", update_type="schema", discard_state=True)
 
 
 def test_register_table(de_service: DataExplorerService):
@@ -696,6 +701,39 @@ def test_pandas_set_sort_columns(pandas_fixture: PandasFixture):
     pandas_fixture.check_sort_case(df, wrapped_keys, expected_filtered, filters=filters)
 
 
+def test_pandas_change_schema_after_sort(
+    shell: PositronShell,
+    de_service: DataExplorerService,
+    variables_comm: DummyComm,
+    pandas_fixture: PandasFixture,
+):
+    df = pd.DataFrame(
+        {
+            "a": np.arange(10),
+            "b": np.arange(10),
+            "c": np.arange(10),
+            "d": np.arange(10),
+            "e": np.arange(10),
+        }
+    )
+    shell.user_ns.update({"df": df})
+    _open_viewer(variables_comm, ["df"])
+
+    # Sort a column that is out of bounds for the table after the
+    # schema change below
+    pandas_fixture.set_sort_columns("df", [{"column_index": 4, "ascending": True}])
+
+    expected_df = df[["a", "b"]]
+    pandas_fixture.register_table("expected_df", df)
+
+    # Sort last column, and we will then change the schema
+    shell.run_cell("df = df[['a', 'b']]")
+    _check_update_variable(de_service, "df", update_type="schema", discard_state=True)
+
+    # Call get_data_values and make sure it works
+    pandas_fixture.compare_tables("df", "expected_df", expected_df.shape)
+
+
 # def test_pandas_get_column_profile(pandas_fixture: PandasFixture):
 #     pass
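
For context, a minimal sketch of the state-handling behavior the data_explorer.py hunk introduces. The TableView dataclass and update_view function below are hypothetical stand-ins (the real code uses _get_table_view, ui_should_update_schema, and comm events); the sketch only illustrates the branch: when a schema change invalidates the view state, the new view is built bare, otherwise the existing filters and sort keys are carried over, and the caller sends either a schema_update or a data_update notification, which is what the tests assert via _check_update_variable.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class TableView:
    # Hypothetical stand-in for the view object produced by _get_table_view()
    table: Any
    filters: List[Any] = field(default_factory=list)
    sort_keys: List[Any] = field(default_factory=list)


def update_view(
    old_view: TableView,
    new_table: Any,
    should_update_schema: bool,
    should_discard_state: bool,
) -> Tuple[TableView, Tuple[str, Dict]]:
    # Mirror of the added branch: drop filters/sort keys only when the schema
    # change invalidates them, otherwise carry them over to the new view.
    if should_discard_state:
        new_view = TableView(new_table)
    else:
        new_view = TableView(
            new_table,
            filters=old_view.filters,
            sort_keys=old_view.sort_keys,
        )
    # The front end is then notified: schema_update (with discard_state) when
    # the schema changed, otherwise a plain data_update.
    if should_update_schema:
        event = ("schema_update", {"discard_state": should_discard_state})
    else:
        event = ("data_update", {})
    return new_view, event

For example, update_view(old_view, new_table, True, True) returns a bare view plus ("schema_update", {"discard_state": True}), while update_view(old_view, new_table, False, False) preserves the sort keys and yields ("data_update", {}), matching the expectations in test_explorer_variable_updates and test_pandas_change_schema_after_sort.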