Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-threading #478

Closed
wants to merge 8 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 46 additions & 34 deletions src/Radials.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using LinearAlgebra
using ExtendableSparse
using Base.Threads
using Distributed

_copy(t::Tuple) = t
_copy(t) = copy(t)
Expand Down Expand Up @@ -43,8 +45,15 @@ of a polynomial term.
References:
https://en.wikipedia.org/wiki/Polyharmonic_spline
"""
function RadialBasis(x, y, lb, ub; rad::RadialFunction = linearRadial(),
scale_factor::Real = 0.5, sparse = false)
function RadialBasis(
x,
y,
lb,
ub;
rad::RadialFunction = linearRadial(),
scale_factor::Real = 0.5,
sparse = false
)
q = rad.q
phi = rad.phi
coeff = _calc_coeffs(x, y, lb, ub, phi, q, scale_factor, sparse)
Expand Down Expand Up @@ -109,10 +118,9 @@ using ChainRulesCore: @non_differentiable

function _make_combination(n, d, ix)
exponents_combinations = [e
for e
in collect(Iterators.product(Iterators.repeated(0:n,
d)...))[:]
if sum(e) <= n]
for
e in collect(Iterators.product(Iterators.repeated(
0:n, d)...))[:] if sum(e) <= n]

return exponents_combinations[ix + 1]
end
Expand Down Expand Up @@ -145,9 +153,7 @@ function multivar_poly_basis(x, ix, d, n)
if n == 0
return one(eltype(x))
else
prod(a^d
for (a, d)
in zip(x, _make_combination(n, d, ix)))
prod(a^d for (a, d) in zip(x, _make_combination(n, d, ix)))
end
end

Expand Down Expand Up @@ -181,13 +187,17 @@ function _add_tmp_to_approx!(approx, i, tmp, rad::RadialBasis; f = identity)
end
end
# specialise when only single output dimension
function _make_approx(val,
::RadialBasis{F, Q, X, <:AbstractArray{<:Number}}) where {F, Q, X}
function _make_approx(
val, ::RadialBasis{F, Q, X, <:AbstractArray{<:Number}}) where {F, Q, X}
return Ref(zero(eltype(val)))
end
function _add_tmp_to_approx!(approx::Base.RefValue, i, tmp,
function _add_tmp_to_approx!(
approx::Base.RefValue,
i,
tmp,
rad::RadialBasis{F, Q, X, <:AbstractArray{<:Number}};
f = identity) where {F, Q, X}
f = identity
) where {F, Q, X}
@inbounds @simd ivdep for j in 1:size(rad.coeff, 1)
approx[] += rad.coeff[j, i] * f(tmp)
end
Expand All @@ -198,29 +208,23 @@ _ret_copy(v) = copy(v)

function _approx_rbf(val, rad::RadialBasis)
n = length(rad.x)
approx = _make_approx(val, rad)

# make sure @inbounds is safe
if n > size(rad.coeff, 2)
throw("Length of model's x vector exceeds number of calculated coefficients ($n != $(size(rad.coeff, 2))).")
# Define a function to compute tmp for a single index i
function compute_tmp(i)
tmp = zero(eltype(val))
@inbounds @simd ivdep for j in eachindex(val)
tmp += ((val[j] - rad.x[i][j]) / rad.scale_factor)^2
end
return sqrt(tmp)
end

approx = _make_approx(val, rad)
# Use pmap to parallelize the computation of tmp
tmp_values = pmap(compute_tmp, 1:n)

if rad.phi === linearRadial().phi
for i in 1:n
tmp = zero(eltype(val))
@inbounds @simd ivdep for j in 1:length(val)
tmp += ((val[j] - rad.x[i][j]) / rad.scale_factor)^2
end
tmp = sqrt(tmp)
_add_tmp_to_approx!(approx, i, tmp, rad)
end
else
tmp = collect(val)
@inbounds for i in 1:n
tmp = (val .- rad.x[i]) ./ rad.scale_factor
_add_tmp_to_approx!(approx, i, tmp, rad; f = rad.phi)
end
# Update the approx using the computed tmp values
for (i, tmp) in enumerate(tmp_values)
_add_tmp_to_approx!(approx, i, tmp, rad)
end

return _ret_copy(approx)
Expand All @@ -244,7 +248,15 @@ function add_point!(rad::RadialBasis, new_x, new_y)
append!(rad.x, new_x)
append!(rad.y, new_y)
end
rad.coeff = _calc_coeffs(rad.x, rad.y, rad.lb, rad.ub, rad.phi, rad.dim_poly,
rad.scale_factor, rad.sparse)
rad.coeff = _calc_coeffs(
rad.x,
rad.y,
rad.lb,
rad.ub,
rad.phi,
rad.dim_poly,
rad.scale_factor,
rad.sparse
)
nothing
end
Loading