diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 1e725116..6d42a1e2 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.3" +julia_version = "1.10.5" manifest_format = "2.0" project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca" @@ -85,9 +85,9 @@ version = "3.5.1+1" [[deps.ArrayInterface]] deps = ["Adapt", "LinearAlgebra"] -git-tree-sha1 = "f54c23a5d304fb87110de62bace7777d59088c34" +git-tree-sha1 = "3640d077b6dafd64ceb8fd5c1ec76f7ca53bcf76" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.15.0" +version = "7.16.0" [deps.ArrayInterface.extensions] ArrayInterfaceBandedMatricesExt = "BandedMatrices" @@ -209,9 +209,9 @@ version = "0.9.2+0" [[deps.CUDA_Runtime_Discovery]] deps = ["Libdl"] -git-tree-sha1 = "f3b237289a5a77c759b2dd5d4c2ff641d67c4030" +git-tree-sha1 = "33576c7c1b2500f8e7e6baa082e04563203b3a45" uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" -version = "0.3.4" +version = "0.3.5" [[deps.CUDA_Runtime_jll]] deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] @@ -359,17 +359,18 @@ uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" version = "0.1.13" [[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "a33b7ced222c6165f624a3f2b55945fac5a598d9" +git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.7" +version = "1.5.8" [deps.ConstructionBase.extensions] ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseLinearAlgebraExt = "LinearAlgebra" ConstructionBaseStaticArraysExt = "StaticArrays" [deps.ConstructionBase.weakdeps] IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.ContextVariablesX]] @@ -569,19 +570,24 @@ uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" version = "1.16.3" 
[[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" +deps = ["Compat", "Dates"] +git-tree-sha1 = "7878ff7172a8e6beedd1dea14bd27c3c6340d361" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.21" +version = "0.9.22" +weakdeps = ["Mmap", "Test"] + + [deps.FilePathsBase.extensions] + FilePathsBaseMmapExt = "Mmap" + FilePathsBaseTestExt = "Test" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra"] -git-tree-sha1 = "fd0002c0b5362d7eb952450ad5eb742443340d6e" +git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.12.0" +version = "1.13.0" weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] @@ -841,10 +847,10 @@ uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] -git-tree-sha1 = "67d4690d32c22e28818a434b293a374cc78473d3" +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] +git-tree-sha1 = "a0746c21bdc986d0dc293efa6b1faee112c37c28" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.51" +version = "0.4.53" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] @@ -854,9 +860,9 @@ version = "0.1.8" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" +version = "1.6.0" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -884,9 +890,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", 
"StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "35ceea58aa34ad08b1ae00a52622c62d1cfb8ce2" +git-tree-sha1 = "cb1cff88ef2f3a157cbad75bbe6b229e1975e498" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.24" +version = "0.9.25" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -1444,9 +1450,9 @@ version = "1.4.1" [[deps.Plots]] deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"] -git-tree-sha1 = "082f0c4b70c202c37784ce4bfbc33c9f437685bf" +git-tree-sha1 = "45470145863035bb124ca51b320ed35d071cc6c2" uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -version = "1.40.5" +version = "1.40.8" [deps.Plots.extensions] FileIOExt = "FileIO" @@ -1514,9 +1520,9 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca" version = "1.10.2" [[deps.PtrArrays]] -git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f" uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" -version = "1.2.0" +version = "1.2.1" [[deps.Qt6Base_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] @@ -1544,9 +1550,15 @@ version = "6.7.1+1" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = 
"e237232771fdafbae3db5c31275303e056afaa9f" +git-tree-sha1 = "1d587203cf851a51bf1ea31ad7ff89eff8d625ea" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.10.1" +version = "2.11.0" + + [deps.QuadGK.extensions] + QuadGKEnzymeExt = "Enzyme" + + [deps.QuadGK.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [[deps.RData]] deps = ["CategoricalArrays", "CodecZlib", "DataFrames", "Dates", "FileIO", "Requires", "TimeZones", "Unicode"] @@ -2274,7 +2286,7 @@ version = "0.15.2+0" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" +version = "5.11.0+0" [[deps.libdecor_jll]] deps = ["Artifacts", "Dbus_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "Pango_jll", "Wayland_jll", "xkbcommon_jll"] diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index 810dc1dd..4a0a1d6c 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -23,19 +23,24 @@ function has_softmax_or_sigmoid_final_layer(model::Flux.Chain) return has_finaliser end -""" +@doc raw""" functional_variance(la::AbstractLaplace, 𝐉::AbstractArray) -Compute the functional variance for the GLM predictive. Dispatches to the appropriate method based on the Hessian structure. +Computes the functional variance for the GLM predictive as `map(j -> (j' * Σ * j), eachrow(𝐉))` which is a (output x output) predictive covariance matrix. Formally, we have ``{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta}`` where ``\mathbf{J}_{\hat\theta}=\nabla_{\theta}f(x;\theta)|\hat\theta`` is the Jacobian evaluated at the MAP estimate. + +Dispatches to the appropriate method based on the Hessian structure. """ function functional_variance(la, 𝐉) return functional_variance(la, la.est_params.hessian_structure, 𝐉) end -""" +@doc raw""" glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray) -Computes the linearized GLM predictive. 
+Computes the linearized GLM predictive from a neural network with a Laplace approximation to the posterior ``p(\theta|\mathcal{D})=\mathcal{N}(\hat\theta,\Sigma)``. +This is the distribution on network outputs given by ``p(f(x)|x,\mathcal{D})\approx \mathcal{N}(f(x;\hat\theta),{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta})``. +For the Bayesian predictive distribution, see [`predict`](@ref). + # Arguments @@ -49,7 +54,7 @@ Computes the linearized GLM predictive. # Examples -```julia-repl +```julia using Flux, LaplaceRedux using LaplaceRedux.Data: toy_data_linear x, y = toy_data_linear() @@ -58,6 +63,7 @@ nn = Chain(Dense(2,1)) la = Laplace(nn; likelihood=:classification) fit!(la, data) glm_predictive_distribution(la, hcat(x...)) +``` """ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray) 𝐉, fμ = Curvature.jacobians(la.est_params.curvature, X) @@ -65,14 +71,20 @@ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray) fvar = functional_variance(la, 𝐉) fvar = reshape(fvar, size(fμ)...) fstd = sqrt.(fvar) - normal_distr = [Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 2)] + normal_distr = [Normal(fμ[i], fstd[i]) for i in axes(fμ, 2)] return (normal_distr, fμ, fvar) end -""" - predict(la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true) +@doc raw""" + predict( + la::AbstractLaplace, + X::AbstractArray; + link_approx=:probit, + predict_proba::Bool=true, + ret_distr::Bool=false, + ) -Computes predictions from Bayesian neural network. +Computes the Bayesian predictive distribution from a neural network with a Laplace approximation to the posterior ``p(\theta | \mathcal{D}) = \mathcal{N}(\hat\theta, \Sigma)``. # Arguments - `la::AbstractLaplace`: Laplace object from `laplace` function. - `X::AbstractArray`: Input data. - `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`. 
- `predict_proba::Bool=true`: If `true` (default) apply a sigmoid or a softmax function to the output of the Flux model. -- `return_distr::Bool=false`: if `false` (default), the function output either the direct output of the chain or pseudo-probabilities (if predict_proba= true). +- `return_distr::Bool=false`: if `false` (default), the function outputs either the direct output of the chain or pseudo-probabilities (if `predict_proba=true`). if `true` predict return a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks. # Returns -For classification tasks, LaplaceRedux provides different options: -if ret_distr is false: - - `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux. -if ret_distr is true: - - a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks. + +For classification tasks: + +1. If `ret_distr` is `false`, `predict` returns `fμ`, i.e. the mean of the predictive distribution, which corresponds to the MAP estimate if the link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux. +2. If `ret_distr` is `true`, `predict` returns a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks. + For regression tasks: -- `normal_distr::Distributions.Normal`:the array of Normal distributions computed by glm_predictive_distribution. + +1. If `ret_distr` is `false`, `predict` returns the mean and the variance of the predictive distribution. The output shape is column-major as in Flux. +2. 
If `ret_distr` is `true`, `predict` returns the predictive posterior distribution, namely: + +``p(y|x,\mathcal{D})\approx \mathcal{N}(f(x;\hat\theta),{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta} + \sigma^2 \mathbf{I})`` + # Examples -```julia-repl +```julia using Flux, LaplaceRedux using LaplaceRedux.Data: toy_data_linear x, y = toy_data_linear() @@ -111,15 +129,22 @@ function predict( predict_proba::Bool=true, ret_distr::Bool=false, ) - normal_distr, fμ, fvar = glm_predictive_distribution(la, X) + _, fμ, fvar = glm_predictive_distribution(la, X) # Regression: if la.likelihood == :regression + + # Add observational noise: + pred_var = fvar .+ la.prior.σ^2 + fstd = sqrt.(pred_var) + pred_dist = [Normal(fμ[i], fstd[i]) for i in axes(fμ, 2)] + if ret_distr - return reshape(normal_distr, (:, 1)) + return reshape(pred_dist, (:, 1)) else - return fμ, fvar + return fμ, pred_var end + end # Classification: diff --git a/src/full.jl b/src/full.jl index f9173bb5..4720150f 100644 --- a/src/full.jl +++ b/src/full.jl @@ -50,10 +50,10 @@ function _fit!( return la.posterior.n_data = n_data end -""" -functional_variance(la::Laplace,𝐉) +@doc raw""" + functional_variance(la::Laplace, hessian_structure::FullHessian, 𝐉) -Compute the linearized GLM predictive variance as `𝐉ₙΣ𝐉ₙ'` where `𝐉=∇f(x;θ)|θ̂` is the Jacobian evaluated at the MAP estimate and `Σ = P⁻¹`. +Computes the functional variance for the GLM predictive as `map(j -> (j' * Σ * j), eachrow(𝐉))` which is a (output x output) predictive covariance matrix. Formally, we have ``{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta}`` where ``\mathbf{J}_{\hat\theta}=\nabla_{\theta}f(x;\theta)|\hat\theta`` is the Jacobian evaluated at the MAP estimate. 
""" function functional_variance(la::Laplace, hessian_structure::FullHessian, 𝐉) Σ = posterior_covariance(la) diff --git a/src/kronecker/kron.jl b/src/kronecker/kron.jl index 9d429efd..54ea0a8a 100644 --- a/src/kronecker/kron.jl +++ b/src/kronecker/kron.jl @@ -133,7 +133,7 @@ function _fit!( end """ -functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix) + functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix) Compute functional variance for the GLM predictive: as the diagonal of the K×K predictive output covariance matrix 𝐉𝐏⁻¹𝐉ᵀ, where K is the number of outputs, 𝐏 is the posterior precision, and 𝐉 is the Jacobian of model output `𝐉=∇f(x;θ)|θ̂`.