Skip to content

Commit

Permalink
Fix edgecase with 1 input value that is not selected in subsample_data
Browse files Browse the repository at this point in the history
  • Loading branch information
leavauchier committed Mar 12, 2024
1 parent 83b6b21 commit ef43f34
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# CHANGELOG
### 3.8.2
- fix: type error in edge case when dropping points in transforms
- fix: points not dropped case when dropping points in transforms

### 3.8.1
- fix: propagate input las format to output las (in particular epsg which comes either from input or config)
Expand Down
6 changes: 3 additions & 3 deletions myria3d/pctl/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def subsample_data(data, num_nodes, choice):
data.num_nodes = choice.size(0)
elif bool(re.search("edge", key)):
continue
elif torch.is_tensor(item) and item.size(0) == num_nodes and item.size(0) != 1:
data[key] = item[choice]
elif isinstance(item, np.ndarray) and item.shape[0] == num_nodes and item.shape[0] != 1:
elif torch.is_tensor(item) and item.size(0) == num_nodes:
data[key] = item[choice]
elif isinstance(item, np.ndarray) and item.shape[0] == num_nodes:
data[key] = item[np.array(choice)]

return data

Expand Down
46 changes: 45 additions & 1 deletion tests/myria3d/pctl/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,51 @@
import torch
import torch_geometric

from myria3d.pctl.transforms.transforms import DropPointsByClass, TargetTransform
from myria3d.pctl.transforms.transforms import (
DropPointsByClass,
TargetTransform,
subsample_data,
)


@pytest.mark.parametrize(
"x,idx,choice",
[
# Standard use case
(
torch.Tensor([10, 11, 12, 13, 14]),
np.array([20, 21, 22, 23, 24]),
torch.IntTensor([0, 1, 4]),
),
# Edge cases
(
torch.Tensor([10, 11, 12, 13, 14]),
np.array([20, 21, 22, 23, 24]),
torch.IntTensor([]),
),
(
torch.Tensor([10]),
np.array([20]),
torch.IntTensor([0]),
),
(
torch.Tensor([10]),
np.array([20]),
torch.IntTensor([]),
),
],
)
def test_subsample_data(x, idx, choice):
# points w.
num_nodes = x.size(0)
data = torch_geometric.data.Data(x=x, idx=idx, num_nodes=num_nodes)
transformed_data = subsample_data(data, num_nodes, choice)
out_num_nodes = choice.size(0)
assert transformed_data.num_nodes == out_num_nodes
assert isinstance(transformed_data.x, torch.Tensor)
assert transformed_data.x.size(0) == out_num_nodes
assert isinstance(transformed_data.idx, np.ndarray)
assert transformed_data.idx.shape[0] == out_num_nodes


def test_TargetTransform_with_valid_config():
Expand Down

0 comments on commit ef43f34

Please sign in to comment.