diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 975f3890b9..b3347228be 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -582,10 +582,7 @@ def __getattr__(self, table: str) -> SupportsReadableRelation: ... def row_counts( self, *, data_tables: bool = True, dlt_tables: bool = False, table_names: List[str] = None - ) -> Dict[str, int]: - """Returns a dictionary of table names and their row counts""" - """If table_names is provided, only the tables in the list are returned regardless of the data_tables and dlt_tables flags""" - return {} + ) -> SupportsReadableRelation: ... class JobClientBase(ABC): diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index 38b5f8219e..e7c0f6e41d 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -302,7 +302,7 @@ def __getattr__(self, table_name: str) -> SupportsReadableRelation: def row_counts( self, *, data_tables: bool = True, dlt_tables: bool = False, table_names: List[str] = None - ) -> Dict[str, int]: + ) -> SupportsReadableRelation: """Returns a dictionary of table names and their row counts, returns counts of all data tables by default""" """If table_names is provided, only the tables in the list are returned regardless of the data_tables and dlt_tables flags""" @@ -324,8 +324,7 @@ def row_counts( query = " UNION ALL ".join(queries) # Execute query and build result dict - with self(query).cursor() as cursor: - return {row[0]: row[1] for row in cursor.fetchall()} + return self(query) def dataset( diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index a00be2bf26..36d25ad90b 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -304,27 +304,68 @@ def test_row_counts(populated_pipeline: Pipeline) -> None: dataset = populated_pipeline._dataset() # default is all data tables - assert dataset.row_counts() == { - "items": total_records, - "double_items": total_records, - "items__children": total_records * 2, + assert set(dataset.row_counts().fetchall()) == { + ( + "items", + total_records, + ), + ( + "double_items", + total_records, + ), + ( + "items__children", + total_records * 2, + ), } # get only one data table - assert dataset.row_counts(table_names=["items"]) == {"items": total_records} + assert set(dataset.row_counts(table_names=["items"]).fetchall()) == { + ( + "items", + total_records, + ), + } # get all dlt tables - assert dataset.row_counts(dlt_tables=True, data_tables=False) == { - "_dlt_version": 1, - "_dlt_loads": 1, - "_dlt_pipeline_state": 1, + assert set(dataset.row_counts(dlt_tables=True, data_tables=False).fetchall()) == { + ( + "_dlt_version", + 1, + ), + ( + "_dlt_loads", + 1, + ), + ( + "_dlt_pipeline_state", + 1, + ), } # get them all - assert dataset.row_counts(dlt_tables=True) == { - "_dlt_version": 1, - "_dlt_loads": 1, - "_dlt_pipeline_state": 1, - "items": total_records, - "double_items": total_records, - "items__children": total_records * 2, + assert set(dataset.row_counts(dlt_tables=True).fetchall()) == { + ( + "_dlt_version", + 1, + ), + ( + "_dlt_loads", + 1, + ), + ( + "_dlt_pipeline_state", + 1, + ), + ( + "items", + total_records, + ), + ( + "double_items", + total_records, + ), + ( + "items__children", + total_records * 2, + ), }