Skip to content

Commit

Permalink
Add multi-threading
Browse files Browse the repository at this point in the history
  • Loading branch information
Spinachboul authored Mar 24, 2024
1 parent 43a4b3e commit f443843
Showing 1 changed file with 11 additions and 23 deletions.
34 changes: 11 additions & 23 deletions src/Radials.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using LinearAlgebra
using ExtendableSparse
using Surrogates
using Base.Threads

_copy(t::Tuple) = t
_copy(t) = copy(t)
Expand Down Expand Up @@ -196,33 +198,19 @@ end
_ret_copy(v::Base.RefValue) = v[]
_ret_copy(v) = copy(v)

function _approx_rbf(val, rad::RadialBasis)
function _approx_rbf_threaded(val, rad::RadialBasis)
n = length(rad.x)

# make sure @inbounds is safe
if n > size(rad.coeff, 2)
throw("Length of model's x vector exceeds number of calculated coefficients ($n != $(size(rad.coeff, 2))).")
end

approx = _make_approx(val, rad)

if rad.phi === linearRadial().phi
for i in 1:n
tmp = zero(eltype(val))
@inbounds @simd ivdep for j in 1:length(val)
tmp += ((val[j] - rad.x[i][j]) / rad.scale_factor)^2
end
tmp = sqrt(tmp)
_add_tmp_to_approx!(approx, i, tmp, rad)
end
else
tmp = collect(val)
@inbounds for i in 1:n
tmp = (val .- rad.x[i]) ./ rad.scale_factor
_add_tmp_to_approx!(approx, i, tmp, rad; f = rad.phi)

Threads.@threads for i in 1:n
tmp = zero(eltype(val))
@inbounds @simd ivdep for j in 1:length(val)
tmp += ((val[j] - rad.x[i][j]) / rad.scale_factor)^2
end
tmp = sqrt(tmp)
_add_tmp_to_approx!(approx, i, tmp, rad)
end

return _ret_copy(approx)
end

Expand Down

0 comments on commit f443843

Please sign in to comment.