diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4b214ce11..8f61804aa 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -293,13 +293,25 @@ jobs: PYTHONPATH: /home/runner/work/astronomer-cosmos/astronomer-cosmos/:$PYTHONPATH Run-Performance-Tests: + needs: Authorize runs-on: ubuntu-latest strategy: matrix: python-version: ["3.11"] airflow-version: ["2.7"] num-models: [1, 10, 50, 100] - + services: + postgres: + image: postgres + env: + POSTGRES_PASSWORD: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 steps: - uses: actions/checkout@v3 with: @@ -335,8 +347,14 @@ jobs: AIRFLOW_CONN_AIRFLOW_DB: postgres://postgres:postgres@0.0.0.0:5432/postgres AIRFLOW__CORE__DAGBAG_IMPORT_TIMEOUT: 90.0 PYTHONPATH: /home/runner/work/astronomer-cosmos/astronomer-cosmos/:$PYTHONPATH + COSMOS_CONN_POSTGRES_PASSWORD: ${{ secrets.COSMOS_CONN_POSTGRES_PASSWORD }} + POSTGRES_HOST: localhost + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + POSTGRES_SCHEMA: public + POSTGRES_PORT: 5432 MODEL_COUNT: ${{ matrix.num-models }} - env: AIRFLOW_HOME: /home/runner/work/astronomer-cosmos/astronomer-cosmos/ AIRFLOW_CONN_AIRFLOW_DB: postgres://postgres:postgres@0.0.0.0:5432/postgres diff --git a/cosmos/config.py b/cosmos/config.py index 52763536f..9dfb672be 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -10,7 +10,14 @@ import warnings from typing import Any, Iterator, Callable -from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, LoadMode, TestIndirectSelection +from cosmos.constants import ( + DbtResourceType, + TestBehavior, + ExecutionMode, + LoadMode, + TestIndirectSelection, + InvocationMode, +) from cosmos.dbt.executable import get_system_dbt from cosmos.exceptions import CosmosValueError from cosmos.log import get_logger @@ -295,12 +302,14 @@ class ExecutionConfig: Contains configuration about how to execute dbt. :param execution_mode: The execution mode for dbt. Defaults to local + :param invocation_mode: The invocation mode for the dbt command. This is only configurable for ExecutionMode.LOCAL. :param test_indirect_selection: The mode to configure the test behavior when performing indirect selection. :param dbt_executable_path: The path to the dbt executable for runtime execution. Defaults to dbt if available on the path. :param dbt_project_path: Configures the DBT project location accessible at runtime for dag execution. This is the project path in a docker container for ExecutionMode.DOCKER or ExecutionMode.KUBERNETES. Mutually Exclusive with ProjectConfig.dbt_project_path """ execution_mode: ExecutionMode = ExecutionMode.LOCAL + invocation_mode: InvocationMode | None = None test_indirect_selection: TestIndirectSelection = TestIndirectSelection.EAGER dbt_executable_path: str | Path = field(default_factory=get_system_dbt) @@ -308,4 +317,6 @@ class ExecutionConfig: project_path: Path | None = field(init=False) def __post_init__(self, dbt_project_path: str | Path | None) -> None: + if self.invocation_mode and self.execution_mode != ExecutionMode.LOCAL: + raise CosmosValueError("ExecutionConfig.invocation_mode is only configurable for ExecutionMode.LOCAL.") self.project_path = Path(dbt_project_path) if dbt_project_path else None diff --git a/cosmos/constants.py b/cosmos/constants.py index b5a1f3daa..e8b9cff1d 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -54,6 +54,15 @@ class ExecutionMode(Enum): AZURE_CONTAINER_INSTANCE = "azure_container_instance" +class InvocationMode(Enum): + """ + How the dbt command should be invoked. + """ + + SUBPROCESS = "subprocess" + DBT_RUNNER = "dbt_runner" + + class TestIndirectSelection(Enum): """ Modes to configure the test behavior when performing indirect selection. diff --git a/cosmos/converter.py b/cosmos/converter.py index 1bd227a42..bafe094e8 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -254,6 +254,8 @@ def __init__( } if execution_config.dbt_executable_path: task_args["dbt_executable_path"] = execution_config.dbt_executable_path + if execution_config.invocation_mode: + task_args["invocation_mode"] = execution_config.invocation_mode validate_arguments( render_config.select, diff --git a/cosmos/dbt/parser/output.py b/cosmos/dbt/parser/output.py index 791c4b605..3690a8f60 100644 --- a/cosmos/dbt/parser/output.py +++ b/cosmos/dbt/parser/output.py @@ -1,33 +1,53 @@ +from __future__ import annotations + import logging import re -from typing import List, Tuple +from typing import List, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from dbt.cli.main import dbtRunnerResult from cosmos.hooks.subprocess import FullOutputSubprocessResult -def parse_output(result: FullOutputSubprocessResult, keyword: str) -> int: +DBT_NO_TESTS_MSG = "Nothing to do" +DBT_WARN_MSG = "WARN" + + +def parse_number_of_warnings_subprocess(result: FullOutputSubprocessResult) -> int: """ - Parses the dbt test output message and returns the number of errors or warnings. + Parses the dbt test output message and returns the number of warnings. :param result: String containing the output to be parsed. - :param keyword: String representing the keyword to search for in the output (WARN, ERROR). :return: An integer value associated with the keyword, or 0 if parsing fails. Usage: ----- output_str = "Done. PASS=15 WARN=1 ERROR=0 SKIP=0 TOTAL=16" - keyword = "WARN" - num_warns = parse_output(output_str, keyword) + num_warns = parse_output(output_str) print(num_warns) # Output: 1 """ output = result.output - try: - num = int(output.split(f"{keyword}=")[1].split()[0]) - except ValueError: - logging.error( - f"Could not parse number of {keyword}s. Check your dbt/airflow version or if --quiet is not being used" - ) + num = 0 + if DBT_NO_TESTS_MSG not in result.output and DBT_WARN_MSG in result.output: + try: + num = int(output.split(f"{DBT_WARN_MSG}=")[1].split()[0]) + except ValueError: + logging.error( + f"Could not parse number of {DBT_WARN_MSG}s. Check your dbt/airflow version or if --quiet is not being used" + ) + return num + + +def parse_number_of_warnings_dbt_runner(result: dbtRunnerResult) -> int: + """Parses a dbt runner result and returns the number of warnings found. This only works for dbtRunnerResult + from invoking dbt build, compile, run, seed, snapshot, test, or run-operation. + """ + num = 0 + for run_result in result.result.results: # type: ignore + if run_result.status == "warn": + num += 1 return num @@ -67,3 +87,28 @@ def clean_line(line: str) -> str: test_results.append(test_result) return test_names, test_results + + +def extract_dbt_runner_issues( + result: dbtRunnerResult, status_levels: list[str] = ["warn"] +) -> Tuple[List[str], List[str]]: + """ + Extracts messages from the dbt runner result and returns them as a formatted string. + + This function iterates over dbtRunnerResult messages in dbt run. It extracts results that match the + status levels provided and appends them to a list of issues. + + :param result: dbtRunnerResult object containing the output to be parsed. + :param status_levels: List of strings, where each string is a result status level. Default is ["warn"]. + :return: two lists of strings, the first one containing the node names and the second one + containing the node result message. + """ + node_names = [] + node_results = [] + + for node_result in result.result.results: # type: ignore + if node_result.status in status_levels: + node_names.append(str(node_result.node.name)) + node_results.append(str(node_result.message)) + + return node_names, node_results diff --git a/cosmos/dbt/project.py b/cosmos/dbt/project.py index 889987b6d..144a1f6df 100644 --- a/cosmos/dbt/project.py +++ b/cosmos/dbt/project.py @@ -44,3 +44,16 @@ def environ(env_vars: dict[str, str]) -> Generator[None, None, None]: del os.environ[key] else: os.environ[key] = value + + +@contextmanager +def change_working_directory(path: str) -> Generator[None, None, None]: + """Temporarily changes the working directory to the given path, and then restores + back to the previous value on exit. + """ + previous_cwd = os.getcwd() + os.chdir(path) + try: + yield + finally: + os.chdir(previous_cwd) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index c6e72ee74..fa4db144d 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -18,6 +18,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.utils.context import Context from airflow.utils.session import NEW_SESSION, create_session, provide_session +from cosmos.constants import InvocationMode try: from openlineage.common.provider.dbt.local import DbtLocalArtifactProcessor @@ -31,6 +32,7 @@ if TYPE_CHECKING: from airflow.datasets import Dataset # noqa: F811 from openlineage.client.run import RunEvent + from dbt.cli.main import dbtRunner, dbtRunnerResult from sqlalchemy.orm import Session @@ -56,11 +58,14 @@ FullOutputSubprocessHook, FullOutputSubprocessResult, ) -from cosmos.dbt.parser.output import extract_log_issues, parse_output -from cosmos.dbt.project import create_symlinks, copy_msgpack_for_partial_parse +from cosmos.dbt.parser.output import ( + extract_dbt_runner_issues, + extract_log_issues, + parse_number_of_warnings_dbt_runner, + parse_number_of_warnings_subprocess, +) +from cosmos.dbt.project import create_symlinks, copy_msgpack_for_partial_parse, environ, change_working_directory -DBT_NO_TESTS_MSG = "Nothing to do" -DBT_WARN_MSG = "WARN" logger = get_logger(__name__) @@ -116,6 +121,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator): def __init__( self, profile_config: ProfileConfig, + invocation_mode: InvocationMode | None = None, install_deps: bool = False, callback: Callable[[str], None] | None = None, should_store_compiled_sql: bool = True, @@ -127,6 +133,12 @@ def __init__( self.compiled_sql = "" self.should_store_compiled_sql = should_store_compiled_sql self.openlineage_events_completes: list[RunEvent] = [] + self.invocation_mode = invocation_mode + self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult] + self.handle_exception: Callable[..., None] + self._dbt_runner: dbtRunner | None = None + if self.invocation_mode: + self._set_invocation_methods() kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes super().__init__(**kwargs) @@ -135,7 +147,31 @@ def subprocess_hook(self) -> FullOutputSubprocessHook: """Returns hook for running the bash command.""" return FullOutputSubprocessHook() - def exception_handling(self, result: FullOutputSubprocessResult) -> None: + def _set_invocation_methods(self) -> None: + """Sets the associated run and exception handling methods based on the invocation mode.""" + if self.invocation_mode == InvocationMode.SUBPROCESS: + self.invoke_dbt = self.run_subprocess + self.handle_exception = self.handle_exception_subprocess + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + self.invoke_dbt = self.run_dbt_runner + self.handle_exception = self.handle_exception_dbt_runner + + def _discover_invocation_mode(self) -> None: + """Discovers the invocation mode based on the availability of dbtRunner for import. If dbtRunner is available, it will + be used since it is faster than subprocess. If dbtRunner is not available, it will fall back to subprocess. + This method is called at runtime to work in the environment where the operator is running. + """ + try: + from dbt.cli.main import dbtRunner + except ImportError: + self.invocation_mode = InvocationMode.SUBPROCESS + logger.info("Could not import dbtRunner. Falling back to subprocess for invoking dbt.") + else: + self.invocation_mode = InvocationMode.DBT_RUNNER + logger.info("dbtRunner is available. Using dbtRunner for invoking dbt.") + self._set_invocation_methods() + + def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> None: if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code: raise AirflowSkipException(f"dbt command returned exit code {self.skip_exit_code}. Skipping.") elif result.exit_code != 0: @@ -144,6 +180,16 @@ def exception_handling(self, result: FullOutputSubprocessResult) -> None: *result.full_output, ) + def handle_exception_dbt_runner(self, result: dbtRunnerResult) -> None: + """dbtRunnerResult has an attribute `success` that is False if the command failed.""" + if not result.success: + if result.exception: + raise AirflowException(f"dbt invocation did not complete with unhandled error: {result.exception}") + else: + node_names, node_results = extract_dbt_runner_issues(result, ["error", "fail", "runtime error"]) + error_message = "\n".join([f"{name}: {result}" for name, result in zip(node_names, node_results)]) + raise AirflowException(f"dbt invocation completed with errors: {error_message}") + @provide_session def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None: """ @@ -191,26 +237,58 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se else: logger.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.") - def run_subprocess(self, *args: Any, **kwargs: Any) -> FullOutputSubprocessResult: - subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(*args, **kwargs) + def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult: + logger.info("Trying to run the command:\n %s\nFrom %s", command, cwd) + subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command( + command=command, + env=env, + cwd=cwd, + output_encoding=self.output_encoding, + ) + logger.info(subprocess_result.output) return subprocess_result + def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str) -> dbtRunnerResult: + """Invokes the dbt command programmatically.""" + try: + from dbt.cli.main import dbtRunner + except ImportError: + raise ImportError( + "Could not import dbt core. Ensure that dbt-core >= v1.5 is installed and available in the environment where the operator is running." + ) + + if self._dbt_runner is None: + self._dbt_runner = dbtRunner() + + # Exclude the dbt executable path from the command + cli_args = command[1:] + + logger.info("Trying to run dbtRunner with:\n %s\n in %s", cli_args, cwd) + + with change_working_directory(cwd), environ(env): + result = self._dbt_runner.invoke(cli_args) + + return result + def run_command( self, cmd: list[str], env: dict[str, str | bytes | os.PathLike[Any]], context: Context, - ) -> FullOutputSubprocessResult: + ) -> FullOutputSubprocessResult | dbtRunnerResult: """ Copies the dbt project to a temporary directory and runs the command. """ + if not self.invocation_mode: + self._discover_invocation_mode() + with tempfile.TemporaryDirectory() as tmp_project_dir: logger.info( "Cloning project to writable temp directory %s from %s", tmp_project_dir, self.project_dir, ) - + env = {k: str(v) for k, v in env.items()} create_symlinks(Path(self.project_dir), Path(tmp_project_dir), self.install_deps) if self.partial_parse: @@ -232,21 +310,18 @@ def run_command( if self.install_deps: deps_command = [self.dbt_executable_path, "deps"] deps_command.extend(flags) - self.run_subprocess( + self.invoke_dbt( command=deps_command, env=env, - output_encoding=self.output_encoding, cwd=tmp_project_dir, ) full_cmd = cmd + flags - logger.info("Trying to run the command:\n %s\nFrom %s", full_cmd, tmp_project_dir) logger.info("Using environment variables keys: %s", env.keys()) - result = self.run_subprocess( + result = self.invoke_dbt( command=full_cmd, env=env, - output_encoding=self.output_encoding, cwd=tmp_project_dir, ) if is_openlineage_available: @@ -263,7 +338,7 @@ def run_command( self.register_dataset(inlets, outlets) self.store_compiled_sql(tmp_project_dir, context) - self.exception_handling(result) + self.handle_exception(result) if self.callback: self.callback(tmp_project_dir) @@ -373,18 +448,20 @@ def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> Ope job_facets=job_facets, ) - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> FullOutputSubprocessResult: + def build_and_run_cmd( + self, context: Context, cmd_flags: list[str] | None = None + ) -> FullOutputSubprocessResult | dbtRunnerResult: dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags) dbt_cmd = dbt_cmd or [] result = self.run_command(cmd=dbt_cmd, env=env, context=context) - logger.info(result.output) return result def on_kill(self) -> None: - if self.cancel_query_on_kill: - self.subprocess_hook.send_sigint() - else: - self.subprocess_hook.send_sigterm() + if self.invocation_mode == InvocationMode.SUBPROCESS: + if self.cancel_query_on_kill: + self.subprocess_hook.send_sigint() + else: + self.subprocess_hook.send_sigterm() class DbtBuildLocalOperator(DbtBuildMixin, DbtLocalBaseOperator): @@ -435,8 +512,10 @@ def __init__( ) -> None: super().__init__(**kwargs) self.on_warning_callback = on_warning_callback + self.extract_issues: Callable[..., tuple[list[str], list[str]]] + self.parse_number_of_warnings: Callable[..., int] - def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) -> None: + def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None: """ Handles warnings by extracting log issues, creating additional context, and calling the on_warning_callback with the updated context. @@ -444,7 +523,7 @@ def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) :param result: The result object from the build and run command. :param context: The original airflow context in which the build and run command was executed. """ - test_names, test_results = extract_log_issues(result.full_output) + test_names, test_results = self.extract_issues(result) warning_context = dict(context) warning_context["test_names"] = test_names @@ -452,19 +531,21 @@ def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) self.on_warning_callback and self.on_warning_callback(warning_context) + def _set_test_result_parsing_methods(self) -> None: + """Sets the extract_issues and parse_number_of_warnings methods based on the invocation mode.""" + if self.invocation_mode == InvocationMode.SUBPROCESS: + self.extract_issues = lambda result: extract_log_issues(result.full_output) + self.parse_number_of_warnings = parse_number_of_warnings_subprocess + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + self.extract_issues = extract_dbt_runner_issues + self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner + def execute(self, context: Context) -> None: result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) - should_trigger_callback = all( - [ - self.on_warning_callback, - DBT_NO_TESTS_MSG not in result.output, - DBT_WARN_MSG in result.output, - ] - ) - if should_trigger_callback: - warnings = parse_output(result, "WARN") - if warnings > 0: - self._handle_warnings(result, context) + self._set_test_result_parsing_methods() + number_of_warnings = self.parse_number_of_warnings(result) # type: ignore + if self.on_warning_callback and number_of_warnings > 0: + self._handle_warnings(result, context) class DbtRunOperationLocalOperator(DbtRunOperationMixin, DbtLocalBaseOperator): diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index bad88e234..6612ab8b8 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -85,11 +85,16 @@ def venv_dbt_path( self.log.info("Using dbt version %s available at %s", dbt_version, dbt_binary) return str(dbt_binary) - def run_subprocess(self, *args: Any, command: list[str], **kwargs: Any) -> FullOutputSubprocessResult: + def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult: if self.py_requirements: command[0] = self.venv_dbt_path - subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(command, *args, **kwargs) + subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command( + command=command, + env=env, + cwd=cwd, + output_encoding=self.output_encoding, + ) return subprocess_result def execute(self, context: Context) -> None: diff --git a/dev/dags/basic_cosmos_task_group.py b/dev/dags/basic_cosmos_task_group.py index 4b6aae71e..06b24f291 100644 --- a/dev/dags/basic_cosmos_task_group.py +++ b/dev/dags/basic_cosmos_task_group.py @@ -12,6 +12,7 @@ from cosmos import DbtTaskGroup, ProjectConfig, ProfileConfig, RenderConfig, ExecutionConfig from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.constants import InvocationMode DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) @@ -25,7 +26,7 @@ ), ) -shared_execution_config = ExecutionConfig() +shared_execution_config = ExecutionConfig(invocation_mode=InvocationMode.DBT_RUNNER) @dag( diff --git a/dev/dags/dbt/perf/profiles.yml b/dev/dags/dbt/perf/profiles.yml index 5b3cf175d..224f565f4 100644 --- a/dev/dags/dbt/perf/profiles.yml +++ b/dev/dags/dbt/perf/profiles.yml @@ -1,11 +1,12 @@ -simple: +default: target: dev outputs: dev: - type: sqlite - threads: 1 - database: "database" - schema: "main" - schemas_and_paths: - main: "{{ env_var('DBT_SQLITE_PATH') }}/imdb.db" - schema_directory: "{{ env_var('DBT_SQLITE_PATH') }}" + type: postgres + host: "{{ env_var('POSTGRES_HOST') }}" + user: "{{ env_var('POSTGRES_USER') }}" + password: "{{ env_var('POSTGRES_PASSWORD') }}" + port: "{{ env_var('POSTGRES_PORT') | int }}" + dbname: "{{ env_var('POSTGRES_DB') }}" + schema: "{{ env_var('POSTGRES_SCHEMA') }}" + threads: 4 diff --git a/dev/dags/performance_dag.py b/dev/dags/performance_dag.py index caf977817..fec5175c8 100644 --- a/dev/dags/performance_dag.py +++ b/dev/dags/performance_dag.py @@ -1,28 +1,31 @@ """ -A DAG that uses Cosmos to render a dbt project for performance testing. +An airflow DAG that uses Cosmos to render a dbt project for performance testing. """ -import airflow from datetime import datetime import os from pathlib import Path from cosmos import DbtDag, ProjectConfig, ProfileConfig, RenderConfig +from cosmos.profiles import PostgresUserPasswordProfileMapping + DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) -DBT_SQLITE_PATH = str(DEFAULT_DBT_ROOT_PATH / "data") + profile_config = ProfileConfig( - profile_name="simple", + profile_name="default", target_name="dev", - profiles_yml_filepath=(DBT_ROOT_PATH / "simple/profiles.yml"), + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="airflow_db", + profile_args={"schema": "public"}, + ), ) cosmos_perf_dag = DbtDag( project_config=ProjectConfig( DBT_ROOT_PATH / "perf", - env_vars={"DBT_SQLITE_PATH": DBT_SQLITE_PATH}, ), profile_config=profile_config, render_config=RenderConfig( diff --git a/docs/configuration/execution-config.rst b/docs/configuration/execution-config.rst index 23b511e37..dd9758d55 100644 --- a/docs/configuration/execution-config.rst +++ b/docs/configuration/execution-config.rst @@ -7,6 +7,7 @@ It does this by exposing a ``cosmos.config.ExecutionConfig`` class that you can The ``ExecutionConfig`` class takes the following arguments: - ``execution_mode``: The way dbt is run when executing within airflow. For more information, see the `execution modes <../getting_started/execution-modes.html>`_ page. +- ``invocation_mode`` (new in v1.4): The way dbt is invoked within the execution mode. This is only configurable for ``ExecutionMode.LOCAL``. For more information, see `invocation modes <../getting_started/execution-modes.html#invocation-modes>`_. - ``test_indirect_selection``: The mode to configure the test behavior when performing indirect selection. - ``dbt_executable_path``: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. - ``dbt_project_path``: Configures the dbt project location accessible at runtime for dag execution. This is the project path in a docker container for ``ExecutionMode.DOCKER`` or ``ExecutionMode.KUBERNETES``. Mutually exclusive with ``ProjectConfig.dbt_project_path``. diff --git a/docs/getting_started/execution-modes.rst b/docs/getting_started/execution-modes.rst index 7c7417cc7..8f7013572 100644 --- a/docs/getting_started/execution-modes.rst +++ b/docs/getting_started/execution-modes.rst @@ -184,3 +184,33 @@ Each task will create a new container on Azure, giving full isolation. This, how "image": "dbt-jaffle-shop:1.0.0", }, ) + + +.. _invocation_modes: +Invocation Modes +================ +.. versionadded:: 1.4 + +For ``ExecutionMode.LOCAL`` execution mode, Cosmos supports two invocation modes for running dbt: + +1. ``InvocationMode.SUBPROCESS``: In this mode, Cosmos runs dbt cli commands using the Python ``subprocess`` module and parses the output to capture logs and to raise exceptions. + +2. ``InvocationMode.DBT_RUNNER``: In this mode, Cosmos uses the ``dbtRunner`` available for `dbt programmatic invocations `__ to run dbt commands. \ + In order to use this mode, dbt must be installed in the same local environment. This mode does not have the overhead of spawning new subprocesses or parsing the output of dbt commands and is faster than ``InvocationMode.SUBPROCESS``. \ + This mode requires dbt version 1.5.0 or higher. It is up to the user to resolve :ref:`execution-modes-local-conflicts` when using this mode. + +The invocation mode can be set in the ``ExecutionConfig`` as shown below: + +.. code-block:: python + + from cosmos.constants import InvocationMode + + dag = DbtDag( + # ... + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.LOCAL, + invocation_mode=InvocationMode.DBT_RUNNER, + ), + ) + +If the invocation mode is not set, Cosmos will attempt to use ``InvocationMode.DBT_RUNNER`` if dbt is installed in the same environment as the worker, otherwise it will fall back to ``InvocationMode.SUBPROCESS``. diff --git a/scripts/test/performance-setup.sh b/scripts/test/performance-setup.sh index b8bce035c..7efb917c1 100644 --- a/scripts/test/performance-setup.sh +++ b/scripts/test/performance-setup.sh @@ -1,4 +1,4 @@ -pip uninstall -y dbt-core dbt-sqlite openlineage-airflow openlineage-integration-common; \ +pip uninstall -y dbt-core dbt-sqlite dbt-postgres openlineage-airflow openlineage-integration-common; \ rm -rf airflow.*; \ airflow db init; \ -pip install 'dbt-core==1.4' 'dbt-sqlite<=1.4' 'dbt-databricks<=1.4' 'dbt-postgres<=1.4' +pip install 'dbt-postgres' diff --git a/tests/dbt/parser/test_output.py b/tests/dbt/parser/test_output.py index 0f4ba56cd..9fae4d3b3 100644 --- a/tests/dbt/parser/test_output.py +++ b/tests/dbt/parser/test_output.py @@ -1,18 +1,52 @@ +import pytest +import logging +from unittest.mock import MagicMock from airflow.hooks.subprocess import SubprocessResult from cosmos.dbt.parser.output import ( + extract_dbt_runner_issues, extract_log_issues, - parse_output, + parse_number_of_warnings_subprocess, + parse_number_of_warnings_dbt_runner, ) -def test_parse_output() -> None: - for warnings in range(0, 3): - output_str = f"Done. PASS=15 WARN={warnings} ERROR=0 SKIP=0 TOTAL=16" - keyword = "WARN" +@pytest.mark.parametrize( + "output_str, expected_warnings", + [ + ("Done. PASS=15 WARN=1 ERROR=0 SKIP=0 TOTAL=16", 1), + ("Done. PASS=15 WARN=0 ERROR=0 SKIP=0 TOTAL=16", 0), + ("Done. PASS=15 WARN=2 ERROR=0 SKIP=0 TOTAL=16", 2), + ("Nothing to do. Exiting without running tests.", 0), + ], +) +def test_parse_number_of_warnings_subprocess(output_str: str, expected_warnings): + result = SubprocessResult(exit_code=0, output=output_str) + num_warns = parse_number_of_warnings_subprocess(result) + assert num_warns == expected_warnings + + +def test_parse_number_of_warnings_subprocess_error_logged(caplog): + output_str = "WARN= should log an error." + with caplog.at_level(logging.ERROR): result = SubprocessResult(exit_code=0, output=output_str) - num_warns = parse_output(result, keyword) - assert num_warns == warnings + parse_number_of_warnings_subprocess(result) + expected_error_log = ( + "Could not parse number of WARNs. Check your dbt/airflow version or if --quiet is not being used" + ) + assert expected_error_log in caplog.text + + +def test_parse_number_of_warnings_dbt_runner_with_warnings(): + runner_result = MagicMock() + runner_result.result.results = [ + MagicMock(status="pass"), + MagicMock(status="warn"), + MagicMock(status="pass"), + MagicMock(status="warn"), + ] + num_warns = parse_number_of_warnings_dbt_runner(runner_result) + assert num_warns == 2 def test_extract_log_issues() -> None: @@ -37,3 +71,43 @@ def test_extract_log_issues() -> None: test_names_no_warns, test_results_no_warns = extract_log_issues(log_list_no_warning) assert test_names_no_warns == [] assert test_results_no_warns == [] + + +def test_extract_dbt_runner_issues(): + """Tests that the function extracts the correct node names and messages from a dbt runner result + for warnings by default. + """ + runner_result = MagicMock() + runner_result.result.results = [ + MagicMock(status="pass"), + MagicMock(status="warn", message="A warning message", node=MagicMock()), + MagicMock(status="pass"), + MagicMock(status="warn", message="A different warning message", node=MagicMock()), + ] + runner_result.result.results[1].node.name = "a_test" + runner_result.result.results[3].node.name = "another_test" + + node_names, node_results = extract_dbt_runner_issues(runner_result) + + assert node_names == ["a_test", "another_test"] + assert node_results == ["A warning message", "A different warning message"] + + +def test_extract_dbt_runner_issues_with_status_levels(): + """Tests that the function extracts the correct test names and results from a dbt runner result + for status levels. + """ + runner_result = MagicMock() + runner_result.result.results = [ + MagicMock(status="pass"), + MagicMock(status="error", message="An error message", node=MagicMock()), + MagicMock(status="warn"), + MagicMock(status="fail", message="A failure message", node=MagicMock()), + ] + runner_result.result.results[1].node.name = "node1" + runner_result.result.results[3].node.name = "node2" + + node_names, node_results = extract_dbt_runner_issues(runner_result, status_levels=["error", "fail"]) + + assert node_names == ["node1", "node2"] + assert node_results == ["An error message", "A failure message"] diff --git a/tests/dbt/test_project.py b/tests/dbt/test_project.py index 85314b8e5..a3cd30819 100644 --- a/tests/dbt/test_project.py +++ b/tests/dbt/test_project.py @@ -4,7 +4,7 @@ import pytest -from cosmos.dbt.project import create_symlinks, copy_msgpack_for_partial_parse, environ +from cosmos.dbt.project import create_symlinks, copy_msgpack_for_partial_parse, environ, change_working_directory DBT_PROJECTS_ROOT_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt" @@ -60,3 +60,18 @@ def test_environ_context_manager(): # Check if the original environment variables are still set assert "value1" == os.environ.get("VAR1") assert "value2" == os.environ.get("VAR2") + + +@patch("os.chdir") +def test_change_working_directory(mock_chdir): + """Tests that the working directory is changed and then restored correctly.""" + # Define the path to change the working directory to + path = "/path/to/directory" + + # Use the change_working_directory context manager + with change_working_directory(path): + # Check if os.chdir is called with the correct path + mock_chdir.assert_called_once_with(path) + + # Check if os.chdir is called with the previous working directory + mock_chdir.assert_called_with(os.getcwd()) diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 90585cc95..9a938bca9 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -31,8 +31,13 @@ DbtRunOperationLocalOperator, ) from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.constants import InvocationMode from tests.utils import test_dag as run_test_dag - +from cosmos.dbt.parser.output import ( + extract_dbt_runner_issues, + parse_number_of_warnings_subprocess, + parse_number_of_warnings_dbt_runner, +) DBT_PROJ_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" MINI_DBT_PROJ_DIR = Path(__file__).parent.parent / "sample/mini" @@ -122,6 +127,51 @@ def test_dbt_base_operator_add_user_supplied_global_flags() -> None: assert cmd[-1] == "cmd" +@pytest.mark.parametrize( + "invocation_mode, invoke_dbt_method, handle_exception_method", + [ + (InvocationMode.SUBPROCESS, "run_subprocess", "handle_exception_subprocess"), + (InvocationMode.DBT_RUNNER, "run_dbt_runner", "handle_exception_dbt_runner"), + ], +) +def test_dbt_base_operator_set_invocation_methods(invocation_mode, invoke_dbt_method, handle_exception_method): + """Tests that the right methods are mapped to DbtLocalBaseOperator.invoke_dbt and + DbtLocalBaseOperator.handle_exception when a known invocation mode passed. + """ + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, task_id="my-task", project_dir="my/dir", invocation_mode=invocation_mode + ) + dbt_base_operator._set_invocation_methods() + assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method + assert dbt_base_operator.handle_exception.__name__ == handle_exception_method + + +@pytest.mark.parametrize( + "can_import_dbt, invoke_dbt_method, handle_exception_method", + [ + (False, "run_subprocess", "handle_exception_subprocess"), + (True, "run_dbt_runner", "handle_exception_dbt_runner"), + ], +) +def test_dbt_base_operator_discover_invocation_mode(can_import_dbt, invoke_dbt_method, handle_exception_method): + """Tests that the right methods are mapped to DbtLocalBaseOperator.invoke_dbt and + DbtLocalBaseOperator.handle_exception if dbt can be imported or not. + """ + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, task_id="my-task", project_dir="my/dir" + ) + with patch.dict(sys.modules, {"dbt.cli.main": MagicMock()} if can_import_dbt else {"dbt.cli.main": None}): + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, task_id="my-task", project_dir="my/dir" + ) + dbt_base_operator._discover_invocation_mode() + assert dbt_base_operator.invocation_mode == ( + InvocationMode.DBT_RUNNER if can_import_dbt else InvocationMode.SUBPROCESS + ) + assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method + assert dbt_base_operator.handle_exception.__name__ == handle_exception_method + + @pytest.mark.parametrize( "indirect_selection_type", [None, "cautious", "buildable", "empty"], @@ -145,6 +195,69 @@ def test_dbt_base_operator_use_indirect_selection(indirect_selection_type) -> No assert cmd[1] == "cmd" +def test_dbt_base_operator_run_dbt_runner_cannot_import(): + """Tests that the right error message is raised if dbtRunner cannot be imported.""" + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, + ) + expected_error_message = "Could not import dbt core. Ensure that dbt-core >= v1.5 is installed and available in the environment where the operator is running." + with patch.dict(sys.modules, {"dbt.cli.main": None}): + with pytest.raises(ImportError, match=expected_error_message): + dbt_base_operator.run_dbt_runner(command=["cmd"], env={}, cwd="some-project") + + +@patch("cosmos.dbt.project.os.environ") +@patch("cosmos.dbt.project.os.chdir") +def test_dbt_base_operator_run_dbt_runner(mock_chdir, mock_environ): + """Tests that dbtRunner.invoke() is called with the expected cli args, that the + cwd is changed to the expected directory, and env variables are set.""" + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, + ) + full_dbt_cmd = ["dbt", "run", "some_model"] + env_vars = {"VAR1": "value1", "VAR2": "value2"} + + mock_dbt = MagicMock() + with patch.dict(sys.modules, {"dbt.cli.main": mock_dbt}): + dbt_base_operator.run_dbt_runner(command=full_dbt_cmd, env=env_vars, cwd="some-dir") + + mock_dbt_runner = mock_dbt.dbtRunner.return_value + expected_cli_args = ["run", "some_model"] + # Assert dbtRunner.invoke was called with the expected cli args + assert mock_dbt_runner.invoke.call_count == 1 + assert mock_dbt_runner.invoke.call_args[0][0] == expected_cli_args + # Assert cwd was changed to the expected directory + assert mock_chdir.call_count == 2 + assert mock_chdir.call_args_list[0][0][0] == "some-dir" + # Assert env variables were updated + assert mock_environ.update.call_count == 1 + assert mock_environ.update.call_args[0][0] == env_vars + + +@patch("cosmos.dbt.project.os.chdir") +def test_dbt_base_operator_run_dbt_runner_is_cached(mock_chdir): + """Tests that if run_dbt_runner is called multiple times a cached runner is used.""" + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, + ) + mock_dbt = MagicMock() + with patch.dict(sys.modules, {"dbt.cli.main": mock_dbt}): + for _ in range(3): + dbt_base_operator.run_dbt_runner(command=["cmd"], env={}, cwd="some-project") + mock_dbt_runner = mock_dbt.dbtRunner + assert mock_dbt_runner.call_count == 1 + assert dbt_base_operator._dbt_runner is not None + + @pytest.mark.parametrize( ["skip_exception", "exception_code_returned", "expected_exception"], [ @@ -158,17 +271,56 @@ def test_dbt_base_operator_use_indirect_selection(indirect_selection_type) -> No "No exception raised", ], ) -def test_dbt_base_operator_exception_handling(skip_exception, exception_code_returned, expected_exception) -> None: +def test_dbt_base_operator_exception_handling_subprocess( + skip_exception, exception_code_returned, expected_exception +) -> None: dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", + invocation_mode=InvocationMode.SUBPROCESS, ) if expected_exception: with pytest.raises(expected_exception): - dbt_base_operator.exception_handling(SubprocessResult(exception_code_returned, None)) + dbt_base_operator.handle_exception(SubprocessResult(exception_code_returned, None)) else: - dbt_base_operator.exception_handling(SubprocessResult(exception_code_returned, None)) + dbt_base_operator.handle_exception(SubprocessResult(exception_code_returned, None)) + + +def test_dbt_base_operator_handle_exception_dbt_runner_unhandled_error(): + """Tests that an AirflowException is raised if the dbtRunner result is not successful with an unhandled error.""" + operator = ConcreteDbtLocalBaseOperator( + profile_config=MagicMock(), + task_id="my-task", + project_dir="my/dir", + ) + result = MagicMock() + result.success = False + result.exception = "some exception" + expected_error_message = "dbt invocation did not complete with unhandled error: some exception" + + with pytest.raises(AirflowException, match=expected_error_message): + operator.handle_exception_dbt_runner(result) + + +@patch("cosmos.operators.local.extract_dbt_runner_issues", return_value=(["node1", "node2"], ["error1", "error2"])) +def test_dbt_base_operator_handle_exception_dbt_runner_handled_error(mock_extract_dbt_runner_issues): + """Tests that an AirflowException is raised if the dbtRunner result is not successful and with handled errors.""" + operator = ConcreteDbtLocalBaseOperator( + profile_config=MagicMock(), + task_id="my-task", + project_dir="my/dir", + ) + result = MagicMock() + result.success = False + result.exception = None + + expected_error_message = "dbt invocation completed with errors: node1: error1\nnode2: error2" + + with pytest.raises(AirflowException, match=expected_error_message): + operator.handle_exception_dbt_runner(result) + + mock_extract_dbt_runner_issues.assert_called_once() @patch("cosmos.operators.base.context_to_airflow_vars") @@ -201,6 +353,33 @@ def test_dbt_base_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None assert env == expected_env +@patch("cosmos.operators.local.extract_log_issues") +def test_dbt_test_local_operator_invocation_mode_methods(mock_extract_log_issues): + # test subprocess invocation mode + operator = DbtTestLocalOperator( + profile_config=profile_config, + invocation_mode=InvocationMode.SUBPROCESS, + task_id="my-task", + project_dir="my/dir", + ) + operator._set_test_result_parsing_methods() + assert operator.parse_number_of_warnings == parse_number_of_warnings_subprocess + result = MagicMock(full_output="some output") + operator.extract_issues(result) + mock_extract_log_issues.assert_called_once_with("some output") + + # test dbt runner invocation mode + operator = DbtTestLocalOperator( + profile_config=profile_config, + invocation_mode=InvocationMode.DBT_RUNNER, + task_id="my-task", + project_dir="my/dir", + ) + operator._set_test_result_parsing_methods() + assert operator.extract_issues == extract_dbt_runner_issues + assert operator.parse_number_of_warnings == parse_number_of_warnings_dbt_runner + + @pytest.mark.skipif( version.parse(airflow_version) < version.parse("2.4"), reason="Airflow DAG did not have datasets until the 2.4 release", @@ -259,7 +438,8 @@ def test_dbt_base_operator_no_partial_parse() -> None: @pytest.mark.integration -def test_run_test_operator_with_callback(failing_test_dbt_project): +@pytest.mark.parametrize("invocation_mode", [InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER]) +def test_run_test_operator_with_callback(invocation_mode, failing_test_dbt_project): on_warning_callback = MagicMock() with DAG("test-id-2", start_date=datetime(2022, 1, 1)) as dag: @@ -275,6 +455,7 @@ def test_run_test_operator_with_callback(failing_test_dbt_project): task_id="test", append_env=True, on_warning_callback=on_warning_callback, + invocation_mode=invocation_mode, ) run_operator >> test_operator run_test_dag(dag) @@ -282,7 +463,8 @@ def test_run_test_operator_with_callback(failing_test_dbt_project): @pytest.mark.integration -def test_run_test_operator_without_callback(): +@pytest.mark.parametrize("invocation_mode", [InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER]) +def test_run_test_operator_without_callback(invocation_mode): on_warning_callback = MagicMock() with DAG("test-id-3", start_date=datetime(2022, 1, 1)) as dag: @@ -291,6 +473,7 @@ def test_run_test_operator_without_callback(): project_dir=MINI_DBT_PROJ_DIR, task_id="run", append_env=True, + invocation_mode=invocation_mode, ) test_operator = DbtTestLocalOperator( profile_config=mini_profile_config, @@ -298,6 +481,7 @@ def test_run_test_operator_without_callback(): task_id="test", append_env=True, on_warning_callback=on_warning_callback, + invocation_mode=invocation_mode, ) run_operator >> test_operator run_test_dag(dag) @@ -403,7 +587,13 @@ def test_store_compiled_sql() -> None: ) @patch("cosmos.operators.local.DbtLocalBaseOperator.build_and_run_cmd") def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwargs, expected_call_kwargs): - task = operator_class(profile_config=profile_config, task_id="my-task", project_dir="my/dir", **kwargs) + task = operator_class( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, + **kwargs, + ) task.execute(context={}) mock_build_and_run_cmd.assert_called_once_with(**expected_call_kwargs) @@ -432,6 +622,7 @@ def test_operator_execute_without_flags(mock_build_and_run_cmd, operator_class): profile_config=profile_config, task_id="my-task", project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, **operator_class_kwargs.get(operator_class, {}), ) task.execute(context={}) @@ -497,11 +688,18 @@ def test_dbt_docs_gcs_local_operator(): @patch("cosmos.operators.local.DbtLocalBaseOperator.store_compiled_sql") -@patch("cosmos.operators.local.DbtLocalBaseOperator.exception_handling") +@patch("cosmos.operators.local.DbtLocalBaseOperator.handle_exception_subprocess") @patch("cosmos.config.ProfileConfig.ensure_profile") @patch("cosmos.operators.local.DbtLocalBaseOperator.run_subprocess") +@patch("cosmos.operators.local.DbtLocalBaseOperator.run_dbt_runner") +@pytest.mark.parametrize("invocation_mode", [InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER]) def test_operator_execute_deps_parameters( - mock_build_and_run_cmd, mock_ensure_profile, mock_exception_handling, mock_store_compiled_sql + mock_dbt_runner, + mock_subprocess, + mock_ensure_profile, + mock_exception_handling, + mock_store_compiled_sql, + invocation_mode, ): expected_call_kwargs = [ "/usr/local/bin/dbt", @@ -520,10 +718,14 @@ def test_operator_execute_deps_parameters( install_deps=True, emit_datasets=False, dbt_executable_path="/usr/local/bin/dbt", + invocation_mode=invocation_mode, ) mock_ensure_profile.return_value.__enter__.return_value = (Path("/path/to/profile"), {"ENV_VAR": "value"}) task.execute(context={"task_instance": MagicMock()}) - assert mock_build_and_run_cmd.call_args_list[0].kwargs["command"] == expected_call_kwargs + if invocation_mode == InvocationMode.SUBPROCESS: + assert mock_subprocess.call_args_list[0].kwargs["command"] == expected_call_kwargs + elif invocation_mode == InvocationMode.DBT_RUNNER: + mock_dbt_runner.all_args_list[0].kwargs["command"] == expected_call_kwargs def test_dbt_docs_local_operator_with_static_flag(): @@ -541,7 +743,11 @@ def test_dbt_docs_local_operator_with_static_flag(): def test_dbt_local_operator_on_kill_sigint(mock_send_sigint) -> None: dbt_base_operator = ConcreteDbtLocalBaseOperator( - profile_config=profile_config, task_id="my-task", project_dir="my/dir", cancel_query_on_kill=True + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + cancel_query_on_kill=True, + invocation_mode=InvocationMode.SUBPROCESS, ) dbt_base_operator.on_kill() @@ -553,7 +759,11 @@ def test_dbt_local_operator_on_kill_sigint(mock_send_sigint) -> None: def test_dbt_local_operator_on_kill_sigterm(mock_send_sigterm) -> None: dbt_base_operator = ConcreteDbtLocalBaseOperator( - profile_config=profile_config, task_id="my-task", project_dir="my/dir", cancel_query_on_kill=False + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + cancel_query_on_kill=False, + invocation_mode=InvocationMode.SUBPROCESS, ) dbt_base_operator.on_kill() diff --git a/tests/operators/test_virtualenv.py b/tests/operators/test_virtualenv.py index 86796308b..036f162de 100644 --- a/tests/operators/test_virtualenv.py +++ b/tests/operators/test_virtualenv.py @@ -7,6 +7,7 @@ from cosmos.config import ProfileConfig from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.constants import InvocationMode profile_config = ProfileConfig( profile_name="default", @@ -25,7 +26,7 @@ class ConcreteDbtVirtualenvBaseOperator(DbtVirtualenvBaseOperator): @patch("airflow.utils.python_virtualenv.execute_in_subprocess") @patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.calculate_openlineage_events_completes") @patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.store_compiled_sql") -@patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.exception_handling") +@patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.handle_exception_subprocess") @patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.subprocess_hook") @patch("airflow.hooks.base.BaseHook.get_connection") def test_run_command( @@ -53,6 +54,7 @@ def test_run_command( py_system_site_packages=False, py_requirements=["dbt-postgres==1.6.0b1"], emit_datasets=False, + invocation_mode=InvocationMode.SUBPROCESS, ) assert venv_operator._venv_tmp_dir is None # Otherwise we are creating empty directories during DAG parsing time # and not deleting them @@ -60,12 +62,12 @@ def test_run_command( run_command_args = mock_subprocess_hook.run_command.call_args_list assert len(run_command_args) == 3 python_cmd = run_command_args[0] - dbt_deps = run_command_args[1] - dbt_cmd = run_command_args[2] + dbt_deps = run_command_args[1].kwargs + dbt_cmd = run_command_args[2].kwargs assert python_cmd[0][0][0].endswith("/bin/python") assert python_cmd[0][-1][-1] == "from importlib.metadata import version; print(version('dbt-core'))" - assert dbt_deps[0][0][1] == "deps" - assert dbt_deps[0][0][0].endswith("/bin/dbt") - assert dbt_deps[0][0][0] == dbt_cmd[0][0][0] - assert dbt_cmd[0][0][1] == "do-something" + assert dbt_deps["command"][1] == "deps" + assert dbt_deps["command"][0].endswith("/bin/dbt") + assert dbt_deps["command"][0] == dbt_cmd["command"][0] + assert dbt_cmd["command"][1] == "do-something" assert mock_execute.call_count == 2 diff --git a/tests/perf/test_performance.py b/tests/perf/test_performance.py index acf5d3544..81b08d8bd 100644 --- a/tests/perf/test_performance.py +++ b/tests/perf/test_performance.py @@ -109,14 +109,18 @@ def test_perf_dag(): # measure the time before and after the dag is run start = time.time() - dag.test() + dag_run = dag.test() end = time.time() - print(f"Ran {num_models} models in {end - start} seconds") - print(f"NUM_MODELS={num_models}\nTIME={end - start}") - - # write the results to a file - with open("/tmp/performance_results.txt", "w") as f: - f.write( - f"NUM_MODELS={num_models}\nTIME={end - start}\nMODELS_PER_SECOND={num_models / (end - start)}\nDBT_VERSION={DBT_VERSION}" - ) + # assert the dag run was successful before writing the results + if dag_run.state == "success": + print(f"Ran {num_models} models in {end - start} seconds") + print(f"NUM_MODELS={num_models}\nTIME={end - start}") + + # write the results to a file + with open("/tmp/performance_results.txt", "w") as f: + f.write( + f"NUM_MODELS={num_models}\nTIME={end - start}\nMODELS_PER_SECOND={num_models / (end - start)}\nDBT_VERSION={DBT_VERSION}" + ) + else: + raise Exception("Performance DAG run failed.") diff --git a/tests/test_config.py b/tests/test_config.py index 795fcffb6..b93ad2627 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,10 +1,12 @@ from pathlib import Path from unittest.mock import patch from cosmos.profiles.postgres.user_pass import PostgresUserPasswordProfileMapping +from contextlib import nullcontext as does_not_raise import pytest -from cosmos.config import ProfileConfig, ProjectConfig, RenderConfig, CosmosConfigException +from cosmos.constants import ExecutionMode, InvocationMode +from cosmos.config import ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig, CosmosConfigException from cosmos.exceptions import CosmosValueError @@ -195,3 +197,18 @@ def test_render_config_env_vars_deprecated(): """RenderConfig.env_vars is deprecated since Cosmos 1.3, should warn user.""" with pytest.deprecated_call(): RenderConfig(env_vars={"VAR": "value"}) + + +@pytest.mark.parametrize( + "execution_mode, expectation", + [ + (ExecutionMode.LOCAL, does_not_raise()), + (ExecutionMode.VIRTUALENV, pytest.raises(CosmosValueError)), + (ExecutionMode.KUBERNETES, pytest.raises(CosmosValueError)), + (ExecutionMode.DOCKER, pytest.raises(CosmosValueError)), + (ExecutionMode.AZURE_CONTAINER_INSTANCE, pytest.raises(CosmosValueError)), + ], +) +def test_execution_config_with_invocation_option(execution_mode, expectation): + with expectation: + ExecutionConfig(execution_mode=execution_mode, invocation_mode=InvocationMode.DBT_RUNNER) diff --git a/tests/test_converter.py b/tests/test_converter.py index b0913acae..e66af468f 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -7,7 +7,7 @@ from airflow.models import DAG from cosmos.converter import DbtToAirflowConverter, validate_arguments, validate_initial_user_config -from cosmos.constants import DbtResourceType, ExecutionMode, LoadMode +from cosmos.constants import DbtResourceType, ExecutionMode, LoadMode, InvocationMode from cosmos.config import ProjectConfig, ProfileConfig, ExecutionConfig, RenderConfig, CosmosConfigException from cosmos.dbt.graph import DbtNode from cosmos.exceptions import CosmosValueError @@ -438,3 +438,34 @@ def test_converter_multiple_calls_same_operator_args( operator_args=operator_args, ) assert operator_args == original_operator_args + + +@pytest.mark.parametrize("invocation_mode", [None, InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER]) +@patch("cosmos.config.ProjectConfig.validate_project") +@patch("cosmos.converter.validate_initial_user_config") +@patch("cosmos.converter.DbtGraph") +@patch("cosmos.converter.build_airflow_graph") +def test_converter_invocation_mode_added_to_task_args( + mock_build_airflow_graph, mock_user_config, mock_dbt_graph, mock_validate_project, invocation_mode +): + """Tests that the `task_args` passed to build_airflow_graph has invocation_mode if it is not None.""" + project_config = ProjectConfig(project_name="fake-project", dbt_project_path="/some/project/path") + execution_config = ExecutionConfig(invocation_mode=invocation_mode) + render_config = MagicMock() + profile_config = MagicMock() + + with DAG("test-id", start_date=datetime(2024, 1, 1)) as dag: + DbtToAirflowConverter( + dag=dag, + nodes=nodes, + project_config=project_config, + profile_config=profile_config, + execution_config=execution_config, + render_config=render_config, + operator_args={}, + ) + _, kwargs = mock_build_airflow_graph.call_args + if invocation_mode: + assert kwargs["task_args"]["invocation_mode"] == invocation_mode + else: + assert "invocation_mode" not in kwargs["task_args"]