Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro committed Oct 17, 2024
1 parent 6f31846 commit c9de757
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 31 deletions.
35 changes: 5 additions & 30 deletions dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,7 @@
except ImportError:
from airflow.providers.common.sql.sensors.sql import SqlSensor

# python sensor was moved in Airflow 2.0.0
try:
from airflow.sensors.python import PythonSensor
except ImportError:
from airflow.contrib.sensors.python_sensor import PythonSensor

from airflow.sensors.python import PythonSensor

# k8s libraries are moved in v5.0.0
try:
Expand Down Expand Up @@ -69,29 +64,20 @@
)
from airflow.kubernetes.secret import Secret
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
except ImportError:
except ImportError: # pragma: no cover
from airflow.contrib.kubernetes.pod import Port
from airflow.contrib.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv
from airflow.contrib.kubernetes.secret import Secret
from airflow.contrib.kubernetes.volume import Volume
from airflow.contrib.kubernetes.volume_mount import VolumeMount
from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator

from airflow.utils.task_group import TaskGroup
from kubernetes.client.models import V1Container, V1Pod

from dagfactory import utils
from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException

# pylint: disable=ungrouped-imports,invalid-name
# Disabling pylint's ungrouped-imports warning because this is a
# conditional import and cannot be done within the import group above
# TaskGroup is introduced in Airflow 2.0.0
if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
from airflow.utils.task_group import TaskGroup
else:
TaskGroup = None
# pylint: disable=ungrouped-imports,invalid-name

# TimeTable is introduced in Airflow 2.2.0
if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"):
from airflow.timetables.base import Timetable
Expand All @@ -104,12 +90,7 @@
else:
MappedOperator = None

# XComArg is introduced in Airflow 2.0.0
if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
from airflow.models.xcom_arg import XComArg
else:
XComArg = None
# pylint: disable=ungrouped-imports,invalid-name
from airflow.models.xcom_arg import XComArg

if version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"):
from airflow.datasets import Dataset
Expand Down Expand Up @@ -149,9 +130,6 @@ def get_dag_params(self) -> Dict[str, Any]:
raise DagFactoryConfigException("Failed to merge config with default config") from err
dag_params["dag_id"]: str = self.dag_name

if dag_params.get("task_groups") and version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
raise DagFactoryConfigException("`task_groups` key can only be used with Airflow 2.x.x")

if utils.check_dict_key(dag_params, "schedule_interval") and dag_params["schedule_interval"] == "None":
dag_params["schedule_interval"] = None

Expand Down Expand Up @@ -691,10 +669,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:
if not dag_params.get("timetable") and not utils.check_dict_key(dag_params, "schedule"):
dag_kwargs["schedule_interval"] = dag_params.get("schedule_interval", timedelta(days=1))

if version.parse(AIRFLOW_VERSION) >= version.parse("1.10.11"):
dag_kwargs["description"] = dag_params.get("description", None)
else:
dag_kwargs["description"] = dag_params.get("description", "")
dag_kwargs["description"] = dag_params.get("description", None)

if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"):
dag_kwargs["max_active_tasks"] = dag_params.get(
Expand Down
30 changes: 29 additions & 1 deletion tests/test_dagbuilder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import datetime
import os
from pathlib import Path
from unittest.mock import patch
from unittest.mock import mock_open, patch

import pendulum
import pytest
from airflow import DAG
from packaging import version

from dagfactory.dagbuilder import Dataset

try:
from airflow.providers.http.sensors.http import HttpSensor
except ImportError:
Expand Down Expand Up @@ -879,3 +881,29 @@ def test_replace_expand_string_with_xcom():
updated_task_conf_xcomarg = dagbuilder.DagBuilder.replace_expand_values(task_conf_xcomarg, tasks_dict)
assert updated_task_conf_output["expand"]["key_1"] == XComArg(tasks_dict["task_1"])
assert updated_task_conf_xcomarg["expand"]["key_1"] == XComArg(tasks_dict["task_1"])


@pytest.mark.parametrize(
"outlets,output",
[
(
{"datasets": "s3://test/test.txt", "file": "file://path/to/my_file.txt"},
["s3://test/test.txt", "file://path/to/my_file.txt"],
),
(["s3://test/test.txt"], ["s3://test/test.txt"]),
],
)
@patch("dagfactory.dagbuilder.utils.get_datasets_uri_yaml_file", new_callable=mock_open)
def test_make_task_outlets(mock_read_file, outlets, output):
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG)
task_params = {
"task_id": "process",
"python_callable_name": "dataset_task",
"python_callable_file": os.path.realpath(__file__),
"outlets": outlets,
}
mock_read_file.return_value = output
if version.parse(AIRFLOW_VERSION) > version.parse("2.4.0"):
operator = "airflow.operators.python_operator.PythonOperator"
actual = td.make_task(operator, task_params)
assert actual.outlets == [Dataset(uri) for uri in output]

0 comments on commit c9de757

Please sign in to comment.