diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py
index 6c198dd468..975f3890b9 100644
--- a/dlt/common/destination/reference.py
+++ b/dlt/common/destination/reference.py
@@ -580,6 +580,13 @@ def __getitem__(self, table: str) -> SupportsReadableRelation: ...
 
     def __getattr__(self, table: str) -> SupportsReadableRelation: ...
 
+    def row_counts(
+        self, *, data_tables: bool = True, dlt_tables: bool = False, table_names: Optional[List[str]] = None
+    ) -> Dict[str, int]:
+        """Returns a dictionary of table names and their row counts. If table_names is provided,
+        only the tables in the list are returned regardless of the data_tables and dlt_tables flags."""
+        return {}
+
 
 class JobClientBase(ABC):
     def __init__(
diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py
index cffdc0f059..38b5f8219e 100644
--- a/dlt/destinations/dataset.py
+++ b/dlt/destinations/dataset.py
@@ -1,4 +1,4 @@
-from typing import Any, Generator, Optional, Sequence, Union, List
+from typing import Any, Generator, Optional, Sequence, Union, List, Dict
 from dlt.common.json import json
 from copy import deepcopy
 
@@ -300,6 +300,37 @@ def __getattr__(self, table_name: str) -> SupportsReadableRelation:
         """access of table via property notation"""
         return self.table(table_name)
 
+    def row_counts(
+        self, *, data_tables: bool = True, dlt_tables: bool = False, table_names: Optional[List[str]] = None
+    ) -> Dict[str, int]:
+        """Returns a dictionary of table names and their row counts, returns counts of all data
+        tables by default. If table_names is provided, only the tables in the list are returned
+        regardless of the data_tables and dlt_tables flags."""
+        selected_tables = table_names or []
+        if not selected_tables:
+            if data_tables:
+                selected_tables += self.schema.data_table_names(seen_data_only=True)
+            if dlt_tables:
+                selected_tables += self.schema.dlt_table_names()
+
+        # nothing selected: return early instead of executing an empty (invalid) SQL statement
+        if not selected_tables:
+            return {}
+
+        # Build UNION ALL query to get row counts for all selected tables in one round trip
+        queries = []
+        for table in selected_tables:
+            queries.append(
+                f"SELECT '{table}' as table_name, COUNT(*) as row_count FROM"
+                f" {self.sql_client.make_qualified_table_name(table)}"
+            )
+
+        query = " UNION ALL ".join(queries)
+
+        # Execute query and build result dict
+        with self(query).cursor() as cursor:
+            return {row[0]: row[1] for row in cursor.fetchall()}
+
 
 def dataset(
     destination: TDestinationReferenceArg,
diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py
index c6019ecf2d..a00be2bf26 100644
--- a/tests/load/test_read_interfaces.py
+++ b/tests/load/test_read_interfaces.py
@@ -203,7 +203,6 @@ def test_dataframe_access(populated_pipeline: Pipeline) -> None:
 
     if not skip_df_chunk_size_check:
         assert len(df.index) == chunk_size
-    # lowercase results for the snowflake case
     assert set(df.columns.values) == set(EXPECTED_COLUMNS)
 
     # iterate all dataframes
@@ -292,6 +291,43 @@ def test_loads_table_access(populated_pipeline: Pipeline) -> None:
     assert len(loads_table.fetchall()) == 1
 
 
+@pytest.mark.no_load
+@pytest.mark.essential
+@pytest.mark.parametrize(
+    "populated_pipeline",
+    configs,
+    indirect=True,
+    ids=lambda x: x.name,
+)
+def test_row_counts(populated_pipeline: Pipeline) -> None:
+    total_records = _total_records(populated_pipeline)
+
+    dataset = populated_pipeline._dataset()
+    # default is all data tables
+    assert dataset.row_counts() == {
+        "items": total_records,
+        "double_items": total_records,
+        "items__children": total_records * 2,
+    }
+    # get only one data table
+    assert dataset.row_counts(table_names=["items"]) == {"items": total_records}
+    # get all dlt tables
+    assert dataset.row_counts(dlt_tables=True, data_tables=False) == {
+        "_dlt_version": 1,
+        "_dlt_loads": 1,
+        "_dlt_pipeline_state": 1,
+    }
+    # get them all
+    assert dataset.row_counts(dlt_tables=True) == {
+        "_dlt_version": 1,
+        "_dlt_loads": 1,
+        "_dlt_pipeline_state": 1,
+        "items": total_records,
+        "double_items": total_records,
+        "items__children": total_records * 2,
+    }
+
+
 @pytest.mark.no_load
 @pytest.mark.essential
 @pytest.mark.parametrize(