diff --git a/orbiter/__init__.py b/orbiter/__init__.py
index e23af30..01a019c 100644
--- a/orbiter/__init__.py
+++ b/orbiter/__init__.py
@@ -1,20 +1,13 @@
from __future__ import annotations
import re
-from enum import Enum
from typing import Any, Tuple
-__version__ = "1.1.0"
+__version__ = "1.2.0"
version = __version__
-class FileType(Enum):
- YAML = "YAML"
- XML = "XML"
- JSON = "JSON"
-
-
def clean_value(s: str):
"""Cleans a string to be a standard value, such as one that might be a python variable name
diff --git a/orbiter/file_types.py b/orbiter/file_types.py
new file mode 100644
index 0000000..efe4098
--- /dev/null
+++ b/orbiter/file_types.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+import json
+from functools import partial
+from typing import Callable, Set, ClassVar, Any
+
+import xmltodict
+import yaml
+from pydantic import (
+ BaseModel,
+)
+from pydantic.v1 import validator
+
+
+class FileType(BaseModel, arbitrary_types_allowed=True):
+    extension: ClassVar[Set[str]]  # extension names, without a leading "."
+    load_fn: ClassVar[Callable[[str], dict]]  # source text -> dict
+    dump_fn: ClassVar[Callable[[dict], str]]  # dict -> source text
+
+    def __hash__(self):
+        return hash(frozenset(self.extension))  # frozenset: equal sets hash equally, whatever their iteration order
+
+    @validator("extension", pre=True)  # NOTE(review): `extension` is a ClassVar, not a field - confirm this ever runs
+    @classmethod
+    def ext_validate(cls, v: Set[str]):
+        if not v:
+            raise ValueError("Extension cannot be an empty set")
+        for ext in v:
+            if not isinstance(ext, str):
+                raise ValueError("Extension should be a string")
+            if "." in ext:  # test each extension string, not membership of "." in the set
+                raise ValueError("Extension should not contain '.'")
+        return {ext.lower() for ext in v}
+
+
+class FileTypeJSON(FileType):
+    extension: ClassVar[Set[str]] = {"JSON"}
+    load_fn: ClassVar[Callable[[str], dict]] = json.loads
+    dump_fn: ClassVar[Callable[[dict], str]] = partial(json.dumps, default=str, indent=2)  # default=str: stringify non-JSON values
+
+
+# noinspection t
+def xmltodict_parse(input_str: str) -> Any:
+ """Calls `xmltodict.parse` and does post-processing fixes.
+
+ !!! note
+
+ The original [`xmltodict.parse`](https://pypi.org/project/xmltodict/) method returns EITHER:
+
+ - a dict (one child element of type)
+ - or a list of dict (many child element of type)
+
+ This behavior can be confusing, and is an issue with the original xml spec being referenced.
+
+ **This method deviates by standardizing to the latter case (always a `list[dict]`).**
+
+ **All XML elements will be a list of dictionaries, even if there's only one element.**
+
+ ```pycon
+ >>> xmltodict_parse("")
+ Traceback (most recent call last):
+ xml.parsers.expat.ExpatError: no element found: line 1, column 0
+ >>> xmltodict_parse("<a></a>")
+ {'a': None}
+ >>> xmltodict_parse("<a foo='bar'></a>")
+ {'a': [{'@foo': 'bar'}]}
+ >>> xmltodict_parse("<a foo='bar'><foo bar='baz'></foo></a>") # Singleton - gets modified
+ {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}]}]}
+ >>> xmltodict_parse("<a foo='bar'><foo bar='baz'><bar><bop></bop></bar></foo></a>") # Nested Singletons - modified
+ {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz', 'bar': [{'bop': None}]}]}]}
+ >>> xmltodict_parse("<a foo='bar'><foo bar='baz'></foo><foo bing='bop'></foo></a>")
+ {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}, {'@bing': 'bop'}]}]}
+
+ ```
+ :param input_str: The XML string to parse
+ :type input_str: str
+ :return: The parsed XML
+ :rtype: dict
+ """
+
+ # noinspection t
+ def _fix(d):
+ """fix the dict in place, recursively, standardizing on a list of dict even if there's only one entry."""
+ # if it's a dict, descend to fix
+ if isinstance(d, dict):
+ for k, v in d.items():
+ # @keys are properties of elements, non-@keys are elements
+ if not k.startswith("@"):
+ if isinstance(v, dict):
+ # THE FIX
+ # any non-@keys should be a list of dict, even if there's just one of the element
+ d[k] = [v]
+ _fix(v)
+ else:
+ _fix(v)
+ # if it's a list, descend to fix
+ if isinstance(d, list):
+ for v in d:
+ _fix(v)
+
+ output = xmltodict.parse(input_str)
+ _fix(output)
+ return output
+
+
+class FileTypeXML(FileType):
+    """XML files: parsed with `xmltodict_parse`, serialized back to XML with `xmltodict.unparse`."""
+
+    extension: ClassVar[Set[str]] = {"XML"}
+    load_fn: ClassVar[Callable[[str], dict]] = xmltodict_parse
+    dump_fn: ClassVar[Callable[[dict], str]] = xmltodict.unparse
+
+
+class FileTypeYAML(FileType):
+ extension: ClassVar[Set[str]] = {"YAML", "YML"}  # both spellings are declared for this type
+ load_fn: ClassVar[Callable[[str], dict]] = yaml.safe_load  # source text -> Python object
+ dump_fn: ClassVar[Callable[[dict], str]] = yaml.safe_dump  # dict -> YAML text
diff --git a/orbiter/rules/rulesets.py b/orbiter/rules/rulesets.py
index aba7cbe..58cc46f 100644
--- a/orbiter/rules/rulesets.py
+++ b/orbiter/rules/rulesets.py
@@ -9,13 +9,23 @@
from itertools import chain
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import List, Any, Collection, Annotated, Callable, Union, Generator
+from typing import (
+ List,
+ Any,
+ Collection,
+ Annotated,
+ Callable,
+ Union,
+ Generator,
+ Set,
+ Type,
+)
from loguru import logger
from pydantic import BaseModel, AfterValidator, validate_call
-from orbiter import FileType
from orbiter import import_from_qualname
+from orbiter.file_types import FileType, FileTypeJSON
from orbiter.objects.dag import OrbiterDAG
from orbiter.objects.project import OrbiterProject
from orbiter.objects.task import OrbiterOperator, OrbiterTaskDependency
@@ -31,6 +41,15 @@
EMPTY_RULE,
) # noqa: F401
+
+def _backport_walk(input_dir: Path):
+    """Stand-in for `Path.walk()` (Python 3.12+): wraps `os.walk`, yielding each directory as a `Path`"""
+    import os
+
+    for dirpath, dirnames, filenames in os.walk(input_dir):
+        yield Path(dirpath), dirnames, filenames
+
+
qualname_validator_regex = r"^[\w.]+$"
qualname_validator = re.compile(qualname_validator_regex)
@@ -98,71 +117,6 @@ def validate_qualified_imports(qualified_imports: List[str]) -> List[str]:
]
-# noinspection t
-def xmltodict_parse(input_str: str) -> Any:
- """Calls `xmltodict.parse` and does post-processing fixes.
-
- !!! note
-
- The original [`xmltodict.parse`](https://pypi.org/project/xmltodict/) method returns EITHER:
-
- - a dict (one child element of type)
- - or a list of dict (many child element of type)
-
- This behavior can be confusing, and is an issue with the original xml spec being referenced.
-
- **This method deviates by standardizing to the latter case (always a `list[dict]`).**
-
- **All XML elements will be a list of dictionaries, even if there's only one element.**
-
- ```pycon
- >>> xmltodict_parse("")
- Traceback (most recent call last):
- xml.parsers.expat.ExpatError: no element found: line 1, column 0
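- >>> xmltodict_parse("<a></a>")
- {'a': None}
- >>> xmltodict_parse("<a foo='bar'></a>")
- {'a': [{'@foo': 'bar'}]}
- >>> xmltodict_parse("<a foo='bar'><foo bar='baz'></foo></a>") # Singleton - gets modified
- {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}]}]}
- >>> xmltodict_parse("<a foo='bar'><foo bar='baz'><bar><bop></bop></bar></foo></a>") # Nested Singletons - modified
- {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz', 'bar': [{'bop': None}]}]}]}
- >>> xmltodict_parse("<a foo='bar'><foo bar='baz'></foo><foo bing='bop'></foo></a>")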
- {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}, {'@bing': 'bop'}]}]}
-
- ```
- :param input_str: The XML string to parse
- :type input_str: str
- :return: The parsed XML
- :rtype: dict
- """
- import xmltodict
-
- # noinspection t
- def _fix(d):
- """fix the dict in place, recursively, standardizing on a list of dict even if there's only one entry."""
- # if it's a dict, descend to fix
- if isinstance(d, dict):
- for k, v in d.items():
- # @keys are properties of elements, non-@keys are elements
- if not k.startswith("@"):
- if isinstance(v, dict):
- # THE FIX
- # any non-@keys should be a list of dict, even if there's just one of the element
- d[k] = [v]
- _fix(v)
- else:
- _fix(v)
- # if it's a list, descend to fix
- if isinstance(d, list):
- for v in d:
- _fix(v)
-
- output = xmltodict.parse(input_str)
- _fix(output)
- return output
-
-
def _add_task_deduped(_task, _tasks, n=""):
"""
If this task_id doesn't already exist, add it as normal to the tasks dictionary.
@@ -621,7 +575,7 @@ class TranslationRuleset(BaseModel, ABC, extra="forbid"):
```pycon
>>> TranslationRuleset(
- ... file_type=FileType.JSON, # Has a file type
+ ... file_type={FileTypeJSON}, # Has a file type
... translate_fn=fake_translate, # and can have a callable
... # translate_fn="orbiter.rules.translate.fake_translate", # or a qualified name to a function
... dag_filter_ruleset={"ruleset": [{"rule": lambda x: None}]}, # Rulesets can be dict within dicts
@@ -635,8 +589,8 @@ class TranslationRuleset(BaseModel, ABC, extra="forbid"):
```
- :param file_type: FileType to translate (`.json`, `.xml`, `.yaml`, etc.)
- :type file_type: FileType
+ :param file_type: FileType to translate
+ :type file_type: Set[Type[FileType]]
:param dag_filter_ruleset: [`DAGFilterRuleset`][orbiter.rules.rulesets.DAGFilterRuleset]
(of [`DAGFilterRule`][orbiter.rules.DAGFilterRule])
:type dag_filter_ruleset: DAGFilterRuleset | dict
@@ -660,7 +614,7 @@ class TranslationRuleset(BaseModel, ABC, extra="forbid"):
:type translate_fn: Callable[[TranslationRuleset, Path], OrbiterProject] | str | TranslateFn
""" # noqa: E501
- file_type: FileType
+ file_type: Set[Type[FileType]]
dag_filter_ruleset: DAGFilterRuleset | dict
dag_ruleset: DAGRuleset | dict
task_filter_ruleset: TaskFilterRuleset | dict
@@ -669,67 +623,54 @@ class TranslationRuleset(BaseModel, ABC, extra="forbid"):
post_processing_ruleset: PostProcessingRuleset | dict
translate_fn: TranslateFn = translate
+    def get_ext(self) -> str:
+        """
+        Get one extension for this ruleset: the alphabetically-first extension of the first `file_type` in the set
+
+        ```pycon
+        >>> EMPTY_TRANSLATION_RULESET.get_ext()
+        'JSON'
+
+        ```
+        """
+        return min(next(iter(self.file_type)).extension)  # min(): stable across runs, unlike set iteration order
+
@validate_call
- def loads(self, input_str: str) -> dict:
+    def loads(self, file: Path) -> dict:
"""
Converts all files of type into a Python dictionary "intermediate representation" form,
prior to any rulesets being applied.
- | FileType | Conversion Method |
- |----------|-------------------------------------------------------------|
- | `XML` | [`xmltodict_parse`][orbiter.rules.rulesets.xmltodict_parse] |
- | `YAML` | `yaml.safe_load` |
- | `JSON` | `json.loads` |
-
- :param input_str: The string to convert to a dictionary
- :type input_str: str
+        :param file: Path of the file to read; its suffix selects which `file_type` loader parses it
+        :type file: Path
:return: The dictionary representation of the input_str
:rtype: dict
"""
-
- if self.file_type == FileType.JSON:
- import json
-
- return json.loads(input_str)
- elif self.file_type == FileType.YAML:
- import yaml
-
- return yaml.safe_load(input_str)
- elif self.file_type == FileType.XML:
- return xmltodict_parse(input_str)
- else:
- raise NotImplementedError(f"Cannot load file_type={self.file_type}")
+        suffix = file.suffix.lower()
+        for candidate in self.file_type:
+            accepted = {f".{ext.lower()}" for ext in candidate.extension}
+            if suffix in accepted:
+                return candidate.load_fn(file.read_text())
+        raise TypeError(
+            f"Invalid file_type={file.suffix}, does not match file_type={self.file_type}"
+        )
@validate_call
- def dumps(self, input_dict: dict) -> str:
+    def dumps(self, input_dict: dict, ext: str | None = None) -> str:
"""
Convert Python dictionary back to source string form, useful for testing
- | FileType | Conversion Method |
- |----------|---------------------|
- | `XML` | `xmltodict.unparse` |
- | `YAML` | `yaml.safe_dump` |
- | `JSON` | `json.dumps` |
-
:param input_dict: The dictionary to convert to a string
:type input_dict: dict
+        :param ext: Extension to dump as (matched case-insensitively), defaults to first 'file_type' in the set
+        :type ext: str | None
:return str: The string representation of the input_dict, in the file_type format
:rtype: str
"""
- if self.file_type == FileType.JSON:
- import json
-
- return json.dumps(input_dict, indent=2)
- elif self.file_type == FileType.YAML:
- import yaml
-
- return yaml.safe_dump(input_dict)
- elif self.file_type == FileType.XML:
- import xmltodict
-
- return xmltodict.unparse(input_dict)
- else:
- raise NotImplementedError(f"Cannot dump file_type={self.file_type}")
+        for file_type in self.file_type:
+            if ext is None or ext.lower() in {known.lower() for known in file_type.extension}:
+                return file_type.dump_fn(input_dict)
+        raise TypeError(f"Invalid file_type={ext}")
def get_files_with_extension(self, input_dir: Path) -> Generator[Path, dict]:
"""
@@ -740,39 +681,28 @@ def get_files_with_extension(self, input_dir: Path) -> Generator[Path, dict]:
:return: Generator item of (Path, dict) for each file found
:rtype: Generator[Path, dict]
"""
- extension = f".{self.file_type.value.lower()}"
- extensions = [extension]
-
- # YAML and YML are both valid extensions
- extension_sub = {
- "yaml": "yml",
- }
- if other_extension := extension_sub.get(self.file_type.value.lower()):
- extensions.append(f".{other_extension}")
-
- logger.debug(f"Finding files with extension={extensions} in {input_dir}")
-
- def backport_walk(input_dir: Path):
- """Path.walk() is only available in Python 3.12+, so, backport"""
- import os
-
- for result in os.walk(input_dir):
- yield Path(result[0]), result[1], result[2]
-
for directory, _, files in (
- input_dir.walk() if hasattr(input_dir, "walk") else backport_walk(input_dir)
+ input_dir.walk()
+ if hasattr(input_dir, "walk")
+ else _backport_walk(input_dir)
):
logger.debug(f"Checking directory={directory}")
for file in files:
file = directory / file
- if file.suffix.lower() in extensions:
- logger.debug(f"File={file} matches extension={extensions}")
+ # noinspection PyBroadException
+ try:
yield (
# Return the file path
file,
# and load the file and convert it into a python dict
- self.loads(file.read_text()),
+ self.loads(file),
)
+ except TypeError:
+ logger.debug(f"File={file} not of correct type, skipping...")
+ continue
+ except Exception as e:
+ logger.exception(f"Error loading file={file}, {e}")
+ continue
def test(self, input_value: str | dict) -> OrbiterProject:
"""
@@ -788,7 +718,7 @@ def test(self, input_value: str | dict) -> OrbiterProject:
:rtype: OrbiterProject
"""
with TemporaryDirectory() as tempdir:
- file = Path(tempdir) / f"{uuid.uuid4()}.{self.file_type.value}"
+ file = Path(tempdir) / f"{uuid.uuid4()}.{self.get_ext()}"
file.write_text(
self.dumps(input_value)
if isinstance(input_value, dict)
@@ -797,6 +727,18 @@ def test(self, input_value: str | dict) -> OrbiterProject:
return self.translate_fn(translation_ruleset=self, input_dir=file.parent)
+EMPTY_TRANSLATION_RULESET = TranslationRuleset(  # JSON ruleset with EMPTY_RULESET in every slot; used by the get_ext() doctest
+ file_type={FileTypeJSON},
+ dag_filter_ruleset=EMPTY_RULESET,
+ dag_ruleset=EMPTY_RULESET,
+ task_filter_ruleset=EMPTY_RULESET,
+ task_ruleset=EMPTY_RULESET,
+ task_dependency_ruleset=EMPTY_RULESET,
+ post_processing_ruleset=EMPTY_RULESET,
+ translate_fn=fake_translate,
+)
+
+
if __name__ == "__main__":
import doctest
diff --git a/tests/orbiter/rules/rulesets_test.py b/tests/orbiter/rules/rulesets_test.py
index 622741e..f68a7d8 100644
--- a/tests/orbiter/rules/rulesets_test.py
+++ b/tests/orbiter/rules/rulesets_test.py
@@ -1,10 +1,23 @@
-from orbiter import FileType
+from orbiter.file_types import FileTypeYAML
from orbiter.rules.rulesets import TranslationRuleset, EMPTY_RULESET
+def test_loads(project_root):  # loads a fixture whose suffix is upper-case ".YAML"
+    ruleset = TranslationRuleset(
+        file_type={FileTypeYAML},
+        dag_ruleset=EMPTY_RULESET,
+        dag_filter_ruleset=EMPTY_RULESET,
+        task_filter_ruleset=EMPTY_RULESET,
+        task_ruleset=EMPTY_RULESET,
+        task_dependency_ruleset=EMPTY_RULESET,
+        post_processing_ruleset=EMPTY_RULESET,
+    )
+    assert ruleset.loads(project_root / "tests/resources/test_get_files_with_extension/one.YAML") == {"one": "foo"}
+
+
def test__get_files_with_extension(project_root):
translation_ruleset = TranslationRuleset(
- file_type=FileType.YAML,
+ file_type={FileTypeYAML},
dag_ruleset=EMPTY_RULESET,
dag_filter_ruleset=EMPTY_RULESET,
task_filter_ruleset=EMPTY_RULESET,
diff --git a/tests/resources/translation_template.py b/tests/resources/translation_template.py
index d8e7498..914521c 100644
--- a/tests/resources/translation_template.py
+++ b/tests/resources/translation_template.py
@@ -1,5 +1,5 @@
from __future__ import annotations
-from orbiter import FileType
+from orbiter.file_types import FileTypeJSON
from orbiter.objects.dag import OrbiterDAG
from orbiter.objects.operators.empty import OrbiterEmptyOperator
from orbiter.objects.project import OrbiterProject
@@ -88,7 +88,7 @@ def basic_post_processing_rule(val: OrbiterProject) -> None:
translation_ruleset = TranslationRuleset(
- file_type=FileType.JSON,
+ file_type={FileTypeJSON},
dag_filter_ruleset=DAGFilterRuleset(ruleset=[basic_dag_filter]),
dag_ruleset=DAGRuleset(ruleset=[basic_dag_rule]),
task_filter_ruleset=TaskFilterRuleset(ruleset=[basic_task_filter]),