diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py
index 6c198dd468..b3347228be 100644
--- a/dlt/common/destination/reference.py
+++ b/dlt/common/destination/reference.py
@@ -580,6 +580,10 @@ def __getitem__(self, table: str) -> SupportsReadableRelation: ...
 
     def __getattr__(self, table: str) -> SupportsReadableRelation: ...
 
+    def row_counts(
+        self, *, data_tables: bool = True, dlt_tables: bool = False, table_names: List[str] = None
+    ) -> SupportsReadableRelation: ...
+
 
 class JobClientBase(ABC):
     def __init__(
diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py
index cffdc0f059..e7c0f6e41d 100644
--- a/dlt/destinations/dataset.py
+++ b/dlt/destinations/dataset.py
@@ -1,4 +1,4 @@
-from typing import Any, Generator, Optional, Sequence, Union, List
+from typing import Any, Generator, Optional, Sequence, Union, List, Dict
 from dlt.common.json import json
 from copy import deepcopy
 
@@ -300,6 +300,41 @@ def __getattr__(self, table_name: str) -> SupportsReadableRelation:
         """access of table via property notation"""
         return self.table(table_name)
 
+    def row_counts(
+        self, *, data_tables: bool = True, dlt_tables: bool = False, table_names: List[str] = None
+    ) -> SupportsReadableRelation:
+        """Returns a relation with the row count of each selected table.
+
+        The relation has one row per table with the columns `table_name` and `row_count`.
+        By default the data tables (selected with `seen_data_only=True`) are counted.
+
+        Args:
+            data_tables: Include the data tables of the schema, defaults to True.
+            dlt_tables: Include the dlt tables of the schema, defaults to False.
+            table_names: If provided, only these tables are counted, regardless of the
+                `data_tables` and `dlt_tables` flags.
+
+        Returns:
+            A readable relation for the query; fetch the counts with e.g. `.df()`.
+        """
+        selected_tables = table_names or []
+        if not selected_tables:
+            if data_tables:
+                selected_tables += self.schema.data_table_names(seen_data_only=True)
+            if dlt_tables:
+                selected_tables += self.schema.dlt_table_names()
+
+        # one "SELECT <name>, COUNT(*)" statement per table, combined with UNION ALL below
+        queries = [
+            f"SELECT '{table}' as table_name, COUNT(*) as row_count FROM"
+            f" {self.sql_client.make_qualified_table_name(table)}"
+            for table in selected_tables
+        ]
+        query = " UNION ALL \n".join(queries)
+
+        # hand the combined query to the dataset, which wraps it in a readable relation
+        return self(query)
+
 
 def dataset(
     destination: TDestinationReferenceArg,
diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py
index c6019ecf2d..a9803e151e 100644
--- a/tests/load/test_read_interfaces.py
+++ b/tests/load/test_read_interfaces.py
@@ -203,7 +203,6 @@ def test_dataframe_access(populated_pipeline: Pipeline) -> None:
     if not skip_df_chunk_size_check:
         assert len(df.index) == chunk_size
 
-    # lowercase results for the snowflake case
     assert set(df.columns.values) == set(EXPECTED_COLUMNS)
 
     # iterate all dataframes
@@ -292,6 +291,90 @@ def test_loads_table_access(populated_pipeline: Pipeline) -> None:
     assert len(loads_table.fetchall()) == 1
 
 
+@pytest.mark.no_load
+@pytest.mark.essential
+@pytest.mark.parametrize(
+    "populated_pipeline",
+    configs,
+    indirect=True,
+    ids=lambda x: x.name,
+)
+def test_row_counts(populated_pipeline: Pipeline) -> None:
+    total_records = _total_records(populated_pipeline)
+
+    dataset = populated_pipeline._dataset()
+    # default is all data tables
+    assert set(dataset.row_counts().df().itertuples(index=False, name=None)) == {
+        (
+            "items",
+            total_records,
+        ),
+        (
+            "double_items",
+            total_records,
+        ),
+        (
+            "items__children",
+            total_records * 2,
+        ),
+    }
+    # get only one data table
+    assert set(
+        dataset.row_counts(table_names=["items"]).df().itertuples(index=False, name=None)
+    ) == {
+        (
+            "items",
+            total_records,
+        ),
+    }
+    # get all dlt tables
+    assert set(
+        dataset.row_counts(dlt_tables=True, data_tables=False)
+        .df()
+        .itertuples(index=False, name=None)
+    ) == {
+        (
+            "_dlt_version",
+            1,
+        ),
+        (
+            "_dlt_loads",
+            1,
+        ),
+        (
+            "_dlt_pipeline_state",
+            1,
+        ),
+    }
+    # get them all
+    assert set(dataset.row_counts(dlt_tables=True).df().itertuples(index=False, name=None)) == {
+        (
+            "_dlt_version",
+            1,
+        ),
+        (
+            "_dlt_loads",
+            1,
+        ),
+        (
+            "_dlt_pipeline_state",
+            1,
+        ),
+        (
+            "items",
+            total_records,
+        ),
+        (
+            "double_items",
+            total_records,
+        ),
+        (
+            "items__children",
+            total_records * 2,
+        ),
+    }
+
+
 @pytest.mark.no_load
 @pytest.mark.essential
 @pytest.mark.parametrize(