Skip to content

Commit

Permalink
feat: add Iceberg writer with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
farbodahm committed Sep 4, 2024
1 parent cc6d6d9 commit 6249743
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 64 deletions.
4 changes: 3 additions & 1 deletion src/sparkle/application/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from pyspark.storagelevel import StorageLevel
from sparkle.config import Config
from sparkle.writer import Writer
from sparkle.data_reader import DataReader
from sparkle.reader.data_reader import DataReader

PROCESS_TIME_COLUMN = "process_time"


class Sparkle(abc.ABC):
Expand Down
91 changes: 91 additions & 0 deletions src/sparkle/application/spark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import os
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SparkSession
from sparkle.utils.logger import logger

try:
from awsglue.context import GlueContext
except ImportError:
logger.warning("Could not import pyspark. This is expected if running locally.")

_SPARK_EXTENSIONS = [
"org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
]
_SPARK_PACKAGES = [
"org.apache.iceberg:iceberg-spark-runtime-3.3_2.12:1.3.1",
"org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0",
"org.apache.spark:spark-avro_2.12:3.3.0",
]


def get_local_session():
"""Create and return a local Spark session configured for use with Iceberg and Kafka.
This function sets up a local Spark session with specific configurations for Iceberg
catalog, session extensions, and other relevant settings needed for local testing
and development. It supports optional custom Ivy settings for managing dependencies.
Returns:
SparkSession: A configured Spark session instance for local use.
"""
ivy_settings_path = os.environ.get("IVY_SETTINGS_PATH", None)
LOCAL_CONFIG = {
"spark.sql.extensions": ",".join(_SPARK_EXTENSIONS),
"spark.jars.packages": ",".join(_SPARK_PACKAGES),
"spark.sql.jsonGenerator.ignoreNullFields": False,
"spark.sql.session.timeZone": "UTC",
"spark.sql.catalog.spark_catalog": "org.apache.iceberg.spark.SparkSessionCatalog",
"spark.sql.catalog.spark_catalog.type": "hive",
"spark.sql.catalog.local": "org.apache.iceberg.spark.SparkCatalog",
"spark.sql.catalog.local.type": "hadoop",
"spark.sql.catalog.local.warehouse": "./tmp/warehouse",
"spark.sql.defaultCatalog": "local",
}

spark_conf = SparkConf()

for key, value in LOCAL_CONFIG.items():
spark_conf.set(key, str(value))

spark_session = (
SparkSession.builder.master("local[*]")
.appName("LocalSparkleApp")
.config(conf=spark_conf)
)

if ivy_settings_path:
spark_session.config("spark.jars.ivySettings", ivy_settings_path)

return spark_session.getOrCreate()


def get_glue_session(warehouse_location: str):
"""Create and return a Glue session configured for use with Iceberg and AWS Glue Catalog.
This function sets up a Spark session integrated with AWS Glue, using configurations
suitable for working with Iceberg tables stored in AWS S3. It configures Spark with
AWS Glue-specific settings, including catalog implementation and S3 file handling.
Args:
warehouse_location (str): The S3 path to the warehouse location for Iceberg tables.
Returns:
SparkSession: A configured Spark session instance integrated with AWS Glue.
"""
GLUE_CONFIG = {
"spark.sql.extensions": ",".join(_SPARK_EXTENSIONS),
"spark.sql.catalog.glue_catalog": "org.apache.iceberg.spark.SparkCatalog",
"spark.sql.catalog.glue_catalog.catalog-impl": "org.apache.iceberg.aws.glue.GlueCatalog",
"spark.sql.catalog.glue_catalog.io-impl": "org.apache.iceberg.aws.s3.S3FileIO",
"spark.sql.catalog.glue_catalog.warehouse": warehouse_location,
"spark.sql.jsonGenerator.ignoreNullFields": False,
}

spark_conf = SparkConf()

for key, value in GLUE_CONFIG.items():
spark_conf.set(key, str(value))

glueContext = GlueContext(SparkContext.getOrCreate(conf=spark_conf))
return glueContext.spark_session
58 changes: 0 additions & 58 deletions src/sparkle/data_reader.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/sparkle/reader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from pyspark.sql import SparkSession
from pyspark.sql import SparkSession, DataFrame
from sparkle.config import Config


Expand Down Expand Up @@ -29,7 +29,7 @@ def with_config(cls, config: Config, spark: SparkSession):
return cls(config, spark)

@abstractmethod
def read(self):
def read(self) -> DataFrame:
"""Read data from the source.
Returns:
Expand Down
6 changes: 3 additions & 3 deletions src/sparkle/utils/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ def table_exists(
spark: SparkSession | None = SparkSession.getActiveSession(),
catalog_name: str = "glue_catalog",
) -> bool:
"""
Checks if a table exists in the specified catalog and database.
"""Checks if a table exists in the specified catalog and database.
Args:
database_name (str): The name of the database where the table is located.
table_name (str): The name of the table to check for existence.
spark (SparkSession | None, optional): The current active Spark session. Defaults to the active Spark session if not provided.
spark (SparkSession | None, optional): The current active Spark session.
Defaults to the active Spark session if not provided.
catalog_name (str, optional): The name of the catalog to search in. Defaults to "glue_catalog".
Returns:
Expand Down
68 changes: 68 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pytest
from pyspark.sql import SparkSession
from sparkle.application.spark import get_local_session


@pytest.fixture(scope="session")
def spark_session() -> SparkSession:
"""Fixture for creating a Spark session.
This fixture creates a Spark session to be used in the tests. It attempts to get
the active Spark session if available; otherwise, it creates a new one using
`get_local_session`.
Returns:
SparkSession: An active Spark session for use in tests.
"""
return SparkSession.getActiveSession() or get_local_session()


@pytest.fixture
def user_dataframe(spark_session: SparkSession):
"""Fixture for creating a DataFrame with user data.
This fixture creates a Spark DataFrame containing sample user data with columns
for name, surname, phone, and email.
Args:
spark_session (SparkSession): The Spark session fixture.
Returns:
pyspark.sql.DataFrame: A Spark DataFrame with sample user data.
"""
data = [
{
"name": "John",
"surname": "Doe",
"phone": "12345",
"email": "[email protected]",
},
{
"name": "Jane",
"surname": "Doe",
"phone": "12345",
"email": "[email protected]",
},
]

return spark_session.createDataFrame(data)


@pytest.fixture
def teardown_table(spark_session, catalog, database, table):
"""Fixture to drop a specified table after a test.
This fixture is used to clean up by dropping the specified table after the test
is completed, ensuring the test environment remains clean.
Args:
spark_session (SparkSession): The Spark session fixture.
catalog (str): The catalog where the table is located.
database (str): The database where the table is located.
table (str): The name of the table to drop.
Yields:
None
"""
yield
spark_session.sql(f"DROP TABLE IF EXISTS {catalog}.{database}.{table}")
32 changes: 32 additions & 0 deletions tests/unit/utils/test_spark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
from sparkle.writer.iceberg_writer import IcebergWriter
from sparkle.utils.spark import table_exists


@pytest.mark.parametrize(
"catalog, database, table",
[("glue_catalog", "test_db", "test_table")],
)
def test_table_exists(spark_session, teardown_table, catalog, database, table):
"""Test the `table_exists` function for checking table existence in a catalog.
Args:
spark_session (SparkSession): The Spark session fixture.
teardown_table (function): Fixture to clean up by dropping the specified table after the test.
catalog (str): The catalog where the table is located, provided via parametrization.
database (str): The database where the table is located, provided via parametrization.
table (str): The name of the table to test for existence, provided via parametrization.
"""
data = [{"id": "001", "value": "some_value"}]
df = spark_session.createDataFrame(data)

writer = IcebergWriter(
database_name=database,
database_path="mock_path",
table_name=table,
spark_session=spark_session,
)
writer.write(df)

assert table_exists(database, table, spark_session) is True
assert table_exists(database, "NON_EXISTENT_TABLE", spark_session) is False
Loading

0 comments on commit 6249743

Please sign in to comment.