diff --git a/src/databricks/labs/lsql/dashboards.py b/src/databricks/labs/lsql/dashboards.py index cd9ec1b2..7d1796a4 100644 --- a/src/databricks/labs/lsql/dashboards.py +++ b/src/databricks/labs/lsql/dashboards.py @@ -28,6 +28,9 @@ Page, Position, Query, + RenderFieldEncoding, + TableEncodingMap, + TableV2Spec, Widget, WidgetSpec, ) @@ -50,6 +53,25 @@ def from_dict(cls, raw: dict[str, str]) -> "DashboardMetadata": def as_dict(self) -> dict[str, str]: return dataclasses.asdict(self) + @classmethod + def from_path(cls, path: Path) -> "DashboardMetadata": + """Export dashboard metadata from a YAML file.""" + fallback_metadata = cls(display_name=path.parent.name) + + if not path.exists(): + return fallback_metadata + + try: + raw = yaml.safe_load(path.read_text()) + except yaml.YAMLError as e: + logger.warning(f"Parsing {path}: {e}") + return fallback_metadata + try: + return cls.from_dict(raw) + except KeyError as e: + logger.warning(f"Parsing {path}: {e}") + return fallback_metadata + class WidgetMetadata: def __init__( @@ -64,34 +86,7 @@ def __init__( self.order = order self.width = width self.height = height - self.id = _id - - size = self._size - self.width = self.width or size[0] - self.height = self.height or size[1] - self.id = self.id or path.stem - - def is_markdown(self) -> bool: - return self.path.suffix == ".md" - - @property - def spec_type(self) -> type[WidgetSpec]: - # TODO: When supporting more specs, infer spec from query - return CounterSpec - - @property - def _size(self) -> tuple[int, int]: - """Get the width and height for a widget. - - The tiling logic works if: - - width < _MAXIMUM_DASHBOARD_WIDTH : heights for widgets on the same row should be equal - - width == _MAXIMUM_DASHBOARD_WIDTH : any height - """ - if self.is_markdown(): - return _MAXIMUM_DASHBOARD_WIDTH, 2 - if self.spec_type == CounterSpec: - return 1, 3 - return 0, 0 + self.id = _id or path.stem def as_dict(self) -> dict[str, str]: body = {"path": self.path.as_posix()} @@ -103,6 +98,15 @@ def as_dict(self) -> dict[str, str]: body[attribute] = str(value) return body + def size(self) -> tuple[int, int]: + return self.width, self.height + + def is_markdown(self) -> bool: + return self.path.suffix == ".md" + + def is_query(self) -> bool: + return self.path.suffix == ".sql" + @staticmethod def _get_arguments_parser() -> ArgumentParser: parser = ArgumentParser("WidgetMetadata", add_help=False, exit_on_error=False) @@ -144,6 +148,137 @@ def from_path(cls, path: Path) -> "WidgetMetadata": return fallback_metadata.replace_from_arguments(shlex.split(first_comment)) +class Tile: + """A dashboard tile.""" + + def __init__(self, widget_metadata: WidgetMetadata) -> None: + self._widget_metadata = widget_metadata + + default_width, default_height = self._default_size() + width = self._widget_metadata.width or default_width + height = self._widget_metadata.height or default_height + self.position = Position(0, 0, width, height) + + def _default_size(self) -> tuple[int, int]: + return 0, 0 + + def place_after(self, position: Position) -> "Tile": + """Place the tile after another tile: + + The tiling logic works if: + - `position.width < _MAXIMUM_DASHBOARD_WIDTH` : tiles in a single row should have the same size + - `position.width == _MAXIMUM_DASHBOARD_WIDTH` : any height + """ + x = position.x + position.width + if x + self.position.width > _MAXIMUM_DASHBOARD_WIDTH: + x = 0 + y = position.y + position.height + else: + y = position.y + new_position = dataclasses.replace(self.position, x=x, y=y) + + replica = copy.deepcopy(self) + replica.position = new_position + return replica + + @property + def widget(self) -> Widget: + widget = Widget(name=self._widget_metadata.id, textbox_spec=self._widget_metadata.path.read_text()) + return widget + + @classmethod + def from_widget_metadata(cls, widget_metadata: WidgetMetadata) -> "Tile": + """Create a tile given the widget metadata.""" + if widget_metadata.is_markdown(): + return MarkdownTile(widget_metadata) + query_tile = QueryTile(widget_metadata) + spec_type = query_tile.infer_spec_type() + if spec_type is None: + return MarkdownTile(widget_metadata) + if spec_type == CounterSpec: + return CounterTile(widget_metadata) + return TableTile(widget_metadata) + + +class MarkdownTile(Tile): + def _default_size(self) -> tuple[int, int]: + return _MAXIMUM_DASHBOARD_WIDTH, 2 + + +class QueryTile(Tile): + def _get_abstract_syntax_tree(self) -> sqlglot.Expression | None: + query = self._widget_metadata.path.read_text() + try: + return sqlglot.parse_one(query, dialect=sqlglot.dialects.Databricks) + except sqlglot.ParseError as e: + logger.warning(f"Parsing {query}: {e}") + return None + + def _find_fields(self) -> list[Field]: + """Find the fields in a query. + + The fields are the projections in the query's top level SELECT. + """ + abstract_syntax_tree = self._get_abstract_syntax_tree() + if abstract_syntax_tree is None: + return [] + + fields = [] + for projection in abstract_syntax_tree.find_all(sqlglot.exp.Select): + if projection.depth > 0: + continue + for named_select in projection.named_selects: + field = Field(name=named_select, expression=f"`{named_select}`") + fields.append(field) + return fields + + @property + def widget(self) -> Widget: + fields = self._find_fields() + named_query = self._get_named_query(fields) + spec = self._get_spec(fields) + widget = Widget(name=self._widget_metadata.id, queries=[named_query], spec=spec) + return widget + + def _get_named_query(self, fields: list[Field]) -> NamedQuery: + query = Query(dataset_name=self._widget_metadata.id, fields=fields, disaggregated=True) + # As far as testing went, a NamedQuery should always have "main_query" as name + named_query = NamedQuery(name="main_query", query=query) + return named_query + + @staticmethod + def _get_spec(fields: list[Field]) -> WidgetSpec: + field_encodings = [RenderFieldEncoding(field_name=field.name) for field in fields] + table_encodings = TableEncodingMap(field_encodings) + spec = TableV2Spec(encodings=table_encodings) + return spec + + def infer_spec_type(self) -> type[WidgetSpec] | None: + """Infer the spec type from the query.""" + fields = self._find_fields() + if len(fields) == 0: + return None + if len(fields) == 1: + return CounterSpec + return TableV2Spec + + +class TableTile(QueryTile): + def _default_size(self) -> tuple[int, int]: + return 6, 6 + + +class CounterTile(QueryTile): + def _default_size(self) -> tuple[int, int]: + return 1, 3 + + @staticmethod + def _get_spec(fields: list[Field]) -> CounterSpec: + counter_encodings = CounterFieldEncoding(field_name=fields[0].name, display_name=fields[0].name) + spec = CounterSpec(CounterEncodingMap(value=counter_encodings)) + return spec + + class Dashboards: def __init__(self, ws: WorkspaceClient): self._ws = ws @@ -191,11 +326,10 @@ def _format_query(query: str) -> str: def create_dashboard(self, dashboard_folder: Path) -> Dashboard: """Create a dashboard from code, i.e. configuration and queries.""" - dashboard_metadata = self._parse_dashboard_metadata(dashboard_folder) - widgets_metadata = self._get_widgets_metadata(dashboard_folder) + dashboard_metadata = DashboardMetadata.from_path(dashboard_folder / "dashboard.yml") + widgets_metadata = self._parse_widgets_metadata(dashboard_folder) datasets = self._get_datasets(dashboard_folder) - widgets = self._get_widgets(widgets_metadata) - layouts = self._get_layouts(widgets, widgets_metadata) + layouts = self._get_layouts(widgets_metadata) page = Page( name=dashboard_metadata.display_name, display_name=dashboard_metadata.display_name, @@ -204,122 +338,48 @@ def create_dashboard(self, dashboard_folder: Path) -> Dashboard: lakeview_dashboard = Dashboard(datasets=datasets, pages=[page]) return lakeview_dashboard + @staticmethod + def _parse_widgets_metadata(dashboard_folder: Path) -> list[WidgetMetadata]: + """Parse the widget metadata from each (optional) header.""" + widgets_metadata = [] + for path in dashboard_folder.iterdir(): + if path.suffix in {".sql", ".md"}: + widget_metadata = WidgetMetadata.from_path(path) + widgets_metadata.append(widget_metadata) + return widgets_metadata + @staticmethod def _get_datasets(dashboard_folder: Path) -> list[Dataset]: datasets = [] for query_path in sorted(dashboard_folder.glob("*.sql")): - with query_path.open("r") as query_file: - raw_query = query_file.read() - dataset = Dataset(name=query_path.stem, display_name=query_path.stem, query=raw_query) + dataset = Dataset(name=query_path.stem, display_name=query_path.stem, query=query_path.read_text()) datasets.append(dataset) return datasets @staticmethod - def _get_widgets_metadata(dashboard_folder: Path) -> list[WidgetMetadata]: - """Read and parse the widget metadata from each (optional) header. - - The order is by default the alphanumerically sorted files, however, the order may be overwritten in the file - header with the `order` key. Hence, the multiple loops to get: - i) the optional order from the file header; - ii) set the order when not specified; - iii) sort the widgets using the order field. + def _get_layouts(widgets_metadata: list[WidgetMetadata]) -> list[Layout]: + """Create layouts from the widgets metadata. + + The order of the tiles is by default the alphanumerically sorted tile ids, however, the order may be overwritten + with the `order` key. Hence, the multiple loops to get: + i) set the order when not specified; + ii) sort the widgets using the order field. """ - widgets_metadata = [] - for path in sorted(dashboard_folder.iterdir()): - if path.suffix not in {".sql", ".md"}: - continue - widget_metadata = WidgetMetadata.from_path(path) - widgets_metadata.append(widget_metadata) widgets_metadata_with_order = [] for order, widget_metadata in enumerate(sorted(widgets_metadata, key=lambda wm: wm.id)): replica = copy.deepcopy(widget_metadata) replica.order = widget_metadata.order or order widgets_metadata_with_order.append(replica) - widgets_metadata_sorted = list(sorted(widgets_metadata_with_order, key=lambda wm: (wm.order, wm.id))) - return widgets_metadata_sorted - - def _get_widgets(self, widgets_metadata: list[WidgetMetadata]) -> list[Widget]: - widgets = [] - for widget_metadata in widgets_metadata: - try: - widget = self._get_widget(widget_metadata) - except sqlglot.ParseError as e: - logger.warning(f"Parsing {widget_metadata.path}: {e}") - continue - widgets.append(widget) - return widgets - - def _get_layouts(self, widgets: list[Widget], widgets_metadata: list[WidgetMetadata]) -> list[Layout]: - layouts, position = [], Position(0, 0, 0, 0) # First widget position - for widget, widget_metadata in zip(widgets, widgets_metadata): - position = self._get_position(position, widget_metadata) - layout = Layout(widget=widget, position=position) + + layouts, position = [], Position(0, 0, 0, 0) # Position of first tile + for widget_metadata in sorted(widgets_metadata_with_order, key=lambda wm: (wm.order, wm.id)): + tile = Tile.from_widget_metadata(widget_metadata) + placed_tile = tile.place_after(position) + layout = Layout(widget=placed_tile.widget, position=placed_tile.position) layouts.append(layout) + position = placed_tile.position return layouts - @staticmethod - def _parse_dashboard_metadata(dashboard_folder: Path) -> DashboardMetadata: - fallback_metadata = DashboardMetadata(display_name=dashboard_folder.name) - - dashboard_metadata_path = dashboard_folder / "dashboard.yml" - if not dashboard_metadata_path.exists(): - return fallback_metadata - - try: - raw = yaml.safe_load(dashboard_metadata_path.read_text()) - except yaml.YAMLError as e: - logger.warning(f"Parsing {dashboard_metadata_path}: {e}") - return fallback_metadata - try: - return DashboardMetadata.from_dict(raw) - except KeyError as e: - logger.warning(f"Parsing {dashboard_metadata_path}: {e}") - return fallback_metadata - - def _get_widget(self, widget_metadata: WidgetMetadata) -> Widget: - if widget_metadata.is_markdown(): - return self._get_text_widget(widget_metadata) - return self._get_counter_widget(widget_metadata) - - @staticmethod - def _get_text_widget(widget_metadata: WidgetMetadata) -> Widget: - widget = Widget(name=widget_metadata.id, textbox_spec=widget_metadata.path.read_text()) - return widget - - def _get_counter_widget(self, widget_metadata: WidgetMetadata) -> Widget: - fields = self._get_fields(widget_metadata.path.read_text()) - query = Query(dataset_name=widget_metadata.id, fields=fields, disaggregated=True) - # As far as testing went, a NamedQuery should always have "main_query" as name - named_query = NamedQuery(name="main_query", query=query) - # Counters are expected to have one field - counter_field_encoding = CounterFieldEncoding(field_name=fields[0].name, display_name=fields[0].name) - counter_spec = CounterSpec(CounterEncodingMap(value=counter_field_encoding)) - widget = Widget(name=widget_metadata.id, queries=[named_query], spec=counter_spec) - return widget - - @staticmethod - def _get_fields(query: str) -> list[Field]: - parsed_query = sqlglot.parse_one(query, dialect=sqlglot.dialects.Databricks) - fields = [] - for projection in parsed_query.find_all(sqlglot.exp.Select): - if projection.depth > 0: - continue - for named_select in projection.named_selects: - field = Field(name=named_select, expression=f"`{named_select}`") - fields.append(field) - return fields - - @staticmethod - def _get_position(previous_position: Position, widget_metadata: WidgetMetadata) -> Position: - x = previous_position.x + previous_position.width - if x + widget_metadata.width > _MAXIMUM_DASHBOARD_WIDTH: - x = 0 - y = previous_position.y + previous_position.height - else: - y = previous_position.y - position = Position(x=x, y=y, width=widget_metadata.width, height=widget_metadata.height) - return position - def deploy_dashboard(self, lakeview_dashboard: Dashboard, *, dashboard_id: str | None = None) -> SDKDashboard: """Deploy a lakeview dashboard.""" if dashboard_id is not None: diff --git a/tests/integration/dashboards/one_table/databricks_office_locations.sql b/tests/integration/dashboards/one_table/databricks_office_locations.sql new file mode 100644 index 00000000..fbe8a2e4 --- /dev/null +++ b/tests/integration/dashboards/one_table/databricks_office_locations.sql @@ -0,0 +1,13 @@ +SELECT + Address, + City, + State, + `Zip Code`, + Country +FROM +VALUES + ('160 Spear St 15th Floor', 'San Francisco', 'CA', '94105', 'USA'), + ('756 W Peachtree St NW, Suite 03W114', 'Atlanta', 'GA', '30308', 'USA'), + ('500 108th Ave NE, Suite 1820', 'Bellevue', 'WA', '98004', 'USA'), + ('125 High St, Suite 220', 'Boston', 'MA', '02110', 'USA'), + ('2120 University Ave, Suite 722', 'Berkeley', 'CA', '94704', 'USA') AS tab(Address, City, State, `Zip Code`, Country) diff --git a/tests/integration/test_dashboards.py b/tests/integration/test_dashboards.py index 21452fff..864df51d 100644 --- a/tests/integration/test_dashboards.py +++ b/tests/integration/test_dashboards.py @@ -155,3 +155,32 @@ def test_dashboards_deploys_dashboard_with_order_overwrite(ws, make_dashboard, t sdk_dashboard = dashboards.deploy_dashboard(lakeview_dashboard, dashboard_id=sdk_dashboard.dashboard_id) assert ws.lakeview.get(sdk_dashboard.dashboard_id) + + +def test_dashboard_deploys_dashboard_with_table(ws, make_dashboard): + sdk_dashboard = make_dashboard() + + dashboard_folder = Path(__file__).parent / "dashboards" / "one_table" + dashboards = Dashboards(ws) + lakeview_dashboard = dashboards.create_dashboard(dashboard_folder) + + sdk_dashboard = dashboards.deploy_dashboard(lakeview_dashboard, dashboard_id=sdk_dashboard.dashboard_id) + + assert ws.lakeview.get(sdk_dashboard.dashboard_id) + + +def test_dashboards_deploys_dashboard_with_invalid_query(ws, make_dashboard, tmp_path): + sdk_dashboard = make_dashboard() + + for query_name in range(6): + with (tmp_path / f"{query_name}.sql").open("w") as f: + f.write(f"SELECT {query_name} AS count") + with (tmp_path / "4.sql").open("w") as f: + f.write("SELECT COUNT(* AS invalid_column") + + dashboards = Dashboards(ws) + lakeview_dashboard = dashboards.create_dashboard(tmp_path) + + sdk_dashboard = dashboards.deploy_dashboard(lakeview_dashboard, dashboard_id=sdk_dashboard.dashboard_id) + + assert ws.lakeview.get(sdk_dashboard.dashboard_id) diff --git a/tests/unit/test_dashboards.py b/tests/unit/test_dashboards.py index 0a2c06c2..c085e249 100644 --- a/tests/unit/test_dashboards.py +++ b/tests/unit/test_dashboards.py @@ -3,11 +3,14 @@ from unittest.mock import create_autospec import pytest +import yaml from databricks.sdk import WorkspaceClient from databricks.labs.lsql.dashboards import ( DashboardMetadata, Dashboards, + QueryTile, + Tile, WidgetMetadata, ) from databricks.labs.lsql.lakeview import ( @@ -20,26 +23,68 @@ Page, Position, Query, + TableV2Spec, Widget, ) -def test_dashboard_configuration_raises_key_error_if_display_name_is_missing(): +def test_dashboard_metadata_raises_key_error_if_display_name_is_missing(): with pytest.raises(KeyError): DashboardMetadata.from_dict({}) -def test_dashboard_configuration_sets_display_name_from_dict(): +def test_dashboard_metadata_sets_display_name_from_dict(): dashboard_metadata = DashboardMetadata.from_dict({"display_name": "test"}) assert dashboard_metadata.display_name == "test" -def test_dashboard_configuration_from_and_as_dict_is_a_unit_function(): +def test_dashboard_metadata_from_and_as_dict_is_a_unit_function(): raw = {"display_name": "test"} dashboard_metadata = DashboardMetadata.from_dict(raw) assert dashboard_metadata.as_dict() == raw +def test_dashboard_metadata_from_raw(tmp_path): + raw = {"display_name": "test"} + + path = tmp_path / "dashboard.yml" + with path.open("w") as f: + yaml.safe_dump(raw, f) + + from_dict = DashboardMetadata.from_dict(raw) + from_path = DashboardMetadata.from_path(path) + + for dashboard_metadata in from_dict, from_path: + assert dashboard_metadata.display_name == "test" + + +@pytest.mark.parametrize("dashboard_content", ["missing_display_name: true", "invalid:\nyml", ""]) +def test_dashboard_metadata_handles_invalid_yml(tmp_path, dashboard_content): + path = tmp_path / "dashboard.yml" + if len(dashboard_content) > 0: + path.write_text(dashboard_content) + + dashboard_metadata = DashboardMetadata.from_path(path) + assert dashboard_metadata.display_name == tmp_path.name + + +def test_widget_metadata_sets_size(): + widget_metadata = WidgetMetadata(Path("test.sql"), 1, 10, 10) + assert widget_metadata.size() == (10, 10) + + +def test_widget_metadata_is_markdown(): + widget_metadata = WidgetMetadata(Path("test.md")) + assert widget_metadata.is_markdown() + assert not widget_metadata.is_query() + + +def test_widget_metadata_is_query(): + widget_metadata = WidgetMetadata(Path("test.sql")) + assert not widget_metadata.is_markdown() + assert widget_metadata.is_query() + + def test_widget_metadata_replaces_width_and_height(): widget_metadata = WidgetMetadata(Path("test.sql"), 1, 1, 1) updated_metadata = widget_metadata.replace_from_arguments(["--width", "10", "--height", "10"]) @@ -60,6 +105,28 @@ def test_widget_metadata_as_dict(): assert widget_metadata.as_dict() == raw +def test_tile_places_tile_to_the_right(): + widget_metadata = WidgetMetadata(Path("test.sql"), 1, 1, 1) + tile = Tile(widget_metadata) + + position = Position(0, 4, 3, 4) + placed_tile = tile.place_after(position) + + assert placed_tile.position.x == position.x + position.width + assert placed_tile.position.y == 4 + + +def test_tile_places_tile_below(): + widget_metadata = WidgetMetadata(Path("test.sql"), 1, 1, 1) + tile = Tile(widget_metadata) + + position = Position(5, 4, 3, 4) + placed_tile = tile.place_after(position) + + assert placed_tile.position.x == 0 + assert placed_tile.position.y == 8 + + def test_dashboards_saves_sql_files_to_folder(tmp_path): ws = create_autospec(WorkspaceClient) queries = Path(__file__).parent / "queries" @@ -139,7 +206,7 @@ def test_dashboards_creates_one_counter_widget_per_query(): assert len(counter_widgets) == len([query for query in queries.glob("*.sql")]) -def test_dashboards_skips_invalid_query(tmp_path, caplog): +def test_dashboards_creates_text_widget_for_invalid_query(tmp_path, caplog): ws = create_autospec(WorkspaceClient) # Test for the invalid query not to be the first or last query @@ -154,7 +221,8 @@ def test_dashboards_skips_invalid_query(tmp_path, caplog): with caplog.at_level(logging.WARNING, logger="databricks.labs.lsql.dashboards"): lakeview_dashboard = Dashboards(ws).create_dashboard(tmp_path) - assert len(lakeview_dashboard.pages[0].layout) == 2 + markdown_widget = lakeview_dashboard.pages[0].layout[1].widget + assert markdown_widget.textbox_spec == invalid_query assert invalid_query in caplog.text @@ -234,16 +302,16 @@ def test_dashboards_does_not_create_widget_for_yml_file(tmp_path, caplog): ("SELECT from_unixtime(timestamp) AS timestamp FROM table", ["timestamp"]), ], ) -def test_dashboards_gets_fields_with_expected_names(tmp_path, query, names): - with (tmp_path / "query.sql").open("w") as f: - f.write(query) +def test_query_tile_finds_fields(tmp_path, query, names): + query_file = tmp_path / "query.sql" + query_file.write_text(query) - ws = create_autospec(WorkspaceClient) - lakeview_dashboard = Dashboards(ws).create_dashboard(tmp_path) + widget_metadata = WidgetMetadata(query_file, 1, 1, 1) + tile = QueryTile(widget_metadata) + + fields = tile._find_fields() # pylint: disable=protected-access - fields = lakeview_dashboard.pages[0].layout[0].widget.queries[0].query.fields assert [field.name for field in fields] == names - ws.assert_not_called() def test_dashboards_creates_dashboard_with_expected_counter_field_encoding_names(tmp_path): @@ -260,6 +328,20 @@ def test_dashboards_creates_dashboard_with_expected_counter_field_encoding_names ws.assert_not_called() +def test_dashboards_creates_dashboard_with_expected_table_field_encodings(tmp_path): + with (tmp_path / "query.sql").open("w") as f: + f.write("SELECT 1 AS first, 2 AS second") + + ws = create_autospec(WorkspaceClient) + lakeview_dashboard = Dashboards(ws).create_dashboard(tmp_path) + + table_spec = lakeview_dashboard.pages[0].layout[0].widget.spec + assert isinstance(table_spec, TableV2Spec) + assert table_spec.encodings.columns[0].field_name == "first" + assert table_spec.encodings.columns[1].field_name == "second" + ws.assert_not_called() + + def test_dashboards_creates_dashboards_with_second_widget_to_the_right_of_the_first_widget(tmp_path): ws = create_autospec(WorkspaceClient) @@ -356,7 +438,13 @@ def test_dashboards_creates_dashboards_with_widget_ordered_using_id(tmp_path): ws.assert_not_called() -@pytest.mark.parametrize("query, width, height", [("SELECT 1 AS count", 1, 3)]) +@pytest.mark.parametrize( + "query, width, height", + [ + ("SELECT 1 AS count", 1, 3), + ("SELECT 1 AS first, 2 AS second", 6, 6), + ], +) def test_dashboards_creates_dashboards_where_widget_has_expected_width_and_height(tmp_path, query, width, height): ws = create_autospec(WorkspaceClient)