Merge pull request #6 from JuliaTrustworthyAI/5-stop-relying-on-ce-for-functionality

trying
pat-alt authored Apr 5, 2024
2 parents e2e431a + be8d227 commit 5ec77f4
Showing 24 changed files with 364 additions and 371 deletions.
25 changes: 23 additions & 2 deletions Project.toml
@@ -1,21 +1,42 @@
name = "TaijaParallel"
uuid = "bf1c2c22-5e42-4e78-8b6b-92e6c673eeb0"
authors = ["Patrick Altmeyer <[email protected]>"]
version = "0.1.0"
version = "1.0.0"

[deps]
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
TaijaBase = "10284c91-9f28-4c9a-abbf-ee43576dfff6"

[weakdeps]
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

[extensions]
MPIExt = "MPI"

[compat]
Aqua = "0.8"
CounterfactualExplanations = "0.1"
Logging = "1.7, 1.8, 1.9, 1.10"
MLUtils = "0.4.4"
MPI = "0.20"
PackageExtensionCompat = "1"
ProgressMeter = "1"
Reexport = "1"
Serialization = "1.7, 1.8, 1.9, 1.10"
TaijaBase = "1"
Test = "1.7, 1.8, 1.9, 1.10"
julia = "1.7, 1.8, 1.9, 1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Aqua", "MPI", "Test"]
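Note on the `[weakdeps]` / `[extensions]` entries above: they declare `MPIExt` as a package extension triggered by `MPI`. The following is a minimal sketch, using only the `TOML` standard library, of how those two tables relate. It parses a copy of the relevant fragment and is not code from this commit; the consistency check reflects the usual convention that extension triggers are listed under `[weakdeps]`.

```julia
using TOML

# Copy of the relevant Project.toml fragment from the diff above.
project = TOML.parse("""
[weakdeps]
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

[extensions]
MPIExt = "MPI"
""")

# Each extension names one trigger package (or a list of them).
for (ext, trigger) in project["extensions"]
    triggers = trigger isa AbstractString ? [trigger] : trigger
    # Illustrative check: every trigger should appear under [weakdeps].
    @assert all(t -> haskey(project["weakdeps"], t), triggers)
    println(ext, " is loaded once ", join(triggers, ", "), " is loaded")
end
```

With this layout, `using TaijaParallel` alone does not pull in `MPI`; the code under `ext/MPIExt` only becomes available after `MPI` is loaded as well.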
1 change: 1 addition & 0 deletions README.md
@@ -5,6 +5,7 @@
[![Build Status](https://github.com/JuliaTrustworthyAI/TaijaParallel.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/JuliaTrustworthyAI/TaijaParallel.jl/actions/workflows/CI.yml?query=branch%3Amaster)
[![Coverage](https://codecov.io/gh/JuliaTrustworthyAI/TaijaParallel.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaTrustworthyAI/TaijaParallel.jl)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

This package adds custom support for parallelization for certain [Taija](https://github.com/JuliaTrustworthyAI) packages.

29 changes: 12 additions & 17 deletions docs/make.jl
@@ -1,25 +1,20 @@
using TaijaParallel
using Documenter

DocMeta.setdocmeta!(TaijaParallel, :DocTestSetup, :(using TaijaParallel); recursive=true)
DocMeta.setdocmeta!(TaijaParallel, :DocTestSetup, :(using TaijaParallel); recursive = true)

makedocs(;
modules=[TaijaParallel],
authors="Patrick Altmeyer",
repo="https://github.com/JuliaTrustworthyAI/TaijaParallel.jl/blob/{commit}{path}#{line}",
sitename="TaijaParallel.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://JuliaTrustworthyAI.github.io/TaijaParallel.jl",
edit_link="main",
assets=String[],
modules = [TaijaParallel],
authors = "Patrick Altmeyer",
repo = "https://github.com/JuliaTrustworthyAI/TaijaParallel.jl/blob/{commit}{path}#{line}",
sitename = "TaijaParallel.jl",
format = Documenter.HTML(;
prettyurls = get(ENV, "CI", "false") == "true",
canonical = "https://JuliaTrustworthyAI.github.io/TaijaParallel.jl",
edit_link = "main",
assets = String[],
),
pages=[
"Home" => "index.md",
],
pages = ["Home" => "index.md"],
)

deploydocs(;
repo="github.com/JuliaTrustworthyAI/TaijaParallel.jl",
devbranch="main",
)
deploydocs(; repo = "github.com/JuliaTrustworthyAI/TaijaParallel.jl", devbranch = "main")
9 changes: 6 additions & 3 deletions ext/MPIExt/MPIExt.jl
@@ -2,13 +2,14 @@ module MPIExt

export MPIParallelizer

using TaijaParallel
using Logging
using MPI
using ProgressMeter
using TaijaBase
using TaijaParallel

"The `MPIParallelizer` type is used to parallelize the evaluation of a function using `MPI.jl`."
struct MPIParallelizer <: AbstractParallelizer
struct MPIParallelizer <: TaijaParallel.AbstractParallelizer
comm::MPI.Comm
rank::Int
n_proc::Int
@@ -22,7 +23,9 @@ end
Create an `MPIParallelizer` object from an `MPI.Comm` object. Optionally, specify the number of observations to send to each process using `n_each`. If `n_each` is `nothing`, then all observations will be split into equally sized bins and sent to each process. If `threaded` is `true`, then the `MPIParallelizer` will use `Threads.@threads` to further parallelize the evaluation of a function.
"""
function TaijaParallel.MPIParallelizer(
comm::MPI.Comm; n_each::Union{Nothing,Int}=nothing, threaded::Bool=false
comm::MPI.Comm;
n_each::Union{Nothing,Int} = nothing,
threaded::Bool = false,
)
rank = MPI.Comm_rank(comm) # Rank of this process in the world 🌍
n_proc = MPI.Comm_size(comm) # Number of processes in the world 🌍
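Note on `n_each`: the docstring above describes two modes, fixed-size batches per process or a single equal split. The following is a rough sketch of that idea in plain Julia. The helpers `chunk_sketch` and `split_sketch` are hypothetical stand-ins written for illustration; they are not the package's `chunk_obs` and `split_obs`, whose exact behaviour is not shown in this diff.

```julia
# Hypothetical stand-in: cut the observations into chunks that hold
# `n_each` observations for each of the `n_proc` processes.
function chunk_sketch(obs::AbstractVector, n_each::Int, n_proc::Int)
    return [collect(c) for c in Iterators.partition(obs, n_each * n_proc)]
end

# Hypothetical stand-in: split one chunk into `n_proc` bins of (nearly)
# equal size, giving the first bins one extra element if needed.
function split_sketch(obs::AbstractVector, n_proc::Int)
    base, extra = divrem(length(obs), n_proc)
    bins = Vector{Vector{eltype(obs)}}()
    start = 1
    for p in 1:n_proc
        len = base + (p <= extra ? 1 : 0)
        push!(bins, obs[start:(start + len - 1)])
        start += len
    end
    return bins
end

chunks = chunk_sketch(collect(1:10), 2, 2)   # chunks of up to 4 observations
bins = split_sketch(collect(1:10), 3)        # one bin per process
```

In this sketch, setting `n_each` bounds how much work is in flight at once, while skipping it sends everything in a single round.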
27 changes: 15 additions & 12 deletions ext/MPIExt/evaluate.jl
@@ -1,5 +1,5 @@
"""
parallelize(
TaijaBase.parallelize(
parallelizer::MPIParallelizer,
f::typeof(CounterfactualExplanations.Evaluation.evaluate),
args...;
@@ -8,19 +8,19 @@
Parallelizes the evaluation of the `CounterfactualExplanations.Evaluation.evaluate` function. This function is used to evaluate the performance of a counterfactual explanation method.
"""
function parallelize(
function TaijaBase.parallelize(
parallelizer::MPIParallelizer,
f::typeof(CounterfactualExplanations.Evaluation.evaluate),
args...;
verbose::Bool=false,
verbose::Bool = false,
kwargs...,
)

# Setup:
n_each = parallelizer.n_each

# Extract positional arguments:
counterfactuals = args[1] |> x -> CounterfactualExplanations.vectorize_collection(x)
counterfactuals = args[1] |> x -> TaijaBase.vectorize_collection(x)
# Get meta data if supplied:
if length(args) > 1
meta_data = args[2]
@@ -33,7 +33,7 @@ function parallelize(
# Break down into chunks:
args = zip(counterfactuals, meta_data)
if !isnothing(n_each)
chunks = Parallelization.chunk_obs(args, n_each, parallelizer.n_proc)
chunks = chunk_obs(args, n_each, parallelizer.n_proc)
else
chunks = [collect(args)]
end
@@ -43,15 +43,15 @@ function parallelize(

# For each chunk:
for (i, chunk) in enumerate(chunks)
worker_chunk = Parallelization.split_obs(chunk, parallelizer.n_proc)
worker_chunk = TaijaParallel.split_obs(chunk, parallelizer.n_proc)
worker_chunk = MPI.scatter(worker_chunk, parallelizer.comm)
worker_chunk = stack(worker_chunk; dims=1)
worker_chunk = stack(worker_chunk; dims = 1)
if !parallelizer.threaded
if parallelizer.rank == 0 && verbose
# Generating counterfactuals with progress bar:
output = []
@showprogress desc = "Evaluating counterfactuals ..." for x in zip(
eachcol(worker_chunk)...
eachcol(worker_chunk)...,
)
with_logger(NullLogger()) do
push!(output, f(x...; kwargs...))
@@ -66,8 +66,11 @@ function parallelize(
else
# Parallelize further with `Threads.@threads`:
second_parallelizer = ThreadsParallelizer()
output = parallelize(
second_parallelizer, f, eachcol(worker_chunk)...; kwargs...
output = TaijaBase.parallelize(
second_parallelizer,
f,
eachcol(worker_chunk)...;
kwargs...,
)
end
MPI.Barrier(parallelizer.comm)
@@ -84,7 +87,7 @@ function parallelize(
# Load output from rank 0:
if parallelizer.rank == 0
outputs = []
for i in 1:length(chunks)
for i = 1:length(chunks)
output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls"))
push!(outputs, output)
end
@@ -95,7 +98,7 @@ function parallelize(
end

# Broadcast output to all processes:
final_output = MPI.bcast(output, parallelizer.comm; root=0)
final_output = MPI.bcast(output, parallelizer.comm; root = 0)
MPI.Barrier(parallelizer.comm)

return final_output
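Note on the collection step: the hunks above show rank 0 reading `output_$i.jls` files back from `storage_path`, but the write side and the final combination are collapsed in this view. The sketch below illustrates the round trip with the `Serialization` standard library only. The stand-in results and the use of `vcat` to combine them are assumptions made for illustration, and no MPI calls are involved.

```julia
using Serialization

# Temporary directory standing in for `storage_path`.
storage_path = mktempdir()

# Stand-in per-chunk results (assumption: one result vector per chunk).
chunk_outputs = [[1, 2], [3], [4, 5, 6]]

# Write side: persist each chunk's output under a predictable name.
for (i, out) in enumerate(chunk_outputs)
    Serialization.serialize(joinpath(storage_path, "output_$i.jls"), out)
end

# Read side, mirroring the loop in the diff above.
outputs = []
for i in 1:length(chunk_outputs)
    output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls"))
    push!(outputs, output)
end

# Assumed combination step: concatenate the per-chunk results.
combined = vcat(outputs...)
```

Writing each chunk to disk keeps memory use bounded between chunks, at the cost of extra I/O on the process that collects the results.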
28 changes: 16 additions & 12 deletions ext/MPIExt/generate_counterfactual.jl
@@ -1,8 +1,9 @@
using CounterfactualExplanations
using MLUtils: stack
using Serialization

"""
parallelize(
TaijaBase.parallelize(
parallelizer::MPIParallelizer,
f::typeof(CounterfactualExplanations.generate_counterfactual),
args...;
@@ -11,19 +12,19 @@ using Serialization
Parallelizes the `CounterfactualExplanations.generate_counterfactual` function using `MPI.jl`. This function is used to generate counterfactual explanations.
"""
function parallelize(
function TaijaBase.parallelize(
parallelizer::MPIParallelizer,
f::typeof(CounterfactualExplanations.generate_counterfactual),
args...;
verbose::Bool=false,
verbose::Bool = false,
kwargs...,
)

# Setup:
n_each = parallelizer.n_each

# Extract positional arguments:
counterfactuals = args[1] |> x -> CounterfactualExplanations.vectorize_collection(x)
counterfactuals = args[1] |> x -> TaijaBase.vectorize_collection(x)
target = args[2] |> x -> isa(x, AbstractArray) ? x : fill(x, length(counterfactuals))
data = args[3] |> x -> isa(x, AbstractArray) ? x : fill(x, length(counterfactuals))
M = args[4] |> x -> isa(x, AbstractArray) ? x : fill(x, length(counterfactuals))
@@ -32,7 +33,7 @@ function parallelize(
# Break down into chunks:
args = zip(counterfactuals, target, data, M, generator)
if !isnothing(n_each)
chunks = Parallelization.chunk_obs(args, n_each, parallelizer.n_proc)
chunks = chunk_obs(args, n_each, parallelizer.n_proc)
else
chunks = [collect(args)]
end
@@ -42,15 +43,15 @@ function parallelize(

# For each chunk:
for (i, chunk) in enumerate(chunks)
worker_chunk = Parallelization.split_obs(chunk, parallelizer.n_proc)
worker_chunk = TaijaParallel.split_obs(chunk, parallelizer.n_proc)
worker_chunk = MPI.scatter(worker_chunk, parallelizer.comm)
worker_chunk = stack(worker_chunk; dims=1)
worker_chunk = stack(worker_chunk; dims = 1)
if !parallelizer.threaded
if parallelizer.rank == 0 && verbose
# Generating counterfactuals with progress bar:
output = []
@showprogress desc = "Generating counterfactuals ..." for x in zip(
eachcol(worker_chunk)...
eachcol(worker_chunk)...,
)
with_logger(NullLogger()) do
push!(output, f(x...; kwargs...))
@@ -65,8 +66,11 @@ function parallelize(
else
# Parallelize further with `Threads.@threads`:
second_parallelizer = ThreadsParallelizer()
output = parallelize(
second_parallelizer, f, eachcol(worker_chunk)...; kwargs...
output = TaijaBase.parallelize(
second_parallelizer,
f,
eachcol(worker_chunk)...;
kwargs...,
)
end
MPI.Barrier(parallelizer.comm)
@@ -83,7 +87,7 @@ function parallelize(
# Load output from rank 0:
if parallelizer.rank == 0
outputs = []
for i in 1:length(chunks)
for i = 1:length(chunks)
output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls"))
push!(outputs, output)
end
@@ -94,7 +98,7 @@ function parallelize(
end

# Broadcast output to all processes:
final_output = MPI.bcast(output, parallelizer.comm; root=0)
final_output = MPI.bcast(output, parallelizer.comm; root = 0)
MPI.Barrier(parallelizer.comm)

return final_output
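Note on argument handling: the method above accepts either one value or an array of values for `target`, `data`, `M` and the generator, repeats scalars with `fill`, and then zips everything into per-observation argument tuples. The sketch below shows that pattern in isolation. The helper `expand_arg` and the toy function `f` are hypothetical and exist only for this illustration.

```julia
# Hypothetical helper: keep arrays as they are, repeat anything else n times.
expand_arg(x, n) = isa(x, AbstractArray) ? x : fill(x, n)

xs = [1.0, 2.0, 3.0]                           # one entry per observation
target = expand_arg(2, length(xs))             # scalar, repeated for each x
scale = expand_arg([10, 20, 30], length(xs))   # array, used as supplied

# Toy stand-in for the function being parallelized.
f(x, t, s) = (x + t) * s

# Zip the aligned arguments and call `f` once per observation.
results = [f(args...) for args in zip(xs, target, scale)]
```

One consequence of the `isa(x, AbstractArray)` test is that an argument which is itself an array is always treated as one value per observation, so its length needs to match the number of observations.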
@@ -1,8 +1,6 @@
using CounterfactualExplanations
using CounterfactualExplanations: generate_counterfactual
using CounterfactualExplanations.Evaluation: evaluate
import CounterfactualExplanations
using Logging
using ProgressMeter

include("assign_traits.jl")
include("threads/threads.jl")
include("threads/threads.jl")
5 changes: 3 additions & 2 deletions src/CounterfactualExplanations.jl/assign_traits.jl
@@ -1,7 +1,8 @@
"The `generate_counterfactual` method is parallelizable."
ProcessStyle(::Type{<:typeof(generate_counterfactual)}) = IsParallel()
ProcessStyle(::Type{<:typeof(CounterfactualExplanations.generate_counterfactual)}) =
IsParallel()

"The `evaluate` function is parallelizable."
function ProcessStyle(::Type{<:typeof(CounterfactualExplanations.Evaluation.evaluate)})
return IsParallel()
end
end
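Note on the trait assignments: `ProcessStyle(::Type{<:typeof(f)}) = IsParallel()` follows the trait-dispatch pattern, where a function's type is mapped to a trait value and callers branch on that value through dispatch. The self-contained sketch below defines its own stand-in types with the same names as in the diff. `NotParallel`, `slow_square` and `run_all` are hypothetical names introduced here and are not taken from TaijaBase or TaijaParallel.

```julia
# Stand-in trait types (the real ones live in the Taija packages).
abstract type ProcessStyle end
struct IsParallel <: ProcessStyle end
struct NotParallel <: ProcessStyle end   # hypothetical default trait

# Default: functions are not marked as parallelizable.
ProcessStyle(::Type) = NotParallel()

# Opt one specific function in, in the same style as the diff above.
slow_square(x) = x^2
ProcessStyle(::Type{<:typeof(slow_square)}) = IsParallel()

# Hypothetical driver that dispatches on the trait value.
run_all(f, xs) = run_all(ProcessStyle(typeof(f)), f, xs)
run_all(::NotParallel, f, xs) = map(f, xs)
function run_all(::IsParallel, f, xs)
    out = Vector{Any}(undef, length(xs))
    Threads.@threads for i in eachindex(xs)
        out[i] = f(xs[i])   # each iteration writes to its own slot
    end
    return out
end

squares = run_all(slow_square, [1, 2, 3])   # takes the threaded path
absolutes = run_all(abs, [-1, 2])           # falls back to `map`
```

The benefit of this design is that opting a function in is a one-line method definition, and code that never defines a trait keeps the serial default.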
20 changes: 11 additions & 9 deletions src/CounterfactualExplanations.jl/threads/evaluate.jl
@@ -1,5 +1,7 @@
import TaijaBase

"""
parallelize(
TaijaBase.parallelize(
parallelizer::ThreadsParallelizer,
f::typeof(CounterfactualExplanations.Evaluation.evaluate),
args...;
@@ -8,16 +10,16 @@
Parallelizes the evaluation of `f` using `Threads.@threads`. This function is used to evaluate counterfactual explanations.
"""
function parallelize(
function TaijaBase.parallelize(
parallelizer::ThreadsParallelizer,
f::typeof(CounterfactualExplanations.Evaluation.evaluate),
args...;
verbose::Bool=true,
verbose::Bool = true,
kwargs...,
)

# Setup:
counterfactuals = args[1] |> x -> CounterfactualExplanations.vectorize_collection(x)
counterfactuals = args[1] |> x -> TaijaBase.vectorize_collection(x)

# Get meta data if supplied:
if length(args) > 1
@@ -28,22 +30,22 @@ function parallelize(

# Check meta data:
if typeof(meta_data) <: AbstractArray
meta_data = CounterfactualExplanations.vectorize_collection(meta_data)
meta_data = TaijaBase.vectorize_collection(meta_data)
@assert length(meta_data) == length(counterfactuals) "The number of meta data must match the number of counterfactuals."
else
meta_data = fill(meta_data, length(counterfactuals))
end

# Preallocate:
evaluations = [[] for _ in 1:Threads.nthreads()]
evaluations = [[] for _ = 1:Threads.nthreads()]

# Verbosity:
if verbose
prog = ProgressMeter.Progress(
length(counterfactuals);
desc="Evaluating counterfactuals ...",
showspeed=true,
color=:green,
desc = "Evaluating counterfactuals ...",
showspeed = true,
color = :green,
)
end


2 comments on commit 5ec77f4

pat-alt commented on 5ec77f4 Apr 5, 2024

@JuliaRegistrator register

Release notes:

It turns out that this package previously relied on existing functionality in CounterfactualExplanations.jl. This has now been fixed, and all functionality related to parallelization has been removed from CounterfactualExplanations.jl (technically, deprecated). This package now depends on TaijaBase.jl, a new meta package for all packages in the Taija ecosystem.
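
Illustration of the change described in these notes: in the diff, methods are now added to `TaijaBase.parallelize` rather than to a locally owned `parallelize`. The toy sketch below shows that arrangement with two stand-in modules. `BaseLike`, `ParallelLike` and `SerialParallelizer` are hypothetical names, and where each real type is defined is an assumption made for illustration.

```julia
# Stand-in for the shared base package: it only owns the generic function.
module BaseLike
function parallelize end
end

# Stand-in for a downstream package that extends the shared function.
module ParallelLike
import ..BaseLike

abstract type AbstractParallelizer end
struct SerialParallelizer <: AbstractParallelizer end   # hypothetical type

# Qualifying the name adds a method to BaseLike's function
# instead of creating a new, unrelated `parallelize`.
BaseLike.parallelize(::SerialParallelizer, f, xs) = map(f, xs)
end

incremented = BaseLike.parallelize(ParallelLike.SerialParallelizer(), x -> x + 1, [1, 2])
```

Because every package adds methods to the same generic function, callers need only one entry point regardless of which parallelizer they pass in.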

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/104314

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.0.0 -m "<description of version>" 5ec77f461a2e76cfe879e0ac35d9f6d3fe40f1ea
git push origin v1.0.0
