Skip to content

Commit

Permalink
✨ Support reading uint/int/float dtypes
Browse files Browse the repository at this point in the history
Add support on the Rust side for reading u8/u16/u32/u64/i8/i16/i32/i64/f32/f64 dtypes via a num_traits::FromPrimitive bound. Different dtypes can be selected via the turbofish operator e.g. by calling `.ndarray::<u16>()`. Added a unit test to check that reading a uint16 tif file works.
  • Loading branch information
weiji14 committed Sep 9, 2024
1 parent ddf179b commit d152e51
Showing 1 changed file with 61 additions and 14 deletions.
75 changes: 61 additions & 14 deletions src/io/geotiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::io::{Read, Seek};

use geo::AffineTransform;
use ndarray::{Array, Array1, Array3};
use num_traits::FromPrimitive;
use tiff::decoder::{Decoder, DecodingResult, Limits};
use tiff::tags::Tag;
use tiff::{ColorType, TiffError, TiffFormatError, TiffResult, TiffUnsupportedError};
Expand All @@ -23,7 +24,7 @@ impl<R: Read + Seek> CogReader<R> {
}

/// Decode GeoTIFF image to an [`ndarray::Array`]
pub fn ndarray(&mut self) -> TiffResult<Array3<f32>> {
pub fn ndarray<T: FromPrimitive + 'static>(&mut self) -> TiffResult<Array3<T>> {
// Count number of bands
let color_type = self.decoder.colortype()?;
let num_bands: usize = match color_type {
Expand All @@ -44,19 +45,45 @@ impl<R: Read + Seek> CogReader<R> {

// Get image pixel data
let decode_result = self.decoder.read_image()?;
let image_data: Vec<f32> = match decode_result {
DecodingResult::F32(img_data) => img_data,
_ => {
return Err(TiffError::UnsupportedError(
TiffUnsupportedError::UnsupportedDataType,
))
let image_data: Vec<T> = match decode_result {
DecodingResult::U8(img_data) => {
img_data.iter().map(|v| T::from_u8(*v).unwrap()).collect()
}
DecodingResult::U16(img_data) => {
img_data.iter().map(|v| T::from_u16(*v).unwrap()).collect()
}
DecodingResult::U32(img_data) => {
img_data.iter().map(|v| T::from_u32(*v).unwrap()).collect()
}
DecodingResult::U64(img_data) => {
img_data.iter().map(|v| T::from_u64(*v).unwrap()).collect()
}
DecodingResult::I8(img_data) => {
img_data.iter().map(|v| T::from_i8(*v).unwrap()).collect()
}
DecodingResult::I16(img_data) => {
img_data.iter().map(|v| T::from_i16(*v).unwrap()).collect()
}
DecodingResult::I32(img_data) => {
img_data.iter().map(|v| T::from_i32(*v).unwrap()).collect()
}
DecodingResult::I64(img_data) => {
img_data.iter().map(|v| T::from_i64(*v).unwrap()).collect()
}
DecodingResult::F32(img_data) => {
img_data.iter().map(|v| T::from_f32(*v).unwrap()).collect()
}
DecodingResult::F64(img_data) => {
img_data.iter().map(|v| T::from_f64(*v).unwrap()).collect()
}
};

// Put image pixel data into an ndarray
let array_data =
Array3::from_shape_vec((num_bands, height as usize, width as usize), image_data)
.map_err(|_| TiffFormatError::InconsistentSizesEncountered)?;
let array_data: Array3<T> = Array3::from_shape_vec(
(num_bands, height as usize, width as usize),
image_data.into(),
)
.map_err(|_| TiffFormatError::InconsistentSizesEncountered)?;

Ok(array_data)
}
Expand Down Expand Up @@ -138,12 +165,14 @@ impl<R: Read + Seek> CogReader<R> {
}

/// Synchronously read a GeoTIFF file into an [`ndarray::Array`]
pub fn read_geotiff<R: Read + Seek>(stream: R) -> TiffResult<Array3<f32>> {
pub fn read_geotiff<T: FromPrimitive + 'static, R: Read + Seek>(
stream: R,
) -> TiffResult<Array3<T>> {
// Open TIFF stream with decoder
let mut reader = CogReader::new(stream)?;

// Decode TIFF into ndarray
let array_data: Array3<f32> = reader.ndarray()?;
let array_data: Array3<T> = reader.ndarray()?;

Ok(array_data)
}
Expand Down Expand Up @@ -205,7 +234,25 @@ mod tests {
let array = reader.ndarray().unwrap();

assert_eq!(array.dim(), (2, 512, 512));
assert_eq!(array.mean(), Some(225.17654));
assert_eq!(array.mean(), Some(225.17439122416545));
}

#[tokio::test]
async fn test_read_geotiff_uint16_dtype() {
let cog_url: &str =
"https://github.com/OSGeo/gdal/raw/v3.9.2/autotest/gcore/data/uint16.tif";
let tif_url = Url::parse(cog_url).unwrap();
let (store, location) = parse_url(&tif_url).unwrap();

let result = store.get(&location).await.unwrap();
let bytes = result.bytes().await.unwrap();
let stream = Cursor::new(bytes);

let mut reader = CogReader::new(stream).unwrap();
let array = reader.ndarray::<u16>().unwrap();

assert_eq!(array.dim(), (1, 20, 20));
assert_eq!(array.mean(), Some(126));
}

#[tokio::test]
Expand All @@ -219,7 +266,7 @@ mod tests {
let stream = Cursor::new(bytes);

let mut reader = CogReader::new(stream).unwrap();
let array = reader.ndarray().unwrap();
let array = reader.ndarray::<f32>().unwrap();

assert_eq!(array.shape(), [1, 2, 3]);
assert_eq!(array, array![[[1.41, 1.23, 0.78], [0.32, -0.23, -1.88]]])
Expand Down

0 comments on commit d152e51

Please sign in to comment.