diff --git a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/connections.py b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/connections.py index 8bf793dfd346..aaefdda6e3d5 100644 --- a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/connections.py +++ b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/connections.py @@ -5,7 +5,7 @@ import logging import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypedDict, Union import comm @@ -49,6 +49,13 @@ class ConnectionObjectFields(TypedDict): dtype: str +class UnsupportedConnectionError(Exception): + pass + + +PathKey = Tuple[str, ...] + + class Connection: """ Base class representing a connection to a data source. @@ -128,7 +135,18 @@ def __init__(self, kernel: PositronIPyKernel, comm_target_name: str): self._kernel = kernel self._comm_target_name = comm_target_name - def register_connection(self, connection: Any) -> str: + # Maps from variable path to set of comm_ids serving requests. + # A variable can point to a single connection object in the pane. + # But a comm_id can be shared by multiple variable paths. + self.path_to_comm_ids: Dict[PathKey, str] = {} + + # Mapping from comm_id to the corresponding variable path. + # Multiple variables paths, might point to the same commm_id. + self.comm_id_to_path: Dict[str, Set[PathKey]] = {} + + def register_connection( + self, connection: Any, variable_path: Optional[List[str]] = None + ) -> str: """ Opens a connection to the given data source. @@ -152,6 +170,7 @@ def register_connection(self, connection: Any) -> str: conn.type, comm_id, ) + self._register_variable_path(variable_path, comm_id) return comm_id comm_id = str(uuid.uuid4()) @@ -161,10 +180,27 @@ def register_connection(self, connection: Any) -> str: data={"name": connection.display_name}, ) + self._register_variable_path(variable_path, comm_id) self.comm_id_to_connection[comm_id] = connection self.on_comm_open(base_comm) return comm_id + def _register_variable_path(self, variable_path: Optional[List[str]], comm_id: str) -> None: + if variable_path is None: + return + + if not isinstance(variable_path, list): + raise ValueError(variable_path) + + key = tuple(variable_path) + + if comm_id in self.comm_id_to_path: + self.comm_id_to_path[comm_id].add(key) + else: + self.comm_id_to_path[comm_id] = {key} + + self.path_to_comm_ids[key] = comm_id + def on_comm_open(self, comm: BaseComm): comm_id = comm.comm_id comm.on_close(lambda msg: self._close_connection(comm_id)) @@ -173,15 +209,53 @@ def on_comm_open(self, comm: BaseComm): self.comms[comm_id] = connections_comm def _wrap_connection(self, obj: Any) -> Connection: - # we don't want to import sqlalchemy for that - type_name = type(obj).__name__ + + if not self.object_is_supported(obj): + type_name = type(obj).__name__ + raise UnsupportedConnectionError(f"Unsupported connection type {type_name}") if safe_isinstance(obj, "sqlite3", "Connection"): return SQLite3Connection(obj) elif safe_isinstance(obj, "sqlalchemy", "Engine"): return SQLAlchemyConnection(obj) - raise ValueError(f"Unsupported connection type {type_name}") + def object_is_supported(self, obj: Any) -> bool: + """ + Checks if an object is supported by the connections pane. + """ + return safe_isinstance(obj, "sqlite3", "Connection") or safe_isinstance( + obj, "sqlalchemy", "Engine" + ) + + def variable_has_active_connections(self, variable_path: List[str]) -> bool: + """ + Checks if the given variable path has an active connection. + """ + return tuple(variable_path) in self.path_to_comm_ids + + def handle_variable_updated(self, variable_name, value) -> None: + """ + Handles a variable being updated in the kernel. + """ + comm_id = self.path_to_comm_ids.get(tuple(variable_name)) + if comm_id is None: + return + + try: + new_comm_id = self.register_connection(value, variable_path=[variable_name]) + except UnsupportedConnectionError: + # if an unsupported connection error, that it means the variable + # is no longer a connection, thus we close the connection. + self._close_connection(comm_id) + return + + # if the connection is the same, we don't need to do anything + if comm_id == new_comm_id: + return + + # if the connections is different, we handle it as if it was a variable deletion + # TODO: we don't really want to close the connection, but 'delete' the variable + self._close_connection(comm_id) def _close_connection(self, comm_id: str): diff --git a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/variables.py b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/variables.py index a9d2b5144365..c04bb284078f 100644 --- a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/variables.py +++ b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/variables.py @@ -122,7 +122,7 @@ def handle_msg( self._send_formatted_var(request.params.path, request.params.format) elif isinstance(request, ViewRequest): - self._open_data_explorer(request.params.path) + self._perform_view_action(request.params.path) else: logger.warning(f"Unhandled request: {request}") @@ -153,6 +153,7 @@ def _send_update(self, assigned: Mapping[str, Any], removed: Set[str]) -> None: # Look for any assigned or removed variables that are active # in the data explorer service exp_service = self.kernel.data_explorer_service + con_service = self.kernel.connections_service for name in removed: if exp_service.variable_has_active_explorers(name): exp_service.handle_variable_deleted(name) @@ -161,6 +162,9 @@ def _send_update(self, assigned: Mapping[str, Any], removed: Set[str]) -> None: if exp_service.variable_has_active_explorers(name): exp_service.handle_variable_updated(name, value) + if con_service.variable_has_active_connection(name): + con_service.handle_variable_updated(name, value) + # Ensure the number of changes does not exceed our maximum items if len(assigned) > MAX_ITEMS or len(removed) > MAX_ITEMS: return self.send_refresh_event() @@ -509,11 +513,11 @@ def _inspect_var(self, path: List[str]) -> None: f"Cannot find variable at '{path}' to inspect", ) - def _open_data_explorer(self, path: List[str]) -> None: - """Opens a DataExplorer comm for the variable at the requested - path in the current user session. - + def _perform_view_action(self, path: List[str]) -> None: """ + Performs the view action depending of the variable type. + """ + if path is None: return @@ -524,6 +528,15 @@ def _open_data_explorer(self, path: List[str]) -> None: f"Cannot find variable at '{path}' to view", ) + if self.kernel.connections_service.object_is_supported(value): + self._open_connections_pane(path, value) + else: + self._open_data_explorer(path, value) + + def _open_data_explorer(self, path: List[str], value: Any) -> None: + """Opens a DataExplorer comm for the variable at the requested + path in the current user session. + """ # Use the leaf segment to get the title access_key = path[-1] @@ -531,6 +544,14 @@ def _open_data_explorer(self, path: List[str]) -> None: self.kernel.data_explorer_service.register_table(value, title, variable_path=path) self._send_result({}) + def _open_connections_pane(self, path: List[str], value: Any) -> None: + """Opens a Connections comm for the variable at the requested + path in the current user session. + """ + # Use the leaf segment to get the title + self.kernel.connections_service.register_connection(value, variable_path=path) + self._send_result({}) + def _send_event(self, name: str, payload: JsonRecord) -> None: """ Send an event payload to the client.