Skip to content

Commit

Permalink
Refactor StreamFile to use takecell
Browse files Browse the repository at this point in the history
  • Loading branch information
Desiders committed Oct 29, 2023
1 parent 93df40a commit e6197a6
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dashmap = "5.4"
regex = "1.10"
backoff = "0.4"
bytes = "1.3"
triomphe = "0.1"
takecell = "0.1"
pathdiff = "0.2"
uuid = { version = "1.5", features = ["v4"] }

Expand Down
12 changes: 7 additions & 5 deletions src/client/session/reqwest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use reqwest::{
use serde::Serialize;
use std::{borrow::Cow, time::Duration};
use tracing::{event, field, instrument, Level, Span};
use triomphe::Arc as TriompheArc;

#[derive(Debug, Clone)]
pub struct Reqwest {
Expand Down Expand Up @@ -59,7 +58,7 @@ impl Reqwest {
return Ok(form);
};

for file in files {
for (index, file) in files.iter().enumerate() {
match file.kind() {
InputFileKind::FS(file) => {
let id = file.id().to_string();
Expand Down Expand Up @@ -89,11 +88,14 @@ impl Reqwest {
form = form.part(id, part);
}
InputFileKind::Stream(file) => {
let Some(stream) = file.take_stream() else {
return Err(SerializerError::Custom(Cow::Owned(format!(
"File stream with index `{index}` already taken. \
Read `StreamFile::take_stream` documentation for more information."
))));
};
let id = file.id().to_string();
let file_name = file.file_name();
let Ok(stream) = TriompheArc::try_unwrap(file.stream()) else {
panic!("Cannot unwrap a stream. `InputFile::stream` shouldn't have more than one strong reference");
};

let body = Body::wrap_stream(stream);
let part = if let Some(file_name) = file_name {
Expand Down
4 changes: 2 additions & 2 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ pub use inline_query_result_voice::InlineQueryResultVoice;
pub use inline_query_results_button::InlineQueryResultsButton;
pub use input_contact_message_content::InputContactMessageContent;
pub use input_file::{
FSFile as InputFSFile, FileId as InputFileId, FileKind as InputFileKind, InputFile,
UrlFile as InputUrlFile,
BufferedFile as InputBufferedFile, FSFile as InputFSFile, FileId as InputFileId,
FileKind as InputFileKind, InputFile, StreamFile as InputStreamFile, UrlFile as InputUrlFile,
};
pub use input_invoice_message_content::InputInvoiceMessageContent;
pub use input_location_message_content::InputLocationMessageContent;
Expand Down
152 changes: 118 additions & 34 deletions src/types/input_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,32 @@ use std::{
hash::{Hash, Hasher},
io,
path::{Path, PathBuf},
sync::Arc,
};
use takecell::TakeOwnCell;
use tokio_util::codec::{BytesCodec, FramedRead};
use triomphe::Arc as TriompheArc;
use uuid::Uuid;

const ATTACH_PREFIX: &str = "attach://";

pub const DEFAULT_CAPACITY: usize = 64 * 1024; // 64 KiB

/// This object represents the contents of a file to be uploaded.
/// Must be posted using `multipart/form-data` in the usual way that files are uploaded via the browser.
/// # Notes
/// Instead of [`InputFile`], you can use any type that implements [`Into<InputFile>`]:
/// - [`FileId`] (for example [`FileId::new(file_id)`])
/// - [`UrlFile`] (for example [`UrlFile::new(url)`])
/// - [`FSFile`] (for example [`FSFile::new(path)`])
/// - [`BufferedFile`] (for example [`BufferedFile::new(bytes)`])
/// - [`StreamFile`] (for example [`StreamFile::new(stream)`])
/// This struct is useful for fast and easy creation of any of these types,
/// but if you want to use methods of specific type (for example [`FSFile::stream`] or [`StreamFile::set_stream`]),
/// you need to use specific type.
/// # Warning
/// If you [`Clone`] file, you will get a new file with the same ID for [`FileId`], [`UrlFile`], [`FSFile`], [`BufferedFile`].
/// So several parts will refer to the same data. It can be useful to minimize the amount of data uploaded.
/// In case of [`StreamFile`] you will get a new file with a new ID.
/// This is done to avoid problems when several parts with different streams refer to the first part and, respectively, to the first stream.
/// # Documentation
/// <https://core.telegram.org/bots/api#inputfile>
#[derive(Debug, Clone, Hash, PartialEq)]
Expand Down Expand Up @@ -319,7 +334,7 @@ impl Hash for FSFile<'_> {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Clone, PartialEq, Eq)]
pub struct BufferedFile<'a> {
id: Uuid,
bytes: Bytes,
Expand Down Expand Up @@ -388,45 +403,35 @@ impl<'a> BufferedFile<'a> {
}
}

/// Manual [`Debug`] implementation that prints a `"..."` placeholder in place of the `bytes` field.
impl Debug for BufferedFile<'_> {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        let mut repr = formatter.debug_struct("BufferedFile");

        repr.field("id", &self.id);
        repr.field("file_name", &self.file_name);
        // The buffered bytes themselves are not printed, only a placeholder.
        repr.field("bytes", &"...");
        repr.field("str_to_file", &self.str_to_file);

        repr.finish()
    }
}

/// Feeds only the file `id` into the hasher; the other fields do not take part in the hash.
impl Hash for BufferedFile<'_> {
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        Hash::hash(&self.id, state);
    }
}

/// Reference-counted handle to a [`TakeOwnCell`] that holds a boxed stream of
/// `Result<Bytes, io::Error>` items.
/// Used as the type of the `stream` field of [`StreamFile`].
type SharedStream =
Arc<TakeOwnCell<Box<dyn Stream<Item = Result<Bytes, io::Error>> + Send + Sync + Unpin>>>;

/// # Warning
/// We use [`TriompheArc`] because to share [`Stream`] between threads without copying it
#[derive(Clone)]
pub struct StreamFile<'a> {
id: Uuid,
file_name: Option<Cow<'a, str>>,
stream: TriompheArc<Box<dyn Stream<Item = Result<Bytes, io::Error>> + Send + Sync + Unpin>>,
stream: SharedStream,
str_to_file: Box<str>,
}

impl Debug for StreamFile<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("StreamFile")
.field("id", &self.id)
.field("file_name", &self.file_name)
.field("stream", &"...")
.field("str_to_file", &self.str_to_file)
.finish()
}
}

impl PartialEq for StreamFile<'_> {
fn eq(&self, other: &Self) -> bool {
self.id == other.id
}
}

impl Hash for StreamFile<'_> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state);
}
}

impl<'a> StreamFile<'a> {
#[must_use]
pub fn new(
Expand All @@ -439,7 +444,7 @@ impl<'a> StreamFile<'a> {
Self {
id,
file_name: None,
stream: TriompheArc::new(Box::new(stream)),
stream: Arc::new(TakeOwnCell::new(Box::new(stream))),
str_to_file: str_to_file.into(),
}
}
Expand All @@ -456,7 +461,7 @@ impl<'a> StreamFile<'a> {
Self {
id,
file_name: Some(name.into()),
stream: TriompheArc::new(Box::new(stream)),
stream: Arc::new(TakeOwnCell::new(Box::new(stream))),
str_to_file: str_to_file.into(),
}
}
Expand All @@ -477,12 +482,55 @@ impl<'a> StreamFile<'a> {
self.file_name.as_deref()
}

/// Gets stream
/// Takes stream.
/// # Warning
/// If stream is taken, default client implementation raises an error,
/// so you need to use [`StreamFile::set_stream`] to set a new stream manually.
/// # Returns
/// After this function once returns `Some(_)`, all subsequent calls before [`StreamFile::set_stream`]
/// will return `None` as the value is already taken
#[must_use]
pub fn stream(
pub fn take_stream(
&self,
) -> TriompheArc<Box<dyn Stream<Item = Result<Bytes, io::Error>> + Send + Sync + Unpin>> {
self.stream.clone()
) -> Option<Box<dyn Stream<Item = Result<Bytes, io::Error>> + Send + Sync + Unpin>> {
self.stream.take()
}

/// Sets stream unconditionally.
/// You need to use this method if you want to use [`StreamFile`] again for another request,
/// because after [`StreamFile::take_stream`] was called, stream is taken and cannot be used again.
/// # Arguments
/// * `stream` - New stream that replaces the current one
/// # Notes
/// This method sets stream regardless of whether the current stream is taken or not.
///
/// If you want to set stream only if stream is taken, use [`StreamFile::set_stream_if_taken`].
pub fn set_stream(
&mut self,
stream: impl Stream<Item = Result<Bytes, io::Error>> + Send + Sync + Unpin + 'static,
) {
// A brand-new `Arc` is assigned to this file only; the previous `Arc` is not modified.
self.stream = Arc::new(TakeOwnCell::new(Box::new(stream)));
}

/// Sets stream only if the current stream is already taken.
/// You need to use this method if you want to use [`StreamFile`] again for another request,
/// because after [`StreamFile::take_stream`] was called, stream is taken and cannot be used again.
/// # Arguments
/// * `stream` - New stream that replaces the taken one
/// # Notes
/// If stream isn't taken, this method does nothing and the passed stream is dropped.
///
/// If you want to set stream unconditionally, use [`StreamFile::set_stream`].
/// # Returns
/// If stream is taken, sets the new stream and returns `true`, otherwise returns `false`.
pub fn set_stream_if_taken(
    &mut self,
    stream: impl Stream<Item = Result<Bytes, io::Error>> + Send + Sync + Unpin + 'static,
) -> bool {
    // The current stream is still available, so keep it untouched.
    if !self.stream.is_taken() {
        return false;
    }

    self.stream = Arc::new(TakeOwnCell::new(Box::new(stream)));

    true
}

/// Gets string to file as path in format `attach://{id}`
Expand All @@ -491,3 +539,39 @@ impl<'a> StreamFile<'a> {
&self.str_to_file
}
}

/// Manual [`Debug`] implementation that prints a `"..."` placeholder in place of the `stream` field.
impl Debug for StreamFile<'_> {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        let mut repr = formatter.debug_struct("StreamFile");

        repr.field("id", &self.id);
        repr.field("file_name", &self.file_name);
        // The stream itself is not printed, only a placeholder.
        repr.field("stream", &"...");
        repr.field("str_to_file", &self.str_to_file);

        repr.finish()
    }
}

impl Clone for StreamFile<'_> {
fn clone(&self) -> Self {
let id = Uuid::new_v4();

Self {
id,
file_name: self.file_name.clone(),
stream: self.stream.clone(),
str_to_file: format!("{ATTACH_PREFIX}{id}").into(),
}
}
}

/// Feeds only the file `id` into the hasher; the other fields do not take part in the hash.
impl Hash for StreamFile<'_> {
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        Hash::hash(&self.id, state);
    }
}

/// Two stream files are considered equal when their IDs are equal; no other field is compared.
impl PartialEq for StreamFile<'_> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.id, &other.id);

        lhs == rhs
    }
}

0 comments on commit e6197a6

Please sign in to comment.