diff --git a/Cargo.toml b/Cargo.toml index 353d57e..9af8600 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ url = "2.2" [dev-dependencies] # Diff view of test failures difference = "2.0" +futures = "0.3" futures-test = "0.3" reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } tokio = { version = "1.0", features = ["macros"] } diff --git a/src/v1/objects/insert.rs b/src/v1/objects/insert.rs index 0a02dc8..99a7d9d 100644 --- a/src/v1/objects/insert.rs +++ b/src/v1/objects/insert.rs @@ -8,6 +8,7 @@ use crate::{ use futures_util::{ io::{AsyncRead, Result as FuturesResult}, task::{Context, Poll}, + Stream, }; #[cfg(feature = "async-multipart")] use pin_utils::unsafe_pinned; @@ -293,6 +294,29 @@ impl AsyncRead for Multipart { } } +#[cfg(feature = "async-multipart")] +impl Stream for Multipart { + type Item = bytes::Bytes; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(match self.cursor.part { + MultipartPart::Prefix => { + self.cursor.part.next(); + Some(self.prefix.clone()) + } + MultipartPart::Body => { + self.cursor.part.next(); + Some(self.body.clone()) + } + MultipartPart::Suffix => { + self.cursor.part.next(); + Some(bytes::Bytes::from(MULTI_PART_SUFFIX)) + } + MultipartPart::End => None, + }) + } +} + impl super::Object { /// Stores a new object and metadata. /// diff --git a/tests/objects.rs b/tests/objects.rs index 1199ba1..6b73dfa 100644 --- a/tests/objects.rs +++ b/tests/objects.rs @@ -392,6 +392,63 @@ fn insert_multipart_async() { util::cmp_strings(&exp_body, &act_body); } +#[cfg(feature = "async-multipart")] +#[test] +fn insert_multipart_stream_bytes() { + use bytes::{BufMut, Bytes, BytesMut}; + + let metadata = Metadata { + name: Some("good_name".to_owned()), + content_type: Some("text/plain".to_owned()), + content_encoding: Some("gzip".to_owned()), + content_disposition: Some("attachment; filename=\"good name.jpg\"".to_owned()), + metadata: Some( + ["akey"] + .iter() + .map(|k| (String::from(*k), format!("{}value", k))) + .collect(), + ), + ..Default::default() + }; + + let insert_req = Object::insert_multipart( + &BucketName::non_validated("bucket"), + Bytes::from(TEST_CONTENT), + TEST_CONTENT.len() as u64, + &metadata, + None, + ) + .unwrap(); + + let exp_body = format!( + "--{b}\ncontent-type: application/json; charset=utf-8\n\n{}\n--{b}\ncontent-type: text/plain\n\n{}\n--{b}--", + serde_json::to_string(&metadata).unwrap(), + TEST_CONTENT, + b = "tame_gcs" + ); + + let expected = http::Request::builder() + .method(http::Method::POST) + .uri("https://www.googleapis.com/upload/storage/v1/b/bucket/o?uploadType=multipart&prettyPrint=false") + .header(http::header::CONTENT_TYPE, "multipart/related; boundary=tame_gcs") + .header(http::header::CONTENT_LENGTH, 5758) + .body(exp_body) + .unwrap(); + + let (exp_parts, exp_body) = expected.into_parts(); + let (act_parts, act_multipart) = insert_req.into_parts(); + + util::cmp_strings(&format!("{:#?}", exp_parts), &format!("{:#?}", act_parts)); + + let mut act_body = BytesMut::with_capacity(2 * 1024); + for chunk in futures::executor::block_on_stream(act_multipart) { + act_body.put(chunk); + } + let act_body = String::from_utf8_lossy(&act_body); + + util::cmp_strings(&exp_body, &act_body); +} + #[test] fn patches() { let mut md = std::collections::BTreeMap::new();