diff --git a/Cargo.lock b/Cargo.lock index 328b86a..0b9bcaf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1531,6 +1531,8 @@ dependencies = [ "pin-utils", "serde", "serde_json", + "signal-hook", + "signal-hook-tokio", "streamstore", "thiserror", "tokio", @@ -1622,6 +1624,16 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -1631,6 +1643,18 @@ dependencies = [ "libc", ] +[[package]] +name = "signal-hook-tokio" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213241f76fb1e37e27de3b6aa1b068a2c333233b59cca6634f634b80a27ecf1e" +dependencies = [ + "futures-core", + "libc", + "signal-hook", + "tokio", +] + [[package]] name = "slab" version = "0.4.9" diff --git a/Cargo.toml b/Cargo.toml index 4a30c85..c2c020e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,8 @@ pin-project-lite = "0.2" pin-utils = "0.1.0" serde = { version = "1.0.214", features = ["derive"] } serde_json = "1.0.132" +signal-hook = "0.3.17" +signal-hook-tokio = { version = "0.3.1", features = ["futures-v0_3"] } streamstore = { git = "https://github.com/s2-streamstore/s2-sdk-rust.git", rev = "63b4964b66503f705e7c73ae07ba47f81019b79a" } thiserror = "1.0.67" tokio = { version = "*", features = ["full"] } diff --git a/src/main.rs b/src/main.rs index b81227a..7126739 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,8 @@ use clap::{builder::styling, Parser, Subcommand}; use colored::*; use config::{config_path, create_config}; use error::{S2CliError, ServiceError, ServiceErrorContext}; +use signal_hook::consts::{SIGINT, SIGTERM, SIGTSTP}; +use signal_hook_tokio::Signals; use stream::{RecordStream, StreamService}; use streamstore::{ bytesize::ByteSize, @@ -16,6 +18,7 @@ use streamstore::{ use tokio::{ fs::{File, OpenOptions}, io::{self, AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}, + select, time::Instant, }; use tokio_stream::StreamExt; @@ -501,27 +504,51 @@ async fn run() -> Result<(), S2CliError> { .lines(), ); + let mut signals = + Signals::new([SIGTSTP, SIGINT, SIGTERM]).expect("valid signals"); + let mut append_output_stream = StreamService::new(stream_client) .append_session(append_input_stream) .await?; - while let Some(append_result) = append_output_stream.next().await { - append_result - .map(|append_result| { - eprintln!( - "{}", - format!( - "✓ [APPENDED] start: {}, end: {}, next: {}", - append_result.start_seq_num, - append_result.end_seq_num, - append_result.next_seq_num - ) - .green() - .bold() - ); - }) - .map_err(|e| { - ServiceError::new(ServiceErrorContext::AppendSession, e) - })?; + loop { + select! { + maybe_append_result = append_output_stream.next() => { + match maybe_append_result { + Some(append_result) => { + match append_result { + Ok(append_result) => { + eprintln!( + "{}", + format!( + "✓ [APPENDED] start: {}, end: {}, next: {}", + append_result.start_seq_num, + append_result.end_seq_num, + append_result.next_seq_num + ) + .green() + .bold() + ); + }, + Err(e) => { + return Err(ServiceError::new(ServiceErrorContext::AppendSession, e).into()); + } + } + } + None => break, + } + } + + Some(signal) = signals.next() => { + match signal { + SIGTSTP | SIGINT | SIGTERM => { + drop(append_output_stream); + eprintln!("{}", "■ [ABORTED]".red().bold()); + break; + } + _ => {} + } + } + } } } StreamActions::Read { @@ -531,6 +558,8 @@ async fn run() -> Result<(), S2CliError> { limit_bytes, } => { let stream_client = StreamClient::new(client_config, basin, stream); + let mut signals = + Signals::new([SIGTSTP, SIGINT, SIGTERM]).expect("valid signals"); let mut read_output_stream = StreamService::new(stream_client) .read_session(start_seq_num, limit_count, limit_bytes) .await?; @@ -539,65 +568,84 @@ async fn run() -> Result<(), S2CliError> { let mut start = None; let mut total_data_len = ByteSize::b(0); - while let Some(read_result) = read_output_stream.next().await { - if start.is_none() { - start = Some(Instant::now()); - } - - let read_result = read_result - .map_err(|e| ServiceError::new(ServiceErrorContext::ReadSession, e))?; - - match read_result { - ReadOutput::Batch(sequenced_record_batch) => { - let num_records = sequenced_record_batch.records.len(); - let mut batch_len = ByteSize::b(0); - - let seq_range = match ( - sequenced_record_batch.records.first(), - sequenced_record_batch.records.last(), - ) { - (Some(first), Some(last)) => first.seq_num..=last.seq_num, - _ => panic!("empty batch"), - }; - for sequenced_record in sequenced_record_batch.records { - let data = &sequenced_record.body; - batch_len += sequenced_record.metered_size(); - - writer - .write_all(data) - .await - .map_err(|e| S2CliError::RecordWrite(e.to_string()))?; - writer - .write_all(b"\n") - .await - .map_err(|e| S2CliError::RecordWrite(e.to_string()))?; + loop { + select! { + maybe_read_result = read_output_stream.next() => { + match maybe_read_result { + Some(read_result) => { + if start.is_none() { + start = Some(Instant::now()); + } + match read_result { + Ok(ReadOutput::Batch(sequenced_record_batch)) => { + let num_records = sequenced_record_batch.records.len(); + let mut batch_len = ByteSize::b(0); + + let seq_range = match ( + sequenced_record_batch.records.first(), + sequenced_record_batch.records.last(), + ) { + (Some(first), Some(last)) => first.seq_num..=last.seq_num, + _ => panic!("empty batch"), + }; + for sequenced_record in sequenced_record_batch.records { + let data = &sequenced_record.body; + batch_len += sequenced_record.metered_size(); + + writer + .write_all(data) + .await + .map_err(|e| S2CliError::RecordWrite(e.to_string()))?; + writer + .write_all(b"\n") + .await + .map_err(|e| S2CliError::RecordWrite(e.to_string()))?; + } + total_data_len += batch_len; + + let throughput_mibps = (total_data_len.0 as f64 + / start.unwrap().elapsed().as_secs_f64()) + / 1024.0 + / 1024.0; + + eprintln!( + "{}", + format!( + "⦿ {throughput_mibps:.2} MiB/s \ + ({num_records} records in range {seq_range:?})", + ) + .blue() + .bold() + ); + } + + Ok(ReadOutput::FirstSeqNum(seq_num)) => { + eprintln!("{}", format!("first_seq_num: {seq_num}").blue().bold()); + } + + Ok(ReadOutput::NextSeqNum(seq_num)) => { + eprintln!("{}", format!("next_seq_num: {seq_num}").blue().bold()); + } + + Err(e) => { + return Err(ServiceError::new(ServiceErrorContext::ReadSession, e).into()); + } + } + } + None => break, + } + }, + Some(signal) = signals.next() => { + match signal { + SIGTSTP | SIGINT | SIGTERM => { + drop(read_output_stream); + eprintln!("{}", "■ [ABORTED]".red().bold()); + break; + } + _ => {} } - total_data_len += batch_len; - - let throughput_mibps = (total_data_len.0 as f64 - / start.unwrap().elapsed().as_secs_f64()) - / 1024.0 - / 1024.0; - - eprintln!( - "{}", - format!( - "⦿ {throughput_mibps:.2} MiB/s \ - ({num_records} records in range {seq_range:?})", - ) - .blue() - .bold() - ); - } - // TODO: better message for these cases - ReadOutput::FirstSeqNum(seq_num) => { - eprintln!("{}", format!("first_seq_num: {seq_num}").blue().bold()); - } - ReadOutput::NextSeqNum(seq_num) => { - eprintln!("{}", format!("next_seq_num: {seq_num}").blue().bold()); } } - let total_elapsed_time = start.unwrap().elapsed().as_secs_f64(); let total_throughput_mibps = @@ -620,5 +668,5 @@ async fn run() -> Result<(), S2CliError> { } } } - Ok(()) + std::process::exit(0); }