Skip to content

Commit

Permalink
return readablerelation from row_counts method
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Nov 19, 2024
1 parent 7eaeecb commit d472ef0
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 23 deletions.
5 changes: 1 addition & 4 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,10 +582,7 @@ def __getattr__(self, table: str) -> SupportsReadableRelation: ...

def row_counts(
self, *, data_tables: bool = True, dlt_tables: bool = False, table_names: List[str] = None
) -> Dict[str, int]:
"""Returns a dictionary of table names and their row counts"""
"""If table_names is provided, only the tables in the list are returned regardless of the data_tables and dlt_tables flags"""
return {}
) -> SupportsReadableRelation: ...


class JobClientBase(ABC):
Expand Down
5 changes: 2 additions & 3 deletions dlt/destinations/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def __getattr__(self, table_name: str) -> SupportsReadableRelation:

def row_counts(
self, *, data_tables: bool = True, dlt_tables: bool = False, table_names: List[str] = None
) -> Dict[str, int]:
) -> SupportsReadableRelation:
"""Returns a dictionary of table names and their row counts, returns counts of all data tables by default"""
"""If table_names is provided, only the tables in the list are returned regardless of the data_tables and dlt_tables flags"""

Expand All @@ -324,8 +324,7 @@ def row_counts(
query = " UNION ALL ".join(queries)

# Execute query and build result dict
with self(query).cursor() as cursor:
return {row[0]: row[1] for row in cursor.fetchall()}
return self(query)


def dataset(
Expand Down
73 changes: 57 additions & 16 deletions tests/load/test_read_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,27 +304,68 @@ def test_row_counts(populated_pipeline: Pipeline) -> None:

dataset = populated_pipeline._dataset()
# default is all data tables
assert dataset.row_counts() == {
"items": total_records,
"double_items": total_records,
"items__children": total_records * 2,
assert set(dataset.row_counts().fetchall()) == {
(
"items",
total_records,
),
(
"double_items",
total_records,
),
(
"items__children",
total_records * 2,
),
}
# get only one data table
assert dataset.row_counts(table_names=["items"]) == {"items": total_records}
assert set(dataset.row_counts(table_names=["items"]).fetchall()) == {
(
"items",
total_records,
),
}
# get all dlt tables
assert dataset.row_counts(dlt_tables=True, data_tables=False) == {
"_dlt_version": 1,
"_dlt_loads": 1,
"_dlt_pipeline_state": 1,
assert set(dataset.row_counts(dlt_tables=True, data_tables=False).fetchall()) == {
(
"_dlt_version",
1,
),
(
"_dlt_loads",
1,
),
(
"_dlt_pipeline_state",
1,
),
}
# get them all
assert dataset.row_counts(dlt_tables=True) == {
"_dlt_version": 1,
"_dlt_loads": 1,
"_dlt_pipeline_state": 1,
"items": total_records,
"double_items": total_records,
"items__children": total_records * 2,
assert set(dataset.row_counts(dlt_tables=True).fetchall()) == {
(
"_dlt_version",
1,
),
(
"_dlt_loads",
1,
),
(
"_dlt_pipeline_state",
1,
),
(
"items",
total_records,
),
(
"double_items",
total_records,
),
(
"items__children",
total_records * 2,
),
}


Expand Down

0 comments on commit d472ef0

Please sign in to comment.