Add fast stack for sparse columns, catch incorrect kwarg usage in `learn_network`
jtackm committed Jun 5, 2024
1 parent 9c1fe58 commit 759e2ee
Showing 5 changed files with 69 additions and 12 deletions.
NEWS.md: 3 changes (2 additions & 1 deletion)
@@ -3,9 +3,10 @@
### general

- strongly speed-up contingency table computation for heterogeneous=true and max_k=0/1
- use the more compiler-friendly `stack` (introduced in Julia v1.9) instead of `hcat` for large numbers of columns (if available)
- use the more compiler-friendly `stack` (introduced in Julia v1.9) instead of `hcat` for large numbers of columns (if available), introduce fast method for sparse columns
- improve univariate pvalue filtering
- remove performance bottleneck in three-way `adjust_df`
- catch incorrect usage of 'meta_data_path' as keyword argument

# v0.19.2 (latest)

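As an illustration of the new keyword guard described in the changelog entry above (the file names below are placeholders, and the package exporting `learn_network` is assumed to be loaded):

```julia
learn_network("otu_table.tsv", "meta_data.tsv")                 # intended usage: both paths positional
learn_network("otu_table.tsv", meta_data_path="meta_data.tsv")  # now rejected with an AssertionError
```

Previously such a call would route `meta_data_path` into the catch-all `kwargs...`, making it easy to miss that the meta data was never loaded.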
src/io.jl: 2 changes (1 addition & 1 deletion)
@@ -52,7 +52,7 @@ function load_data(data_path::AbstractString, meta_path::StrOrNoth=nothing; tran
end
ld_results = load_jld(data_path, otu_data_key, otu_header_key, meta_data_key, meta_header_key, transposed=transposed)
else
error("$(file_ext) not a valid output format. Choose one of $(valid_data_formats)")
error("$(file_ext) not a valid input format. Choose one of $(valid_data_formats)")
end

ld_results
src/learning.jl: 38 changes (33 additions & 5 deletions)
@@ -298,7 +298,7 @@ end
function make_table(data_path::AbstractString, meta_data_path::StrOrNoth=nothing;
otu_data_key::StrOrNoth="otu_data", otu_header_key::StrOrNoth="otu_header",
meta_data_key::StrOrNoth="meta_data", meta_header_key::StrOrNoth="meta_header",
verbose::Bool=true, transposed::Bool=false)
transposed::Bool=false)

data, header, meta_data, meta_header = load_data(data_path, meta_data_path, otu_data_key=otu_data_key,
otu_header_key=otu_header_key, meta_data_key=meta_data_key,
@@ -316,6 +316,24 @@ function make_table(data_path::AbstractString, meta_data_path::StrOrNoth=nothing
data, header, meta_mask
end

#learn_local_neighborhood(target_var::AbstractString, header, args...; kwargs...) =
# learn_local_neighborhood(findfirst(==(target_var), header), args...; kwargs...)
#
#function learn_local_neighborhood(target_var::Int, data_path::AbstractString, meta_data_path::StrOrNoth=nothing;
# otu_data_key::StrOrNoth="otu_data",
# otu_header_key::StrOrNoth="otu_header", meta_data_key::StrOrNoth="meta_data",
# meta_header_key::StrOrNoth="meta_header", verbose::Bool=true,
# transposed::Bool=false, kwargs...)
#
# verbose && println("\n### Loading data ###\n")
# data, header, meta_mask = make_table(data_path, meta_data_path, otu_data_key=otu_data_key,
# otu_header_key=otu_header_key, meta_data_key=meta_data_key,
# meta_header_key=meta_header_key, transposed=transposed)
#
#
#
#end

"""
learn_network(data_path::AbstractString, meta_data_path::AbstractString) -> FWResult{<:Integer}
@@ -339,6 +357,11 @@ function learn_network(data_path::AbstractString, meta_data_path::StrOrNoth=noth
meta_header_key::StrOrNoth="meta_header", verbose::Bool=true,
transposed::Bool=false, kwargs...)

# Catch incorrect usage of data_path/meta_data_path as keyword argument
for key in (:data_path, :meta_data_path)
@assert !(key in keys(kwargs)) "'$key' is a positional argument, please use 'learn_network(<data_path>, <meta_data_path>; <kwargs...>)'."
end

verbose && println("\n### Loading data ###\n")
data, header, meta_mask = make_table(data_path, meta_data_path, otu_data_key=otu_data_key,
otu_header_key=otu_header_key, meta_data_key=meta_data_key,
@@ -353,10 +376,12 @@ end
Works like learn_network(data_path::AbstractString, meta_data_path::AbstractString), but takes paths to multiple data sets (independent sequencing experiments (e.g. 16S + ITS) for the same biological samples) which are normalized independently.
"""
function learn_network(all_data_paths::AbstractVector{<:AbstractString}, meta_data_path::StrOrNoth=nothing;
otu_data_key::StrOrNoth="otu_data",
otu_header_key::StrOrNoth="otu_header", meta_data_key::StrOrNoth="meta_data",
meta_header_key::StrOrNoth="meta_header", verbose::Bool=true,
transposed::Bool=false, kwargs...)
otu_data_key::StrOrNoth="otu_data", otu_header_key::StrOrNoth="otu_header", transposed::Bool=false, kwargs...)

# Catch incorrect usage of data_path/meta_data_path as keyword argument
for key in (:all_data_paths, :meta_data_path)
@assert !(key in keys(kwargs)) "'$key' is a positional argument, please use 'learn_network(<all_data_paths>, <meta_data_path>; <kwargs...>)'."
end

data_path = all_data_paths[1]
if length(all_data_paths) > 1
@@ -447,6 +472,9 @@ function learn_network(data::AbstractMatrix; sensitive::Bool=true,
cache_pcor::Bool=false, time_limit::AbstractFloat=-1.0, update_interval::AbstractFloat=30.0, parallel_mode="auto",
extra_data::Union{AbstractVector,Nothing}=nothing, share_data::Bool=true, experimental_kwargs...)

@assert !(:meta_data_path in keys(experimental_kwargs)) "You provided an OTU matrix together with a meta data path; this is currently not supported.
Use either 'learn_network(<otu_table_path>, <meta_data_path>; <kwargs...>)' or 'learn_network(<otu_matrix>; <kwargs...>)'."

start_time = time()

cont_mode = sensitive ? "fz" : "mi"
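A self-contained sketch of the guard pattern used above (the function name is invented for the example; only the check itself mirrors the commit): any keyword a method does not declare is collected into the slurped `kwargs`, so inspecting `keys(kwargs)` up front detects a positional argument passed by name.

```julia
# Sketch only; mirrors the pattern added to learn_network, not the package code itself.
function load_table_sketch(data_path::AbstractString, meta_data_path=nothing; kwargs...)
    for key in (:data_path, :meta_data_path)
        @assert !(key in keys(kwargs)) "'$key' is a positional argument, not a keyword."
    end
    return data_path, meta_data_path
end

load_table_sketch("otu.tsv", "meta.tsv")                 # ("otu.tsv", "meta.tsv")
load_table_sketch("otu.tsv"; meta_data_path="meta.tsv")  # AssertionError instead of silently dropping the keyword
```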
src/preprocessing.jl: 36 changes (32 additions & 4 deletions)
@@ -1,10 +1,38 @@
function _fast_stack_sparse(vecs::Vector{SparseVector{T1, T2}}) where {T1 <: Real, T2 <: Integer}
"""Fast method for stacking sparse columns"""
n_rows = length(vecs[1])
@assert all(length(x) == n_rows for x in vecs)

rids, cids, nzvals = Int[], Int[], T1[]

for (col_i, v) in enumerate(vecs)
n_val = nnz(v)

if n_val > 0
append!(rids, rowvals(v))
append!(cids, repeat([col_i], n_val))
append!(nzvals, nonzeros(v))
end
end

n_cols = length(vecs)
return sparse(rids, cids, nzvals, n_rows, n_cols)
end

function stack_or_hcat(vecs::AbstractVector{<:AbstractArray})
# use more efficient stack (introduced in Julia v1.9) if available
if isdefined(Base, :stack)
return stack(vecs)
stacked_matrix = if isdefined(Base, :stack)
# use even faster custom implementation for sparse vectors
if isa(vecs, AbstractVector{<:SparseVector})
_fast_stack_sparse(vecs)
else
stack(vecs)
end
else
return hcat(vecs...)
hcat(vecs...)
end

return stacked_matrix
end


@@ -187,7 +215,7 @@ function discretize(X::AbstractMatrix{ElType}; n_bins::Integer=3, nz::Bool=true,
rank_method::String="tied", disc_method::String="median", nz_mask::BitMatrix=BitMatrix(undef, (0,0))) where ElType <: AbstractFloat
if nz
if issparse(X)
disc_vecs = SparseVector{Int}[]
disc_vecs = SparseVector{Int,Int}[]
for j in 1:size(X, 2)
push!(disc_vecs, discretize_nz(X[:, j], n_bins, rank_method=rank_method, disc_method=disc_method))
end
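For context, a standalone sketch of the triplet-based stacking idea behind `_fast_stack_sparse` above (the helper name and the `findnz` accessor are my choices for a self-contained example; the committed method reads the index and value buffers of each `SparseVector` directly): collect the (row, column, value) triplets of all columns and assemble the result with a single `sparse` call rather than repeated concatenation.

```julia
using SparseArrays

# Standalone sketch of the triplet-based stacking idea (not the package function itself).
function fast_stack_sparse_sketch(vecs::Vector{<:SparseVector})
    n_rows = length(vecs[1])
    @assert all(length(v) == n_rows for v in vecs)
    rids, cids, nzvals = Int[], Int[], eltype(vecs[1])[]
    for (col_i, v) in enumerate(vecs)
        inds, vals = findnz(v)                    # stored indices and values of this column
        append!(rids, inds)
        append!(cids, fill(col_i, length(inds)))  # column index repeated once per stored entry
        append!(nzvals, vals)
    end
    return sparse(rids, cids, nzvals, n_rows, length(vecs))  # one CSC assembly instead of repeated hcat
end

cols = [sprand(10_000, 0.01) for _ in 1:200]
fast_stack_sparse_sketch(cols) == hcat(cols...)  # true: same matrix either way
```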
src/types.jl: 2 changes (1 addition & 1 deletion)
@@ -149,7 +149,7 @@ end
## RESULT TYPES ##
##################

const RejDict{T} = Dict{T,Tuple{Tuple{Int64,Vararg{Int64,N} where N},TestResult,Tuple{Int,Float64}}}
const RejDict{T} = Dict{T,Tuple{Tuple{Int64,Vararg{Int64,N}} where N,TestResult,Tuple{Int,Float64}}}

struct HitonState{T}
phase::Char
