Skip to content

Commit

Permalink
Rework self.
Browse files Browse the repository at this point in the history
  • Loading branch information
milesj committed Sep 12, 2023
1 parent a8ed6ba commit c2aa8e1
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 88 deletions.
144 changes: 64 additions & 80 deletions nextgen/pipeline/src/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::time::{Duration, Instant};
use tokio::task::JoinHandle;
use tokio::time::{sleep, timeout};
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace};
use tracing::{debug, trace, warn};

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
Expand All @@ -26,31 +26,29 @@ pub struct JobResult<T> {
pub struct Job<T: Send> {
pub batch_id: Option<String>,
pub id: String,
pub state: RunState,

/// Maximum seconds to run before it's cancelled.
pub timeout: Option<u64>,

/// Seconds to emit progress events on an interval.
pub interval: Option<u64>,

action: Option<Box<dyn JobAction<T>>>,
action: Box<dyn JobAction<T>>,
}

impl<T: 'static + Send> Job<T> {
pub fn new(id: String, action: impl JobAction<T> + 'static) -> Self {
Self {
action: Some(Box::new(action)),
action: Box::new(action),
batch_id: None,
id,
state: RunState::Pending,
timeout: None,
interval: Some(30),
}
}

pub async fn run(&mut self, context: Context<T>) -> miette::Result<RunState> {
let action_fn = self.action.take().expect("Missing job action!");
pub async fn run(self, context: Context<T>) -> miette::Result<RunState> {
let action_fn = self.action;

debug!(
batch = self.batch_id.as_ref(),
Expand All @@ -64,11 +62,18 @@ impl<T: 'static + Send> Job<T> {
let mut error = None;
let mut error_report = None;

self.update_state(&context, RunState::Running).await?;
context
.on_job_state_change
.emit(JobStateChangeEvent {
job: self.id.clone(),
state: RunState::Running,
prev_state: RunState::Pending,
})
.await?;

let timeout_token = CancellationToken::new();
let timeout_handle = self.track_timeout(timeout_token.clone());
let progress_handle = self.track_progress(context.clone());
let timeout_handle = track_timeout(self.timeout, timeout_token.clone());
let progress_handle = track_progress(self.interval, context.clone(), self.id.clone());

let final_state = tokio::select! {
// Abort if a sibling job has failed
Expand Down Expand Up @@ -133,7 +138,14 @@ impl<T: 'static + Send> Job<T> {
},
};

self.update_state(&context, final_state).await?;
context
.on_job_state_change
.emit(JobStateChangeEvent {
job: self.id.clone(),
state: final_state,
prev_state: RunState::Running,
})
.await?;

timeout_handle.abort();
progress_handle.abort();
Expand All @@ -146,7 +158,7 @@ impl<T: 'static + Send> Job<T> {
finished_at: Utc::now(),
id: self.id.clone(),
started_at,
state: self.state,
state: final_state,
};

debug!(
Expand All @@ -157,83 +169,55 @@ impl<T: 'static + Send> Job<T> {
"Ran job",
);

// context
// .on_job_finished
// .emit(JobFinishedEvent {
// job: id.clone(),
// result: result.clone(),
// })
// .await?;

// Send the result or abort pipeline on failure
if context.result_sender.send(result).await.is_err() {
context.abort_token.cancel();
}

Ok(self.state)
}

async fn update_state(
&mut self,
context: &Context<T>,
next_state: RunState,
) -> miette::Result<()> {
let prev_state = self.state;
let state = next_state;

context
.on_job_state_change
.emit(JobStateChangeEvent {
job: self.id.clone(),
state,
prev_state,
})
.await?;

self.state = state;

Ok(())
Ok(final_state)
}
}

fn track_progress(&self, context: Context<T>) -> JoinHandle<()> {
let duration = self.interval;
let id = self.id.clone();

tokio::spawn(async move {
if let Some(duration) = duration {
let mut secs = 0;

loop {
sleep(Duration::from_secs(duration)).await;
secs += duration;

let _ = context
.on_job_progress
.emit(JobProgressEvent {
job: id.clone(),
elapsed: secs as u32,
})
.await;
fn track_progress<T>(duration: Option<u64>, context: Context<T>, id: String) -> JoinHandle<()> {
tokio::spawn(async move {
if let Some(duration) = duration {
let mut secs = 0;

loop {
sleep(Duration::from_secs(duration)).await;
secs += duration;

if let Err(error) = context
.on_job_progress
.emit(JobProgressEvent {
job: id.clone(),
elapsed: secs as u32,
})
.await
{
warn!(
job = &id,
error = error.to_string(),
"Failed to emit job progress update event!"
);
}
}
})
}
}
})
}

fn track_timeout(&self, timeout_token: CancellationToken) -> JoinHandle<()> {
let duration = self.timeout;

tokio::spawn(async move {
if let Some(duration) = duration {
if timeout(
Duration::from_secs(duration),
sleep(Duration::from_secs(86400)), // 1 day
)
.await
.is_err()
{
timeout_token.cancel();
}
fn track_timeout(duration: Option<u64>, timeout_token: CancellationToken) -> JoinHandle<()> {
tokio::spawn(async move {
if let Some(duration) = duration {
if timeout(
Duration::from_secs(duration),
sleep(Duration::from_secs(86400)), // 1 day
)
.await
.is_err()
{
timeout_token.cancel();
}
})
}
}
})
}
13 changes: 5 additions & 8 deletions nextgen/pipeline/src/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ use async_trait::async_trait;
use tokio::task::JoinHandle;
use tracing::debug;

async fn spawn_job<T: 'static + Send>(
mut job: Job<T>,
context: Context<T>,
) -> JoinHandle<RunState> {
async fn spawn_job<T: 'static + Send>(job: Job<T>, context: Context<T>) -> JoinHandle<RunState> {
let permit = context
.semaphore
.clone()
Expand Down Expand Up @@ -91,16 +88,16 @@ impl<T: 'static + Send> Step<T> for BatchedStep<T> {
batch.push(spawn_job(job, context.clone()).await);
}

for job in batch {
if job.is_finished() {
for handle in batch {
if handle.is_finished() {
continue;
}

if context.abort_token.is_cancelled() {
job.abort();
handle.abort();
}

if let Err(error) = job.await {
if let Err(error) = handle.await {
fail_count += 1;

if !error.is_cancelled() || error.is_panic() {
Expand Down

0 comments on commit c2aa8e1

Please sign in to comment.