diff --git a/.gitignore b/.gitignore index 3a8cabc..945a65e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target .idea +**/.DS_Store \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b6c682..32667e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.3](https://github.com/jBernavaPrah/azure-speech-sdk-rs/compare/v0.2.2...v0.2.3) - 2024-08-16 + +### Other +- Improve documentation ([#10](https://github.com/jBernavaPrah/azure-speech-sdk-rs/pull/10)) + ## [0.2.2](https://github.com/jBernavaPrah/azure-speech-sdk-rs/compare/v0.2.1...v0.2.2) - 2024-08-16 ### Other diff --git a/Cargo.lock b/Cargo.lock index 2c19a43..21f992b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,7 +100,7 @@ checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" [[package]] name = "azure-speech" -version = "0.2.2" +version = "0.2.3" dependencies = [ "async-channel", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index 38b6322..18cbb38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "azure-speech" -version = "0.2.2" +version = "0.2.3" authors = ["Jure Bernava Prah "] -description = "Pure Rust implementation for Microsoft Speech Service" +description = "Pure Rust SDK for Azure Speech Service" edition = "2021" rust-version = "1.71.0" license = "MIT" @@ -41,7 +41,7 @@ serde_json = "1.0.114" os_info = "3" ssml = "0.1.0" -async-channel = "1.9.0" +async-channel = "1.9.0" # needed for ezsockets 0.6 for call_with; [dev-dependencies] diff --git a/examples/recognize_callbacks.rs b/examples/recognize_callbacks.rs new file mode 100644 index 0000000..c7eae59 --- /dev/null +++ b/examples/recognize_callbacks.rs @@ -0,0 +1,79 @@ +use azure_speech::stream::Stream; +use azure_speech::Auth; +use azure_speech::{recognizer, StreamExt}; +use std::env; +use std::error::Error; +use std::path::Path; +use tokio::fs::File; +use tokio::io::{AsyncReadExt, BufReader}; +use tokio_stream::wrappers::ReceiverStream; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + // Check on the example recognize_simple.rs for more details on how to set the recognizer. + let auth = Auth::from_subscription( + env::var("AZURE_REGION").expect("Region set on AZURE_REGION env"), + env::var("AZURE_SUBSCRIPTION_KEY").expect("Subscription set on AZURE_SUBSCRIPTION_KEY env"), + ); + let config = recognizer::Config::default(); + + let client = recognizer::Client::connect(auth, config) + .await + .expect("to connect to azure"); + + // Create a callbacks for the recognizer. + // The callbacks are used to get information about the recognition process. + let callbacks = recognizer::Callback::default() + .on_start_detected(|id, offset| async move { + tracing::info!("Start detected: {:?} - {:?}", id, offset); + }) + .on_recognized(|id, result, _offset, _duration, _raw| async move { + tracing::info!("Recognized: {:?} - {:?}", id, result); + }) + .on_session_end(|id| async move { + tracing::info!("Session end: {:?}", id); + }); + //.on_... // check the other callbacks available. + + client + .recognize( + create_audio_stream("tests/audios/examples_sample_files_turn_on_the_lamp.wav").await, // Try also the mp3 version of the file. + recognizer::ContentType::Wav, // Be sure to set it correctly. + recognizer::Details::file(), + ) + .await + .expect("to recognize") + // When you set the callbacks, the events will be sent to the callbacks and not to the stream. + .use_callbacks(callbacks) + .await; // it's important to await here. + + tracing::info!("Completed!"); + + Ok(()) +} + +async fn create_audio_stream(path: impl AsRef) -> impl Stream> { + let (tx, rx) = tokio::sync::mpsc::channel(1024); + let file = File::open(path).await.expect("Failed to open file"); + let mut reader = BufReader::new(file); + + tokio::spawn(async move { + let mut chunk = vec![0; 4096]; + while let Ok(n) = reader.read(&mut chunk).await { + if n == 0 { + break; + } + if tx.send(chunk.clone()).await.is_err() { + tracing::error!("Error sending data"); + break; + } + } + drop(tx); + }); + + ReceiverStream::new(rx) +} diff --git a/examples/recognize_from_bbc_word_radio.rs b/examples/recognize_from_bbc_word_radio.rs index 88ae3f1..42de04a 100644 --- a/examples/recognize_from_bbc_word_radio.rs +++ b/examples/recognize_from_bbc_word_radio.rs @@ -15,19 +15,20 @@ async fn main() { let client = recognizer::Client::connect( auth, - recognizer::Config::default().set_detect_languages( - vec![recognizer::Language::EnGb], - recognizer::LanguageDetectMode::Continuous, - ), + recognizer::Config::default() + // The BBC World Service stream is in English. + .set_language(recognizer::Language::EnGb), ) .await .expect("Failed to connect to Azure"); let mut events = client .recognize( + // The BBC World Service stream is a good example to test the recognizer. create_audio_stream("https://stream.live.vc.bbcmedia.co.uk/bbc_world_service").await, + // The content type is MPEG. recognizer::ContentType::Mpeg, - recognizer::Details::stream("mac", "stream"), + recognizer::Details::stream("unknown", "stream"), ) .await .expect("Failed to recognize"); diff --git a/examples/recognize_from_microphone.rs b/examples/recognize_from_microphone.rs index f20fb3e..2aa5abd 100644 --- a/examples/recognize_from_microphone.rs +++ b/examples/recognize_from_microphone.rs @@ -13,7 +13,7 @@ async fn main() -> Result<(), Box> { .with_max_level(tracing::Level::INFO) .init(); - // More information on the configuration can be found in the simple example. + // More information on the configuration can be found in the examples/recognize_simple.rs example. let auth = Auth::from_subscription( env::var("AZURE_REGION").expect("Region set on AZURE_REGION env"), @@ -31,6 +31,7 @@ async fn main() -> Result<(), Box> { // As the audio is raw, the WAV format is used. let (stream, microphone) = listen_from_default_input().await; + // Start the microphone. microphone.play().expect("play failed"); let mut events = client diff --git a/examples/recognize_simple.rs b/examples/recognize_simple.rs index ce1349e..7afb5d8 100644 --- a/examples/recognize_simple.rs +++ b/examples/recognize_simple.rs @@ -33,6 +33,8 @@ async fn main() -> Result<(), Box> { .await .expect("to connect to azure"); + // Here we are streaming the events from the synthesizer. + // But you can also use the callbacks (see: examples/recognize_callbacks.rs) if you prefer. let mut stream = client .recognize( // Here is your input audio stream. The audio headers needs to be present if required by the content type used. @@ -65,8 +67,9 @@ async fn main() -> Result<(), Box> { tracing::info!("Result: {:?}", result); tracing::info!("Offset: {:?}", offset); tracing::info!("Duration: {:?}", duration); - // the raw message is the raw json message from the service. - // You can use it to extract more information if needed. + + // the raw message is the json message received from the service. + // You can use it to extract more information when needed. tracing::info!("Raw message: {:?}", raw_message); } _ => { diff --git a/examples/synthesize_callbacks.rs b/examples/synthesize_callbacks.rs index a28fb96..2db7836 100644 --- a/examples/synthesize_callbacks.rs +++ b/examples/synthesize_callbacks.rs @@ -1,7 +1,6 @@ -use azure_speech::{synthesizer, Auth}; +use azure_speech::{synthesizer, Auth, StreamExt}; use std::env; use std::error::Error; -use tokio_stream::StreamExt; #[tokio::main] async fn main() -> Result<(), Box> { @@ -9,53 +8,55 @@ async fn main() -> Result<(), Box> { .with_max_level(tracing::Level::INFO) .init(); + // Check the examples/synthesize_simple.rs file for the full code. + let auth = Auth::from_subscription( env::var("AZURE_REGION").expect("Region set on AZURE_REGION env"), env::var("AZURE_SUBSCRIPTION_KEY").expect("Subscription set on AZURE_SUBSCRIPTION_KEY env"), ); - let config = synthesizer::Config::default() - .on_synthesising(|request_id, audio| { + let config = synthesizer::Config::default(); + let client = synthesizer::Client::connect(auth, config) + .await + .expect("to connect to azure"); + + // Create the callbacks for the synthesizer. + let callbacks = synthesizer::Callback::default() + .on_synthesising(|request_id, audio| async move { tracing::info!( "Callback - request: {:?}: Synthesising bytes {:?} ", request_id, audio.len() ); }) - .on_synthesised(|request_id| { + .on_synthesised(|request_id| async move { tracing::info!("Callback - request: {:?}: Synthesised", request_id); }) - .on_error(|request_id, error| { - tracing::info!("Callback - request: {:?}: Error {:?}", request_id, error); - }) - .on_audio_metadata(|request_id, metadata| { + .on_audio_metadata(|request_id, metadata| async move { tracing::info!( "Callback - request: {:?}: Audio metadata {:?}", request_id, metadata ); }) - .on_session_start(|request_id| { + .on_session_start(|request_id| async move { tracing::info!("Callback - request: {:?}: Session started", request_id); }) - .on_session_end(|request_id| { + .on_session_end(|request_id| async move { tracing::info!("Callback - request: {:?}: Session ended", request_id); + }) + .on_error(|request_id, error| async move { + tracing::info!("Callback - request: {:?}: Error {:?}", request_id, error); }); - let client = synthesizer::Client::connect(auth, config) - .await - .expect("to connect to azure"); - // you can use both the stream and callback in the same functions. - let mut stream = client + client // here you put your text to synthesize. .synthesize("Hello World!") .await - .expect("to synthesize"); - - while let Some(event) = stream.next().await { - tracing::info!("Synthesizer Event: {:?}", event); - } + .expect("to synthesize") + .use_callbacks(callbacks) + .await; Ok(()) } diff --git a/examples/synthesize_simple.rs b/examples/synthesize_simple.rs index e366e50..c67cd81 100644 --- a/examples/synthesize_simple.rs +++ b/examples/synthesize_simple.rs @@ -25,6 +25,8 @@ async fn main() -> Result<(), Box> { // It will understand the en-US language and will use the EnUsJennyNeural voice. // You can change it by using the Config struct and its methods. let config = synthesizer::Config::default(); + //.with_language(synthesizer::Language::EnGb) + //.with_voice(synthesizer::Voice::EnGbLibbyNeural) let client = synthesizer::Client::connect(auth, config) .await @@ -36,6 +38,8 @@ async fn main() -> Result<(), Box> { .await .expect("to synthesize"); + // Here we are streaming the events from the synthesizer. + // But you can also use the callbacks (see: examples/synthesize_callbacks.rs) if you prefer. while let Some(event) = stream.next().await { // Each event is a part of the synthesis process. match event { diff --git a/examples/synthesize_using_ssml.rs b/examples/synthesize_using_ssml.rs index 42a3e8c..e90a196 100644 --- a/examples/synthesize_using_ssml.rs +++ b/examples/synthesize_using_ssml.rs @@ -39,7 +39,7 @@ async fn main() -> Result<(), Box> { // this will print a lot of events to the console. // you can use the events to create your own audio output. - // check other examples to see how to create an audio output. + // check examples/synthesize_to_standard_output.rs to see how to create an audio output. tracing::info!("Synthesized: {:?}", event); } diff --git a/readme.md b/readme.md index aa13954..895a5d6 100644 --- a/readme.md +++ b/readme.md @@ -19,17 +19,20 @@ This library aims to provide an easy-to-install and straightforward interface fo The library currently supports the following features: -- [X] Speech-to-Text (Speech Recognition) +- [X] Speech Recognition (Speech-to-Text) [examples](examples/recognize_simple.rs) - [X] Real-time Speech Recognition - [X] Custom Speech Recognition -- [X] Text-to-Speech (Speech Synthesis) + - [X] Phrase List + - [ ] Conversation Transcriber - Real-time Diarization (Work in Progress) + - [ ] Pronunciation Assessment (Work in Progress) +- [X] Speech Synthesis (Text-to-Speech) [example](examples/synthesize_simple.rs) - [X] Real-time Speech Synthesis - [X] Custom Voice - [X] SSML Support -- [ ] Speech Translation -- [ ] Intent Recognition -- [ ] Speaker Recognition -- [ ] Keyword Recognition +- [ ] Speech Translation (Work in Progress) +- [ ] Intent Recognition (Work in Progress) +- [ ] Keyword Recognition (Work in Progress) + The library is currently in the early stages of development, and I am actively working on adding more features and improving the existing ones. @@ -48,6 +51,10 @@ Add this library to your project using the following command: cargo add azure_speech ``` +**And that's it!** + +You are now ready to use the Azure Speech SDK in your Rust project. + ## Usage For usage examples, please refer to the [examples folder](https://github.com/jBernavaPrah/azure-speech-sdk-rs/tree/master/examples) in the repository. Or check the [documentation](https://docs.rs/azure-speech). diff --git a/src/callback.rs b/src/callback.rs new file mode 100644 index 0000000..4ac43b4 --- /dev/null +++ b/src/callback.rs @@ -0,0 +1,15 @@ +// src/callback.rs +use crate::RequestId; +use std::future::Future; +use std::pin::Pin; + +pub(crate) type OnSessionStarted = Box BoxFuture>; +pub(crate) type OnSessionEnded = Box BoxFuture>; +pub(crate) type OnError = Box BoxFuture>; +pub(crate) type BoxFuture = Pin + Send + 'static>>; + +#[async_trait::async_trait] +pub trait Callback { + type Item; + fn on_event(&self, item: Self::Item) -> impl Future; +} diff --git a/src/lib.rs b/src/lib.rs index 9a630d9..0c6dfe8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,11 +8,11 @@ //! `tokio` runtime, it minimizes external dependencies wherever possible. //! //! ## Core Functionalities -//! - [X] Speech to Text -//! - [X] Text to Speech +//! - [X] Speech to Text [recognizer] +//! - [X] Text to Speech [synthesizer] //! //! For comprehensive information on Microsoft Speech Service, refer to the official -//! documentation [here](https://docs.microsoft.com/en-us/azure/cognitive-services/speech-service/speech-sdk?tabs=windows%2Cubuntu%2Cios-xcode%2Cmac-xcode%2Candroid-studio). +//! documentation [here](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-sdk). //! //! ## Notes //! This crate, in its current version, does not include some features available in the @@ -20,7 +20,7 @@ //! However, examples demonstrating these capabilities can be found in the `examples` directory. //! //! ## Usage and Examples -//! Detailed usage instructions and examples are provided in the `examples` directory. +//! Detailed usage instructions and examples are provided in the [examples](https://github.com/jBernavaPrah/azure-speech-sdk-rs/blob/master/examples) folder in the GitHub repository. //! mod auth; @@ -31,6 +31,7 @@ mod event; mod stream_ext; mod utils; +mod callback; pub mod recognizer; pub mod synthesizer; @@ -39,7 +40,7 @@ pub use connector::*; pub use error::*; pub use event::*; -pub use stream_ext::StreamExt; +pub use stream_ext::*; pub mod stream { //! Re-export of `tokio_stream` crate. diff --git a/src/recognizer/callback.rs b/src/recognizer/callback.rs new file mode 100644 index 0000000..1935a4e --- /dev/null +++ b/src/recognizer/callback.rs @@ -0,0 +1,195 @@ +use crate::callback::{BoxFuture, OnError, OnSessionEnded, OnSessionStarted}; +use crate::recognizer::{Duration, Event, Offset, RawMessage, Recognized}; +use crate::RequestId; +use std::future::Future; +use std::sync::Arc; + +pub(crate) type OnRecognizing = + Box BoxFuture>; +pub(crate) type OnRecognized = + Box BoxFuture>; +pub(crate) type OnUnMatch = Box BoxFuture>; +pub(crate) type OnStartDetected = Box BoxFuture>; +pub(crate) type OnEndDetected = Box BoxFuture>; + +#[derive(Default, Clone)] +pub struct Callback { + pub(crate) on_session_started: Option>, + pub(crate) on_error: Option>, + pub(crate) on_session_ended: Option>, + + pub(crate) on_recognizing: Option>, + pub(crate) on_recognized: Option>, + pub(crate) on_un_match: Option>, + pub(crate) on_start_detected: Option>, + pub(crate) on_end_detected: Option>, +} + +impl Callback { + pub fn on_session_start(mut self, func: F) -> Self + where + F: Fn(RequestId) -> Fut + 'static, + Fut: Future + Send + Sync + 'static, + { + self.on_session_started = Some(Arc::new(Box::new(move |str| Box::pin(func(str))))); + self + } + + pub fn on_session_end(mut self, func: F) -> Self + where + F: Fn(RequestId) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_session_ended = Some(Arc::new(Box::new(move |str| Box::pin(func(str))))); + self + } + + pub fn on_error(mut self, func: F) -> Self + where + F: Fn(RequestId, crate::Error) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_error = Some(Arc::new(Box::new(move |request, err| { + Box::pin(func(request, err)) + }))); + self + } + + pub fn on_recognizing(mut self, func: F) -> Self + where + F: Fn(RequestId, Recognized, Offset, Duration, RawMessage) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_recognizing = Some(Arc::new(Box::new( + move |request_id, recognized, offset, duration, raw_message| { + Box::pin(func(request_id, recognized, offset, duration, raw_message)) + }, + ))); + self + } + + pub fn on_recognized(mut self, func: F) -> Self + where + F: Fn(RequestId, Recognized, Offset, Duration, RawMessage) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_recognized = Some(Arc::new(Box::new( + move |request_id, recognized, offset, duration, raw_message| { + Box::pin(func(request_id, recognized, offset, duration, raw_message)) + }, + ))); + self + } + + pub fn on_un_match(mut self, func: F) -> Self + where + F: Fn(RequestId, Offset, Duration, RawMessage) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_un_match = Some(Arc::new(Box::new( + move |request_id, offset, duration, raw_message| { + Box::pin(func(request_id, offset, duration, raw_message)) + }, + ))); + self + } + + pub fn on_start_detected(mut self, func: F) -> Self + where + F: Fn(RequestId, Offset) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_start_detected = Some(Arc::new(Box::new(move |request, offset| { + Box::pin(func(request, offset)) + }))); + self + } + + pub fn on_end_detected(mut self, func: F) -> Self + where + F: Fn(RequestId, Offset) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_end_detected = Some(Arc::new(Box::new(move |request, offset| { + Box::pin(func(request, offset)) + }))); + self + } +} + +#[async_trait::async_trait] +impl crate::callback::Callback for Callback { + type Item = crate::Result; + + #[allow(clippy::manual_async_fn)] + fn on_event(&self, item: Self::Item) -> impl Future { + async move { + match &item { + Ok(Event::SessionStarted(request_id)) => { + tracing::debug!("Session started"); + if let Some(f) = self.on_session_started.as_ref() { + f(*request_id).await + } + } + Ok(Event::SessionEnded(request_id)) => { + tracing::debug!("Session ended"); + if let Some(f) = self.on_session_ended.as_ref() { + f(*request_id).await + } + } + + Ok(Event::Recognizing(request_id, recognized, offset, duration, raw)) => { + if let Some(f) = self.on_recognizing.as_ref() { + f( + *request_id, + recognized.clone(), + *offset, + *duration, + raw.clone(), + ) + .await + } + } + + Ok(Event::Recognized(request_id, recognized, offset, duration, raw)) => { + if let Some(f) = self.on_recognized.as_ref() { + f( + *request_id, + recognized.clone(), + *offset, + *duration, + raw.clone(), + ) + .await + } + } + + Ok(Event::UnMatch(request_id, offset, duration, raw)) => { + if let Some(f) = self.on_un_match.as_ref() { + f(*request_id, *offset, *duration, raw.clone()).await + } + } + + Ok(Event::EndDetected(request_id, offset)) => { + if let Some(f) = self.on_end_detected.as_ref() { + f(*request_id, *offset).await + } + } + + Ok(Event::StartDetected(request_id, offset)) => { + if let Some(f) = self.on_start_detected.as_ref() { + f(*request_id, *offset).await + } + } + + Err(e) => { + tracing::error!("Error: {:?}", e); + if let Some(_f) = self.on_error.as_ref() { + // todo: improve the error with adding the request_id on it! + // f(session.request_id(), e.clone()) + } + } + } + } + } +} diff --git a/src/recognizer/client.rs b/src/recognizer/client.rs index ca6ff15..c4686d6 100644 --- a/src/recognizer/client.rs +++ b/src/recognizer/client.rs @@ -186,11 +186,6 @@ impl Client { .merge(tokio_stream::iter(vec![Ok(Event::SessionStarted( session3.request_id(), ))])) - // Handle the events and call the callbacks. - .map(move |event| { - // todo: implement the callbacks for events - event - }) // Stop the stream if there is an error or the session ended. .stop_after(move |event| event.is_err() || matches!(event, Ok(Event::SessionEnded(_))))) } diff --git a/src/recognizer/config.rs b/src/recognizer/config.rs index b344b92..ddff088 100644 --- a/src/recognizer/config.rs +++ b/src/recognizer/config.rs @@ -5,13 +5,14 @@ use serde::{Deserialize, Serialize}; /// The configuration for the recognizer. /// /// The configuration is used to set the parameters of the speech recognition. -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct Config { pub(crate) device: Device, pub(crate) languages: Vec, pub(crate) output_format: OutputFormat, + // todo: probably this will be removed and moved directly in the connection. pub(crate) mode: RecognitionMode, // todo: what is this? pub(crate) language_detect_mode: Option, @@ -25,6 +26,8 @@ pub struct Config { pub(crate) store_audio: bool, // todo: is this needed? pub(crate) profanity: Profanity, + // todo: check diarization https://learn.microsoft.com/en-us/azure/ai-services/speech-service/get-started-stt-diarization?tabs=macos&pivots=programming-language-javascript + // probably will be moved from here and added to a separate module. //pub(crate) recognize_speaker: bool, // todo add more detailed configuration from default: src/common.speech/ConnectionFactoryBase.ts @@ -43,7 +46,6 @@ impl Default for Config { store_audio: false, device: Device::default(), profanity: Profanity::Masked, - //recognize_speaker: false, } } } diff --git a/src/recognizer/event.rs b/src/recognizer/event.rs index 5f3857c..16675a8 100644 --- a/src/recognizer/event.rs +++ b/src/recognizer/event.rs @@ -53,6 +53,8 @@ pub struct Recognized { pub text: String, /// The primary language of the recognized text. pub primary_language: Option, + + // todo: Remove from here and add to a diarization module. /// The speaker id of the recognized text. /// This will be None if the detection of the speaker is not activated. pub speaker_id: Option, diff --git a/src/recognizer/mod.rs b/src/recognizer/mod.rs index 05c3e44..5ebe36d 100644 --- a/src/recognizer/mod.rs +++ b/src/recognizer/mod.rs @@ -40,6 +40,7 @@ //! } //! +mod callback; mod client; mod config; mod content_type; @@ -49,6 +50,7 @@ mod message; mod session; mod utils; +pub use callback::*; pub use client::*; pub use config::*; pub use content_type::*; diff --git a/src/stream_ext.rs b/src/stream_ext.rs index b414a2f..8738bf2 100644 --- a/src/stream_ext.rs +++ b/src/stream_ext.rs @@ -1,8 +1,11 @@ +use crate::callback::Callback; use core::fmt; use core::pin::Pin; use core::task::{Context, Poll}; use pin_project_lite::pin_project; -use tokio_stream::Stream; +use std::future::Future; +use std::pin::pin; +use tokio_stream::{Stream, StreamExt as _}; pin_project! { /// Stream for the [`stop_after`](stop_after) method. @@ -72,7 +75,10 @@ where } /// An extension trait for `Stream` that provides a variety of convenient combinator functions. -pub trait StreamExt: Stream { +pub trait StreamExt: Stream +where + Self: 'static, +{ /// Takes elements from this stream until the provided predicate resolves to `true`. /// /// This function operates similarly to `Iterator::take_while`, extracting elements from the @@ -110,6 +116,23 @@ pub trait StreamExt: Stream { { StopAfter::new(self, f) } + + /// Calls the provided callback for each item in the stream. + /// + /// + + fn use_callbacks(self, callback: C) -> impl Future + where + Self: Sized + Send + Sync, + C: Callback + 'static, + { + async move { + let mut _self = pin!(self); + while let Some(event) = _self.next().await { + callback.on_event(event).await; + } + } + } } -impl StreamExt for St where St: Stream {} +impl StreamExt for St where St: Stream {} diff --git a/src/synthesizer/callback.rs b/src/synthesizer/callback.rs new file mode 100644 index 0000000..cf42710 --- /dev/null +++ b/src/synthesizer/callback.rs @@ -0,0 +1,139 @@ +use crate::callback::{BoxFuture, OnError, OnSessionEnded, OnSessionStarted}; +use crate::synthesizer::Event; +use crate::RequestId; +use std::future::Future; +use std::sync::Arc; + +pub(crate) type OnSynthesising = Arc) -> BoxFuture>>; +pub(crate) type OnAudioMetadata = Arc BoxFuture>>; +pub(crate) type OnSynthesised = Arc BoxFuture>>; + +#[derive(Default, Clone)] + +pub struct Callback { + pub(crate) on_session_started: Option>, + pub(crate) on_error: Option>, + pub(crate) on_session_ended: Option>, + + pub(crate) on_synthesising: Option, + pub(crate) on_audio_metadata: Option, + pub(crate) on_synthesised: Option, +} + +impl Callback { + pub fn on_session_start(mut self, func: F) -> Self + where + F: Fn(RequestId) -> Fut + 'static, + Fut: Future + Send + Sync + 'static, + { + self.on_session_started = Some(Arc::new(Box::new(move |str| Box::pin(func(str))))); + self + } + + pub fn on_session_end(mut self, func: F) -> Self + where + F: Fn(RequestId) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_session_ended = Some(Arc::new(Box::new(move |request_id| { + Box::pin(func(request_id)) + }))); + self + } + + pub fn on_error(mut self, func: F) -> Self + where + F: Fn(RequestId, crate::Error) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_error = Some(Arc::new(Box::new(move |request, err| { + Box::pin(func(request, err)) + }))); + self + } + + pub fn on_synthesising(mut self, func: F) -> Self + where + F: Fn(RequestId, Vec) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_synthesising = Some(Arc::new(Box::new(move |request_id, audio| { + Box::pin(func(request_id, audio)) + }))); + self + } + + pub fn on_audio_metadata(mut self, func: F) -> Self + where + F: Fn(RequestId, String) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_audio_metadata = Some(Arc::new(Box::new(move |request_id, metadata| { + Box::pin(func(request_id, metadata)) + }))); + self + } + + pub fn on_synthesised(mut self, func: F) -> Self + where + F: Fn(RequestId) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.on_synthesised = Some(Arc::new(Box::new(move |request_id| { + Box::pin(func(request_id)) + }))); + self + } +} +#[async_trait::async_trait] +impl crate::callback::Callback for Callback { + type Item = crate::Result; + + #[allow(clippy::manual_async_fn)] + fn on_event(&self, item: Self::Item) -> impl Future { + async move { + match &item { + Ok(Event::SessionStarted(request_id)) => { + tracing::debug!("Session started"); + if let Some(f) = self.on_session_started.as_ref() { + f(*request_id).await + } + } + Ok(Event::SessionEnded(request_id)) => { + tracing::debug!("Session ended"); + if let Some(f) = self.on_session_ended.as_ref() { + f(*request_id).await + } + } + + Ok(Event::Synthesising(request_id, audio)) => { + tracing::debug!("Synthesising audio: {:?}", audio.len()); + if let Some(f) = self.on_synthesising.as_ref() { + f(*request_id, audio.clone()).await + } + } + + Ok(Event::Synthesised(request_id)) => { + tracing::debug!("Synthesised"); + if let Some(f) = self.on_synthesised.as_ref() { + f(*request_id).await + } + } + + Ok(Event::AudioMetadata(request_id, metadata)) => { + tracing::debug!("Audio metadata: {:?}", metadata); + if let Some(f) = self.on_audio_metadata.as_ref() { + f(*request_id, metadata.clone()).await + } + } + + Err(e) => { + tracing::error!("Error: {:?}", e); + if let Some(_f) = self.on_error.as_ref() { + //f(session3.request_id(), e.clone()).await + } + } + } + } + } +} diff --git a/src/synthesizer/client.rs b/src/synthesizer/client.rs index 2b3c2dd..6dda336 100644 --- a/src/synthesizer/client.rs +++ b/src/synthesizer/client.rs @@ -82,7 +82,6 @@ impl Client { .send_text(create_ssml_message(request_id.to_string(), xml))?; let session2 = session.clone(); - let session3 = session.clone(); Ok(stream // Map errors. .map(move |message| match message { @@ -99,53 +98,6 @@ impl Client { Ok(message) => convert_message_to_event(message, session2.clone()), Err(e) => Some(Err(e)), }) - // Handle the events and call the callbacks. - .map(move |event| { - match &event { - Ok(Event::SessionEnded(request_id)) => { - tracing::debug!("Session ended"); - if let Some(f) = config.on_session_ended.as_ref() { - f(*request_id) - } - } - Ok(Event::SessionStarted(request_id)) => { - tracing::debug!("Session started"); - if let Some(f) = config.on_session_started.as_ref() { - f(*request_id) - } - } - - Ok(Event::Synthesising(request_id, audio)) => { - tracing::debug!("Synthesising audio: {:?}", audio.len()); - if let Some(f) = config.on_synthesising.as_ref() { - f(*request_id, audio.clone()) - } - } - - Ok(Event::Synthesised(request_id)) => { - tracing::debug!("Synthesised"); - if let Some(f) = config.on_synthesised.as_ref() { - f(*request_id) - } - } - - Ok(Event::AudioMetadata(request_id, metadata)) => { - tracing::debug!("Audio metadata: {:?}", metadata); - if let Some(f) = config.on_audio_metadata.as_ref() { - f(*request_id, metadata.clone()) - } - } - - Err(e) => { - tracing::error!("Error: {:?}", e); - if let Some(f) = config.on_error.as_ref() { - f(session3.request_id(), e.clone()) - } - } - } - - event - }) // Stop the stream if there is an error or the session ended. .stop_after(|event| event.is_err() || matches!(event, Ok(Event::SessionEnded(_))))) } diff --git a/src/synthesizer/config.rs b/src/synthesizer/config.rs index 25b5d99..899b377 100644 --- a/src/synthesizer/config.rs +++ b/src/synthesizer/config.rs @@ -1,7 +1,5 @@ use crate::config::Device; use crate::synthesizer::{AudioFormat, Language, Voice}; -use crate::RequestId; -use std::sync::Arc; #[derive(Clone, Default)] pub struct Config { @@ -20,22 +18,8 @@ pub struct Config { pub(crate) viseme_enabled: bool, pub(crate) auto_detect_language: bool, - - pub(crate) on_session_started: Option, - pub(crate) on_session_ended: Option, - pub(crate) on_synthesising: Option, - pub(crate) on_audio_metadata: Option, - pub(crate) on_synthesised: Option, - pub(crate) on_error: Option, } -pub type OnSessionStarted = Arc>; -pub type OnSessionEnded = Arc>; -pub type OnSynthesising = Arc) + Send + Sync + 'static>>; -pub type OnAudioMetadata = Arc>; -pub type OnSynthesised = Arc>; -pub type OnError = Arc>; - impl Config { pub fn new() -> Self { Self { @@ -99,52 +83,4 @@ impl Config { self.device = device; self } - - pub fn on_session_start(mut self, func: Func) -> Self - where - Func: Send + Sync + 'static + Fn(RequestId), - { - self.on_session_started = Some(Arc::new(Box::new(func))); - self - } - - pub fn on_session_end(mut self, func: Func) -> Self - where - Func: Send + Sync + 'static + Fn(RequestId), - { - self.on_session_ended = Some(Arc::new(Box::new(func))); - self - } - - pub fn on_synthesising(mut self, func: Func) -> Self - where - Func: Send + Sync + 'static + Fn(RequestId, Vec), - { - self.on_synthesising = Some(Arc::new(Box::new(func))); - self - } - - pub fn on_audio_metadata(mut self, func: Func) -> Self - where - Func: Send + Sync + 'static + Fn(RequestId, String), - { - self.on_audio_metadata = Some(Arc::new(Box::new(func))); - self - } - - pub fn on_synthesised(mut self, func: Func) -> Self - where - Func: Send + Sync + 'static + Fn(RequestId), - { - self.on_synthesised = Some(Arc::new(Box::new(func))); - self - } - - pub fn on_error(mut self, func: Func) -> Self - where - Func: Send + Sync + 'static + Fn(RequestId, crate::Error), - { - self.on_error = Some(Arc::new(Box::new(func))); - self - } } diff --git a/src/synthesizer/mod.rs b/src/synthesizer/mod.rs index 6c44051..e32f95c 100644 --- a/src/synthesizer/mod.rs +++ b/src/synthesizer/mod.rs @@ -45,9 +45,11 @@ mod session; mod utils; mod voice; +mod callback; pub mod ssml; pub use audio_format::*; +pub use callback::*; pub use client::*; pub use config::*; pub use event::*;