diff --git a/src/io/geotiff.rs b/src/io/geotiff.rs index 6a68854..3c40eda 100644 --- a/src/io/geotiff.rs +++ b/src/io/geotiff.rs @@ -2,6 +2,7 @@ use std::io::{Read, Seek}; use geo::AffineTransform; use ndarray::{Array, Array1, Array3}; +use num_traits::FromPrimitive; use tiff::decoder::{Decoder, DecodingResult, Limits}; use tiff::tags::Tag; use tiff::{ColorType, TiffError, TiffFormatError, TiffResult, TiffUnsupportedError}; @@ -23,7 +24,7 @@ impl CogReader { } /// Decode GeoTIFF image to an [`ndarray::Array`] - pub fn ndarray(&mut self) -> TiffResult> { + pub fn ndarray(&mut self) -> TiffResult> { // Count number of bands let color_type = self.decoder.colortype()?; let num_bands: usize = match color_type { @@ -44,19 +45,45 @@ impl CogReader { // Get image pixel data let decode_result = self.decoder.read_image()?; - let image_data: Vec = match decode_result { - DecodingResult::F32(img_data) => img_data, - _ => { - return Err(TiffError::UnsupportedError( - TiffUnsupportedError::UnsupportedDataType, - )) + let image_data: Vec = match decode_result { + DecodingResult::U8(img_data) => { + img_data.iter().map(|v| T::from_u8(*v).unwrap()).collect() + } + DecodingResult::U16(img_data) => { + img_data.iter().map(|v| T::from_u16(*v).unwrap()).collect() + } + DecodingResult::U32(img_data) => { + img_data.iter().map(|v| T::from_u32(*v).unwrap()).collect() + } + DecodingResult::U64(img_data) => { + img_data.iter().map(|v| T::from_u64(*v).unwrap()).collect() + } + DecodingResult::I8(img_data) => { + img_data.iter().map(|v| T::from_i8(*v).unwrap()).collect() + } + DecodingResult::I16(img_data) => { + img_data.iter().map(|v| T::from_i16(*v).unwrap()).collect() + } + DecodingResult::I32(img_data) => { + img_data.iter().map(|v| T::from_i32(*v).unwrap()).collect() + } + DecodingResult::I64(img_data) => { + img_data.iter().map(|v| T::from_i64(*v).unwrap()).collect() + } + DecodingResult::F32(img_data) => { + img_data.iter().map(|v| T::from_f32(*v).unwrap()).collect() + } + DecodingResult::F64(img_data) => { + img_data.iter().map(|v| T::from_f64(*v).unwrap()).collect() } }; // Put image pixel data into an ndarray - let array_data = - Array3::from_shape_vec((num_bands, height as usize, width as usize), image_data) - .map_err(|_| TiffFormatError::InconsistentSizesEncountered)?; + let array_data: Array3 = Array3::from_shape_vec( + (num_bands, height as usize, width as usize), + image_data.into(), + ) + .map_err(|_| TiffFormatError::InconsistentSizesEncountered)?; Ok(array_data) } @@ -138,12 +165,14 @@ impl CogReader { } /// Synchronously read a GeoTIFF file into an [`ndarray::Array`] -pub fn read_geotiff(stream: R) -> TiffResult> { +pub fn read_geotiff( + stream: R, +) -> TiffResult> { // Open TIFF stream with decoder let mut reader = CogReader::new(stream)?; // Decode TIFF into ndarray - let array_data: Array3 = reader.ndarray()?; + let array_data: Array3 = reader.ndarray()?; Ok(array_data) } @@ -205,7 +234,25 @@ mod tests { let array = reader.ndarray().unwrap(); assert_eq!(array.dim(), (2, 512, 512)); - assert_eq!(array.mean(), Some(225.17654)); + assert_eq!(array.mean(), Some(225.17439122416545)); + } + + #[tokio::test] + async fn test_read_geotiff_uint16_dtype() { + let cog_url: &str = + "https://github.com/OSGeo/gdal/raw/v3.9.2/autotest/gcore/data/uint16.tif"; + let tif_url = Url::parse(cog_url).unwrap(); + let (store, location) = parse_url(&tif_url).unwrap(); + + let result = store.get(&location).await.unwrap(); + let bytes = result.bytes().await.unwrap(); + let stream = Cursor::new(bytes); + + let mut reader = CogReader::new(stream).unwrap(); + let array = reader.ndarray::().unwrap(); + + assert_eq!(array.dim(), (1, 20, 20)); + assert_eq!(array.mean(), Some(126)); } #[tokio::test] @@ -219,7 +266,7 @@ mod tests { let stream = Cursor::new(bytes); let mut reader = CogReader::new(stream).unwrap(); - let array = reader.ndarray().unwrap(); + let array = reader.ndarray::().unwrap(); assert_eq!(array.shape(), [1, 2, 3]); assert_eq!(array, array![[[1.41, 1.23, 0.78], [0.32, -0.23, -1.88]]])