Skip to content

Commit

Permalink
Fix slicing by stepped range
Browse files Browse the repository at this point in the history
  • Loading branch information
sabiwara committed Aug 28, 2024
1 parent 224f5da commit 50b41ee
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 28 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## Dev

### Bug fixes

- Fix `Aja.Vector.slice/2` for ranges with step != 1
- Fix `Aja.Enum.count/1` for ranges with step != 1

## v0.7.0 (2024-08-28)

### Enhancements
Expand Down
8 changes: 4 additions & 4 deletions lib/enum.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,7 @@ defmodule Aja.Enum do
size = RawVector.size(vector)

if amount < size do
RawVector.slice(vector, 0, amount - 1)
RawVector.slice(vector, 0, amount - 1, 1)
else
RawVector.to_list(vector)
end
Expand All @@ -1207,7 +1207,7 @@ defmodule Aja.Enum do
start = amount + size

if start > 0 do
RawVector.slice(vector, start, size - 1)
RawVector.slice(vector, start, size - 1, 1)
else
RawVector.to_list(vector)
end
Expand All @@ -1233,7 +1233,7 @@ defmodule Aja.Enum do
size = RawVector.size(vector)

if amount < size do
RawVector.slice(vector, amount, size - 1)
RawVector.slice(vector, amount, size - 1, 1)
else
[]
end
Expand All @@ -1244,7 +1244,7 @@ defmodule Aja.Enum do
last = amount + size

if last > 0 do
RawVector.slice(vector, 0, last - 1)
RawVector.slice(vector, 0, last - 1, 1)
else
[]
end
Expand Down
21 changes: 13 additions & 8 deletions lib/vector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1525,14 +1525,16 @@ defmodule Aja.Vector do
vec([80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90])
iex> Aja.Vector.new(0..100) |> Aja.Vector.slice(-40..-30//1)
vec([61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71])
iex> Aja.Vector.new(0..100) |> Aja.Vector.slice(20..50//10)
vec([20, 30, 40, 50])
iex> Aja.Vector.new([:only_one]) |> Aja.Vector.slice(0..1000)
vec([:only_one])
"""
@spec slice(t(val), Range.t()) :: t(val) when val: value
def slice(%__MODULE__{} = vector, first..last//1 = index_range) do
case first do
0 ->
def slice(%__MODULE__{} = vector, first..last//step = index_range) do
case {first, step} do
{0, 1} ->
amount = last + 1

if last < 0 do
Expand Down Expand Up @@ -1612,7 +1614,7 @@ defmodule Aja.Vector do
case size + amount do
start when start > 0 ->
internal
|> Raw.slice(start, size - 1)
|> Raw.slice(start, size - 1, 1)
|> Raw.from_list()

_ ->
Expand Down Expand Up @@ -1668,7 +1670,7 @@ defmodule Aja.Vector do
@empty_raw
else
internal
|> Raw.slice(amount, size - 1)
|> Raw.slice(amount, size - 1, 1)
|> Raw.from_list()
end
end
Expand Down Expand Up @@ -1763,7 +1765,7 @@ defmodule Aja.Vector do
size = Raw.size(internal)

internal
|> Raw.slice(index, size - 1)
|> Raw.slice(index, size - 1, 1)
|> from_list()
end
end
Expand Down Expand Up @@ -1804,7 +1806,7 @@ defmodule Aja.Vector do

dropped =
internal
|> Raw.slice(index, size - 1)
|> Raw.slice(index, size - 1, 1)
|> from_list()

{taken, dropped}
Expand Down Expand Up @@ -1995,7 +1997,10 @@ defmodule Aja.Vector do
size = Aja.Vector.Raw.size(internal)

{:ok, size,
fn start, length -> Aja.Vector.Raw.slice(internal, start, start + length - 1) end}
fn start, length, step ->
# dbg({start, length, step, size})
Aja.Vector.Raw.slice(internal, start, start + length - 1, step)
end}
end

def reduce(%Aja.Vector{__vector__: internal}, acc, fun) do
Expand Down
21 changes: 11 additions & 10 deletions lib/vector/raw.ex
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ defmodule Aja.Vector.Raw do

_ ->
left = take(vector, index)
[popped | right] = slice(vector, index, size - 1)
[popped | right] = slice(vector, index, size - 1, 1)
new_vector = concat_list(left, right)
{popped, new_vector}
end
Expand All @@ -452,7 +452,7 @@ defmodule Aja.Vector.Raw do

amount ->
left = take(vector, index)
right = slice(vector, amount, size - 1)
right = slice(vector, amount, size - 1, 1)
concat_list(left, right)
end
end
Expand Down Expand Up @@ -906,15 +906,15 @@ defmodule Aja.Vector.Raw do

def map(empty_pattern(), _fun), do: @empty

@compile {:inline, slice: 3}
@spec slice(t(val), non_neg_integer, non_neg_integer) :: [val] when val: value
def slice(vector, start, last)
@compile {:inline, slice: 4}
@spec slice(t(val), non_neg_integer, non_neg_integer, pos_integer) :: [val] when val: value
def slice(vector, start, last, step)

def slice(small(size, tail, _first), start, last) do
Tail.slice(tail, start, last, size)
def slice(small(size, tail, _first), start, last, step) do
Tail.slice(tail, start, last, size, step)
end

def slice(large(size, tail_offset, level, trie, tail, _first), start, last) do
def slice(large(size, tail_offset, level, trie, tail, _first), start, last, step) do
acc =
if last < tail_offset do
[]
Expand All @@ -923,7 +923,8 @@ defmodule Aja.Vector.Raw do
tail,
Kernel.max(0, start - tail_offset),
last - tail_offset,
size - tail_offset
size - tail_offset,
step
)
end

Expand All @@ -934,7 +935,7 @@ defmodule Aja.Vector.Raw do
end
end

def slice(empty_pattern(), _start, _last), do: []
def slice(empty_pattern(), _start, _last, _step), do: []

@compile {:inline, take: 2}
@spec take(t(val), non_neg_integer) :: t(val) when val: value
Expand Down
12 changes: 6 additions & 6 deletions lib/vector/tail.ex
Original file line number Diff line number Diff line change
Expand Up @@ -260,19 +260,19 @@ defmodule Aja.Vector.Tail do
end
end

def slice(tail, start, last, tail_size) do
def slice(tail, start, last, tail_size, step) do
offset = C.branch_factor() - tail_size
do_slice(tail, start + offset, last + offset + 1, [])
do_slice(tail, start + offset, last + offset + 1, step, [])
end

@compile {:inline, do_slice: 4}
defp do_slice(_tail, i, i, acc) do
@compile {:inline, do_slice: 5}
defp do_slice(_tail, i, i, _step, acc) do
acc
end

defp do_slice(tail, start, i, acc) do
defp do_slice(tail, start, i, step, acc) do
new_acc = [:erlang.element(i, tail) | acc]
do_slice(tail, start, i - 1, new_acc)
do_slice(tail, start, i - step, step, new_acc)
end

def partial_map_reduce(tail, _i = C.branch_factor(), acc, _fun), do: {tail, acc}
Expand Down
10 changes: 10 additions & 0 deletions test/vector_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,16 @@ defmodule Aja.VectorTest do
assert Aja.Vector.new([2, 3, 4, 5, 6]) == Aja.Vector.new(1..100) |> Aja.Vector.slice(1..5)
assert Aja.Vector.new([18, 19]) == Aja.Vector.new(1..20) |> Aja.Vector.slice(-3, 2)
assert Aja.Vector.new(2..99) == Aja.Vector.new(1..100) |> Aja.Vector.slice(1..98)

assert Aja.Vector.new([2, 4, 6]) ==
Aja.Vector.new(1..10) |> Aja.Vector.slice(1..5//2)

assert Aja.Vector.new([2, 4, 6]) ==
Aja.Vector.new(1..10) |> Aja.Vector.slice(1..6//2)

assert_raise ArgumentError, fn ->
Aja.Vector.new(1..10) |> Aja.Vector.slice(1..5//-2)
end
end

test "take/2" do
Expand Down

0 comments on commit 50b41ee

Please sign in to comment.