diff --git a/orbiter/__init__.py b/orbiter/__init__.py index e23af30..01a019c 100644 --- a/orbiter/__init__.py +++ b/orbiter/__init__.py @@ -1,20 +1,13 @@ from __future__ import annotations import re -from enum import Enum from typing import Any, Tuple -__version__ = "1.1.0" +__version__ = "1.2.0" version = __version__ -class FileType(Enum): - YAML = "YAML" - XML = "XML" - JSON = "JSON" - - def clean_value(s: str): """Cleans a string to be a standard value, such as one that might be a python variable name diff --git a/orbiter/file_types.py b/orbiter/file_types.py new file mode 100644 index 0000000..efe4098 --- /dev/null +++ b/orbiter/file_types.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import json +from functools import partial +from typing import Callable, Set, ClassVar, Any + +import xmltodict +import yaml +from pydantic import ( + BaseModel, +) +from pydantic.v1 import validator + + +class FileType(BaseModel, arbitrary_types_allowed=True): + extension: ClassVar[Set[str]] + load_fn: ClassVar[Callable[[str], dict]] + dump_fn: ClassVar[Callable[[dict], str]] + + def __hash__(self): + return hash(tuple(self.extension)) + + @validator("extension", pre=True) + @classmethod + def ext_validate(cls, v: Set[str]): + if not v: + raise ValueError("Extension cannot be an empty set") + for ext in v: + if not isinstance(ext, str): + raise ValueError("Extension should be a string") + if "." in v: + raise ValueError("Extension should not contain '.'") + return {ext.lower() for ext in v} + + +class FileTypeJSON(FileType): + extension: ClassVar[Set[str]] = {"JSON"} + load_fn: ClassVar[Callable[[str], dict]] = json.loads + dump_fn: ClassVar[Callable[[dict], str]] = xmltodict.unparse + + +# noinspection t +def xmltodict_parse(input_str: str) -> Any: + """Calls `xmltodict.parse` and does post-processing fixes. + + !!! note + + The original [`xmltodict.parse`](https://pypi.org/project/xmltodict/) method returns EITHER: + + - a dict (one child element of type) + - or a list of dict (many child element of type) + + This behavior can be confusing, and is an issue with the original xml spec being referenced. + + **This method deviates by standardizing to the latter case (always a `list[dict]`).** + + **All XML elements will be a list of dictionaries, even if there's only one element.** + + ```pycon + >>> xmltodict_parse("") + Traceback (most recent call last): + xml.parsers.expat.ExpatError: no element found: line 1, column 0 + >>> xmltodict_parse("") + {'a': None} + >>> xmltodict_parse("") + {'a': [{'@foo': 'bar'}]} + >>> xmltodict_parse("") # Singleton - gets modified + {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}]}]} + >>> xmltodict_parse("") # Nested Singletons - modified + {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz', 'bar': [{'bop': None}]}]}]} + >>> xmltodict_parse("") + {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}, {'@bing': 'bop'}]}]} + + ``` + :param input_str: The XML string to parse + :type input_str: str + :return: The parsed XML + :rtype: dict + """ + + # noinspection t + def _fix(d): + """fix the dict in place, recursively, standardizing on a list of dict even if there's only one entry.""" + # if it's a dict, descend to fix + if isinstance(d, dict): + for k, v in d.items(): + # @keys are properties of elements, non-@keys are elements + if not k.startswith("@"): + if isinstance(v, dict): + # THE FIX + # any non-@keys should be a list of dict, even if there's just one of the element + d[k] = [v] + _fix(v) + else: + _fix(v) + # if it's a list, descend to fix + if isinstance(d, list): + for v in d: + _fix(v) + + output = xmltodict.parse(input_str) + _fix(output) + return output + + +class FileTypeXML(FileType): + extension: ClassVar[Set[str]] = {"XML"} + load_fn: ClassVar[Callable[[str], dict]] = xmltodict_parse + dump_fn: ClassVar[Callable[[dict], str]] = partial( + json.dumps, default=str, indent=2 + ) + + +class FileTypeYAML(FileType): + extension: ClassVar[Set[str]] = {"YAML", "YML"} + load_fn: ClassVar[Callable[[str], dict]] = yaml.safe_load + dump_fn: ClassVar[Callable[[dict], str]] = yaml.safe_dump diff --git a/orbiter/rules/rulesets.py b/orbiter/rules/rulesets.py index aba7cbe..58cc46f 100644 --- a/orbiter/rules/rulesets.py +++ b/orbiter/rules/rulesets.py @@ -9,13 +9,23 @@ from itertools import chain from pathlib import Path from tempfile import TemporaryDirectory -from typing import List, Any, Collection, Annotated, Callable, Union, Generator +from typing import ( + List, + Any, + Collection, + Annotated, + Callable, + Union, + Generator, + Set, + Type, +) from loguru import logger from pydantic import BaseModel, AfterValidator, validate_call -from orbiter import FileType from orbiter import import_from_qualname +from orbiter.file_types import FileType, FileTypeJSON from orbiter.objects.dag import OrbiterDAG from orbiter.objects.project import OrbiterProject from orbiter.objects.task import OrbiterOperator, OrbiterTaskDependency @@ -31,6 +41,15 @@ EMPTY_RULE, ) # noqa: F401 + +def _backport_walk(input_dir: Path): + """Path.walk() is only available in Python 3.12+, so, backport""" + import os + + for result in os.walk(input_dir): + yield Path(result[0]), result[1], result[2] + + qualname_validator_regex = r"^[\w.]+$" qualname_validator = re.compile(qualname_validator_regex) @@ -98,71 +117,6 @@ def validate_qualified_imports(qualified_imports: List[str]) -> List[str]: ] -# noinspection t -def xmltodict_parse(input_str: str) -> Any: - """Calls `xmltodict.parse` and does post-processing fixes. - - !!! note - - The original [`xmltodict.parse`](https://pypi.org/project/xmltodict/) method returns EITHER: - - - a dict (one child element of type) - - or a list of dict (many child element of type) - - This behavior can be confusing, and is an issue with the original xml spec being referenced. - - **This method deviates by standardizing to the latter case (always a `list[dict]`).** - - **All XML elements will be a list of dictionaries, even if there's only one element.** - - ```pycon - >>> xmltodict_parse("") - Traceback (most recent call last): - xml.parsers.expat.ExpatError: no element found: line 1, column 0 - >>> xmltodict_parse("") - {'a': None} - >>> xmltodict_parse("") - {'a': [{'@foo': 'bar'}]} - >>> xmltodict_parse("") # Singleton - gets modified - {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}]}]} - >>> xmltodict_parse("") # Nested Singletons - modified - {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz', 'bar': [{'bop': None}]}]}]} - >>> xmltodict_parse("") - {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}, {'@bing': 'bop'}]}]} - - ``` - :param input_str: The XML string to parse - :type input_str: str - :return: The parsed XML - :rtype: dict - """ - import xmltodict - - # noinspection t - def _fix(d): - """fix the dict in place, recursively, standardizing on a list of dict even if there's only one entry.""" - # if it's a dict, descend to fix - if isinstance(d, dict): - for k, v in d.items(): - # @keys are properties of elements, non-@keys are elements - if not k.startswith("@"): - if isinstance(v, dict): - # THE FIX - # any non-@keys should be a list of dict, even if there's just one of the element - d[k] = [v] - _fix(v) - else: - _fix(v) - # if it's a list, descend to fix - if isinstance(d, list): - for v in d: - _fix(v) - - output = xmltodict.parse(input_str) - _fix(output) - return output - - def _add_task_deduped(_task, _tasks, n=""): """ If this task_id doesn't already exist, add it as normal to the tasks dictionary. @@ -621,7 +575,7 @@ class TranslationRuleset(BaseModel, ABC, extra="forbid"): ```pycon >>> TranslationRuleset( - ... file_type=FileType.JSON, # Has a file type + ... file_type={FileTypeJSON}, # Has a file type ... translate_fn=fake_translate, # and can have a callable ... # translate_fn="orbiter.rules.translate.fake_translate", # or a qualified name to a function ... dag_filter_ruleset={"ruleset": [{"rule": lambda x: None}]}, # Rulesets can be dict within dicts @@ -635,8 +589,8 @@ class TranslationRuleset(BaseModel, ABC, extra="forbid"): ``` - :param file_type: FileType to translate (`.json`, `.xml`, `.yaml`, etc.) - :type file_type: FileType + :param file_type: FileType to translate + :type file_type: Set[Type[FileType]] :param dag_filter_ruleset: [`DAGFilterRuleset`][orbiter.rules.rulesets.DAGFilterRuleset] (of [`DAGFilterRule`][orbiter.rules.DAGFilterRule]) :type dag_filter_ruleset: DAGFilterRuleset | dict @@ -660,7 +614,7 @@ class TranslationRuleset(BaseModel, ABC, extra="forbid"): :type translate_fn: Callable[[TranslationRuleset, Path], OrbiterProject] | str | TranslateFn """ # noqa: E501 - file_type: FileType + file_type: Set[Type[FileType]] dag_filter_ruleset: DAGFilterRuleset | dict dag_ruleset: DAGRuleset | dict task_filter_ruleset: TaskFilterRuleset | dict @@ -669,67 +623,54 @@ class TranslationRuleset(BaseModel, ABC, extra="forbid"): post_processing_ruleset: PostProcessingRuleset | dict translate_fn: TranslateFn = translate + def get_ext(self) -> str: + """ + Get the first file extension for this ruleset + + ```pycon + >>> EMPTY_TRANSLATION_RULESET.get_ext() + 'JSON' + + ``` + """ + return next(iter(next(iter(self.file_type)).extension)) + @validate_call - def loads(self, input_str: str) -> dict: + def loads(self, file: Path) -> dict: """ Converts all files of type into a Python dictionary "intermediate representation" form, prior to any rulesets being applied. - | FileType | Conversion Method | - |----------|-------------------------------------------------------------| - | `XML` | [`xmltodict_parse`][orbiter.rules.rulesets.xmltodict_parse] | - | `YAML` | `yaml.safe_load` | - | `JSON` | `json.loads` | - - :param input_str: The string to convert to a dictionary - :type input_str: str + :param file: The file to load + :type file: Path :return: The dictionary representation of the input_str :rtype: dict """ - - if self.file_type == FileType.JSON: - import json - - return json.loads(input_str) - elif self.file_type == FileType.YAML: - import yaml - - return yaml.safe_load(input_str) - elif self.file_type == FileType.XML: - return xmltodict_parse(input_str) - else: - raise NotImplementedError(f"Cannot load file_type={self.file_type}") + for file_type in self.file_type: + if file.suffix.lower() in { + f".{ext.lower()}" for ext in file_type.extension + }: + return file_type.load_fn(file.read_text()) + raise TypeError( + f"Invalid file_type={file.suffix}, does not match file_type={self.file_type}" + ) @validate_call - def dumps(self, input_dict: dict) -> str: + def dumps(self, input_dict: dict, ext: str | None) -> str: """ Convert Python dictionary back to source string form, useful for testing - | FileType | Conversion Method | - |----------|---------------------| - | `XML` | `xmltodict.unparse` | - | `YAML` | `yaml.safe_dump` | - | `JSON` | `json.dumps` | - :param input_dict: The dictionary to convert to a string :type input_dict: dict + :param ext: The file type extension to dump as, defaults to first 'file_type' in the set + :type ext: str | None :return str: The string representation of the input_dict, in the file_type format :rtype: str """ - if self.file_type == FileType.JSON: - import json - - return json.dumps(input_dict, indent=2) - elif self.file_type == FileType.YAML: - import yaml - - return yaml.safe_dump(input_dict) - elif self.file_type == FileType.XML: - import xmltodict - - return xmltodict.unparse(input_dict) - else: - raise NotImplementedError(f"Cannot dump file_type={self.file_type}") + for file_type in self.file_type: + if ext is None or ext.lower() in file_type.extension: + return file_type.dump_fn(input_dict) + raise TypeError(f"Invalid file_type={ext}") def get_files_with_extension(self, input_dir: Path) -> Generator[Path, dict]: """ @@ -740,39 +681,28 @@ def get_files_with_extension(self, input_dir: Path) -> Generator[Path, dict]: :return: Generator item of (Path, dict) for each file found :rtype: Generator[Path, dict] """ - extension = f".{self.file_type.value.lower()}" - extensions = [extension] - - # YAML and YML are both valid extensions - extension_sub = { - "yaml": "yml", - } - if other_extension := extension_sub.get(self.file_type.value.lower()): - extensions.append(f".{other_extension}") - - logger.debug(f"Finding files with extension={extensions} in {input_dir}") - - def backport_walk(input_dir: Path): - """Path.walk() is only available in Python 3.12+, so, backport""" - import os - - for result in os.walk(input_dir): - yield Path(result[0]), result[1], result[2] - for directory, _, files in ( - input_dir.walk() if hasattr(input_dir, "walk") else backport_walk(input_dir) + input_dir.walk() + if hasattr(input_dir, "walk") + else _backport_walk(input_dir) ): logger.debug(f"Checking directory={directory}") for file in files: file = directory / file - if file.suffix.lower() in extensions: - logger.debug(f"File={file} matches extension={extensions}") + # noinspection PyBroadException + try: yield ( # Return the file path file, # and load the file and convert it into a python dict - self.loads(file.read_text()), + self.loads(file), ) + except TypeError: + logger.debug(f"File={file} not of correct type, skipping...") + continue + except Exception as e: + logger.exception(f"Error loading file={file}, {e}") + continue def test(self, input_value: str | dict) -> OrbiterProject: """ @@ -788,7 +718,7 @@ def test(self, input_value: str | dict) -> OrbiterProject: :rtype: OrbiterProject """ with TemporaryDirectory() as tempdir: - file = Path(tempdir) / f"{uuid.uuid4()}.{self.file_type.value}" + file = Path(tempdir) / f"{uuid.uuid4()}.{self.get_ext()}" file.write_text( self.dumps(input_value) if isinstance(input_value, dict) @@ -797,6 +727,18 @@ def test(self, input_value: str | dict) -> OrbiterProject: return self.translate_fn(translation_ruleset=self, input_dir=file.parent) +EMPTY_TRANSLATION_RULESET = TranslationRuleset( + file_type={FileTypeJSON}, + dag_filter_ruleset=EMPTY_RULESET, + dag_ruleset=EMPTY_RULESET, + task_filter_ruleset=EMPTY_RULESET, + task_ruleset=EMPTY_RULESET, + task_dependency_ruleset=EMPTY_RULESET, + post_processing_ruleset=EMPTY_RULESET, + translate_fn=fake_translate, +) + + if __name__ == "__main__": import doctest diff --git a/tests/orbiter/rules/rulesets_test.py b/tests/orbiter/rules/rulesets_test.py index 622741e..f68a7d8 100644 --- a/tests/orbiter/rules/rulesets_test.py +++ b/tests/orbiter/rules/rulesets_test.py @@ -1,10 +1,23 @@ -from orbiter import FileType +from orbiter.file_types import FileTypeYAML from orbiter.rules.rulesets import TranslationRuleset, EMPTY_RULESET +def test_loads(project_root): + actual = TranslationRuleset( + file_type={FileTypeYAML}, + dag_ruleset=EMPTY_RULESET, + dag_filter_ruleset=EMPTY_RULESET, + task_filter_ruleset=EMPTY_RULESET, + task_ruleset=EMPTY_RULESET, + task_dependency_ruleset=EMPTY_RULESET, + post_processing_ruleset=EMPTY_RULESET, + ).loads(project_root / "tests/resources/test_get_files_with_extension/one.YAML") + assert actual == {"one": "foo"} + + def test__get_files_with_extension(project_root): translation_ruleset = TranslationRuleset( - file_type=FileType.YAML, + file_type={FileTypeYAML}, dag_ruleset=EMPTY_RULESET, dag_filter_ruleset=EMPTY_RULESET, task_filter_ruleset=EMPTY_RULESET, diff --git a/tests/resources/translation_template.py b/tests/resources/translation_template.py index d8e7498..914521c 100644 --- a/tests/resources/translation_template.py +++ b/tests/resources/translation_template.py @@ -1,5 +1,5 @@ from __future__ import annotations -from orbiter import FileType +from orbiter.file_types import FileTypeJSON from orbiter.objects.dag import OrbiterDAG from orbiter.objects.operators.empty import OrbiterEmptyOperator from orbiter.objects.project import OrbiterProject @@ -88,7 +88,7 @@ def basic_post_processing_rule(val: OrbiterProject) -> None: translation_ruleset = TranslationRuleset( - file_type=FileType.JSON, + file_type={FileTypeJSON}, dag_filter_ruleset=DAGFilterRuleset(ruleset=[basic_dag_filter]), dag_ruleset=DAGRuleset(ruleset=[basic_dag_rule]), task_filter_ruleset=TaskFilterRuleset(ruleset=[basic_task_filter]),