Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix classes removal in standardisation #63

Merged
merged 2 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/cicd_light.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ on:


jobs:
docker_build_and_test:
runs-on: ubuntu-latest
test_light:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
permissions:
contents: read
packages: write
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def compute_count_one_file(filepath: str, attribute: str = "Classification") ->
pipeline |= pdal.Filter.stats(dimensions=attribute, count=attribute)
pipeline.execute()
# List of "class/count" on the only dimension that is counted
raw_counts = pipeline.metadata["metadata"]["filters.stats"]["statistic"][0]["counts"]
raw_counts = pipeline.metadata["metadata"]["filters.stats"]["statistic"][0].get("counts", [])
split_counts = [c.split("/") for c in raw_counts]
try:
# Try to prettify the value by converting it to an integer (eg. for Classification that
Expand Down
42 changes: 10 additions & 32 deletions pdaltools/standardize_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@

import argparse
import os
import platform
leavauchier marked this conversation as resolved.
Show resolved Hide resolved
import subprocess as sp
import tempfile
import platform
import numpy as np
from typing import Dict
from typing import Dict, List

import pdal

from pdaltools.unlock_file import copy_and_hack_decorator
from pdaltools.las_info import get_writer_parameters_from_reader_metadata

STANDARD_PARAMETERS = dict(
major_version="1",
Expand Down Expand Up @@ -74,37 +72,17 @@ def get_writer_parameters(new_parameters: Dict) -> Dict:
return params


def remove_points_from_class(points, class_points_removed: []) :
input_dimensions = list(points.dtype.fields.keys())
dim_class = input_dimensions.index("Classification")

indice_pts_delete = [id for id in range(0, len(points)) if points[id][dim_class] in class_points_removed]
points_preserved = np.delete(points, indice_pts_delete)

if len(points_preserved) == 0:
raise Exception("All points removed !")

return points_preserved


def rewrite_with_pdal(input_file: str, output_file: str, params_from_parser: Dict, class_points_removed: []) -> None:
# Update parameters with command line values
def rewrite_with_pdal(
input_file: str, output_file: str, params_from_parser: Dict, classes_to_remove: List = []
) -> None:
params = get_writer_parameters(params_from_parser)
pipeline = pdal.Pipeline()
pipeline |= pdal.Reader.las(input_file)
if classes_to_remove:
expression = "&&".join([f"Classification != {c}" for c in classes_to_remove])
pipeline |= pdal.Filter.expression(expression=expression)
pipeline |= pdal.Writer(filename=output_file, forward="all", **params)
pipeline.execute()
points = pipeline.arrays[0]

if class_points_removed:
points = remove_points_from_class(points, class_points_removed)

#ToDo : it seems that the forward="all" doesn't work because we use a new pipeline
# since we create a new pipeline, the 2 metadatas creation_doy and creation_year are update
# to current date instead of forwarded from input LAS

params = get_writer_parameters(params_from_parser)
pipeline_end = pdal.Pipeline(arrays=[points])
pipeline_end |= pdal.Writer.las(output_file, forward="all", **params)
pipeline_end.execute()


def exec_las2las(input_file: str, output_file: str):
Expand Down
112 changes: 45 additions & 67 deletions test/test_standardize_format.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import logging
import os
import platform
import shutil
import subprocess as sp
import platform
import json
from test.utils import EXPECTED_DIMS_BY_DATAFORMAT, get_pdal_infos_summary

import pdal
import pytest

from pdaltools.standardize_format import exec_las2las, rewrite_with_pdal, standardize, remove_points_from_class
from pdaltools.count_occurences.count_occurences_for_attribute import (
compute_count_one_file,
)
from pdaltools.standardize_format import exec_las2las, rewrite_with_pdal, standardize

TEST_PATH = os.path.dirname(os.path.abspath(__file__))
TMP_PATH = os.path.join(TEST_PATH, "tmp")
INPUT_DIR = os.path.join(TEST_PATH, "data")

DEFAULT_PARAMS = {"dataformat_id": 6, "a_srs": "EPSG:2154", "extra_dims": []}

MUTLIPLE_PARAMS = [
{"dataformat_id": 6, "a_srs": "EPSG:2154", "extra_dims": []},
DEFAULT_PARAMS,
{"dataformat_id": 8, "a_srs": "EPSG:4326", "extra_dims": []},
{"dataformat_id": 8, "a_srs": "EPSG:2154", "extra_dims": ["dtm_marker=double", "dsm_marker=double"]},
{"dataformat_id": 8, "a_srs": "EPSG:2154", "extra_dims": "all"},
Expand All @@ -32,7 +36,18 @@ def setup_module(module):
os.mkdir(TMP_PATH)


def _test_standardize_format_one_params_set(input_file, output_file, params):
@pytest.mark.parametrize(
"params",
[
DEFAULT_PARAMS,
{"dataformat_id": 8, "a_srs": "EPSG:4326", "extra_dims": []},
{"dataformat_id": 8, "a_srs": "EPSG:2154", "extra_dims": ["dtm_marker=double", "dsm_marker=double"]},
{"dataformat_id": 8, "a_srs": "EPSG:2154", "extra_dims": "all"},
],
)
def test_standardize_format(params):
input_file = os.path.join(INPUT_DIR, "test_data_77055_627755_LA93_IGN69_extra_dims.laz")
output_file = os.path.join(TMP_PATH, "formatted.laz")
rewrite_with_pdal(input_file, output_file, params, [])
# check file exists
assert os.path.isfile(output_file)
Expand All @@ -56,15 +71,35 @@ def _test_standardize_format_one_params_set(input_file, output_file, params):
extra_dims_names = [dim.split("=")[0] for dim in params["extra_dims"]]
assert dimensions == EXPECTED_DIMS_BY_DATAFORMAT[params["dataformat_id"]].union(extra_dims_names)

# Check that there is the expected number of points for each class
expected_points_counts = compute_count_one_file(input_file)

output_points_counts = compute_count_one_file(output_file)
assert output_points_counts == expected_points_counts

# TODO: Check srs
# TODO: check precision


def test_standardize_format():
@pytest.mark.parametrize(
"classes_to_remove",
[
[],
[2, 3],
[1, 2, 3, 4, 5, 6, 64], # remove all classes
],
)
def test_standardize_classes(classes_to_remove):
input_file = os.path.join(INPUT_DIR, "test_data_77055_627755_LA93_IGN69_extra_dims.laz")
output_file = os.path.join(TMP_PATH, "formatted.laz")
for params in MUTLIPLE_PARAMS:
_test_standardize_format_one_params_set(input_file, output_file, params)
rewrite_with_pdal(input_file, output_file, DEFAULT_PARAMS, classes_to_remove)
# Check that there is the expected number of points for each class
expected_points_counts = compute_count_one_file(input_file)
for cl in classes_to_remove:
expected_points_counts.pop(str(cl))

output_points_counts = compute_count_one_file(output_file)
assert output_points_counts == expected_points_counts


def exec_lasinfo(input_file: str):
Expand Down Expand Up @@ -108,74 +143,17 @@ def test_standardize_does_NOT_produce_any_warning_with_Lasinfo():
# if you want to see input_file warnings
# assert_lasinfo_no_warning(input_file)

standardize(input_file, output_file, MUTLIPLE_PARAMS[0], [])
standardize(input_file, output_file, DEFAULT_PARAMS, [])
assert_lasinfo_no_warning(output_file)


def test_standardize_malformed_laz():
input_file = os.path.join(TEST_PATH, "data/test_pdalfail_0643_6319_LA93_IGN69.laz")
output_file = os.path.join(TMP_PATH, "standardize_pdalfail_0643_6319_LA93_IGN69.laz")
standardize(input_file, output_file, MUTLIPLE_PARAMS[0], [])
standardize(input_file, output_file, DEFAULT_PARAMS, [])
assert os.path.isfile(output_file)


def get_pipeline_metadata_cross_plateform(pipeline):
try:
metadata = json.loads(pipeline.metadata)
except TypeError:
d_metadata = json.dumps(pipeline.metadata)
metadata = json.loads(d_metadata)
return metadata

def get_statistics_from_las_points(points):
pipeline = pdal.Pipeline(arrays=[points])
pipeline |= pdal.Filter.stats(dimensions="Classification", enumerate="Classification")
pipeline.execute()
metadata = get_pipeline_metadata_cross_plateform(pipeline)
statistic = metadata["metadata"]["filters.stats"]["statistic"]
return statistic[0]["count"], statistic[0]["values"]

@pytest.mark.parametrize(
"classes_to_remove",
[
[2, 3],
[2, 3, 4],
[0, 1, 2, 3, 4, 5, 6],
],
)
def test_remove_points_from_class(classes_to_remove):
input_file = os.path.join(TEST_PATH, "data/classified_laz/test_data_77050_627755_LA93_IGN69.laz")
output_file = os.path.join(TMP_PATH, "test_remove_points_from_class.laz")

# count points of class not in classes_to_remove (get the point we should have in fine)
pipeline = pdal.Pipeline() | pdal.Reader.las(input_file)

where = ' && '.join(["CLassification != " + str(cl) for cl in classes_to_remove])
pipeline |= pdal.Filter.stats(dimensions="Classification", enumerate="Classification", where=where)
pipeline.execute()

points = pipeline.arrays[0]
nb_points_before, class_before = get_statistics_from_las_points(points)

metadata = get_pipeline_metadata_cross_plateform(pipeline)
statistic = metadata["metadata"]["filters.stats"]["statistic"]
nb_points_to_get = statistic[0]["count"]

try:
points = remove_points_from_class(points, classes_to_remove)
except Exception as error: # error because all points are removed
assert nb_points_to_get == 0
return

nb_points_after, class_after = get_statistics_from_las_points(points)

assert nb_points_before > 0
assert nb_points_before > nb_points_after
assert set(classes_to_remove).issubset(set(class_before))
assert not set(classes_to_remove).issubset(set(class_after))
assert nb_points_after == nb_points_to_get


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
test_standardize_format()