Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Nov 28, 2022
1 parent ce22a2d commit 9aa5f98
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 246 deletions.
234 changes: 0 additions & 234 deletions docs/src/intro.ipynb

This file was deleted.

2 changes: 1 addition & 1 deletion docs/src/intro.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ using LaplaceRedux.Data: toy_data_regression
using Flux.Optimise: update!, Adam
# Data:
n = 120 # number of observations
n = 150 # number of observations
σtrue = 0.3 # true observational noise
x, y = toy_data_regression(100;noise=σtrue)
xs = [[x] for x in x]
Expand Down
14 changes: 5 additions & 9 deletions src/baselaplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,10 @@ function get_params(la::BaseLaplace)
params = Flux.params(nn)
n_elements = length(params)
if la.subset_of_weights == :all
params = [θ for θ ∈ params] # get all parameters and constants in logitbinarycrossentropy
params = [θ for θ ∈ params] # get all parameters and constants in logitbinarycrossentropy
elseif la.subset_of_weights == :last_layer
params = [params[n_elements-1],params[n_elements]] # only get last parameters and constants
else
@error "`subset_of_weights` of weights should be one of the following: `[:all, :last_layer]`"
end
params = [params[n_elements-1],params[n_elements]] # only get last parameters and constants
end
return params
end

Expand Down Expand Up @@ -123,11 +121,9 @@ function log_marginal_likelihood(la::BaseLaplace; P₀::Union{Nothing,AbstractFl
if !isnothing(σ)
@assert (la.likelihood==:regression || la.σ == σ) "Can only change observational noise σ for regression."
la.σ = σ
end

mll = log_likelihood(la) - 0.5 * (log_det_ratio(la) + _weight_penalty(la))
end

return mll
return log_likelihood(la) - 0.5 * (log_det_ratio(la) + _weight_penalty(la))
end

"""
Expand Down
5 changes: 5 additions & 0 deletions src/laplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ function Laplace(model::Any; likelihood::Symbol, kwargs...)

# Load hyperparameters:
args = LaplaceParams(;kwargs...)

# Assertions:
@assert !(args.σ != 1.0 && likelihood != :regression) "Observation noise σ ≠ 1 only available for regression."
@assert args.subset_of_weights ∈ [:all, :last_layer] "`subset_of_weights` of weights should be one of the following: `[:all, :last_layer]`"

# Setup:
P₀ = isnothing(args.P₀) ? UniformScaling(args.λ) : args.P₀
nn = model
n_out = outdim(nn)
Expand Down
12 changes: 10 additions & 2 deletions test/laplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ using Statistics

# One layer:
nn = Chain(Dense(2,1))

# Expected error
@test_throws AssertionError Laplace(nn; likelihood=:classification, subset_of_weights=:last)

# Correct:
la = Laplace(nn; likelihood=:classification)
@test la.n_params == 3

Expand Down Expand Up @@ -179,8 +184,11 @@ end

la = Laplace(nn; likelihood=likelihood, λ=λ, subset_of_weights=:last_layer)
fit!(la, data)
optimize_prior!(la)
plt = plot(la, X, y)
optimize_prior!(la; verbose=true)
plot(la, X, y) # standard
plot(la, X, y; xlims=(-5,5), ylims=(-5,5)) # lims
plot(la, X, y; link_approx=:plugin) # plugin approximation

end
end

Expand Down

2 comments on commit 9aa5f98

@pat-alt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "Tag with name v0.1.1 already exists and points to a different commit"

Please sign in to comment.