feat: add table reader

farbodahm committed Sep 5, 2024
1 parent 801fbad commit 8919420
Showing 5 changed files with 238 additions and 14 deletions.
5 changes: 2 additions & 3 deletions src/sparkle/config/__init__.py
@@ -13,7 +13,6 @@ class Config:
version: str
database_bucket: str
kafka: KafkaConfig | None
- input_database: TableConfig | None
- output_database: TableConfig | None
- iceberg_config: IcebergConfig | None
+ hive_table_input: TableConfig | None
+ iceberg_output_config: IcebergConfig | None
spark_trigger: str = '{"once": True}'
5 changes: 4 additions & 1 deletion src/sparkle/config/database_config.py
@@ -5,5 +5,8 @@
class TableConfig:
"""Hive Table Configuration."""

- bucket: str
database: str
table: str
+ bucket: str
+ catalog_name: str = "glue_catalog"
+ catalog_id: str | None = None
96 changes: 96 additions & 0 deletions src/sparkle/reader/table_reader.py
@@ -0,0 +1,96 @@
from pyspark.sql import SparkSession, DataFrame
from sparkle.config import Config
from sparkle.utils.logger import logger


class TableReader:
"""A class for reading tables from a specified catalog using Spark.
The `TableReader` class provides methods to read data from a table in a specified
catalog and database using Apache Spark. It supports reading tables with specified
configurations and provides utility methods to access the fully qualified table name.
Attributes:
spark (SparkSession): The Spark session used for reading data.
database_name (str): The name of the database containing the table.
table_name (str): The name of the table to read.
catalog_name (str): The name of the catalog containing the table. Defaults to "glue_catalog".
catalog_id (Optional[str]): The catalog ID, if applicable. Defaults to None.
"""

def __init__(
self,
spark: SparkSession,
database_name: str,
table_name: str,
catalog_name: str = "glue_catalog",
catalog_id: str | None = None,
):
"""Initializes a TableReader instance.
Args:
spark (SparkSession): The Spark session used for reading data.
database_name (str): The name of the database containing the table.
table_name (str): The name of the table to read.
catalog_name (str, optional): The name of the catalog containing the table.
Defaults to "glue_catalog".
catalog_id (Optional[str], optional): The catalog ID, if applicable. Defaults to None.
"""
self.spark = spark
self.database_name = database_name
self.table_name = table_name
self.catalog_name = catalog_name
self.catalog_id = catalog_id

@classmethod
def with_config(
cls, spark: SparkSession, config: Config, **kwargs
) -> "TableReader":
"""Creates a TableReader instance using a configuration object.
Args:
spark (SparkSession): The Spark session used for reading data.
config (Config): The configuration object containing table input configuration.
**kwargs: Additional keyword arguments passed to the TableReader initializer.
Returns:
TableReader: An instance of TableReader configured with the provided settings.
Raises:
ValueError: If the input configuration is missing in the provided config.
"""
if not config.hive_table_input:
raise ValueError("Hive input configuration is missing.")
return cls(
spark=spark,
database_name=config.hive_table_input.database,
table_name=config.hive_table_input.table,
catalog_name=config.hive_table_input.catalog_name,
catalog_id=config.hive_table_input.catalog_id,
**kwargs,
)

@property
def qualified_table_name(self) -> str:
"""Gets the fully qualified table name.
Returns:
str: The fully qualified table name in the format "catalog_name.database_name.table_name".
"""
return f"{self.catalog_name}.{self.database_name}.{self.table_name}"

def read(self) -> DataFrame:
"""Reads the table as a DataFrame.
This method reads data from the specified table in the configured catalog and database,
returning it as a Spark DataFrame.
Returns:
DataFrame: A Spark DataFrame containing the data read from the table.
"""
table_fqdn = self.qualified_table_name
logger.info(f"Reading dataframe from {table_fqdn}")
df = self.spark.read.table(table_fqdn)

logger.info(f"Finished reading dataframe from {table_fqdn}")
return df
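
Usage sketch (not part of this commit): with a SparkSession whose catalogs are already configured, the new reader can be pointed at a table either directly or via a Config carrying hive_table_input. The session setup and table names below are illustrative assumptions.

from pyspark.sql import SparkSession
from sparkle.reader.table_reader import TableReader

# Illustrative session; catalog/warehouse configuration is assumed to happen elsewhere.
spark = SparkSession.builder.appName("table-reader-example").getOrCreate()

# Direct construction with explicit names (signature as added in this commit).
reader = TableReader(
    spark=spark,
    database_name="test_db",
    table_name="test_table",
    catalog_name="glue_catalog",
)
df = reader.read()  # equivalent to spark.read.table("glue_catalog.test_db.test_table")
df.show()

# Or, given a Config object whose hive_table_input is a TableConfig:
# reader = TableReader.with_config(spark=spark, config=config)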
20 changes: 10 additions & 10 deletions src/sparkle/writer/iceberg_writer.py
@@ -87,18 +87,18 @@ def with_config(
Raises:
ValueError: If the Iceberg configuration is not provided in the config object.
"""
- if not config.iceberg_config:
-     raise ValueError("Iceberg configuration is not provided")
+ if not config.iceberg_output_config:
+     raise ValueError("Iceberg output configuration is not provided")

return cls(
- database_name=config.iceberg_config.database_name,
- database_path=config.iceberg_config.database_path,
- table_name=config.iceberg_config.table_name,
- delete_before_write=config.iceberg_config.delete_before_write,
- catalog_name=config.iceberg_config.catalog_name,
- catalog_id=config.iceberg_config.catalog_id,
- partitions=config.iceberg_config.partitions,
- number_of_partitions=config.iceberg_config.number_of_partitions,
+ database_name=config.iceberg_output_config.database_name,
+ database_path=config.iceberg_output_config.database_path,
+ table_name=config.iceberg_output_config.table_name,
+ delete_before_write=config.iceberg_output_config.delete_before_write,
+ catalog_name=config.iceberg_output_config.catalog_name,
+ catalog_id=config.iceberg_output_config.catalog_id,
+ partitions=config.iceberg_output_config.partitions,
+ number_of_partitions=config.iceberg_output_config.number_of_partitions,
spark_session=spark,
**kwargs,
)
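
A minimal sketch of the writer side after the rename. Hedged: the full with_config signature and the IcebergConfig fields are only partially visible in this hunk, so the call below is an assumption.

from sparkle.writer.iceberg_writer import IcebergWriter

# Assumes config.iceberg_output_config is an IcebergConfig carrying the attributes
# referenced above (database_name, database_path, table_name, partitions, ...).
writer = IcebergWriter.with_config(spark=spark, config=config)
# The subsequent write call is outside this diff.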
126 changes: 126 additions & 0 deletions tests/unit/reader/test_table_reader.py
@@ -0,0 +1,126 @@
import os
import pytest
from pyspark.sql import DataFrame
from sparkle.reader.table_reader import TableReader
from sparkle.config import TableConfig, Config

TEST_DB = "test_db"
TEST_TABLE = "test_table"
CATALOG = "glue_catalog"
WAREHOUSE = "./tmp/warehouse"


@pytest.fixture
def test_db_path(spark_session):
"""Fixture for setting up the test database and performing cleanup.
This fixture sets up the test database path and drops the test table after the tests.
Args:
spark_session (SparkSession): The Spark session fixture.
Yields:
str: The path to the test database.
"""
path = os.path.join(WAREHOUSE, CATALOG, TEST_DB)
spark_session.sql(f"CREATE DATABASE IF NOT EXISTS {TEST_DB}")
yield path
spark_session.sql(f"DROP TABLE IF EXISTS {CATALOG}.{TEST_DB}.{TEST_TABLE}")


@pytest.fixture
def config():
"""Fixture for creating a configuration object.
This fixture returns a Config object with necessary attributes set
for testing the TableReader class.
Returns:
Config: A configuration object with test database and table names.
"""
table_config = TableConfig(
database=TEST_DB,
table=TEST_TABLE,
bucket="test_bucket",
catalog_name=CATALOG,
catalog_id=None,
)
return Config(
app_name="test_app",
app_id="test_id",
version="1.0",
database_bucket="test_bucket",
kafka=None,
hive_table_input=table_config,
iceberg_output_config=None,
)


def test_table_reader_with_config(spark_session, config):
"""Test the TableReader initialization using a Config object.
This test verifies that the TableReader initializes correctly using
the class method with_config and a Config object.
Args:
spark_session (SparkSession): The Spark session fixture.
config (Config): The configuration object.
"""
reader = TableReader.with_config(spark=spark_session, config=config)

assert reader.database_name == TEST_DB
assert reader.table_name == TEST_TABLE
assert reader.catalog_name == CATALOG
assert reader.catalog_id is None


def test_qualified_table_name(spark_session):
"""Test the qualified_table_name property of TableReader.
This test checks that the qualified_table_name property returns the
correctly formatted string.
Args:
spark_session (SparkSession): The Spark session fixture.
"""
reader = TableReader(
spark=spark_session,
database_name=TEST_DB,
table_name=TEST_TABLE,
catalog_name=CATALOG,
)

assert reader.qualified_table_name == f"{CATALOG}.{TEST_DB}.{TEST_TABLE}"


def test_read_table(spark_session, test_db_path):
"""Test the read method of TableReader.
This test verifies that the read method correctly reads data from the specified
table and returns a DataFrame.
Args:
spark_session (SparkSession): The Spark session fixture.
test_db_path (str): Path to the test database.
"""
# Create a sample table for testing
spark_session.sql(
f"CREATE TABLE {CATALOG}.{TEST_DB}.{TEST_TABLE} (id INT, name STRING)"
)
spark_session.sql(
f"INSERT INTO {CATALOG}.{TEST_DB}.{TEST_TABLE} VALUES (1, 'Alice'), (2, 'Bob')"
)

reader = TableReader(
spark=spark_session,
database_name=TEST_DB,
table_name=TEST_TABLE,
catalog_name=CATALOG,
)

df = reader.read()

assert isinstance(df, DataFrame)
assert df.count() == 2
assert df.filter(df.name == "Alice").count() == 1
assert df.filter(df.name == "Bob").count() == 1
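
These tests depend on a spark_session fixture that is not part of this diff. A hedged sketch of what such a fixture could look like for a local Hadoop-type Iceberg catalog named glue_catalog (everything here is an assumption; the project's real conftest.py may differ):

import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark_session():
    """Hypothetical local Iceberg-enabled session for the reader tests."""
    spark = (
        SparkSession.builder.appName("sparkle-tests")
        # Assumes the Iceberg Spark runtime jar is available on the classpath.
        .config("spark.sql.catalog.glue_catalog", "org.apache.iceberg.spark.SparkCatalog")
        .config("spark.sql.catalog.glue_catalog.type", "hadoop")
        .config("spark.sql.catalog.glue_catalog.warehouse", "./tmp/warehouse")
        .getOrCreate()
    )
    yield spark
    spark.stop()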
