Skip to content

Commit

Permalink
Vendor take_random and shuffle for deterministic CI
Browse files Browse the repository at this point in the history
  • Loading branch information
sabiwara committed Aug 28, 2024
1 parent 00e84da commit d8d10c7
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 7 deletions.
5 changes: 3 additions & 2 deletions lib/enum.ex
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ defmodule Aja.Enum do

require Aja.Vector.Raw, as: RawVector
alias Aja.EnumHelper, as: H
alias Aja.RandomHelper

@compile :inline_list_funcs

Expand Down Expand Up @@ -826,7 +827,7 @@ defmodule Aja.Enum do
def take_random(enumerable, count) do
enumerable
|> H.to_list()
|> Enum.take_random(count)
|> RandomHelper.take_random(count)
end

@doc """
Expand All @@ -838,7 +839,7 @@ defmodule Aja.Enum do
def shuffle(enumerable) do
enumerable
|> H.to_list()
|> Enum.shuffle()
|> RandomHelper.shuffle()
end

# UNIQ
Expand Down
124 changes: 124 additions & 0 deletions lib/helpers/random_helper.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
defmodule Aja.RandomHelper do
@moduledoc false

# TODO Remove this module when dropping support for Elixir 1.16

# vendoring Elixir's 1.17 implementation of take_random/2 and shuffle/1 which are:
# - faster
# - necessary to keep CI deterministic across multiple Elixir versions

def shuffle(enumerable) do
randomized =
Enum.reduce(enumerable, [], fn x, acc ->
[{:rand.uniform(), x} | acc]
end)

shuffle_unwrap(:lists.keysort(1, randomized))
end

defp shuffle_unwrap([{_, h} | rest]), do: [h | shuffle_unwrap(rest)]
defp shuffle_unwrap([]), do: []

def take_random(enumerable, count)
def take_random(_enumerable, 0), do: []
def take_random([], _), do: []

def take_random(enumerable, 1) do
enumerable
|> Enum.reduce({0, 0, 1.0, nil}, fn
elem, {idx, idx, w, _current} ->
{jdx, w} = take_jdx_w(idx, w, 1)
{idx + 1, jdx, w, elem}

_elem, {idx, jdx, w, current} ->
{idx + 1, jdx, w, current}
end)
|> case do
{0, 0, 1.0, nil} -> []
{_idx, _jdx, _w, current} -> [current]
end
end

def take_random(enumerable, count) when count in 0..128 do
sample = Tuple.duplicate(nil, count)

reducer = fn
elem, {idx, jdx, w, sample} when idx < count ->
rand = take_index(idx)
sample = sample |> put_elem(idx, elem(sample, rand)) |> put_elem(rand, elem)

if idx == jdx do
{jdx, w} = take_jdx_w(idx, w, count)
{idx + 1, jdx, w, sample}
else
{idx + 1, jdx, w, sample}
end

elem, {idx, idx, w, sample} ->
pos = :rand.uniform(count) - 1
{jdx, w} = take_jdx_w(idx, w, count)
{idx + 1, jdx, w, put_elem(sample, pos, elem)}

_elem, {idx, jdx, w, sample} ->
{idx + 1, jdx, w, sample}
end

{size, _, _, sample} = Enum.reduce(enumerable, {0, count - 1, 1.0, sample}, reducer)

if count < size do
Tuple.to_list(sample)
else
take_tupled(sample, size, [])
end
end

def take_random(enumerable, count) when is_integer(count) and count >= 0 do
reducer = fn
elem, {idx, jdx, w, sample} when idx < count ->
rand = take_index(idx)
sample = sample |> Map.put(idx, Map.get(sample, rand)) |> Map.put(rand, elem)

if idx == jdx do
{jdx, w} = take_jdx_w(idx, w, count)
{idx + 1, jdx, w, sample}
else
{idx + 1, jdx, w, sample}
end

elem, {idx, idx, w, sample} ->
pos = :rand.uniform(count) - 1
{jdx, w} = take_jdx_w(idx, w, count)
{idx + 1, jdx, w, %{sample | pos => elem}}

_elem, {idx, jdx, w, sample} ->
{idx + 1, jdx, w, sample}
end

{size, _, _, sample} = Enum.reduce(enumerable, {0, count - 1, 1.0, %{}}, reducer)
take_mapped(sample, Kernel.min(count, size), [])
end

@compile {:inline, take_jdx_w: 3, take_index: 1}
defp take_jdx_w(idx, w, count) do
w = w * :math.exp(:math.log(:rand.uniform()) / count)
jdx = idx + floor(:math.log(:rand.uniform()) / :math.log(1 - w)) + 1
{jdx, w}
end

defp take_index(0), do: 0
defp take_index(idx), do: :rand.uniform(idx + 1) - 1

defp take_tupled(_sample, 0, acc), do: acc

defp take_tupled(sample, position, acc) do
position = position - 1
take_tupled(sample, position, [elem(sample, position) | acc])
end

defp take_mapped(_sample, 0, acc), do: acc

defp take_mapped(sample, position, acc) do
position = position - 1
take_mapped(sample, position, [Map.fetch!(sample, position) | acc])
end
end
3 changes: 2 additions & 1 deletion lib/vector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ defmodule Aja.Vector do
"""

alias Aja.RandomHelper
alias Aja.Vector.{EmptyError, IndexError, Raw}
require Raw

Expand Down Expand Up @@ -1895,7 +1896,7 @@ defmodule Aja.Vector do
# Note: benchmarks suggest that this is already fast without further optimization
internal
|> Raw.to_list()
|> Enum.shuffle()
|> RandomHelper.shuffle()
|> from_list()
end

Expand Down
5 changes: 3 additions & 2 deletions lib/vector/raw.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ defmodule Aja.Vector.Raw do

require Aja.Vector.CodeGen, as: C

alias Aja.RandomHelper
alias Aja.Vector.{Builder, Node, Tail, Trie}

@empty {0}
Expand Down Expand Up @@ -1049,11 +1050,11 @@ defmodule Aja.Vector.Raw do
end

def take_random(vector, amount) when amount >= size(vector) do
vector |> to_list() |> Enum.shuffle() |> from_list()
vector |> to_list() |> RandomHelper.shuffle() |> from_list()
end

def take_random(vector, amount) do
vector |> to_list() |> Enum.take_random(amount) |> from_list()
vector |> to_list() |> RandomHelper.take_random(amount) |> from_list()
end

def scan(vector, fun) do
Expand Down
8 changes: 6 additions & 2 deletions test/enum_prop_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ defmodule Aja.Enum.PropTest do
import Aja.TestHelpers
import Aja.TestDataGenerators

alias Aja.RandomHelper

@moduletag timeout: :infinity
@moduletag :property

Expand Down Expand Up @@ -410,7 +412,8 @@ defmodule Aja.Enum.PropTest do
assert ^sorted_by = Aja.Enum.sort_by(stream, fun)
assert Enum.sort_by(map_set, fun) === Aja.Enum.sort_by(map_set, fun)

shuffled = with_seed(fn -> Enum.shuffle(list) end)
# TODO Replace RandomHelper by Enum when dropping support for Elixir 1.16
shuffled = with_seed(fn -> RandomHelper.shuffle(list) end)
assert ^shuffled = with_seed(fn -> Aja.Enum.shuffle(list) end)
assert ^shuffled = with_seed(fn -> Aja.Enum.shuffle(vector) end)
assert ^shuffled = with_seed(fn -> Aja.Enum.shuffle(stream) end)
Expand All @@ -425,7 +428,8 @@ defmodule Aja.Enum.PropTest do
assert with_seed(fn -> Enum.random(map_set) |> capture_error() end) ===
with_seed(fn -> Aja.Enum.random(map_set) |> capture_error() end)

rand_taken = with_seed(fn -> Enum.take_random(list, amount) end)
# TODO Replace RandomHelper by Enum when dropping support for Elixir 1.16
rand_taken = with_seed(fn -> RandomHelper.take_random(list, amount) end)
assert ^rand_taken = with_seed(fn -> Aja.Enum.take_random(list, amount) end)
assert ^rand_taken = with_seed(fn -> Aja.Enum.take_random(vector, amount) end)
assert ^rand_taken = with_seed(fn -> Aja.Enum.take_random(stream, amount) end)
Expand Down

0 comments on commit d8d10c7

Please sign in to comment.