From f04b913cf2b185cc16127276f3eff14eb0608f1a Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 23 Jun 2023 16:55:49 +0200 Subject: [PATCH 1/4] WIP writer --- dask_deltatable/write.py | 239 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 dask_deltatable/write.py diff --git a/dask_deltatable/write.py b/dask_deltatable/write.py new file mode 100644 index 0000000..1c5e3a7 --- /dev/null +++ b/dask_deltatable/write.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import json +from deltalake import DeltaTable +from typing import Any, Literal, Mapping +from pathlib import Path +import pyarrow as pa +import pyarrow.dataset as ds +import dask.dataframe as dd +import pyarrow.fs as pa_fs +from dask.highlevelgraph import HighLevelGraph +from dask.dataframe.core import Scalar +from deltalake.writer import ( + try_get_table_and_table_uri, + __enforce_append_only, + MAX_SUPPORTED_WRITER_VERSION, + DeltaStorageHandler, + DeltaProtocolError, + get_partitions_from_path, + PYARROW_MAJOR_VERSION, + get_file_stats_from_metadata, + AddAction, + DeltaJSONEncoder, + _write_new_deltalake, +) +from datetime import datetime + +import uuid +from pyarrow.lib import RecordBatchReader +from dask.core import flatten + + +def write_deltalake( + table_or_uri: str | Path | DeltaTable, + df: dd.DataFrame, + *, + schema: pa.Schema | None = None, + partition_by: list[str] | str | None = None, + filesystem: pa_fs.FileSystem | None = None, + mode: Literal["error", "append", "overwrite", "ignore"] = "error", + file_options: ds.ParquetFileWriteOptions | None = None, + max_partitions: int | None = None, + max_open_files: int = 1024, + max_rows_per_file: int = 10 * 1024 * 1024, + min_rows_per_group: int = 64 * 1024, + max_rows_per_group: int = 128 * 1024, + name: str | None = None, + description: str | None = None, + configuration: Mapping[str, str | None] | None = None, + overwrite_schema: bool = False, + storage_options: dict[str, str] | None = None, + partition_filters: list[tuple[str, str, Any]] | None = None, +): + table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options) + + # We need to write against the latest table version + if table: + table.update_incremental() + + __enforce_append_only(table=table, configuration=configuration, mode=mode) + + if filesystem is None: + if table is not None: + storage_options = table._storage_options or {} + storage_options.update(storage_options or {}) + + filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) + + if isinstance(partition_by, str): + partition_by = [partition_by] + + if table: # already exists + if schema != table.schema().to_pyarrow() and not ( + mode == "overwrite" and overwrite_schema + ): + raise ValueError( + "Schema of data does not match table schema\n" + f"Table schema:\n{schema}\nData Schema:\n{table.schema().to_pyarrow()}" + ) + + if mode == "error": + raise AssertionError("DeltaTable already exists.") + elif mode == "ignore": + return + + current_version = table.version() + + if partition_by: + assert partition_by == table.metadata().partition_columns + else: + partition_by = table.metadata().partition_columns + + if table.protocol().min_writer_version > MAX_SUPPORTED_WRITER_VERSION: + raise DeltaProtocolError( + "This table's min_writer_version is " + f"{table.protocol().min_writer_version}, " + "but this method only supports version 2." 
+ ) + else: # creating a new table + current_version = -1 + + if partition_by: + partition_schema = pa.schema([schema.field(name) for name in partition_by]) + partitioning = ds.partitioning(partition_schema, flavor="hive") + else: + partitioning = None + if mode == "overwrite": + # FIXME: There are a couple of checks that are not migrated yet + raise NotImplementedError() + + written = df.map_partitions( + _write_partition, + schema=schema, + partitioning=partitioning, + current_version=current_version, + file_options=file_options, + max_open_files=max_open_files, + max_rows_per_file=max_rows_per_file, + min_rows_per_group=min_rows_per_group, + max_rows_per_group=max_rows_per_group, + filesystem=filesystem, + max_partitions=max_partitions, + ) + final_name = "delta-commit" + dsk = { + (final_name, 0): ( + _commit, + table, + written.__dask_keys__(), + table_uri, + schema, + mode, + partition_by, + name, + description, + configuration, + storage_options, + partition_filters, + ) + } + graph = HighLevelGraph.from_collections(final_name, dsk, dependencies=(written,)) + return Scalar(graph, final_name, "") + + +def _commit( + table, + add_actions_nested, + table_uri, + schema, + mode, + partition_by, + name, + description, + configuration, + storage_options, + partition_filters, +): + add_actions = flatten(add_actions_nested) + if table is None: + _write_new_deltalake( + table_uri, + schema, + add_actions, + mode, + partition_by or [], + name, + description, + configuration, + storage_options, + ) + else: + table._table.create_write_transaction( + add_actions, + mode, + partition_by or [], + schema, + partition_filters, + ) + table.update_incremental() + + +def _write_partition( + df, + *, + schema, + partitioning, + current_version, + file_options, + max_open_files, + max_rows_per_file, + min_rows_per_group, + max_rows_per_group, + filesystem, + max_partitions, +): + data = pa.Table.from_pandas(df) + + add_actions: list[AddAction] = [] + + def visitor(written_file: Any) -> None: + path, partition_values = get_partitions_from_path(written_file.path) + stats = get_file_stats_from_metadata(written_file.metadata) + + # PyArrow added support for written_file.size in 9.0.0 + if PYARROW_MAJOR_VERSION >= 9: + size = written_file.size + else: + size = filesystem.get_file_info([path])[0].size # type: ignore + + add_actions.append( + AddAction( + path, + size, + partition_values, + int(datetime.now().timestamp() * 1000), + True, + json.dumps(stats, cls=DeltaJSONEncoder), + ) + ) + + ds.write_dataset( + data, + base_dir="/", + basename_template=f"{current_version + 1}-{uuid.uuid4()}-{{i}}.parquet", + format="parquet", + partitioning=partitioning, + # It will not accept a schema if using a RBR + schema=schema if not isinstance(data, RecordBatchReader) else None, + existing_data_behavior="overwrite_or_ignore", + file_options=file_options, + max_open_files=max_open_files, + file_visitor=visitor, + max_rows_per_file=max_rows_per_file, + min_rows_per_group=min_rows_per_group, + max_rows_per_group=max_rows_per_group, + filesystem=filesystem, + max_partitions=max_partitions, + ) + return add_actions From e36b34c6f55e5861b01326f7476a1c3157650e0c Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 28 Jun 2023 16:05:06 +0200 Subject: [PATCH 2/4] Finish simple roundtrip --- .flake8 | 2 +- .pre-commit-config.yaml | 5 +- dask_deltatable/_schema.py | 325 +++++++++++++++++++++++++++++++++++++ dask_deltatable/core.py | 4 +- dask_deltatable/write.py | 65 +++++--- pyproject.toml | 1 + tests/test_write.py | 92 +++++++++++ 7 
files changed, 468 insertions(+), 26 deletions(-) create mode 100644 dask_deltatable/_schema.py create mode 100644 tests/test_write.py diff --git a/.flake8 b/.flake8 index 7f39e8d..7a214b1 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,3 @@ # flake8 doesn't support pyproject.toml yet https://github.com/PyCQA/flake8/issues/234 [flake8] -max-line-length = 104 +max-line-length = 120 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c691a09..9f691b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,11 +30,12 @@ repos: args: [--warn-unused-configs] additional_dependencies: # Type stubs - - types-setuptools - boto3-stubs - - pytest - dask - deltalake + - pandas-stubs + - pytest + - types-setuptools - repo: https://github.com/pycqa/flake8 rev: 6.0.0 hooks: diff --git a/dask_deltatable/_schema.py b/dask_deltatable/_schema.py new file mode 100644 index 0000000..c85d7d5 --- /dev/null +++ b/dask_deltatable/_schema.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +""" +Most of this code was taken from + +https://github.com/data-engineering-collective/plateau + +https://github.com/data-engineering-collective/plateau/blob/d4c4522f5a829d43e3368fc82e1568c91fa352f3/plateau/core/common_metadata.py + +and adapted to this project + +under the original license + +MIT License + +Copyright (c) 2022 The plateau contributors. +Copyright (c) 2020-2021 The kartothek contributors. +Copyright (c) 2019 JDA Software, Inc + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +""" +import difflib +import json +import logging +import pprint +from copy import deepcopy +from typing import Iterable + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq + +_logger = logging.getLogger() + + +class SchemaWrapper: + def __init__(self, schema: pa.Schema): + self.schema = schema + + def __hash__(self): + # FIXME: pyarrow raises a "cannot hash type dict" error + return hash(_schema2bytes(self.schema)) + + +def _pandas_in_schemas(schemas): + """Check if any schema contains pandas metadata.""" + has_pandas = False + for schema in schemas: + if schema.metadata and b"pandas" in schema.metadata: + has_pandas = True + return has_pandas + + +def _determine_schemas_to_compare( + schemas: Iterable[pa.Schema], ignore_pandas: bool +) -> tuple[pa.Schema | None, list[tuple[pa.Schema, list[str]]]]: + """Iterate over a list of `pyarrow.Schema` objects and prepares them for + comparison by picking a reference and determining all null columns. + + .. 
note:: + + If pandas metadata exists, the version stored in the metadata is overwritten with the currently + installed version since we expect to stay backwards compatible + + Returns + ------- + reference: Schema + A reference schema which is picked from the input list. The reference schema is guaranteed + to be a schema having the least number of null columns of all input columns. The set of null + columns is guaranteed to be a true subset of all null columns of all input schemas. If no such + schema can be found, an Exception is raised + list_of_schemas: List[Tuple[Schema, List]] + A list holding pairs of (Schema, null_columns) where the null_columns are all columns which are null and + must be removed before comparing the schemas + """ + has_pandas = _pandas_in_schemas(schemas) and not ignore_pandas + schemas_to_evaluate: list[tuple[pa.Schema, list[str]]] = [] + reference = None + null_cols_in_reference = set() + # Hashing the schemas is a very fast way to reduce the number of schemas to + # actually compare since in most circumstances this reduces to very few + # (which differ in e.g. null columns) + for schema_wrapped in set(map(SchemaWrapper, schemas)): + schema = schema_wrapped.schema + del schema_wrapped + if has_pandas: + metadata = schema.metadata + if metadata is None or b"pandas" not in metadata: + raise ValueError( + "Pandas and non-Pandas schemas are not comparable. " + "Use ignore_pandas=True if you only want to compare " + "on Arrow level." + ) + pandas_metadata = json.loads(metadata[b"pandas"].decode("utf8")) + + # we don't care about the pandas version, since we assume it's safe + # to read datasets that were written by older or newer versions. + pandas_metadata["pandas_version"] = f"{pd.__version__}" + + metadata_clean = deepcopy(metadata) + metadata_clean[b"pandas"] = _dict_to_binary(pandas_metadata) + current = pa.schema(schema, metadata_clean) + else: + current = schema + + # If a field is null we cannot compare it and must therefore reject it + null_columns = {field.name for field in current if field.type == pa.null()} + + # Determine a valid reference schema. A valid reference schema is considered to be the schema + # of all input schemas with the least empty columns. + # The reference schema ought to be a schema whose empty columns are a true subset for all sets + # of empty columns. This ensures that the actual reference schema is the schema with the most + # information possible. A schema which doesn't fulfil this requirement would weaken the + # comparison and would allow for false positives + + # Trivial case + if reference is None: + reference = current + null_cols_in_reference = null_columns + # The reference has enough information to validate against current schema. + # Append it to the list of schemas to be verified + elif null_cols_in_reference.issubset(null_columns): + schemas_to_evaluate.append((current, null_columns)) + # current schema includes all information of reference and more. + # Add reference to schemas_to_evaluate and update reference + elif null_columns.issubset(null_cols_in_reference): + schemas_to_evaluate.append((reference, list(null_cols_in_reference))) + reference = current + null_cols_in_reference = null_columns + # If there is no clear subset available elect the schema with the least null columns as `reference`. 
+ # Iterate over the null columns of `reference` and replace it with a non-null field of the `current` + # schema which recovers the loop invariant (null columns of `reference` is subset of `current`) + else: + if len(null_columns) < len(null_cols_in_reference): + reference, current = current, reference + null_cols_in_reference, null_columns = ( + null_columns, + null_cols_in_reference, + ) + + for col in null_cols_in_reference - null_columns: + # Enrich the information in the reference by grabbing the missing fields + # from the current iteration. This assumes that we only check for global validity and + # isn't relevant where the reference comes from. + reference = _swap_fields_by_name(reference, current, col) + null_cols_in_reference.remove(col) + schemas_to_evaluate.append((current, null_columns)) + + assert (reference is not None) or (not schemas_to_evaluate) + + return reference, schemas_to_evaluate + + +def _swap_fields_by_name(reference, current, field_name): + current_field = current.field(field_name) + reference_index = reference.get_field_index(field_name) + return reference.set(reference_index, current_field) + + +def _strip_columns_from_schema(schema, field_names): + stripped_schema = schema + + for name in field_names: + ix = stripped_schema.get_field_index(name) + if ix >= 0: + stripped_schema = stripped_schema.remove(ix) + else: + # If the returned index is negative, the field doesn't exist in the schema. + # This is most likely an indicator for incompatible schemas and we refuse to strip the schema + # to not obfurscate the validation result + _logger.warning( + "Unexpected field `%s` encountered while trying to strip `null` columns.\n" + "Schema was:\n\n`%s`" % (name, schema) + ) + return schema + return stripped_schema + + +def _schema2bytes(schema: SchemaWrapper) -> bytes: + buf = pa.BufferOutputStream() + pq.write_metadata(schema, buf, coerce_timestamps="us") + return buf.getvalue().to_pybytes() + + +def _remove_diff_header(diff): + diff = list(diff) + for ix, el in enumerate(diff): + # This marks the first actual entry of the diff + # e.g. @@ -1,5 + 2,5 @@ + if el.startswith("@"): + return diff[ix:] + return diff + + +def _diff_schemas(first, second): + # see https://issues.apache.org/jira/browse/ARROW-4176 + + first_pyarrow_info = str(first.remove_metadata()) + second_pyarrow_info = str(second.remove_metadata()) + pyarrow_diff = _remove_diff_header( + difflib.unified_diff( + str(first_pyarrow_info).splitlines(), str(second_pyarrow_info).splitlines() + ) + ) + + first_pandas_info = first.pandas_metadata + second_pandas_info = second.pandas_metadata + pandas_meta_diff = _remove_diff_header( + difflib.unified_diff( + pprint.pformat(first_pandas_info).splitlines(), + pprint.pformat(second_pandas_info).splitlines(), + ) + ) + + diff_string = ( + "Arrow schema:\n" + + "\n".join(pyarrow_diff) + + "\n\nPandas_metadata:\n" + + "\n".join(pandas_meta_diff) + ) + + return diff_string + + +def validate_compatible( + schemas: Iterable[pa.Schema], ignore_pandas: bool = False +) -> pa.Schema: + """Validate that all schemas in a given list are compatible. + + Apart from the pandas version preserved in the schema metadata, schemas must be completely identical. That includes + a perfect match of the whole metadata (except the pandas version) and pyarrow types. + + In the case that all schemas don't contain any pandas metadata, we will check the Arrow + schemas directly for compatibility. 
+
+    Parameters
+    ----------
+    schemas: List[Schema]
+        Schema information from multiple sources, e.g. multiple partitions. List may be empty.
+    ignore_pandas: bool
+        Ignore the schema information given by Pandas and always use the Arrow schema.
+
+    Returns
+    -------
+    schema: pa.Schema
+        The reference schema against which all other schemas were validated.
+
+    Raises
+    ------
+    ValueError
+        At least two schemas are incompatible.
+    """
+    reference, schemas_to_evaluate = _determine_schemas_to_compare(
+        schemas, ignore_pandas
+    )
+
+    for current, null_columns in schemas_to_evaluate:
+        # We have schemas so the reference schema should be non-none.
+        assert reference is not None
+        # Compare each schema to the reference but ignore the null_cols and the Pandas schema information.
+        reference_to_compare = _strip_columns_from_schema(
+            reference, null_columns
+        ).remove_metadata()
+        current_to_compare = _strip_columns_from_schema(
+            current, null_columns
+        ).remove_metadata()
+
+        def _fmt_origin(origin):
+            origin = sorted(origin)
+            # dask cuts off exception messages at 1k chars:
+            # https://github.com/dask/distributed/blob/6e0c0a6b90b1d3c/distributed/core.py#L964
+            # therefore, we cut to the maximum length
+            max_len = 200
+            inner_msg = ", ".join(origin)
+            ellipsis = "..."
+            if len(inner_msg) > max_len + len(ellipsis):
+                inner_msg = inner_msg[:max_len] + ellipsis
+            return f"{{{inner_msg}}}"
+
+        if reference_to_compare != current_to_compare:
+            schema_diff = _diff_schemas(reference, current)
+            exception_message = """Schema violation
+
+Origin schema: {origin_schema}
+Origin reference: {origin_reference}
+
+Diff:
+{schema_diff}
+
+Reference schema:
+{reference}""".format(
+                schema_diff=schema_diff,
+                reference=str(reference),
+                origin_schema=_fmt_origin(getattr(current, "origin", ())),  # plain pa.Schema has no `origin`
+                origin_reference=_fmt_origin(getattr(reference, "origin", ())),
+            )
+            raise ValueError(exception_message)
+
+    # add all origins to result AFTER error checking, otherwise the error message would be pretty misleading due to the
+    # reference containing all origins.
+ if reference is None: + return None + else: + return reference + + +def _dict_to_binary(dct): + return json.dumps(dct, sort_keys=True).encode("utf8") diff --git a/dask_deltatable/core.py b/dask_deltatable/core.py index ea875a9..7f080bb 100644 --- a/dask_deltatable/core.py +++ b/dask_deltatable/core.py @@ -8,12 +8,12 @@ import dask import dask.dataframe as dd -import pyarrow.parquet as pq # type: ignore[import] +import pyarrow.parquet as pq from dask.base import tokenize from dask.dataframe.utils import make_meta from dask.delayed import delayed from deltalake import DataCatalog, DeltaTable -from fsspec.core import get_fs_token_paths # type: ignore[import] +from fsspec.core import get_fs_token_paths from pyarrow import dataset as pa_ds diff --git a/dask_deltatable/write.py b/dask_deltatable/write.py index 1c5e3a7..58b945a 100644 --- a/dask_deltatable/write.py +++ b/dask_deltatable/write.py @@ -1,36 +1,39 @@ from __future__ import annotations import json -from deltalake import DeltaTable -from typing import Any, Literal, Mapping +import uuid +from datetime import datetime from pathlib import Path +from typing import Any, Literal, Mapping + +import dask.dataframe as dd import pyarrow as pa import pyarrow.dataset as ds -import dask.dataframe as dd import pyarrow.fs as pa_fs -from dask.highlevelgraph import HighLevelGraph +from dask.core import flatten from dask.dataframe.core import Scalar +from dask.highlevelgraph import HighLevelGraph +from deltalake import DeltaTable from deltalake.writer import ( - try_get_table_and_table_uri, - __enforce_append_only, MAX_SUPPORTED_WRITER_VERSION, - DeltaStorageHandler, - DeltaProtocolError, - get_partitions_from_path, PYARROW_MAJOR_VERSION, - get_file_stats_from_metadata, AddAction, DeltaJSONEncoder, + DeltaProtocolError, + DeltaStorageHandler, + __enforce_append_only, _write_new_deltalake, + get_file_stats_from_metadata, + get_partitions_from_path, + try_get_table_and_table_uri, ) -from datetime import datetime - -import uuid from pyarrow.lib import RecordBatchReader -from dask.core import flatten +from toolz.itertoolz import pluck +from ._schema import validate_compatible -def write_deltalake( + +def to_deltalake( table_or_uri: str | Path | DeltaTable, df: dd.DataFrame, *, @@ -51,6 +54,10 @@ def write_deltalake( storage_options: dict[str, str] | None = None, partition_filters: list[tuple[str, str, Any]] | None = None, ): + """Write a given dask.DataFrame to a delta table + + TODO: + """ table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options) # We need to write against the latest table version @@ -99,10 +106,13 @@ def write_deltalake( else: # creating a new table current_version = -1 - if partition_by: + # FIXME: schema is only known at this point if provided by the user + if partition_by and schema: partition_schema = pa.schema([schema.field(name) for name in partition_by]) partitioning = ds.partitioning(partition_schema, flavor="hive") else: + if partition_by: + raise NotImplementedError("Have to provide schema when using partition_by") partitioning = None if mode == "overwrite": # FIXME: There are a couple of checks that are not migrated yet @@ -120,6 +130,7 @@ def write_deltalake( max_rows_per_group=max_rows_per_group, filesystem=filesystem, max_partitions=max_partitions, + meta=(None, object), ) final_name = "delta-commit" dsk = { @@ -144,7 +155,7 @@ def write_deltalake( def _commit( table, - add_actions_nested, + schemas_add_actions_nested, table_uri, schema, mode, @@ -155,7 +166,17 @@ def _commit( storage_options, 
partition_filters, ): - add_actions = flatten(add_actions_nested) + schemas = list(flatten(pluck(0, schemas_add_actions_nested))) + add_actions = list(flatten(pluck(1, schemas_add_actions_nested))) + # TODO: What should the behavior be if the schema is provided? Cast the + # data? + if schema: + schemas.append(schema) + + # TODO: This is applying a potentially stricted schema control than what + # Delta requires but if this passes, it should be good to go + schema = validate_compatible(schemas) + assert schema if table is None: _write_new_deltalake( table_uri, @@ -192,8 +213,10 @@ def _write_partition( max_rows_per_group, filesystem, max_partitions, -): +) -> tuple[pa.Schema, list[AddAction]]: + # TODO: what to do with the schema, if provided data = pa.Table.from_pandas(df) + schema = schema or data.schema add_actions: list[AddAction] = [] @@ -205,7 +228,7 @@ def visitor(written_file: Any) -> None: if PYARROW_MAJOR_VERSION >= 9: size = written_file.size else: - size = filesystem.get_file_info([path])[0].size # type: ignore + size = filesystem.get_file_info([path])[0].size add_actions.append( AddAction( @@ -236,4 +259,4 @@ def visitor(written_file: Any) -> None: filesystem=filesystem, max_partitions=max_partitions, ) - return add_actions + return schema, add_actions diff --git a/pyproject.toml b/pyproject.toml index 7cd19ee..2565cd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ allow_incomplete_defs = true allow_untyped_defs = true warn_return_any = false disallow_untyped_calls = false +ignore_missing_imports = true [tool.isort] profile = "black" diff --git a/tests/test_write.py b/tests/test_write.py new file mode 100644 index 0000000..0c33938 --- /dev/null +++ b/tests/test_write.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import os +import uuid +from datetime import date + +import numpy as np +import pandas as pd +import pytest +from dask.dataframe.utils import assert_eq +from dask.datasets import timeseries + +from dask_deltatable import read_delta_table +from dask_deltatable.write import to_deltalake + + +def get_dataframe_not_nested(): + return pd.DataFrame( + { + "bool": pd.Series([1], dtype=np.bool_), + "int8": pd.Series([1], dtype=np.int8), + "int16": pd.Series([1], dtype=np.int16), + "int32": pd.Series([1], dtype=np.int32), + "int64": pd.Series([1], dtype=np.int64), + "uint8": pd.Series([1], dtype=np.uint8), + "uint16": pd.Series([1], dtype=np.uint16), + "uint32": pd.Series([1], dtype=np.uint32), + "uint64": pd.Series([1], dtype=np.uint64), + "float32": pd.Series([1.0], dtype=np.float32), + "float64": pd.Series([1.0], dtype=np.float64), + "date": pd.Series([date(2018, 1, 1)], dtype=object), + "datetime64": pd.Series(["2018-01-01"], dtype="datetime64[ns]"), + "unicode": pd.Series(["Ö"], dtype=str), + "null": pd.Series([None], dtype=object), + # Adding a byte type with value as byte sequence which can not be encoded as UTF8 + "byte": pd.Series([uuid.uuid4().bytes], dtype=object), + } + ).sort_index(axis=1) + + +@pytest.mark.parametrize( + "with_index", + [ + pytest.param( + True, + marks=[ + pytest.mark.xfail( + reason="TS index is always ns resolution but delta can only handle us" + ) + ], + ), + False, + ], +) +def test_roundtrip(tmpdir, with_index): + dtypes = { + "str": object, + # FIXME: Categorical data does not work + # "category": "category", + "float": float, + "int": int, + } + tmpdir = str(tmpdir) + ddf = timeseries( + start="2023-01-01", + end="2023-01-15", + # FIXME: Setting the partition frequency destroys the roundtrip for some + # 
reason + # partition_freq="1w", + dtypes=dtypes, + ) + # FIXME: us is the only precision delta supports. This lib should likely + # case this itself + + ddf = ddf.reset_index() + ddf.timestamp = ddf.timestamp.astype("datetime64[us]") + if with_index: + ddf = ddf.set_index("timestamp") + + out = to_deltalake(tmpdir, ddf) + assert not os.listdir(tmpdir) + out.compute() + assert len(os.listdir(tmpdir)) > 0 + + ddf_read = read_delta_table(tmpdir) + # FIXME: The index is not recovered + if with_index: + ddf = ddf.reset_index() + + # By default, arrow reads with ns resolution + ddf.timestamp = ddf.timestamp.astype("datetime64[ns]") + assert_eq(ddf, ddf_read) From bf1e061211331909625bf07a228d66e4c9ed2f23 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 28 Jun 2023 16:31:32 +0200 Subject: [PATCH 3/4] drop get_dataframe_not_nested --- tests/test_write.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/tests/test_write.py b/tests/test_write.py index 0c33938..76234e2 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -1,11 +1,7 @@ from __future__ import annotations import os -import uuid -from datetime import date -import numpy as np -import pandas as pd import pytest from dask.dataframe.utils import assert_eq from dask.datasets import timeseries @@ -14,30 +10,6 @@ from dask_deltatable.write import to_deltalake -def get_dataframe_not_nested(): - return pd.DataFrame( - { - "bool": pd.Series([1], dtype=np.bool_), - "int8": pd.Series([1], dtype=np.int8), - "int16": pd.Series([1], dtype=np.int16), - "int32": pd.Series([1], dtype=np.int32), - "int64": pd.Series([1], dtype=np.int64), - "uint8": pd.Series([1], dtype=np.uint8), - "uint16": pd.Series([1], dtype=np.uint16), - "uint32": pd.Series([1], dtype=np.uint32), - "uint64": pd.Series([1], dtype=np.uint64), - "float32": pd.Series([1.0], dtype=np.float32), - "float64": pd.Series([1.0], dtype=np.float64), - "date": pd.Series([date(2018, 1, 1)], dtype=object), - "datetime64": pd.Series(["2018-01-01"], dtype="datetime64[ns]"), - "unicode": pd.Series(["Ö"], dtype=str), - "null": pd.Series([None], dtype=object), - # Adding a byte type with value as byte sequence which can not be encoded as UTF8 - "byte": pd.Series([uuid.uuid4().bytes], dtype=object), - } - ).sort_index(axis=1) - - @pytest.mark.parametrize( "with_index", [ From 3a17af6d2a17ff01bd6f0716befb129c24047ccb Mon Sep 17 00:00:00 2001 From: fjetter Date: Mon, 10 Jul 2023 10:28:12 +0200 Subject: [PATCH 4/4] nits --- dask_deltatable/write.py | 7 +++---- tests/test_write.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/dask_deltatable/write.py b/dask_deltatable/write.py index 58b945a..80bdeeb 100644 --- a/dask_deltatable/write.py +++ b/dask_deltatable/write.py @@ -27,7 +27,6 @@ get_partitions_from_path, try_get_table_and_table_uri, ) -from pyarrow.lib import RecordBatchReader from toolz.itertoolz import pluck from ._schema import validate_compatible @@ -101,7 +100,7 @@ def to_deltalake( raise DeltaProtocolError( "This table's min_writer_version is " f"{table.protocol().min_writer_version}, " - "but this method only supports version 2." + f"but this method only supports version {MAX_SUPPORTED_WRITER_VERSION}." 
) else: # creating a new table current_version = -1 @@ -116,7 +115,7 @@ def to_deltalake( partitioning = None if mode == "overwrite": # FIXME: There are a couple of checks that are not migrated yet - raise NotImplementedError() + raise NotImplementedError("mode='overwrite' is not implemented") written = df.map_partitions( _write_partition, @@ -248,7 +247,7 @@ def visitor(written_file: Any) -> None: format="parquet", partitioning=partitioning, # It will not accept a schema if using a RBR - schema=schema if not isinstance(data, RecordBatchReader) else None, + schema=schema, existing_data_behavior="overwrite_or_ignore", file_options=file_options, max_open_files=max_open_files, diff --git a/tests/test_write.py b/tests/test_write.py index 76234e2..314092b 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -42,7 +42,7 @@ def test_roundtrip(tmpdir, with_index): dtypes=dtypes, ) # FIXME: us is the only precision delta supports. This lib should likely - # case this itself + # cast this itself ddf = ddf.reset_index() ddf.timestamp = ddf.timestamp.astype("datetime64[us]")
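
For anyone trying this branch out, here is a minimal round-trip sketch of the API the series adds. It is modelled directly on tests/test_write.py: the target path is an arbitrary placeholder, the microsecond cast works around the timestamp-resolution limitation noted in the test, and the lazy-commit behaviour (nothing is written until the returned Scalar is computed) is assumed to hold exactly as the test asserts.

```python
from dask.datasets import timeseries

from dask_deltatable import read_delta_table
from dask_deltatable.write import to_deltalake

# Build a small Dask DataFrame; Delta only supports microsecond timestamps,
# so cast the nanosecond-resolution column produced by `timeseries`.
ddf = timeseries(
    start="2023-01-01",
    end="2023-01-15",
    dtypes={"str": object, "float": float, "int": int},
).reset_index()
ddf["timestamp"] = ddf["timestamp"].astype("datetime64[us]")

# "/tmp/delta-example" is a placeholder; any URI accepted by deltalake should work.
out = to_deltalake("/tmp/delta-example", ddf)

# The write is lazy: partitions are written and the single `delta-commit`
# task creates the transaction only when the Scalar is computed.
out.compute()

# Read the table back as a Dask DataFrame.
ddf_read = read_delta_table("/tmp/delta-example")
print(ddf_read.head())
```

Funnelling everything through the single `_commit` task means the Delta transaction is created exactly once, after every partition's `AddAction`s and observed schemas have been gathered and validated.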