Skip to content

Commit

Permalink
Move table shape into data explorer get_state request and test pandas…
Browse files Browse the repository at this point in the history
… state requests (posit-dev/positron-python#393)

* Move table shape into get_state request and test pandas state requests

* Handle parameter change, cleaning

* fix pyright

* Rename shape again

* Dictify state result
  • Loading branch information
wesm committed Mar 28, 2024
1 parent 374bcc3 commit a64b0e1
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from .access_keys import decode_access_key
from .data_explorer_comm import (
BackendState,
ColumnFilter,
ColumnFilterCompareOp,
ColumnSchema,
Expand All @@ -42,6 +41,8 @@
SetSortColumnsRequest,
TableData,
TableSchema,
TableShape,
TableState,
)
from .positron_comm import CommMessage, PositronComm
from .third_party import pd_
Expand Down Expand Up @@ -128,7 +129,7 @@ def get_column_profile(self, request: GetColumnProfileRequest):
return self._get_column_profile(request.params.profile_type, request.params.column_index)

def get_state(self, request: GetStateRequest):
return self._get_state()
return self._get_state().dict()

def _get_schema(self, column_start: int, num_columns: int) -> TableSchema:
raise NotImplementedError
Expand All @@ -154,7 +155,7 @@ def _get_column_profile(
) -> None:
raise NotImplementedError

def _get_state(self) -> BackendState:
def _get_state(self) -> TableState:
raise NotImplementedError


Expand Down Expand Up @@ -185,6 +186,7 @@ class PandasView(DataExplorerTableView):
"float64": "number",
"mixed-integer": "number",
"mixed-integer-float": "number",
"mixed": "unknown",
"decimal": "number",
"complex": "number",
"categorical": "categorical",
Expand Down Expand Up @@ -291,11 +293,7 @@ def _get_schema(self, column_start: int, num_columns: int) -> TableSchema:
)
column_schemas.append(col_schema)

return TableSchema(
columns=column_schemas,
num_rows=self.table.shape[0],
total_num_columns=self.table.shape[1],
)
return TableSchema(columns=column_schemas)

def _get_data_values(
self, row_start: int, num_rows: int, column_indices: Sequence[int]
Expand Down Expand Up @@ -420,8 +418,12 @@ def _get_column_profile(
) -> None:
pass

def _get_state(self) -> BackendState:
return BackendState(filters=self.filters, sort_keys=self.sort_keys)
def _get_state(self) -> TableState:
return TableState(
table_shape=TableShape(num_rows=self.table.shape[0], num_columns=self.table.shape[1]),
filters=self.filters,
sort_keys=self.sort_keys,
)


COMPARE_OPS = {
Expand Down Expand Up @@ -503,7 +505,13 @@ def shutdown(self) -> None:
for comm_id in list(self.comms.keys()):
self._close_explorer(comm_id)

def register_table(self, table, title, variable_path=None, comm_id=None):
def register_table(
self,
table,
title,
variable_path: Optional[List[str]] = None,
comm_id=None,
):
"""
Set up a new comm and data explorer table query wrapper to
handle requests and manage state.
Expand Down Expand Up @@ -552,6 +560,9 @@ def close_callback(msg):
base_comm.on_close(close_callback)

if variable_path is not None:
if not isinstance(variable_path, list):
raise ValueError(variable_path)

key = tuple(variable_path)
self.comm_id_to_path[comm_id] = key

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,6 @@ class TableSchema(BaseModel):
description="Schema for each column in the table",
)

num_rows: int = Field(
description="Numbers of rows in the unfiltered dataset",
)

total_num_columns: int = Field(
description="Total number of columns in the unfiltered dataset",
)


class TableData(BaseModel):
"""
Expand Down Expand Up @@ -211,11 +203,15 @@ class FreqtableCounts(BaseModel):
)


class BackendState(BaseModel):
class TableState(BaseModel):
"""
The current backend state
The current backend table state
"""

table_shape: TableShape = Field(
description="Provides number of rows and columns in table",
)

filters: List[ColumnFilter] = Field(
description="The set of currently applied filters",
)
Expand All @@ -225,6 +221,20 @@ class BackendState(BaseModel):
)


class TableShape(BaseModel):
"""
Provides number of rows and columns in table
"""

num_rows: int = Field(
description="Numbers of rows in the unfiltered dataset",
)

num_columns: int = Field(
description="Number of columns in the unfiltered dataset",
)


class ColumnSchema(BaseModel):
"""
Schema for a column in a table
Expand Down Expand Up @@ -548,7 +558,7 @@ class GetColumnProfileRequest(BaseModel):

class GetStateRequest(BaseModel):
"""
Request the current backend state (applied filters and sort columns)
Request the current table state (applied filters and sort columns)
"""

method: Literal[DataExplorerBackendRequest.GetState] = Field(
Expand Down Expand Up @@ -606,7 +616,9 @@ class SchemaUpdateParams(BaseModel):

FreqtableCounts.update_forward_refs()

BackendState.update_forward_refs()
TableState.update_forward_refs()

TableShape.update_forward_refs()

ColumnSchema.update_forward_refs()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from ..access_keys import encode_access_key
from .._vendor.pydantic import BaseModel
from ..data_explorer import COMPARE_OPS, DataExplorerService
from ..data_explorer_comm import ColumnSchema, ColumnSortKey, FilterResult
from ..data_explorer_comm import (
ColumnFilter,
ColumnSchema,
ColumnSortKey,
FilterResult,
)

from .conftest import DummyComm, PositronShell
from .utils import json_rpc_notification, json_rpc_request, json_rpc_response
Expand Down Expand Up @@ -58,6 +63,17 @@ def get_last_message(de_service: DataExplorerService, comm_id: str):
# Test basic service functionality


class MyData:
def __init__(self, value):
self.value = value

def __str__(self):
return str(self.value)

def __repr__(self):
return repr(self.value)


SIMPLE_PANDAS_DF = pd.DataFrame(
{
"a": [1, 2, 3, 4, 5],
Expand All @@ -73,6 +89,7 @@ def get_last_message(de_service: DataExplorerService, comm_id: str):
"2024-01-05 00:00:00",
]
),
"f": [None, MyData(5), MyData(-1), None, None],
}
)

Expand Down Expand Up @@ -216,6 +233,7 @@ def _check_update_variable(name, update_type="schema", discard_state=True):

# Do a simple update and make sure that sort keys are preserved
x_comm_id = list(de_service.path_to_comm_ids[path_x])[0]
x_sort_keys = [{"column_index": 0, "ascending": True}]
msg = json_rpc_request(
"set_sort_columns",
params={"sort_keys": [{"column_index": 0, "ascending": True}]},
Expand All @@ -227,9 +245,15 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
_check_update_variable("x", update_type="data")

tv = de_service.table_views[x_comm_id]
assert tv.sort_keys == [ColumnSortKey(column_index=0, ascending=True)]
assert tv.sort_keys == [ColumnSortKey(**k) for k in x_sort_keys]
assert tv._need_recompute

pf = PandasFixture(de_service)
new_state = pf.get_state("x")
assert new_state["table_shape"]["num_rows"] == 5
assert new_state["table_shape"]["num_columns"] == 1
assert new_state["sort_keys"] == [ColumnSortKey(**k) for k in x_sort_keys]

# Execute code that triggers an update event for big_x because it's large
shell.run_cell("print('hello world')")
_check_update_variable("big_x", update_type="data")
Expand Down Expand Up @@ -281,17 +305,30 @@ def test_shutdown(de_service: DataExplorerService):
class PandasFixture:
def __init__(self, de_service: DataExplorerService):
self.de_service = de_service
self._table_ids = {}

self.register_table("simple", SIMPLE_PANDAS_DF)

def register_table(self, table_name: str, table):
comm_id = guid()
self.de_service.register_table(table, table_name, comm_id=comm_id)
self._table_ids[table_name] = comm_id

paths = self.de_service.get_paths_for_variable(table_name)
for path in paths:
for old_comm_id in list(self.de_service.path_to_comm_ids[path]):
self.de_service._close_explorer(old_comm_id)

self.de_service.register_table(
table,
table_name,
comm_id=comm_id,
variable_path=[encode_access_key(table_name)],
)

def do_json_rpc(self, table_name, method, **params):
comm_id = self._table_ids[table_name]
paths = self.de_service.get_paths_for_variable(table_name)
assert len(paths) == 1

comm_id = list(self.de_service.path_to_comm_ids[paths[0]])[0]

request = json_rpc_request(
method,
params=params,
Expand All @@ -313,6 +350,9 @@ def get_schema(self, table_name, start_index, num_columns):
num_columns=num_columns,
)

def get_state(self, table_name):
return self.do_json_rpc(table_name, "get_state")

def get_data_values(self, table_name, **params):
return self.do_json_rpc(table_name, "get_data_values", **params)

Expand Down Expand Up @@ -372,10 +412,26 @@ def _wrap_json(model: Type[BaseModel], data: JsonRecords):
return [model(**d).dict() for d in data]


def test_pandas_get_state(pandas_fixture: PandasFixture):
result = pandas_fixture.get_state("simple")
assert result["table_shape"]["num_rows"] == 5
assert result["table_shape"]["num_columns"] == 6

sort_keys = [
{"column_index": 0, "ascending": True},
{"column_index": 1, "ascending": False},
]
filters = [_compare_filter(0, ">", 0), _compare_filter(0, "<", 5)]
pandas_fixture.set_sort_columns("simple", sort_keys=sort_keys)
pandas_fixture.set_column_filters("simple", filters=filters)

result = pandas_fixture.get_state("simple")
assert result["sort_keys"] == sort_keys
assert result["filters"] == [ColumnFilter(**f) for f in filters]


def test_pandas_get_schema(pandas_fixture: PandasFixture):
result = pandas_fixture.get_schema("simple", 0, 100)
assert result["num_rows"] == 5
assert result["total_num_columns"] == 5

full_schema = [
{
Expand Down Expand Up @@ -403,20 +459,15 @@ def test_pandas_get_schema(pandas_fixture: PandasFixture):
"type_name": "datetime64[ns]",
"type_display": "datetime",
},
{"column_name": "f", "type_name": "mixed", "type_display": "unknown"},
]

assert result["columns"] == _wrap_json(ColumnSchema, full_schema)

result = pandas_fixture.get_schema("simple", 2, 100)
assert result["num_rows"] == 5
assert result["total_num_columns"] == 5

assert result["columns"] == _wrap_json(ColumnSchema, full_schema[2:])

result = pandas_fixture.get_schema("simple", 5, 100)
assert result["num_rows"] == 5
assert result["total_num_columns"] == 5

result = pandas_fixture.get_schema("simple", 6, 100)
assert result["columns"] == []

# Make a really big schema
Expand All @@ -426,13 +477,9 @@ def test_pandas_get_schema(pandas_fixture: PandasFixture):
pandas_fixture.register_table(bigger_name, bigger_df)

result = pandas_fixture.get_schema(bigger_name, 0, 100)
assert result["num_rows"] == 5
assert result["total_num_columns"] == 500
assert result["columns"] == _wrap_json(ColumnSchema, bigger_schema[:100])

result = pandas_fixture.get_schema(bigger_name, 10, 10)
assert result["num_rows"] == 5
assert result["total_num_columns"] == 500
assert result["columns"] == _wrap_json(ColumnSchema, bigger_schema[10:20])


Expand Down Expand Up @@ -466,7 +513,7 @@ def test_pandas_get_data_values(pandas_fixture: PandasFixture):
"simple",
row_start_index=0,
num_rows=20,
column_indices=list(range(5)),
column_indices=list(range(6)),
)

# TODO: pandas pads all values to fixed width, do we want to do
Expand All @@ -483,6 +530,7 @@ def test_pandas_get_data_values(pandas_fixture: PandasFixture):
"2024-01-04 00:00:00",
"2024-01-05 00:00:00",
],
["None", "5", "-1", "None", "None"],
]

assert _trim_whitespace(result["columns"]) == expected_columns
Expand Down

0 comments on commit a64b0e1

Please sign in to comment.