diff --git a/python-sdk/src/astro/files/base.py b/python-sdk/src/astro/files/base.py index 1b9ba8f6a..350286005 100644 --- a/python-sdk/src/astro/files/base.py +++ b/python-sdk/src/astro/files/base.py @@ -114,18 +114,19 @@ def is_pattern(self) -> bool: """ return not pathlib.PosixPath(self.path).suffix - def create_from_dataframe(self, df: pd.DataFrame, store_as_dataframe: bool = True) -> None: + def create_from_dataframe(self, df: pd.DataFrame, store_as_dataframe: bool = True, export_options: dict | None = None) -> None: """Create a file in the desired location using the values of a dataframe. :param store_as_dataframe: Whether the data should later be deserialized as a dataframe or as a file containing delimited data (e.g. csv, parquet, etc.). :param df: pandas dataframe + :param export_options: additional arguments to pass to the underlying write functionality """ self.is_dataframe = store_as_dataframe with self.location.get_stream() as stream: - self.type.create_from_dataframe(stream=stream, df=df) + self.type.create_from_dataframe(stream=stream, df=df, **export_options) @property def openlineage_dataset_namespace(self) -> str: diff --git a/python-sdk/src/astro/files/types/base.py b/python-sdk/src/astro/files/types/base.py index 48dcdda5e..d71a042f3 100644 --- a/python-sdk/src/astro/files/types/base.py +++ b/python-sdk/src/astro/files/types/base.py @@ -27,11 +27,12 @@ def export_to_dataframe(self, stream, **kwargs) -> pd.DataFrame: raise NotImplementedError @abstractmethod - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: """Write file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the underlying write functionality """ raise NotImplementedError diff --git a/python-sdk/src/astro/files/types/csv.py b/python-sdk/src/astro/files/types/csv.py index f5d6e6229..2aae0bf7e 100644 --- a/python-sdk/src/astro/files/types/csv.py +++ b/python-sdk/src/astro/files/types/csv.py @@ -38,13 +38,15 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 """Write csv file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the pandas `to_csv` function """ - df.to_csv(stream, index=False) + + df.to_csv(stream, **dict(index=False, **kwargs)) @property def name(self): diff --git a/python-sdk/src/astro/files/types/excel.py b/python-sdk/src/astro/files/types/excel.py index 1073deaaf..e4e6b434a 100644 --- a/python-sdk/src/astro/files/types/excel.py +++ b/python-sdk/src/astro/files/types/excel.py @@ -37,10 +37,11 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 """Write Excel file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the pandas `to_excel` function """ - df.to_excel(stream, index=False) + df.to_excel(stream, **dict(index=False, **kwargs)) diff --git a/python-sdk/src/astro/files/types/json.py b/python-sdk/src/astro/files/types/json.py index 91cf878f7..18a153e70 100644 --- a/python-sdk/src/astro/files/types/json.py +++ b/python-sdk/src/astro/files/types/json.py @@ -42,13 +42,14 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 """Write json file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the pandas `to_json` function """ - df.to_json(stream, orient="records") + df.to_json(stream, **dict(orient="records", **kwargs)) @property def name(self): diff --git a/python-sdk/src/astro/files/types/ndjson.py b/python-sdk/src/astro/files/types/ndjson.py index 5bd92b33f..94935167d 100644 --- a/python-sdk/src/astro/files/types/ndjson.py +++ b/python-sdk/src/astro/files/types/ndjson.py @@ -39,13 +39,14 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 """Write ndjson file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the pandas `to_json` function """ - df.to_json(stream, orient="records", lines=True) + df.to_json(stream, **dict(orient="records", lines=True, **kwargs)) @property def name(self): diff --git a/python-sdk/src/astro/files/types/parquet.py b/python-sdk/src/astro/files/types/parquet.py index a6213dda2..1a61e7446 100644 --- a/python-sdk/src/astro/files/types/parquet.py +++ b/python-sdk/src/astro/files/types/parquet.py @@ -57,13 +57,14 @@ def _convert_remote_file_to_byte_stream(stream) -> io.IOBase: return remote_obj_buffer # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 """Write parquet file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the pandas `to_parquet` method """ - df.to_parquet(stream) + df.to_parquet(stream, **kwargs) @property def name(self): diff --git a/python-sdk/src/astro/sql/operators/export_to_file.py b/python-sdk/src/astro/sql/operators/export_to_file.py index e4be82201..27a00a409 100644 --- a/python-sdk/src/astro/sql/operators/export_to_file.py +++ b/python-sdk/src/astro/sql/operators/export_to_file.py @@ -21,20 +21,23 @@ class ExportToFileOperator(AstroSQLBaseOperator): :param input_data: Table to convert to file :param output_file: File object containing the path to the file and connection id. :param if_exists: Overwrite file if exists. Default False. + :param export_options: Additional options to pass to the file export functions. """ - template_fields = ("input_data", "output_file") + template_fields = ("input_data", "output_file", "export_options") def __init__( self, input_data: BaseTable | pd.DataFrame, output_file: File, if_exists: ExportExistsStrategy = "exception", + export_options: dict | None = None, **kwargs, ) -> None: self.output_file = output_file self.input_data = input_data self.if_exists = if_exists + self.export_options = export_options or {} self.kwargs = kwargs datasets = {"output_datasets": self.output_file} if isinstance(input_data, Table): @@ -57,7 +60,7 @@ def execute(self, context: Context) -> File: # skipcq PYL-W0613 raise ValueError(f"Expected input_table to be Table or dataframe. Got {type(self.input_data)}") # Write file if overwrite == True or if file doesn't exist. if self.if_exists == "replace" or not self.output_file.exists(): - self.output_file.create_from_dataframe(df, store_as_dataframe=False) + self.output_file.create_from_dataframe(df, store_as_dataframe=False, export_options=self.export_options) return self.output_file else: raise FileExistsError(f"{self.output_file.path} file already exists.") @@ -144,7 +147,8 @@ def export_to_file( output_file: File, if_exists: ExportExistsStrategy = "exception", task_id: str | None = None, - **kwargs: Any, + export_options: dict | None = None, + **kwargs, ) -> XComArg: """Convert ExportToFileOperator into a function. Returns XComArg. @@ -170,6 +174,7 @@ def export_to_file( :param input_data: Input table / dataframe :param if_exists: Overwrite file if exists. Default "exception" :param task_id: task id, optional + :param export_options: Additional options to pass to the file export functions. """ task_id = task_id or get_unique_task_id("export_to_file") @@ -179,5 +184,6 @@ def export_to_file( output_file=output_file, input_data=input_data, if_exists=if_exists, + export_options=export_options, **kwargs, ).output diff --git a/python-sdk/tests/sql/operators/test_export_file.py b/python-sdk/tests/sql/operators/test_export_file.py index b474197e5..cb97b231f 100644 --- a/python-sdk/tests/sql/operators/test_export_file.py +++ b/python-sdk/tests/sql/operators/test_export_file.py @@ -37,6 +37,25 @@ def make_df(): assert df.equals(pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})) +def test_save_dataframe_to_local_with_options(sample_dag): + @aql.dataframe + def make_df(): + return pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}) + + with sample_dag: + df = make_df() + aql.export_to_file( + input_data=df, + output_file=File(path="/tmp/saved_df.csv"), + if_exists="replace", + export_options={"header": None}, + ) + test_utils.run_dag(sample_dag) + + df = pd.read_csv("/tmp/saved_df.csv") + assert df.equals(pd.DataFrame(data={"0": [1, 2], "1": [3, 4]})) + + @pytest.mark.parametrize("database_table_fixture", [{"database": Database.SQLITE}], indirect=True) def test_save_temp_table_to_local(sample_dag, database_table_fixture): _, test_table = database_table_fixture