From 751172373cc4386faebd52a8d4193c59ff24dfcc Mon Sep 17 00:00:00 2001 From: Lea Vauchier Date: Tue, 13 Feb 2024 17:04:12 +0100 Subject: [PATCH] Pass input file parameters to las writer --- environment.yml | 10 ++++++---- myria3d/models/interpolation.py | 14 +++++++++----- myria3d/pctl/dataset/utils.py | 19 ++++++------------- tests/myria3d/test_train_and_predict.py | 6 ++++++ 4 files changed, 27 insertions(+), 22 deletions(-) diff --git a/environment.yml b/environment.yml index 553bccf0..0d582df4 100644 --- a/environment.yml +++ b/environment.yml @@ -29,10 +29,11 @@ dependencies: - numpy - h5py # --------- geo --------- # - - pdal + - pdal>=2.6.0 - python-pdal - - pyproj - # --------- Visualization --------- # + - conda-forge:gdal + - conda_forge:pyproj + # --------- Visualization --------- # - pandas - matplotlib # --------- loggers --------- # @@ -68,4 +69,5 @@ dependencies: # --------- Documentation --------- # - myst_parser==0.17.* - sphinxnotes-mock==1.0.0b0 # still a beta - - sphinx_paramlinks==0.5.* \ No newline at end of file + - sphinx_paramlinks==0.5.* + - ign-pdal-tools>=1.5.2 \ No newline at end of file diff --git a/myria3d/models/interpolation.py b/myria3d/models/interpolation.py index dbefbc55..7998d85c 100644 --- a/myria3d/models/interpolation.py +++ b/myria3d/models/interpolation.py @@ -8,6 +8,8 @@ from torch.distributions import Categorical from torch_scatter import scatter_sum +from pdaltools import las_info + from myria3d.pctl.dataset.utils import get_pdal_info_metadata, get_pdal_reader log = logging.getLogger(__name__) @@ -83,7 +85,10 @@ def load_full_las_for_update(self, src_las: str, epsg: str) -> np.ndarray: pipeline |= pdal.Filter.assign(value=f"{self.entropy_channel}=0") pipeline.execute() - return pipeline.arrays[0] + writer_params = las_info.get_writer_parameters_from_reader_metadata( + pipeline.metadata, a_srs=f"EPSG:{epsg}" if str(epsg).isdigit() else epsg + ) + return pipeline.arrays[0], writer_params def store_predictions(self, logits, idx_in_original_cloud) -> None: """Keep a list of predictions made so far.""" @@ -143,7 +148,7 @@ def reduce_predictions_and_save(self, raw_path: str, output_dir: str, epsg: str) del logits # Read las after fetching all information to write into it - las = self.load_full_las_for_update(raw_path, epsg) + las, writer_params = self.load_full_las_for_update(raw_path, epsg) for idx, class_name in enumerate(self.classification_dict.values()): if class_name in self.probas_to_save: @@ -173,9 +178,8 @@ def reduce_predictions_and_save(self, raw_path: str, output_dir: str, epsg: str) out_f = os.path.abspath(out_f) log.info(f"Updated LAS ({basename}) will be saved to: \n {output_dir}\n") log.info("Saving...") - pipeline = pdal.Writer.las( - filename=out_f, extra_dims="all", minor_version=4, dataformat_id=8 - ).pipeline(las) + writer_params["extra_dims"] = "all" + pipeline = pdal.Writer.las(filename=out_f, **writer_params).pipeline(las) pipeline.execute() log.info("Saved.") diff --git a/myria3d/pctl/dataset/utils.py b/myria3d/pctl/dataset/utils.py index 9d1b269f..73ede8e9 100644 --- a/myria3d/pctl/dataset/utils.py +++ b/myria3d/pctl/dataset/utils.py @@ -85,19 +85,12 @@ def get_pdal_reader(las_path: str, epsg: str) -> pdal.Reader.las: if epsg : # if an epsg in provided, force pdal to read the lidar file with it - try : # epsg can be added as a number like "2154" or as a string like "EPSG:2154" - int(epsg) - return pdal.Reader.las( - filename=las_path, - nosrs=True, - override_srs=f"EPSG:{epsg}", - ) - except ValueError: - return pdal.Reader.las( - filename=las_path, - nosrs=True, - override_srs=epsg, - ) + # epsg can be added as a number like "2154" or as a string like "EPSG:2154" + return pdal.Reader.las( + filename=las_path, + nosrs=True, + override_srs=f"EPSG:{epsg}" if str(epsg).isdigit() else epsg, + ) try : if get_metadata(las_path)['metadata']['readers.las']['srs']['compoundwkt']: diff --git a/tests/myria3d/test_train_and_predict.py b/tests/myria3d/test_train_and_predict.py index 0ef0c388..29818f22 100644 --- a/tests/myria3d/test_train_and_predict.py +++ b/tests/myria3d/test_train_and_predict.py @@ -4,6 +4,8 @@ import numpy as np import pytest from lightning.pytorch.accelerators import find_usable_cuda_devices +from pathlib import Path +from pdaltools import las_info from myria3d.pctl.dataset.toy_dataset import TOY_LAS_DATA @@ -94,6 +96,10 @@ def test_predict_as_command(one_epoch_trained_RandLaNet_checkpoint, tmpdir): "task.task_name=predict", ] run_hydra_decorated_command(command) + output_path = Path(tmpdir) / Path(abs_path_to_toy_LAS).name + metadata = las_info.las_info_metadata(output_path) + out_pesg = las_info.get_epsg_from_header_info(metadata) + assert out_pesg == DEFAULT_EPSG def test_command_without_epsg(one_epoch_trained_RandLaNet_checkpoint, tmpdir):