diff --git a/src/interleaved.jl b/src/interleaved.jl index b07f552..be13094 100644 --- a/src/interleaved.jl +++ b/src/interleaved.jl @@ -45,7 +45,7 @@ function interleaved_backend(target_vars::AbstractVector{Int}, data::AbstractMat convergence_threshold::AbstractFloat=0.01, conv_check_start::AbstractFloat=0.1, conv_time_step::AbstractFloat=0.1, parallel::String="multi_il", edge_rule::String="OR", edge_merge_fun=maxweight, nonsparse_cond::Bool=false, verbose::Bool=true, - workers_local::Bool=true, feed_forward::Bool=true) where {ElType<:Real, DiscType<:Integer, ContType<:AbstractFloat} + workers_local::Bool=true, feed_forward::Bool=true, kill_remote_workers::Bool=true) where {ElType<:Real, DiscType<:Integer, ContType<:AbstractFloat} test_name = GLL_args[:test_name] weight_type = GLL_args[:weight_type] @@ -235,7 +235,7 @@ function interleaved_backend(target_vars::AbstractVector{Int}, data::AbstractMat end end - if !workers_local + if !workers_local && kill_remote_workers rmprocs(workers()) else wait.(worker_returns) diff --git a/src/learning.jl b/src/learning.jl index cb097bc..d313dc2 100644 --- a/src/learning.jl +++ b/src/learning.jl @@ -204,7 +204,7 @@ function LGL(data::AbstractMatrix; test_name::String="mi", max_k::Integer=3, tmp_folder::AbstractString="", debug::Integer=0, time_limit::AbstractFloat=-1.0, header=nothing, meta_variable_mask=nothing, dense_cor::Bool=true, recursive_pcor::Bool=true, cache_pcor::Bool=false, correct_reliable_only::Bool=true, feed_forward::Bool=true, - track_rejections::Bool=false, all_univar_nbrs=nothing, kwargs...) + track_rejections::Bool=false, all_univar_nbrs=nothing, kill_remote_workers::Bool=true, kwargs...) """ time_limit: -1.0 set heuristically, 0.0 no time_limit, otherwise time limit in seconds parallel: 'single', 'single_il', 'multi_il' @@ -239,7 +239,8 @@ function LGL(data::AbstractMatrix; test_name::String="mi", max_k::Integer=3, :feed_forward => feed_forward, :edge_rule => edge_rule, :edge_merge_fun => edge_merge_fun, :workers_local => workers_all_local(), - :variable_ids => header, :meta_variable_mask => meta_variable_mask) + :variable_ids => header, :meta_variable_mask => meta_variable_mask, + :kill_remote_workers => kill_remote_workers) nbr_dict, unfinished_state_dict, rej_dict = learn_graph_structure(target_vars, data, all_univar_nbrs, levels,