From 62497430ab192f370238b54e763fe9d3a4a32086 Mon Sep 17 00:00:00 2001
From: Farbod Ahmadian
Date: Wed, 4 Sep 2024 10:46:03 +0200
Subject: [PATCH] feat: add Iceberg writer with tests

---
 src/sparkle/application/__init__.py      |   4 +-
 src/sparkle/application/spark.py         |  91 ++++++++++
 src/sparkle/data_reader.py               |  58 -------
 src/sparkle/reader/__init__.py           |   4 +-
 src/sparkle/utils/spark.py               |   6 +-
 tests/conftest.py                        |  68 ++++++++
 tests/unit/utils/test_spark.py           |  32 ++++
 tests/unit/writer/test_iceberg_writer.py | 206 +++++++++++++++++++++++
 8 files changed, 405 insertions(+), 64 deletions(-)
 create mode 100644 src/sparkle/application/spark.py
 delete mode 100644 src/sparkle/data_reader.py
 create mode 100644 tests/conftest.py
 create mode 100644 tests/unit/utils/test_spark.py
 create mode 100644 tests/unit/writer/test_iceberg_writer.py

diff --git a/src/sparkle/application/__init__.py b/src/sparkle/application/__init__.py
index 7fceb2a..dc33fdb 100644
--- a/src/sparkle/application/__init__.py
+++ b/src/sparkle/application/__init__.py
@@ -3,7 +3,9 @@
 from pyspark.storagelevel import StorageLevel
 from sparkle.config import Config
 from sparkle.writer import Writer
-from sparkle.data_reader import DataReader
+from sparkle.reader.data_reader import DataReader
+
+PROCESS_TIME_COLUMN = "process_time"
 
 
 class Sparkle(abc.ABC):
diff --git a/src/sparkle/application/spark.py b/src/sparkle/application/spark.py
new file mode 100644
index 0000000..506733e
--- /dev/null
+++ b/src/sparkle/application/spark.py
@@ -0,0 +1,91 @@
+import os
+from pyspark.conf import SparkConf
+from pyspark.context import SparkContext
+from pyspark.sql import SparkSession
+from sparkle.utils.logger import logger
+
+try:
+    from awsglue.context import GlueContext
+except ImportError:
+    logger.warning("Could not import awsglue. This is expected if running locally.")
+
+_SPARK_EXTENSIONS = [
+    "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
+]
+_SPARK_PACKAGES = [
+    "org.apache.iceberg:iceberg-spark-runtime-3.3_2.12:1.3.1",
+    "org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0",
+    "org.apache.spark:spark-avro_2.12:3.3.0",
+]
+
+
+def get_local_session():
+    """Create and return a local Spark session configured for use with Iceberg and Kafka.
+
+    This function sets up a local Spark session with specific configurations for the
+    Iceberg catalog, session extensions, and other settings needed for local testing
+    and development. It supports optional custom Ivy settings for managing dependencies.
+
+    Returns:
+        SparkSession: A configured Spark session instance for local use.
+ """ + ivy_settings_path = os.environ.get("IVY_SETTINGS_PATH", None) + LOCAL_CONFIG = { + "spark.sql.extensions": ",".join(_SPARK_EXTENSIONS), + "spark.jars.packages": ",".join(_SPARK_PACKAGES), + "spark.sql.jsonGenerator.ignoreNullFields": False, + "spark.sql.session.timeZone": "UTC", + "spark.sql.catalog.spark_catalog": "org.apache.iceberg.spark.SparkSessionCatalog", + "spark.sql.catalog.spark_catalog.type": "hive", + "spark.sql.catalog.local": "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.local.type": "hadoop", + "spark.sql.catalog.local.warehouse": "./tmp/warehouse", + "spark.sql.defaultCatalog": "local", + } + + spark_conf = SparkConf() + + for key, value in LOCAL_CONFIG.items(): + spark_conf.set(key, str(value)) + + spark_session = ( + SparkSession.builder.master("local[*]") + .appName("LocalSparkleApp") + .config(conf=spark_conf) + ) + + if ivy_settings_path: + spark_session.config("spark.jars.ivySettings", ivy_settings_path) + + return spark_session.getOrCreate() + + +def get_glue_session(warehouse_location: str): + """Create and return a Glue session configured for use with Iceberg and AWS Glue Catalog. + + This function sets up a Spark session integrated with AWS Glue, using configurations + suitable for working with Iceberg tables stored in AWS S3. It configures Spark with + AWS Glue-specific settings, including catalog implementation and S3 file handling. + + Args: + warehouse_location (str): The S3 path to the warehouse location for Iceberg tables. + + Returns: + SparkSession: A configured Spark session instance integrated with AWS Glue. + """ + GLUE_CONFIG = { + "spark.sql.extensions": ",".join(_SPARK_EXTENSIONS), + "spark.sql.catalog.glue_catalog": "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.glue_catalog.catalog-impl": "org.apache.iceberg.aws.glue.GlueCatalog", + "spark.sql.catalog.glue_catalog.io-impl": "org.apache.iceberg.aws.s3.S3FileIO", + "spark.sql.catalog.glue_catalog.warehouse": warehouse_location, + "spark.sql.jsonGenerator.ignoreNullFields": False, + } + + spark_conf = SparkConf() + + for key, value in GLUE_CONFIG.items(): + spark_conf.set(key, str(value)) + + glueContext = GlueContext(SparkContext.getOrCreate(conf=spark_conf)) + return glueContext.spark_session diff --git a/src/sparkle/data_reader.py b/src/sparkle/data_reader.py deleted file mode 100644 index 96cb3d5..0000000 --- a/src/sparkle/data_reader.py +++ /dev/null @@ -1,58 +0,0 @@ -from pyspark.sql import DataFrame, SparkSession -from sparkle.config import Config - - -class KafkaReader: - # TODO - pass - - -class TableReader: - # TODO - pass - - -class DataReader: - def __init__( - self, - entity_name: str, - kafka_reader: None | KafkaReader = None, - table_reader: None | TableReader = None, - spark: SparkSession | None = SparkSession.getActiveSession(), - ): - if not spark: - raise Exception("No Spark session is provided or discoverable.") - - self.entity_name = entity_name - self.kafka_reader = kafka_reader - self.table_reader = table_reader - self.spark = spark - - @classmethod - def with_config( - cls, entity_name: str, config: Config, spark: SparkSession - ) -> "DataReader": - return cls( - entity_name=entity_name, - kafka_reader=KafkaReader.with_config(config=config, spark=spark), - table_reader=TableReader.with_config(config=config, spark=spark), - spark=spark, - ) - - def stream(self, avro: bool = False) -> DataFrame: - if not self.kafka_reader: - raise Exception("No Kafka reader configuration found!") - if avro: - # TODO - if not 
self.kafka_reader.schema_registry: - raise Exception("Kafka reader doesn't support schema registry") - return self.kafka_reader.read_avro(topic=self.entity_name) - else: - return self.kafka_reader.read_data(topic=self.entity_name) - - def batch( - self, - ) -> DataFrame: - if not self.table_reader: - raise Exception("No table reader configuration found!") - return self.table_reader.read(table_name=self.entity_name) diff --git a/src/sparkle/reader/__init__.py b/src/sparkle/reader/__init__.py index ba7ae00..639cd76 100644 --- a/src/sparkle/reader/__init__.py +++ b/src/sparkle/reader/__init__.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from pyspark.sql import SparkSession +from pyspark.sql import SparkSession, DataFrame from sparkle.config import Config @@ -29,7 +29,7 @@ def with_config(cls, config: Config, spark: SparkSession): return cls(config, spark) @abstractmethod - def read(self): + def read(self) -> DataFrame: """Read data from the source. Returns: diff --git a/src/sparkle/utils/spark.py b/src/sparkle/utils/spark.py index 4f6b902..a8b5412 100644 --- a/src/sparkle/utils/spark.py +++ b/src/sparkle/utils/spark.py @@ -8,13 +8,13 @@ def table_exists( spark: SparkSession | None = SparkSession.getActiveSession(), catalog_name: str = "glue_catalog", ) -> bool: - """ - Checks if a table exists in the specified catalog and database. + """Checks if a table exists in the specified catalog and database. Args: database_name (str): The name of the database where the table is located. table_name (str): The name of the table to check for existence. - spark (SparkSession | None, optional): The current active Spark session. Defaults to the active Spark session if not provided. + spark (SparkSession | None, optional): The current active Spark session. + Defaults to the active Spark session if not provided. catalog_name (str, optional): The name of the catalog to search in. Defaults to "glue_catalog". Returns: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..754e4bc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,68 @@ +import pytest +from pyspark.sql import SparkSession +from sparkle.application.spark import get_local_session + + +@pytest.fixture(scope="session") +def spark_session() -> SparkSession: + """Fixture for creating a Spark session. + + This fixture creates a Spark session to be used in the tests. It attempts to get + the active Spark session if available; otherwise, it creates a new one using + `get_local_session`. + + Returns: + SparkSession: An active Spark session for use in tests. + """ + return SparkSession.getActiveSession() or get_local_session() + + +@pytest.fixture +def user_dataframe(spark_session: SparkSession): + """Fixture for creating a DataFrame with user data. + + This fixture creates a Spark DataFrame containing sample user data with columns + for name, surname, phone, and email. + + Args: + spark_session (SparkSession): The Spark session fixture. + + Returns: + pyspark.sql.DataFrame: A Spark DataFrame with sample user data. + """ + data = [ + { + "name": "John", + "surname": "Doe", + "phone": "12345", + "email": "john@test.com", + }, + { + "name": "Jane", + "surname": "Doe", + "phone": "12345", + "email": "jane.doe@test.com", + }, + ] + + return spark_session.createDataFrame(data) + + +@pytest.fixture +def teardown_table(spark_session, catalog, database, table): + """Fixture to drop a specified table after a test. 
+ + This fixture is used to clean up by dropping the specified table after the test + is completed, ensuring the test environment remains clean. + + Args: + spark_session (SparkSession): The Spark session fixture. + catalog (str): The catalog where the table is located. + database (str): The database where the table is located. + table (str): The name of the table to drop. + + Yields: + None + """ + yield + spark_session.sql(f"DROP TABLE IF EXISTS {catalog}.{database}.{table}") diff --git a/tests/unit/utils/test_spark.py b/tests/unit/utils/test_spark.py new file mode 100644 index 0000000..1d99462 --- /dev/null +++ b/tests/unit/utils/test_spark.py @@ -0,0 +1,32 @@ +import pytest +from sparkle.writer.iceberg_writer import IcebergWriter +from sparkle.utils.spark import table_exists + + +@pytest.mark.parametrize( + "catalog, database, table", + [("glue_catalog", "test_db", "test_table")], +) +def test_table_exists(spark_session, teardown_table, catalog, database, table): + """Test the `table_exists` function for checking table existence in a catalog. + + Args: + spark_session (SparkSession): The Spark session fixture. + teardown_table (function): Fixture to clean up by dropping the specified table after the test. + catalog (str): The catalog where the table is located, provided via parametrization. + database (str): The database where the table is located, provided via parametrization. + table (str): The name of the table to test for existence, provided via parametrization. + """ + data = [{"id": "001", "value": "some_value"}] + df = spark_session.createDataFrame(data) + + writer = IcebergWriter( + database_name=database, + database_path="mock_path", + table_name=table, + spark_session=spark_session, + ) + writer.write(df) + + assert table_exists(database, table, spark_session) is True + assert table_exists(database, "NON_EXISTENT_TABLE", spark_session) is False diff --git a/tests/unit/writer/test_iceberg_writer.py b/tests/unit/writer/test_iceberg_writer.py new file mode 100644 index 0000000..1b7f61b --- /dev/null +++ b/tests/unit/writer/test_iceberg_writer.py @@ -0,0 +1,206 @@ +import datetime +import os +import pytest +from pyspark.sql.functions import days + +from sparkle.application import PROCESS_TIME_COLUMN +from sparkle.writer.iceberg_writer import IcebergWriter +from sparkle.utils.spark import table_exists + +TEST_DB = "default" +TEST_TABLE = "test_table" +WAREHOUSE = "./tmp/warehouse" +CATALOG = "glue_catalog" + + +@pytest.fixture +def test_db_path(spark_session): + """Fixture for creating the test database path. + + Sets up the path for the test database and performs cleanup by dropping + the test table after the test is completed. + + Args: + spark_session (SparkSession): The Spark session fixture. + + Yields: + str: The path to the test database. + """ + path = os.path.join(WAREHOUSE, CATALOG, TEST_DB) + + yield path + + # Teardown + spark_session.sql(f"DROP TABLE IF EXISTS {CATALOG}.{TEST_DB}.{TEST_TABLE}") + + +@pytest.fixture +def partition_df(spark_session): + """Fixture for creating a DataFrame with partitioned user data. + + This fixture creates a DataFrame with sample user data, including a timestamp + column used for partitioning. + + Args: + spark_session (SparkSession): The Spark session fixture. + + Returns: + pyspark.sql.DataFrame: A Spark DataFrame with partitioned user data. 
+ """ + data = [ + { + "user_id": 1, + "name": "Bob", + PROCESS_TIME_COLUMN: datetime.datetime.fromisoformat("2023-11-03").replace( + tzinfo=datetime.timezone.utc + ), + }, + { + "user_id": 2, + "name": "Alice", + PROCESS_TIME_COLUMN: datetime.datetime.fromisoformat("2023-11-02").replace( + tzinfo=datetime.timezone.utc + ), + }, + ] + return spark_session.createDataFrame(data) + + +@pytest.fixture +def partition_df_evolved_schema(spark_session): + """Fixture for creating a DataFrame with an evolved schema. + + This fixture creates a DataFrame that includes an additional field not present + in the original schema, simulating a schema evolution scenario. + + Args: + spark_session (SparkSession): The Spark session fixture. + + Returns: + pyspark.sql.DataFrame: A Spark DataFrame with an evolved schema. + """ + data = [ + { + "user_id": 1, + "name": "Bob", + "new_field": "new_field_value", + PROCESS_TIME_COLUMN: datetime.datetime.fromisoformat("2023-11-03").replace( + tzinfo=datetime.timezone.utc + ), + } + ] + return spark_session.createDataFrame(data) + + +def test_writer_should_write_iceberg(user_dataframe, test_db_path, spark_session): + """Test that the IcebergWriter writes data to the Iceberg table. + + This test verifies that the IcebergWriter correctly writes the provided DataFrame + to the specified Iceberg table and checks that the table exists afterward. + + Args: + user_dataframe (pyspark.sql.DataFrame): Fixture providing sample user data. + test_db_path (str): Path to the test database. + spark_session (SparkSession): The Spark session fixture. + """ + writer = IcebergWriter( + database_name=TEST_DB, + database_path=test_db_path, + table_name=TEST_TABLE, + spark_session=spark_session, + ) + + writer.write(user_dataframe) + + assert table_exists( + database_name=TEST_DB, table_name=TEST_TABLE, spark=spark_session + ) + + +def test_write_with_partitions(test_db_path, partition_df, spark_session): + """Test writing data to Iceberg with partitioning. + + This test checks that data is correctly written to the Iceberg table with partitions + based on the `PROCESS_TIME_COLUMN`. It verifies the presence of partitioned data files. + + Args: + test_db_path (str): Path to the test database. + partition_df (pyspark.sql.DataFrame): DataFrame with partitioned user data. + spark_session (SparkSession): The Spark session fixture. + """ + writer = IcebergWriter( + database_name=TEST_DB, + database_path=test_db_path, + table_name=TEST_TABLE, + spark_session=spark_session, + partitions=[days(PROCESS_TIME_COLUMN)], + ) + + writer.write(partition_df) + + assert os.path.exists( + os.path.join( + test_db_path, TEST_TABLE, "data", f"{PROCESS_TIME_COLUMN}_day=2023-11-02" + ) + ) + assert os.path.exists( + os.path.join( + test_db_path, TEST_TABLE, "data", f"{PROCESS_TIME_COLUMN}_day=2023-11-03" + ) + ) + + +def test_write_with_partitions_no_partition_column_provided( + test_db_path, partition_df, spark_session +): + """Test writing data to Iceberg without specifying partitions. + + This test verifies that data is written to the Iceberg table without any partitions + when the partition list is explicitly set to empty. + + Args: + test_db_path (str): Path to the test database. + partition_df (pyspark.sql.DataFrame): DataFrame with partitioned user data. + spark_session (SparkSession): The Spark session fixture. 
+ """ + writer = IcebergWriter( + database_name=TEST_DB, + database_path=test_db_path, + table_name=TEST_TABLE, + spark_session=spark_session, + partitions=[], # Explicitly setting no partitions + ) + + writer.write(partition_df) + + assert os.path.exists(os.path.join(test_db_path, TEST_TABLE, "data")) + + +def test_write_with_schema_evolution( + test_db_path, partition_df, partition_df_evolved_schema, spark_session +): + """Test writing data to Iceberg with schema evolution. + + This test checks that the Iceberg table correctly handles schema evolution by + adding new fields. It writes initial data and then writes data with an evolved + schema, verifying that the new field is present in the table schema. + + Args: + test_db_path (str): Path to the test database. + partition_df (pyspark.sql.DataFrame): DataFrame with initial schema. + partition_df_evolved_schema (pyspark.sql.DataFrame): DataFrame with evolved schema. + spark_session (SparkSession): The Spark session fixture. + """ + writer = IcebergWriter( + database_name=TEST_DB, + database_path=test_db_path, + table_name=TEST_TABLE, + spark_session=spark_session, + partitions=[], + ) + + writer.write(partition_df) + writer.write(partition_df_evolved_schema) + + final_df = spark_session.table(f"{CATALOG}.{TEST_DB}.{TEST_TABLE}") + assert "new_field" in final_df.schema.fieldNames()