Commit bb11e84: Merge branch 'main' into feature/direct-access
LauJohansson authored Oct 7, 2022
2 parents: 16ab023 + 50f0821
Showing 5 changed files with 111 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/atc/transformers/__init__.py
@@ -1,6 +1,8 @@
from .drop_oldest_duplicate_transformer import (  # noqa: F401
    DropOldestDuplicatesTransformer,
)
from .dropColumnsTransformer import DropColumnsTransformer # noqa: F401
from .selectColumnsTransformer import SelectColumnsTransformer # noqa: F401
from .simple_sql_transformer import SimpleSqlServerTransformer # noqa: F401
from .timezone_transformer_nc import TimeZoneTransformerNC # noqa: F401
from .union_transformer import UnionTransformer # noqa: F401
22 changes: 22 additions & 0 deletions src/atc/transformers/dropColumnsTransformer.py
@@ -0,0 +1,22 @@
from typing import List

from pyspark.sql import DataFrame

from atc.etl import TransformerNC


class DropColumnsTransformer(TransformerNC):
    def __init__(
        self,
        columnList: List[str],
        dataset_input_key: str = None,
        dataset_output_key: str = None,
    ):
        super().__init__(
            dataset_input_key=dataset_input_key, dataset_output_key=dataset_output_key
        )
        self.columnList = columnList

    def process(self, df: DataFrame) -> DataFrame:
        df = df.drop(*self.columnList)
        return df
21 changes: 21 additions & 0 deletions src/atc/transformers/selectColumnsTransformer.py
@@ -0,0 +1,21 @@
from typing import List

from pyspark.sql import DataFrame

from atc.etl import TransformerNC


class SelectColumnsTransformer(TransformerNC):
    def __init__(
        self,
        columnList: List[str],
        dataset_input_key: str = None,
        dataset_output_key: str = None,
    ):
        super().__init__(
            dataset_input_key=dataset_input_key, dataset_output_key=dataset_output_key
        )
        self.columnList = columnList

    def process(self, df: DataFrame) -> DataFrame:
        return df.select(self.columnList)
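
Both new transformers follow the same pattern: optional dataset_input_key/dataset_output_key arguments forwarded to TransformerNC, plus a columnList applied in process(). A minimal usage sketch (illustrative only, not part of this commit; the DataFrame and column names below are made up) showing the two chained directly on a DataFrame:

from atc.spark import Spark
from atc.transformers import DropColumnsTransformer, SelectColumnsTransformer

# Illustrative input DataFrame; the schema is given as a DDL string.
df = Spark.get().createDataFrame(
    [(1, "keep-me", "noise1", "noise2")],
    "id LONG, keep STRING, extra1 STRING, extra2 STRING",
)

# Drop the unwanted columns, then project the remaining ones explicitly.
df = DropColumnsTransformer(columnList=["extra1", "extra2"]).process(df)
df = SelectColumnsTransformer(columnList=["id", "keep"]).process(df)

df.show()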
36 changes: 36 additions & 0 deletions tests/local/transformers/test_drop_columns.py
@@ -0,0 +1,36 @@
import pyspark.sql.types as T
from atc_tools.testing import DataframeTestCase

from atc.spark import Spark
from atc.transformers import DropColumnsTransformer


class TestDropColumnsTransformer(DataframeTestCase):
    def test_drop_columns_transformer(self):

        inputSchema = T.StructType(
            [
                T.StructField("id", T.LongType(), True),
                T.StructField("text1", T.StringType(), True),
                T.StructField("text2", T.StringType(), True),
            ]
        )

        inputData = [
            (
                1,
                "text1",
                "text2",
            ),
        ]

        df_input = Spark.get().createDataFrame(data=inputData, schema=inputSchema)

        expectedData = [
            (1,),
        ]

        df_transformed = DropColumnsTransformer(columnList=["text1", "text2"]).process(
            df_input
        )
        self.assertDataframeMatches(df_transformed, None, expectedData)
30 changes: 30 additions & 0 deletions tests/local/transformers/test_select_columns.py
@@ -0,0 +1,30 @@
import pyspark.sql.types as T
from atc_tools.testing import DataframeTestCase

from atc.spark import Spark
from atc.transformers import SelectColumnsTransformer


class TestSelectColumnsTransformer(DataframeTestCase):
    def test_select_columns_transformer(self):

        inputSchema = T.StructType(
            [
                T.StructField("Col1", T.StringType(), True),
                T.StructField("Col2", T.IntegerType(), True),
                T.StructField("Col3", T.DoubleType(), True),
                T.StructField("Col4", T.StringType(), True),
                T.StructField("Col5", T.StringType(), True),
            ]
        )
        inputData = [("Col1Data", 42, 13.37, "Col4Data", "Col5Data")]

        input_df = Spark.get().createDataFrame(data=inputData, schema=inputSchema)

        expectedData = [(42, 13.37, "Col4Data")]

        transformed_df = SelectColumnsTransformer(
            columnList=["Col2", "Col3", "Col4"]
        ).process(input_df)

        self.assertDataframeMatches(transformed_df, None, expectedData)
