Skip to content

Commit

Permalink
Pass input file parameters to las writer
Browse files Browse the repository at this point in the history
  • Loading branch information
leavauchier committed Feb 13, 2024
1 parent d7b38bc commit 7511723
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 22 deletions.
10 changes: 6 additions & 4 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ dependencies:
- numpy
- h5py
# --------- geo --------- #
- pdal
- pdal>=2.6.0
- python-pdal
- pyproj
# --------- Visualization --------- #
- conda-forge:gdal
- conda_forge:pyproj
# --------- Visualization --------- #
- pandas
- matplotlib
# --------- loggers --------- #
Expand Down Expand Up @@ -68,4 +69,5 @@ dependencies:
# --------- Documentation --------- #
- myst_parser==0.17.*
- sphinxnotes-mock==1.0.0b0 # still a beta
- sphinx_paramlinks==0.5.*
- sphinx_paramlinks==0.5.*
- ign-pdal-tools>=1.5.2
14 changes: 9 additions & 5 deletions myria3d/models/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from torch.distributions import Categorical
from torch_scatter import scatter_sum

from pdaltools import las_info

from myria3d.pctl.dataset.utils import get_pdal_info_metadata, get_pdal_reader

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -83,7 +85,10 @@ def load_full_las_for_update(self, src_las: str, epsg: str) -> np.ndarray:
pipeline |= pdal.Filter.assign(value=f"{self.entropy_channel}=0")

pipeline.execute()
return pipeline.arrays[0]
writer_params = las_info.get_writer_parameters_from_reader_metadata(
pipeline.metadata, a_srs=f"EPSG:{epsg}" if str(epsg).isdigit() else epsg
)
return pipeline.arrays[0], writer_params

def store_predictions(self, logits, idx_in_original_cloud) -> None:
"""Keep a list of predictions made so far."""
Expand Down Expand Up @@ -143,7 +148,7 @@ def reduce_predictions_and_save(self, raw_path: str, output_dir: str, epsg: str)
del logits

# Read las after fetching all information to write into it
las = self.load_full_las_for_update(raw_path, epsg)
las, writer_params = self.load_full_las_for_update(raw_path, epsg)

for idx, class_name in enumerate(self.classification_dict.values()):
if class_name in self.probas_to_save:
Expand Down Expand Up @@ -173,9 +178,8 @@ def reduce_predictions_and_save(self, raw_path: str, output_dir: str, epsg: str)
out_f = os.path.abspath(out_f)
log.info(f"Updated LAS ({basename}) will be saved to: \n {output_dir}\n")
log.info("Saving...")
pipeline = pdal.Writer.las(
filename=out_f, extra_dims="all", minor_version=4, dataformat_id=8
).pipeline(las)
writer_params["extra_dims"] = "all"
pipeline = pdal.Writer.las(filename=out_f, **writer_params).pipeline(las)
pipeline.execute()
log.info("Saved.")

Expand Down
19 changes: 6 additions & 13 deletions myria3d/pctl/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,12 @@ def get_pdal_reader(las_path: str, epsg: str) -> pdal.Reader.las:

if epsg :
# if an epsg in provided, force pdal to read the lidar file with it
try : # epsg can be added as a number like "2154" or as a string like "EPSG:2154"
int(epsg)
return pdal.Reader.las(
filename=las_path,
nosrs=True,
override_srs=f"EPSG:{epsg}",
)
except ValueError:
return pdal.Reader.las(
filename=las_path,
nosrs=True,
override_srs=epsg,
)
# epsg can be added as a number like "2154" or as a string like "EPSG:2154"
return pdal.Reader.las(
filename=las_path,
nosrs=True,
override_srs=f"EPSG:{epsg}" if str(epsg).isdigit() else epsg,
)

try :
if get_metadata(las_path)['metadata']['readers.las']['srs']['compoundwkt']:
Expand Down
6 changes: 6 additions & 0 deletions tests/myria3d/test_train_and_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import pytest
from lightning.pytorch.accelerators import find_usable_cuda_devices
from pathlib import Path
from pdaltools import las_info


from myria3d.pctl.dataset.toy_dataset import TOY_LAS_DATA
Expand Down Expand Up @@ -94,6 +96,10 @@ def test_predict_as_command(one_epoch_trained_RandLaNet_checkpoint, tmpdir):
"task.task_name=predict",
]
run_hydra_decorated_command(command)
output_path = Path(tmpdir) / Path(abs_path_to_toy_LAS).name
metadata = las_info.las_info_metadata(output_path)
out_pesg = las_info.get_epsg_from_header_info(metadata)
assert out_pesg == DEFAULT_EPSG


def test_command_without_epsg(one_epoch_trained_RandLaNet_checkpoint, tmpdir):
Expand Down

0 comments on commit 7511723

Please sign in to comment.