Fix Airflow Scheduler #335

Open · wants to merge 3 commits into base: lyft-stable-2.3.4
323 changes: 157 additions & 166 deletions airflow/jobs/scheduler_job.py
@@ -28,7 +28,7 @@
from datetime import timedelta
from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple

from sqlalchemy import func, not_, or_, text, select
from sqlalchemy import func, and_, or_, text, select, desc
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import load_only, selectinload
from sqlalchemy.orm.session import Session, make_transient
@@ -253,7 +253,7 @@ def _get_starved_dags(self, session: Session = None) -> Set[str]:
TI.dag_id,
func.count().label('current_active_tasks')
)
.where(TI.state.in_(['running', 'queued']))
.where(TI.state.in_([TaskInstanceState.RUNNING, TaskInstanceState.QUEUED]))
.group_by(TI.dag_id)
.subquery()
)
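
The hunk above swaps the raw `'running'`/`'queued'` strings for the `TaskInstanceState` enum. As a minimal sketch (not the PR's exact code), the per-DAG active-count subquery can be turned into a starved-DAG set by comparing it against `DagModel.max_active_tasks`; the `TI`/`DM` aliases below follow the names used in this diff:

```python
from sqlalchemy import func

from airflow.models.dag import DagModel as DM
from airflow.models.taskinstance import TaskInstance as TI
from airflow.utils.state import TaskInstanceState


def starved_dag_ids(session):
    """Return dag_ids whose running+queued task count has reached max_active_tasks."""
    active_counts = (
        session.query(TI.dag_id, func.count().label("current_active_tasks"))
        .filter(TI.state.in_([TaskInstanceState.RUNNING, TaskInstanceState.QUEUED]))
        .group_by(TI.dag_id)
        .subquery()
    )
    rows = (
        session.query(DM.dag_id)
        .join(active_counts, DM.dag_id == active_counts.c.dag_id)
        .filter(active_counts.c.current_active_tasks >= DM.max_active_tasks)
        .all()
    )
    return {dag_id for (dag_id,) in rows}
```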
@@ -330,198 +330,189 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
# dag and task ids that can't be queued because of concurrency limits
starved_dags: Set[str] = self._get_starved_dags(session=session)
starved_tasks: Set[Tuple[str, str]] = set()

pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int)

# Subquery to get the current active task count for each DAG
# Only considering tasks from running DAG runs
current_active_tasks = (
session.query(
TI.dag_id,
func.count().label('active_count')
)
.join(DR, and_(DR.dag_id == TI.dag_id, DR.run_id == TI.run_id))
.filter(DR.state == DagRunState.RUNNING)
.filter(TI.state.in_([TaskInstanceState.RUNNING, TaskInstanceState.QUEUED]))
.group_by(TI.dag_id)
.subquery()
)
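
Roughly, this subquery counts running and queued task instances per DAG, restricted to running DAG runs via the join on `(dag_id, run_id)`. The SQL it should compile to looks approximately like the sketch below (PostgreSQL assumed), and the exact statement can be inspected with the same `literal_binds` trick the commented-out debug prints later in this diff use; `current_active_tasks` here is the subquery defined above:

```python
# Hedged sketch: approximate compiled SQL for the subquery above
# (PostgreSQL dialect; table names follow Airflow's schema).
#
#   SELECT task_instance.dag_id, count(*) AS active_count
#   FROM task_instance
#   JOIN dag_run
#     ON dag_run.dag_id = task_instance.dag_id
#    AND dag_run.run_id = task_instance.run_id
#   WHERE dag_run.state = 'running'
#     AND task_instance.state IN ('running', 'queued')
#   GROUP BY task_instance.dag_id
#
# To check the exact statement locally:
from sqlalchemy.dialects import postgresql

print(
    current_active_tasks.select().compile(
        dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}
    )
)
```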

for loop_count in itertools.count(start=1):

num_starved_pools = len(starved_pools)
num_starved_dags = len(starved_dags)
num_starved_tasks = len(starved_tasks)

# Get task instances associated with scheduled
# DagRuns which are not backfilled, in the given states,
# and the dag is not paused
query = (
session.query(TI)
.with_hint(TI, 'USE INDEX (ti_state)', dialect_name='mysql')
.join(TI.dag_run)
.filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
.join(TI.dag_model)
.filter(not_(DM.is_paused))
.filter(TI.state == TaskInstanceState.SCHEDULED)
.options(selectinload('dag_model'))
.order_by(-TI.priority_weight, DR.execution_date)
# Get the limit for each DAG
dag_limit_subquery = (
session.query(
DM.dag_id,
func.greatest(DM.max_active_tasks - func.coalesce(current_active_tasks.c.active_count, 0), 0).label('dag_limit')
)
.outerjoin(current_active_tasks, DM.dag_id == current_active_tasks.c.dag_id)
.subquery()
)
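
The per-DAG budget is `greatest(max_active_tasks - coalesce(active_count, 0), 0)`: `coalesce` covers DAGs with no active tasks at all (the outer join yields NULL), and `greatest(..., 0)` clamps DAGs that are already over their cap. The same arithmetic in plain Python, with illustrative numbers:

```python
from typing import Optional


# greatest(max_active_tasks - coalesce(active_count, 0), 0), in plain Python.
def dag_limit(max_active_tasks: int, active_count: Optional[int]) -> int:
    return max(max_active_tasks - (active_count or 0), 0)


assert dag_limit(16, None) == 16  # no running/queued tasks -> full budget
assert dag_limit(16, 10) == 6     # 6 more task instances may be queued
assert dag_limit(16, 20) == 0     # already over the cap -> queue nothing
```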

if starved_pools:
query = query.filter(not_(TI.pool.in_(starved_pools)))
# Subquery to rank tasks within each DAG
ranked_tis = (
session.query(
TI,
func.row_number().over(
partition_by=TI.dag_id,
order_by=[desc(TI.priority_weight), TI.start_date]
).label('row_number'),
dag_limit_subquery.c.dag_limit
)
.join(TI.dag_run)
.join(DM, TI.dag_id == DM.dag_id)
.join(dag_limit_subquery, TI.dag_id == dag_limit_subquery.c.dag_id)
.filter(
DR.state == DagRunState.RUNNING,
DR.run_type != DagRunType.BACKFILL_JOB,
~DM.is_paused,
~TI.dag_id.in_(starved_dags),
~TI.pool.in_(starved_pools),
TI.state == TaskInstanceState.SCHEDULED,
)
).subquery()

if starved_tasks:
ranked_tis = ranked_tis.filter(
~func.concat(TI.dag_id, TI.task_id).in_([f"{dag_id}{task_id}" for dag_id, task_id in starved_tasks])
)
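
`row_number()` ranks each DAG's SCHEDULED task instances by priority (descending) and then start date, so the join that follows can keep only the first `dag_limit` rows per DAG. A condensed sketch of the pattern, assuming the starved-task exclusion is applied while the ranked query is still a `Query` and using the `tuple_in_condition` helper from `airflow.utils.sqlalchemy` (already used on the old side of this diff) for the composite key; `TI`, `session`, `starved_tasks`, and `dag_limit_subquery` are the names from the surrounding method:

```python
from sqlalchemy import desc, func, not_

from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import TaskInstanceState

row_number = (
    func.row_number()
    .over(
        partition_by=TI.dag_id,
        order_by=[desc(TI.priority_weight), TI.start_date],
    )
    .label("row_number")
)

ranked_query = (
    session.query(TI, row_number, dag_limit_subquery.c.dag_limit)
    .join(TI.dag_run)
    .join(dag_limit_subquery, TI.dag_id == dag_limit_subquery.c.dag_id)
    .filter(TI.state == TaskInstanceState.SCHEDULED)
)

if starved_tasks:
    # Exclude (dag_id, task_id) pairs already known to be starved, before
    # wrapping the query in a subquery.
    ranked_query = ranked_query.filter(
        not_(tuple_in_condition((TI.dag_id, TI.task_id), starved_tasks))
    )

ranked_tis = ranked_query.subquery()
```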

final_query = (
session.query(TI)
.join(
ranked_tis,
and_(
TI.task_id == ranked_tis.c.task_id,
TI.dag_id == ranked_tis.c.dag_id,
TI.run_id == ranked_tis.c.run_id
)
)
.filter(ranked_tis.c.row_number <= ranked_tis.c.dag_limit)
.order_by(desc(ranked_tis.c.priority_weight), ranked_tis.c.start_date)
.limit(max_tis)
)

# Execute the query with row locks
task_instances_to_examine: List[TI] = with_row_locks(
final_query,
of=TI,
session=session,
**skip_locked(session=session),
).all()
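
`with_row_locks` and `skip_locked` are the existing Airflow helpers used on both sides of this diff; where the backend supports it they add `FOR UPDATE ... SKIP LOCKED`, so two schedulers running concurrently never pick up the same candidate rows. A rough plain-SQLAlchemy equivalent, for illustration only:

```python
# Rough equivalent of the locked fetch above (PostgreSQL / MySQL 8+ semantics
# assumed; on backends without SKIP LOCKED the helpers are expected to fall
# back to an unlocked SELECT).
locked_candidates = (
    session.query(TI)
    .filter(TI.state == TaskInstanceState.SCHEDULED)
    .with_for_update(of=TI, skip_locked=True)
    .limit(max_tis)
    .all()
)
# Rows locked by another scheduler's open transaction are skipped rather
# than waited on, so each scheduler examines a disjoint set of candidates.
```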

if starved_dags:
query = query.filter(not_(TI.dag_id.in_(starved_dags)))

if len(task_instances_to_examine) == 0:
self.log.debug("No tasks to consider for execution.")
return []
# else:
# print("---dag_limit_subquery")
# print(str(dag_limit_subquery.select().compile(compile_kwargs={"literal_binds": True})))
# print("---ranked_tis-query")
# print(str(ranked_tis.select().compile(compile_kwargs={"literal_binds": True})))
# print("---FINAL QUERY")
# print(str(final_query.statement.compile(compile_kwargs={"literal_binds": True})))

# Put one task instance on each line
task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
self.log.info(
"%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str
)

pool_slot_tracker = {pool_name: stats['open'] for pool_name, stats in pools.items()}

if starved_tasks:
task_filter = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), starved_tasks)
query = query.filter(not_(task_filter))
for task_instance in task_instances_to_examine:
pool_name = task_instance.pool

query = query.limit(max_tis)
pool_stats = pools.get(pool_name)
if not pool_stats:
self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
starved_pools.add(pool_name)
continue

task_instances_to_examine: List[TI] = with_row_locks(
query,
of=TI,
session=session,
**skip_locked(session=session),
).all()
# TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
# Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))

# # Make sure to emit metrics if pool has no starving tasks
# # pool_num_starving_tasks.setdefault(pool_name, 0)
# pool_total = pool_stats["total"]
open_slots = pool_stats["open"]

if len(task_instances_to_examine) == 0:
self.log.debug("No tasks to consider for execution.")
break
# Check to make sure that the task max_active_tasks of the DAG hasn't been
# reached.
# This shouldn't happen anymore, but it is kept here for debugging purposes
dag_id = task_instance.dag_id

# Put one task instance on each line
task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
self.log.info(
"%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str
"DAG %s has %s/%s running and queued tasks",
dag_id,
current_active_tasks_per_dag,
max_active_tasks_per_dag_limit,
)

for task_instance in task_instances_to_examine:
pool_name = task_instance.pool

pool_stats = pools.get(pool_name)
if not pool_stats:
self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
starved_pools.add(pool_name)
continue

# Make sure to emit metrics if pool has no starving tasks
pool_num_starving_tasks.setdefault(pool_name, 0)

pool_total = pool_stats["total"]
open_slots = pool_stats["open"]

if open_slots <= 0:
self.log.info(
"Not scheduling since there are %s open slots in pool %s", open_slots, pool_name
)
# Can't schedule any more since there are no more open slots.
pool_num_starving_tasks[pool_name] += 1
num_starving_tasks_total += 1
starved_pools.add(pool_name)
continue

if task_instance.pool_slots > pool_total:
self.log.warning(
"Not executing %s. Requested pool slots (%s) are greater than "
"total pool slots: '%s' for pool: %s.",
task_instance,
task_instance.pool_slots,
pool_total,
pool_name,
)

pool_num_starving_tasks[pool_name] += 1
num_starving_tasks_total += 1
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
continue

if task_instance.pool_slots > open_slots:
self.log.info(
"Not executing %s since it requires %s slots "
"but there are %s open slots in the pool %s.",
task_instance,
task_instance.pool_slots,
open_slots,
pool_name,
)
pool_num_starving_tasks[pool_name] += 1
num_starving_tasks_total += 1
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
# Though we can execute tasks with lower priority if there's enough room
continue

# Check to make sure that the task max_active_tasks of the DAG hasn't been
# reached.
dag_id = task_instance.dag_id

current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
self.log.info(
"DAG %s has %s/%s running and queued tasks",
"Not executing %s since the number of tasks running or queued "
"from DAG %s is >= to the DAG's max_active_tasks limit of %s",
task_instance,
dag_id,
current_active_tasks_per_dag,
max_active_tasks_per_dag_limit,
)
if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
self.log.info(
"Not executing %s since the number of tasks running or queued "
"from DAG %s is >= to the DAG's max_active_tasks limit of %s",
task_instance,
starved_dags.add(dag_id)
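
Even though the ranked query should already cap each DAG at its remaining budget, the loop keeps a Python-side check against `max_active_tasks` (per the "shouldn't happen anymore" comment) and records the DAG as starved so the next pass filters it out in SQL. Condensed into a helper for clarity, a sketch of what the guard checks; `dag_active_tasks_map` and `task_instance` are the loop-local names from the method:

```python
def dag_is_at_capacity(task_instance, dag_active_tasks_map) -> bool:
    """True if the task's DAG already has max_active_tasks running or queued."""
    current = dag_active_tasks_map[task_instance.dag_id]
    limit = task_instance.dag_model.max_active_tasks
    return current >= limit


# In the loop (sketch): when dag_is_at_capacity(...) is true, the dag_id is
# added to starved_dags and the candidate is skipped, as the other
# starvation branches do.
```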

if task_instance.dag_model.has_task_concurrency_limits:
# Many dags don't have a task_concurrency, so where we can avoid loading the full
# serialized DAG the better.
serialized_dag = self.dagbag.get_dag(dag_id, session=session)
# If the dag is missing, fail the task and continue to the next task.
if not serialized_dag:
self.log.error(
"DAG '%s' for task instance %s not found in serialized_dag table",
dag_id,
max_active_tasks_per_dag_limit,
task_instance,
)
starved_dags.add(dag_id)
continue
session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
{TI.state: State.FAILED}, synchronize_session='fetch'
)
# continue

task_concurrency_limit: Optional[int] = None
if serialized_dag.has_task(task_instance.task_id):
task_concurrency_limit = serialized_dag.get_task(
task_instance.task_id
).max_active_tis_per_dag

if task_instance.dag_model.has_task_concurrency_limits:
# Many dags don't have a task_concurrency, so where we can avoid loading the full
# serialized DAG the better.
serialized_dag = self.dagbag.get_dag(dag_id, session=session)
# If the dag is missing, fail the task and continue to the next task.
if not serialized_dag:
self.log.error(
"DAG '%s' for task instance %s not found in serialized_dag table",
dag_id,
if task_concurrency_limit is not None:
current_task_concurrency = task_concurrency_map[
(task_instance.dag_id, task_instance.task_id)
]

if current_task_concurrency >= task_concurrency_limit:
self.log.info(
"Not executing %s since the task concurrency for"
" this task has been reached.",
task_instance,
)
session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
{TI.state: State.FAILED}, synchronize_session='fetch'
)
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
continue

task_concurrency_limit: Optional[int] = None
if serialized_dag.has_task(task_instance.task_id):
task_concurrency_limit = serialized_dag.get_task(
task_instance.task_id
).max_active_tis_per_dag

if task_concurrency_limit is not None:
current_task_concurrency = task_concurrency_map[
(task_instance.dag_id, task_instance.task_id)
]

if current_task_concurrency >= task_concurrency_limit:
self.log.info(
"Not executing %s since the task concurrency for"
" this task has been reached.",
task_instance,
)
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
continue
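
Per-task limits (`max_active_tis_per_dag`) are only consulted when `DagModel.has_task_concurrency_limits` is set, which avoids loading the serialized DAG for the many DAGs that never use them; when the serialized DAG is missing, the new code also fails that DAG's SCHEDULED task instances. The gate itself, condensed into one helper as a sketch (not the PR's exact control flow); `dagbag`, `task_concurrency_map`, and `task_instance` are the objects used by the surrounding method:

```python
from typing import Optional


def task_over_concurrency_limit(dagbag, session, task_instance, task_concurrency_map) -> bool:
    """True if the task already has max_active_tis_per_dag instances running or queued."""
    if not task_instance.dag_model.has_task_concurrency_limits:
        return False
    serialized_dag = dagbag.get_dag(task_instance.dag_id, session=session)
    if not serialized_dag or not serialized_dag.has_task(task_instance.task_id):
        return False
    limit: Optional[int] = serialized_dag.get_task(task_instance.task_id).max_active_tis_per_dag
    if limit is None:
        return False
    current = task_concurrency_map[(task_instance.dag_id, task_instance.task_id)]
    return current >= limit
```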


# Check pool-specific slot availability
if (pool_slot_tracker.get(pool_name, 0) >= task_instance.pool_slots):
executable_tis.append(task_instance)
open_slots -= task_instance.pool_slots
dag_active_tasks_map[dag_id] += 1
task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1

pool_stats["open"] = open_slots
else:
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
pool_num_starving_tasks[pool_name] += 1
num_starving_tasks_total += 1
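
The new `pool_slot_tracker` holds the remaining open slots per pool for the current scheduling round: a candidate is queued only if its pool still has at least `pool_slots` available, and queuing it also bumps the per-DAG and per-task counters so later candidates in the same batch see the updated totals. A self-contained sketch of that bookkeeping (the helper name, structure, and numbers are illustrative, not the PR's exact code):

```python
from collections import defaultdict

# Illustrative bookkeeping for one scheduling round.
pool_slot_tracker = {"default_pool": 128, "etl": 4}   # open slots per pool (example numbers)
dag_active_tasks_map = defaultdict(int)               # running+queued per dag_id
task_concurrency_map = defaultdict(int)               # running+queued per (dag_id, task_id)
executable_tis = []


def try_queue(ti) -> bool:
    """Queue ti if its pool still has enough open slots; update the counters."""
    open_slots = pool_slot_tracker.get(ti.pool, 0)
    if open_slots < ti.pool_slots:
        return False                                  # pool starved for this candidate
    pool_slot_tracker[ti.pool] = open_slots - ti.pool_slots
    dag_active_tasks_map[ti.dag_id] += 1
    task_concurrency_map[(ti.dag_id, ti.task_id)] += 1
    executable_tis.append(ti)
    return True
```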

is_done = executable_tis or len(task_instances_to_examine) < max_tis
# Check this to avoid accidental infinite loops
found_new_filters = (
len(starved_pools) > num_starved_pools
or len(starved_dags) > num_starved_dags
or len(starved_tasks) > num_starved_tasks
)

if is_done or not found_new_filters:
break

self.log.debug(
"Found no task instances to queue on the %s. iteration "
"but there could be more candidate task instances to check.",
loop_count,
)
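
The loop exits when something was queued, when the query returned fewer than `max_tis` candidates (there is nothing more to fetch), or when a pass discovered no new starved pool, DAG, or task, since re-running the same query with the same filters could only return the same rows. A worked example of the exit condition with invented values from a single pass:

```python
# One pass, invented values: nothing could be queued, the query filled its
# limit, but a new starved pool was discovered.
executable_tis = []
max_tis = 512
task_instances_to_examine = list(range(512))

is_done = bool(executable_tis) or len(task_instances_to_examine) < max_tis  # False
found_new_filters = True  # this pass added pool "etl" to starved_pools

# is_done is False but a new filter was found, so the loop runs another pass
# with that pool excluded; if the next pass queues nothing and adds no new
# filters either, the loop breaks to avoid re-issuing an identical query.
assert not (is_done or not found_new_filters)
```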

for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks)
5 changes: 3 additions & 2 deletions airflow/www/static/js/ti_log.js
@@ -140,7 +140,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) {
}
recurse().then(() => autoTailingLog(tryNumber, res.metadata, autoTailing));
}).catch((error) => {
console.error(`Error while retrieving log: ${error}`);
console.error(`Error while retrieving log`, error);

const externalLogUrl = getMetaValue('external_log_url');
const fullExternalUrl = `${externalLogUrl
@@ -151,7 +151,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) {

document.getElementById(`loading-${tryNumber}`).style.display = 'none';

const logBlockElementId = `try-${tryNumber}-${item[0]}`;
const logBlockElementId = `try-${tryNumber}-error`;
let logBlock = document.getElementById(logBlockElementId);
if (!logBlock) {
const logDivBlock = document.createElement('div');
@@ -164,6 +164,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) {

logBlock.innerHTML += "There was an error while retrieving the log from S3. Please use Kibana to view the logs.";
logBlock.innerHTML += `<a href="${fullExternalUrl}" target="_blank">View logs in Kibana</a>`;

});
}
