From 50b41ee6ade5ced0ae7e2af760ac50b94c5ad561 Mon Sep 17 00:00:00 2001 From: sabiwara Date: Sat, 25 May 2024 11:18:17 +0900 Subject: [PATCH] Fix slicing by stepped range --- CHANGELOG.md | 5 +++++ lib/enum.ex | 8 ++++---- lib/vector.ex | 21 +++++++++++++-------- lib/vector/raw.ex | 21 +++++++++++---------- lib/vector/tail.ex | 12 ++++++------ test/vector_test.exs | 10 ++++++++++ 6 files changed, 49 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12ea233..c5f3794 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## Dev +### Bug fixes + +- Fix `Aja.Vector.slice/2` for ranges with step != 1 +- Fix `Aja.Enum.count/1` for ranges with step != 1 + ## v0.7.0 (2024-08-28) ### Enhancements diff --git a/lib/enum.ex b/lib/enum.ex index 8139d2f..7cbdc04 100644 --- a/lib/enum.ex +++ b/lib/enum.ex @@ -1196,7 +1196,7 @@ defmodule Aja.Enum do size = RawVector.size(vector) if amount < size do - RawVector.slice(vector, 0, amount - 1) + RawVector.slice(vector, 0, amount - 1, 1) else RawVector.to_list(vector) end @@ -1207,7 +1207,7 @@ defmodule Aja.Enum do start = amount + size if start > 0 do - RawVector.slice(vector, start, size - 1) + RawVector.slice(vector, start, size - 1, 1) else RawVector.to_list(vector) end @@ -1233,7 +1233,7 @@ defmodule Aja.Enum do size = RawVector.size(vector) if amount < size do - RawVector.slice(vector, amount, size - 1) + RawVector.slice(vector, amount, size - 1, 1) else [] end @@ -1244,7 +1244,7 @@ defmodule Aja.Enum do last = amount + size if last > 0 do - RawVector.slice(vector, 0, last - 1) + RawVector.slice(vector, 0, last - 1, 1) else [] end diff --git a/lib/vector.ex b/lib/vector.ex index b7a624f..e0892d0 100644 --- a/lib/vector.ex +++ b/lib/vector.ex @@ -1525,14 +1525,16 @@ defmodule Aja.Vector do vec([80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90]) iex> Aja.Vector.new(0..100) |> Aja.Vector.slice(-40..-30//1) vec([61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71]) + iex> Aja.Vector.new(0..100) |> Aja.Vector.slice(20..50//10) + vec([20, 30, 40, 50]) iex> Aja.Vector.new([:only_one]) |> Aja.Vector.slice(0..1000) vec([:only_one]) """ @spec slice(t(val), Range.t()) :: t(val) when val: value - def slice(%__MODULE__{} = vector, first..last//1 = index_range) do - case first do - 0 -> + def slice(%__MODULE__{} = vector, first..last//step = index_range) do + case {first, step} do + {0, 1} -> amount = last + 1 if last < 0 do @@ -1612,7 +1614,7 @@ defmodule Aja.Vector do case size + amount do start when start > 0 -> internal - |> Raw.slice(start, size - 1) + |> Raw.slice(start, size - 1, 1) |> Raw.from_list() _ -> @@ -1668,7 +1670,7 @@ defmodule Aja.Vector do @empty_raw else internal - |> Raw.slice(amount, size - 1) + |> Raw.slice(amount, size - 1, 1) |> Raw.from_list() end end @@ -1763,7 +1765,7 @@ defmodule Aja.Vector do size = Raw.size(internal) internal - |> Raw.slice(index, size - 1) + |> Raw.slice(index, size - 1, 1) |> from_list() end end @@ -1804,7 +1806,7 @@ defmodule Aja.Vector do dropped = internal - |> Raw.slice(index, size - 1) + |> Raw.slice(index, size - 1, 1) |> from_list() {taken, dropped} @@ -1995,7 +1997,10 @@ defmodule Aja.Vector do size = Aja.Vector.Raw.size(internal) {:ok, size, - fn start, length -> Aja.Vector.Raw.slice(internal, start, start + length - 1) end} + fn start, length, step -> + # dbg({start, length, step, size}) + Aja.Vector.Raw.slice(internal, start, start + length - 1, step) + end} end def reduce(%Aja.Vector{__vector__: internal}, acc, fun) do diff --git a/lib/vector/raw.ex b/lib/vector/raw.ex index 96e7eab..57bb07f 100644 --- a/lib/vector/raw.ex +++ b/lib/vector/raw.ex @@ -439,7 +439,7 @@ defmodule Aja.Vector.Raw do _ -> left = take(vector, index) - [popped | right] = slice(vector, index, size - 1) + [popped | right] = slice(vector, index, size - 1, 1) new_vector = concat_list(left, right) {popped, new_vector} end @@ -452,7 +452,7 @@ defmodule Aja.Vector.Raw do amount -> left = take(vector, index) - right = slice(vector, amount, size - 1) + right = slice(vector, amount, size - 1, 1) concat_list(left, right) end end @@ -906,15 +906,15 @@ defmodule Aja.Vector.Raw do def map(empty_pattern(), _fun), do: @empty - @compile {:inline, slice: 3} - @spec slice(t(val), non_neg_integer, non_neg_integer) :: [val] when val: value - def slice(vector, start, last) + @compile {:inline, slice: 4} + @spec slice(t(val), non_neg_integer, non_neg_integer, pos_integer) :: [val] when val: value + def slice(vector, start, last, step) - def slice(small(size, tail, _first), start, last) do - Tail.slice(tail, start, last, size) + def slice(small(size, tail, _first), start, last, step) do + Tail.slice(tail, start, last, size, step) end - def slice(large(size, tail_offset, level, trie, tail, _first), start, last) do + def slice(large(size, tail_offset, level, trie, tail, _first), start, last, step) do acc = if last < tail_offset do [] @@ -923,7 +923,8 @@ defmodule Aja.Vector.Raw do tail, Kernel.max(0, start - tail_offset), last - tail_offset, - size - tail_offset + size - tail_offset, + step ) end @@ -934,7 +935,7 @@ defmodule Aja.Vector.Raw do end end - def slice(empty_pattern(), _start, _last), do: [] + def slice(empty_pattern(), _start, _last, _step), do: [] @compile {:inline, take: 2} @spec take(t(val), non_neg_integer) :: t(val) when val: value diff --git a/lib/vector/tail.ex b/lib/vector/tail.ex index 3602056..3fac770 100644 --- a/lib/vector/tail.ex +++ b/lib/vector/tail.ex @@ -260,19 +260,19 @@ defmodule Aja.Vector.Tail do end end - def slice(tail, start, last, tail_size) do + def slice(tail, start, last, tail_size, step) do offset = C.branch_factor() - tail_size - do_slice(tail, start + offset, last + offset + 1, []) + do_slice(tail, start + offset, last + offset + 1, step, []) end - @compile {:inline, do_slice: 4} - defp do_slice(_tail, i, i, acc) do + @compile {:inline, do_slice: 5} + defp do_slice(_tail, i, i, _step, acc) do acc end - defp do_slice(tail, start, i, acc) do + defp do_slice(tail, start, i, step, acc) do new_acc = [:erlang.element(i, tail) | acc] - do_slice(tail, start, i - 1, new_acc) + do_slice(tail, start, i - step, step, new_acc) end def partial_map_reduce(tail, _i = C.branch_factor(), acc, _fun), do: {tail, acc} diff --git a/test/vector_test.exs b/test/vector_test.exs index ac93b4c..c7e1246 100644 --- a/test/vector_test.exs +++ b/test/vector_test.exs @@ -443,6 +443,16 @@ defmodule Aja.VectorTest do assert Aja.Vector.new([2, 3, 4, 5, 6]) == Aja.Vector.new(1..100) |> Aja.Vector.slice(1..5) assert Aja.Vector.new([18, 19]) == Aja.Vector.new(1..20) |> Aja.Vector.slice(-3, 2) assert Aja.Vector.new(2..99) == Aja.Vector.new(1..100) |> Aja.Vector.slice(1..98) + + assert Aja.Vector.new([2, 4, 6]) == + Aja.Vector.new(1..10) |> Aja.Vector.slice(1..5//2) + + assert Aja.Vector.new([2, 4, 6]) == + Aja.Vector.new(1..10) |> Aja.Vector.slice(1..6//2) + + assert_raise ArgumentError, fn -> + Aja.Vector.new(1..10) |> Aja.Vector.slice(1..5//-2) + end end test "take/2" do