Write operation #29

Merged
merged 4 commits on Jul 10, 2023
Changes from 2 commits
2 changes: 1 addition & 1 deletion .flake8
@@ -1,3 +1,3 @@
# flake8 doesn't support pyproject.toml yet https://github.com/PyCQA/flake8/issues/234
[flake8]
-max-line-length = 104
+max-line-length = 120
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
@@ -30,11 +30,12 @@ repos:
        args: [--warn-unused-configs]
        additional_dependencies:
          # Type stubs
-         - types-setuptools
          - boto3-stubs
-         - pytest
          - dask
+         - deltalake
          - pandas-stubs
+         - pytest
+         - types-setuptools
  - repo: https://github.com/pycqa/flake8
    rev: 6.0.0
    hooks:
325 changes: 325 additions & 0 deletions dask_deltatable/_schema.py
@@ -0,0 +1,325 @@
from __future__ import annotations

"""
Most of this code was taken from

https://github.com/data-engineering-collective/plateau

https://github.com/data-engineering-collective/plateau/blob/d4c4522f5a829d43e3368fc82e1568c91fa352f3/plateau/core/common_metadata.py

and adapted to this project

under the original license

MIT License

Copyright (c) 2022 The plateau contributors.
Copyright (c) 2020-2021 The kartothek contributors.
Copyright (c) 2019 JDA Software, Inc

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

"""
import difflib
import json
import logging
import pprint
from copy import deepcopy
from typing import Iterable

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

_logger = logging.getLogger()


class SchemaWrapper:
    def __init__(self, schema: pa.Schema):
        self.schema = schema

    def __hash__(self):
        # FIXME: pyarrow raises a "cannot hash type dict" error
        return hash(_schema2bytes(self.schema))

    def __eq__(self, other):
        # Without an explicit __eq__, the set()-based deduplication below falls back to
        # identity comparison and never collapses identical schemas.
        if not isinstance(other, SchemaWrapper):
            return NotImplemented
        return self.schema.equals(other.schema, check_metadata=True)


def _pandas_in_schemas(schemas):
    """Check if any schema contains pandas metadata."""
    has_pandas = False
    for schema in schemas:
        if schema.metadata and b"pandas" in schema.metadata:
            has_pandas = True
    return has_pandas


def _determine_schemas_to_compare(
    schemas: Iterable[pa.Schema], ignore_pandas: bool
) -> tuple[pa.Schema | None, list[tuple[pa.Schema, list[str]]]]:
    """Iterate over a list of `pyarrow.Schema` objects and prepare them for
    comparison by picking a reference and determining all null columns.

    .. note::

        If pandas metadata exists, the version stored in the metadata is overwritten with the currently
        installed version since we expect to stay backwards compatible.

    Returns
    -------
    reference: Schema
        A reference schema which is picked from the input list. The reference schema is guaranteed
        to be a schema having the least number of null columns of all input schemas. The set of null
        columns is guaranteed to be a true subset of all null columns of all input schemas. If no such
        schema can be found, an Exception is raised.
    list_of_schemas: List[Tuple[Schema, List]]
        A list holding pairs of (Schema, null_columns) where the null_columns are all columns which are
        null and must be removed before comparing the schemas.
    """
    has_pandas = _pandas_in_schemas(schemas) and not ignore_pandas
    schemas_to_evaluate: list[tuple[pa.Schema, list[str]]] = []
    reference = None
    null_cols_in_reference = set()
    # Hashing the schemas is a very fast way to reduce the number of schemas to
    # actually compare since in most circumstances this reduces to very few
    # (which differ in e.g. null columns)
    for schema_wrapped in set(map(SchemaWrapper, schemas)):
        schema = schema_wrapped.schema
        del schema_wrapped
        if has_pandas:
            metadata = schema.metadata
            if metadata is None or b"pandas" not in metadata:
                raise ValueError(
                    "Pandas and non-Pandas schemas are not comparable. "
                    "Use ignore_pandas=True if you only want to compare "
                    "on Arrow level."
                )
            pandas_metadata = json.loads(metadata[b"pandas"].decode("utf8"))

            # We don't care about the pandas version, since we assume it's safe
            # to read datasets that were written by older or newer versions.
            pandas_metadata["pandas_version"] = f"{pd.__version__}"

            metadata_clean = deepcopy(metadata)
            metadata_clean[b"pandas"] = _dict_to_binary(pandas_metadata)
            current = pa.schema(schema, metadata_clean)
        else:
            current = schema

        # If a field is null we cannot compare it and must therefore reject it
        null_columns = {field.name for field in current if field.type == pa.null()}

        # Determine a valid reference schema. A valid reference schema is considered to be the schema
        # of all input schemas with the fewest empty columns.
        # The reference schema ought to be a schema whose empty columns are a true subset of all sets
        # of empty columns. This ensures that the actual reference schema is the schema with the most
        # information possible. A schema which doesn't fulfil this requirement would weaken the
        # comparison and would allow for false positives.

        # Trivial case
        if reference is None:
            reference = current
            null_cols_in_reference = null_columns
        # The reference has enough information to validate against the current schema.
        # Append it to the list of schemas to be verified
        elif null_cols_in_reference.issubset(null_columns):
            schemas_to_evaluate.append((current, null_columns))
        # The current schema includes all information of the reference and more.
        # Add the reference to schemas_to_evaluate and update the reference
        elif null_columns.issubset(null_cols_in_reference):
            schemas_to_evaluate.append((reference, list(null_cols_in_reference)))
            reference = current
            null_cols_in_reference = null_columns
        # If there is no clear subset available, elect the schema with the fewest null columns as
        # `reference`. Iterate over the null columns of `reference` and replace them with the non-null
        # fields of the `current` schema, which recovers the loop invariant (null columns of `reference`
        # are a subset of those of `current`)
        else:
            if len(null_columns) < len(null_cols_in_reference):
                reference, current = current, reference
                null_cols_in_reference, null_columns = (
                    null_columns,
                    null_cols_in_reference,
                )

            for col in null_cols_in_reference - null_columns:
                # Enrich the information in the reference by grabbing the missing fields
                # from the current iteration. This assumes that we only check for global validity
                # and that it is not relevant where the reference comes from.
                reference = _swap_fields_by_name(reference, current, col)
                null_cols_in_reference.remove(col)
            schemas_to_evaluate.append((current, null_columns))

    assert (reference is not None) or (not schemas_to_evaluate)

    return reference, schemas_to_evaluate


def _swap_fields_by_name(reference, current, field_name):
    current_field = current.field(field_name)
    reference_index = reference.get_field_index(field_name)
    return reference.set(reference_index, current_field)


def _strip_columns_from_schema(schema, field_names):
    stripped_schema = schema

    for name in field_names:
        ix = stripped_schema.get_field_index(name)
        if ix >= 0:
            stripped_schema = stripped_schema.remove(ix)
        else:
            # If the returned index is negative, the field doesn't exist in the schema.
            # This is most likely an indicator for incompatible schemas and we refuse to strip the
            # schema so as not to obfuscate the validation result
            _logger.warning(
                "Unexpected field `%s` encountered while trying to strip `null` columns.\n"
                "Schema was:\n\n`%s`" % (name, schema)
            )
            return schema
    return stripped_schema


def _schema2bytes(schema: pa.Schema) -> bytes:
    buf = pa.BufferOutputStream()
    pq.write_metadata(schema, buf, coerce_timestamps="us")
    return buf.getvalue().to_pybytes()


def _remove_diff_header(diff):
    diff = list(diff)
    for ix, el in enumerate(diff):
        # This marks the first actual entry of the diff,
        # e.g. @@ -1,5 +2,5 @@
        if el.startswith("@"):
            return diff[ix:]
    return diff


def _diff_schemas(first, second):
    # see https://issues.apache.org/jira/browse/ARROW-4176

    first_pyarrow_info = str(first.remove_metadata())
    second_pyarrow_info = str(second.remove_metadata())
    pyarrow_diff = _remove_diff_header(
        difflib.unified_diff(
            first_pyarrow_info.splitlines(), second_pyarrow_info.splitlines()
        )
    )

    first_pandas_info = first.pandas_metadata
    second_pandas_info = second.pandas_metadata
    pandas_meta_diff = _remove_diff_header(
        difflib.unified_diff(
            pprint.pformat(first_pandas_info).splitlines(),
            pprint.pformat(second_pandas_info).splitlines(),
        )
    )

    diff_string = (
        "Arrow schema:\n"
        + "\n".join(pyarrow_diff)
        + "\n\nPandas_metadata:\n"
        + "\n".join(pandas_meta_diff)
    )

    return diff_string


def validate_compatible(
    schemas: Iterable[pa.Schema], ignore_pandas: bool = False
) -> pa.Schema:
    """Validate that all schemas in a given list are compatible.

    Apart from the pandas version preserved in the schema metadata, schemas must be completely
    identical. That includes a perfect match of the whole metadata (except the pandas version)
    and of the pyarrow types.

    If the schemas don't contain any pandas metadata, the Arrow schemas are
    checked directly for compatibility.

    Parameters
    ----------
    schemas: List[Schema]
        Schema information from multiple sources, e.g. multiple partitions. List may be empty.
    ignore_pandas: bool
        Ignore the schema information given by pandas and always use the Arrow schema.

    Returns
    -------
    schema: Schema
        The reference schema which was tested against.

    Raises
    ------
    ValueError
        At least two schemas are incompatible.
    """
    reference, schemas_to_evaluate = _determine_schemas_to_compare(
        schemas, ignore_pandas
    )

    for current, null_columns in schemas_to_evaluate:
        # We have schemas, so the reference schema should be non-None.
        assert reference is not None
        # Compare each schema to the reference but ignore the null_cols and the pandas schema information.
        reference_to_compare = _strip_columns_from_schema(
            reference, null_columns
        ).remove_metadata()
        current_to_compare = _strip_columns_from_schema(
            current, null_columns
        ).remove_metadata()

        def _fmt_origin(origin):
            origin = sorted(origin)
            # dask cuts off exception messages at 1k chars:
            # https://github.com/dask/distributed/blob/6e0c0a6b90b1d3c/distributed/core.py#L964
            # therefore, truncate the origin listing to a maximum length
            max_len = 200
            inner_msg = ", ".join(origin)
            ellipsis = "..."
            if len(inner_msg) > max_len + len(ellipsis):
                inner_msg = inner_msg[:max_len] + ellipsis
            return f"{{{inner_msg}}}"

        if reference_to_compare != current_to_compare:
            schema_diff = _diff_schemas(reference, current)
            exception_message = """Schema violation

Origin schema: {origin_schema}
Origin reference: {origin_reference}

Diff:
{schema_diff}

Reference schema:
{reference}""".format(
                schema_diff=schema_diff,
                reference=str(reference),
                # Plain pyarrow schemas carry no "origin" attribute, so fall back to an empty origin.
                origin_schema=_fmt_origin(getattr(current, "origin", [])),
                origin_reference=_fmt_origin(getattr(reference, "origin", [])),
            )
            raise ValueError(exception_message)

    # If the input list of schemas was empty, there is no reference schema to return.
    if reference is None:
        return None
    else:
        return reference


def _dict_to_binary(dct):
    return json.dumps(dct, sort_keys=True).encode("utf8")
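
For illustration, here is a minimal usage sketch of the new validate_compatible helper; the schemas are hypothetical examples, not part of this diff:

import pyarrow as pa

from dask_deltatable._schema import validate_compatible

# Two partition schemas: one partition wrote column "b" entirely as nulls,
# the other wrote it as strings.
s1 = pa.schema([("a", pa.int64()), ("b", pa.null())])
s2 = pa.schema([("a", pa.int64()), ("b", pa.string())])

# The schema without null columns is elected as the reference and returned;
# the all-null column of s1 is ignored during the comparison.
reference = validate_compatible([s1, s2])
assert reference.field("b").type == pa.string()

# Schemas with genuinely conflicting types (e.g. "a" as int64 in one partition
# and string in another) raise a ValueError containing a diff of the two schemas.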
4 changes: 2 additions & 2 deletions dask_deltatable/core.py
@@ -8,12 +8,12 @@

import dask
import dask.dataframe as dd
-import pyarrow.parquet as pq  # type: ignore[import]
+import pyarrow.parquet as pq
from dask.base import tokenize
from dask.dataframe.utils import make_meta
from dask.delayed import delayed
from deltalake import DataCatalog, DeltaTable
-from fsspec.core import get_fs_token_paths  # type: ignore[import]
+from fsspec.core import get_fs_token_paths
from pyarrow import dataset as pa_ds

