From 6211c43cb2f59332696683bb824afd76af34c01e Mon Sep 17 00:00:00 2001 From: pat-alt Date: Thu, 21 Mar 2024 17:49:39 +0100 Subject: [PATCH 1/3] yup, this has done it --- Project.toml | 1 + src/baselaplace.jl | 2 +- src/curvature/utils.jl | 10 +- test/Manifest.toml | 491 +++++++++++++++++++++++++++- test/Project.toml | 24 +- test/counterfactual_explanations.jl | 18 + test/runtests.jl | 4 + 7 files changed, 523 insertions(+), 27 deletions(-) create mode 100644 test/counterfactual_explanations.jl diff --git a/Project.toml b/Project.toml index 0ad4bef9..ec740ef5 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Patrick Altmeyer"] version = "0.1.5" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/src/baselaplace.jl b/src/baselaplace.jl index e584e099..16c6b7b8 100644 --- a/src/baselaplace.jl +++ b/src/baselaplace.jl @@ -260,7 +260,7 @@ end # Posterior predictions: """ - predict(la::BaseLaplace, X::AbstractArray; link_approx=:probit) + predict(la::BaseLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true) Computes predictions from Bayesian neural network. diff --git a/src/curvature/utils.jl b/src/curvature/utils.jl index 69ac90b1..16edab47 100644 --- a/src/curvature/utils.jl +++ b/src/curvature/utils.jl @@ -1,3 +1,5 @@ +using ChainRulesCore + """ jacobians(curvature::CurvatureInterface, X::AbstractArray; batched::Bool=false) @@ -25,7 +27,9 @@ function jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray) ŷ = vec(ŷ) # Jacobian: # Differentiate f with regards to the model parameters - 𝐉 = jacobian(() -> nn(X), Flux.params(nn)) + ChainRulesCore.ignore_derivatives() do + 𝐉 = jacobian(() -> nn(X), Flux.params(nn)) + end # Concatenate Jacobians for the selected parameters, to produce a matrix (K, P), where P is the total number of parameter scalars. 𝐉 = reduce(hcat, [𝐉[θ] for θ in curvature.params]) if curvature.subset_of_weights == :subnetwork @@ -46,7 +50,9 @@ function jacobians_batched(curvature::CurvatureInterface, X::AbstractArray) batch_size = size(X)[end] out_size = outdim(nn) # Jacobian: - grads = jacobian(() -> nn(X), Flux.params(nn)) + ChainRulesCore.ignore_derivatives() do + grads = jacobian(() -> nn(X), Flux.params(nn)) + end grads_joint = reduce(hcat, [grads[θ] for θ in curvature.params]) views = [ @view grads_joint[batch_start:(batch_start + out_size - 1), :] for diff --git a/test/Manifest.toml b/test/Manifest.toml index 54bd8dce..cf216271 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.2" manifest_format = "2.0" -project_hash = "6abfc793ce90c622b1d64da61a41b1ec9660001e" +project_hash = "c67c13a7821c9510fe35d0b70abdf2bff3a78953" [[deps.ARFFFiles]] deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] @@ -21,6 +21,11 @@ weakdeps = ["ChainRulesCore", "Test"] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" AbstractFFTsTestExt = "Test" +[[deps.AbstractTrees]] +git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.5" + [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] git-tree-sha1 = "cea4ac3f5b4bc4b3000aa55afb6e5626518948fa" @@ -46,6 +51,18 @@ version = "2.3.0" uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" +[[deps.Arpack]] +deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"] +git-tree-sha1 = "9b9b347613394885fd1c8c7729bfc60528faa436" +uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" +version = "0.5.4" + +[[deps.Arpack_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "OpenBLAS_jll", "Pkg"] +git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e" +uuid = "68821587-b530-5797-8361-c406ea357684" +version = "3.5.1+1" + [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -55,6 +72,18 @@ git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" version = "0.1.0" +[[deps.AtomsBase]] +deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" +uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" +version = "0.3.5" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.4.2" + [[deps.BSON]] git-tree-sha1 = "4c3e506685c527ac6a54ccc0c8c76fd6f91b42fb" uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" @@ -93,6 +122,11 @@ git-tree-sha1 = "2dc09997850d68179b69dafb58ae806167a32b1b" uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" version = "0.1.8" +[[deps.BufferedStreams]] +git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" +uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" +version = "1.2.1" + [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" @@ -110,6 +144,41 @@ git-tree-sha1 = "a44910ceb69b0d44fe262dd451ab11ead3ed0be8" uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" version = "0.10.13" +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] +git-tree-sha1 = "baa8ea7a1ea63316fa3feb454635215773c9c845" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "5.2.0" +weakdeps = ["ChainRulesCore", "SpecialFunctions"] + + [deps.CUDA.extensions] + ChainRulesCoreExt = "ChainRulesCore" + SpecialFunctionsExt = "SpecialFunctions" + +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "d01bfc999768f0a31ed36f5d22a76161fc63079c" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.7.0+1" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "2cb12f6b2209f40a4b8967697689a47c50485490" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.2.3" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "8e25c009d2bf16c2c31a70a6e9e8939f7325cc84" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.11.1+0" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "75923dce4275ead3799b238e10178a68c07dbd3b" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "8.9.4+0" + [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] git-tree-sha1 = "a4c43f59baa34011e303e76f5c8c91bf58415aaf" @@ -127,6 +196,7 @@ deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Un git-tree-sha1 = "1568b28f91293458345dabba6a5ea3f183250a61" uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" version = "0.10.8" +weakdeps = ["JSON", "RecipesBase", "SentinelArrays", "StructTypes"] [deps.CategoricalArrays.extensions] CategoricalArraysJSONExt = "JSON" @@ -134,12 +204,6 @@ version = "0.10.8" CategoricalArraysSentinelArraysExt = "SentinelArrays" CategoricalArraysStructTypesExt = "StructTypes" - [deps.CategoricalArrays.weakdeps] - JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" - RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" - SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" - StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" - [[deps.CategoricalDistributions]] deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes"] git-tree-sha1 = "6d4569d555704cdf91b3417c0667769a4a7cbaa2" @@ -168,6 +232,18 @@ weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] ChainRulesCoreSparseArraysExt = "SparseArrays" +[[deps.Chemfiles]] +deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" +uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" +version = "0.10.41" + +[[deps.Chemfiles_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" +uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" +version = "0.10.4+0" + [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73" @@ -275,6 +351,22 @@ git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" version = "0.6.2" +[[deps.CounterfactualExplanations]] +deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] +git-tree-sha1 = "af4687806d81a3265173fad6250e3902eb659f37" +uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" +version = "0.1.31" + + [deps.CounterfactualExplanations.extensions] + MPIExt = "MPI" + PythonCallExt = "PythonCall" + RCallExt = "RCall" + + [deps.CounterfactualExplanations.weakdeps] + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" + RCall = "6f49c342-dc21-5d91-9882-a32aef131414" + [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" @@ -285,6 +377,12 @@ git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.16.0" +[[deps.DataDeps]] +deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] +git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" +uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +version = "0.7.13" + [[deps.DataFrames]] deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" @@ -306,6 +404,12 @@ version = "1.0.0" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.DecisionTree]] +deps = ["AbstractTrees", "DelimitedFiles", "LinearAlgebra", "Random", "ScikitLearnBase", "Statistics"] +git-tree-sha1 = "526ca14aaaf2d5a0e242f3a8a7966eb9065d7d78" +uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" +version = "0.12.4" + [[deps.DefineSingletons]] git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" @@ -377,6 +481,12 @@ git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" version = "0.6.8" +[[deps.EarCut_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "e3290f2d49e661fbd94046d7e3726ffcb2d41053" +uuid = "5ae413db-bbd1-5e63-b57d-d24a61df00f5" +version = "2.2.4+0" + [[deps.EarlyStopping]] deps = ["Dates", "Statistics"] git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" @@ -389,6 +499,16 @@ git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" uuid = "2702e6a9-849d-5ed8-8c21-79e8b8f9ee43" version = "0.0.20230411+0" +[[deps.EvoTrees]] +deps = ["BSON", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "e1107e45d7fe1a3c5dd335376bb6333b42cf9d1c" +uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" +version = "0.16.6" +weakdeps = ["CUDA"] + + [deps.EvoTrees.extensions] + EvoTreesCUDAExt = "CUDA" + [[deps.ExceptionUnwrapping]] deps = ["Test"] git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" @@ -401,6 +521,16 @@ git-tree-sha1 = "4558ab818dcceaab612d1bb8c19cee87eda2b83c" uuid = "2e619515-83b5-522b-bb60-26c02a35a201" version = "2.5.0+0" +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.Extents]] +git-tree-sha1 = "2140cd04483da90b2da7f99b2add0750504fc39c" +uuid = "411431e0-e8b7-467b-b5e0-f676ba4f2910" +version = "0.1.2" + [[deps.FFMPEG]] deps = ["FFMPEG_jll"] git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" @@ -537,6 +667,12 @@ git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.6" +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "a846f297ce9d09ccba02ead0cae70690e072a119" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.25.0" + [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"] git-tree-sha1 = "3437ade7073682993e092ca570ad68a2aba26983" @@ -549,6 +685,24 @@ git-tree-sha1 = "a96d5c713e6aa28c242b0d25c1347e258d6541ab" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" version = "0.73.3+0" +[[deps.GZip]] +deps = ["Libdl", "Zlib_jll"] +git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" +uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" +version = "0.6.2" + +[[deps.GeoInterface]] +deps = ["Extents"] +git-tree-sha1 = "d4f85701f569584f2cff7ba67a137d03f0cfb7d0" +uuid = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" +version = "1.3.3" + +[[deps.GeometryBasics]] +deps = ["EarCut_jll", "Extents", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] +git-tree-sha1 = "5694b56ccf9d15addedc35e9a4ba9c317721b788" +uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" +version = "0.4.10" + [[deps.Gettext_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" @@ -561,6 +715,11 @@ git-tree-sha1 = "359a1ba2e320790ddbe4ee8b4d54a305c0ea2aff" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" version = "2.80.0+0" +[[deps.Glob]] +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.1" + [[deps.Graphite2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "344bf40dcab1073aca04aa0df4fb092f920e4011" @@ -572,6 +731,24 @@ git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" version = "1.0.2" +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "26407bd1c60129062cec9da63dc7d08251544d53" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +version = "0.17.1" + + [deps.HDF5.extensions] + MPIExt = "MPI" + + [deps.HDF5.weakdeps] + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + +[[deps.HDF5_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "e4591176488495bf44d7456bd73179d87d5e6eab" +uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.14.3+1" + [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] git-tree-sha1 = "995f762e0182ebc50548c434c171a5bb6635f8e4" @@ -584,6 +761,12 @@ git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" version = "2.8.1+1" +[[deps.Hwloc_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "ca0f6bf568b4bfc807e7537f081c81e35ceca114" +uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" +version = "2.10.0+0" + [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" @@ -596,6 +779,24 @@ git-tree-sha1 = "5d8c5713f38f7bc029e26627b687710ba406d0dd" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" version = "0.4.12" +[[deps.ImageBase]] +deps = ["ImageCore", "Reexport"] +git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" +uuid = "c817782e-172a-44cc-b673-b171935fbb9e" +version = "0.1.7" + +[[deps.ImageCore]] +deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] +git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" +uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" +version = "0.10.2" + +[[deps.ImageShow]] +deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] +git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" +uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" +version = "0.3.8" + [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" @@ -611,6 +812,12 @@ version = "1.4.0" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[deps.InternedStrings]] +deps = ["Random", "Test"] +git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" +uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" +version = "0.7.0" + [[deps.InvertedIndices]] git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" @@ -621,6 +828,11 @@ git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" version = "0.2.2" +[[deps.IterTools]] +git-tree-sha1 = "42d5f897009e7ff2cf88db414a389e5ed1bdd023" +uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" +version = "1.10.0" + [[deps.IterationControl]] deps = ["EarlyStopping", "InteractiveUtils"] git-tree-sha1 = "e663925ebc3d93c1150a7570d114f9ea2f664726" @@ -656,12 +868,30 @@ git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.4" +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.0" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "3336abae9a713d2210bb57ab484b1e065edd7d23" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" version = "3.0.2+0" +[[deps.JuliaNVTXCallbacks_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f" +uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e" +version = "0.2.1+0" + [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" @@ -697,19 +927,22 @@ deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", git-tree-sha1 = "7c6650580b4c3169d9905858160db895bff6d2e2" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" version = "6.6.1" +weakdeps = ["BFloat16s"] [deps.LLVM.extensions] BFloat16sExt = "BFloat16s" - [deps.LLVM.weakdeps] - BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" - [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" version = "0.0.29+0" +[[deps.LLVMLoopInfo]] +git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" +uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586" +version = "1.0.0" + [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" @@ -727,6 +960,12 @@ git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" version = "1.3.1" +[[deps.LaplaceRedux]] +deps = ["Compat", "ComputationalResources", "Flux", "LinearAlgebra", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote", "cuDNN"] +git-tree-sha1 = "28b08415d15f8cad6bc2935203a3f99f00f5195a" +uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" +version = "0.1.4" + [[deps.Latexify]] deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] git-tree-sha1 = "cad560042a7cc108f5a4c24ea1431a9221f22c1b" @@ -751,6 +990,11 @@ version = "1.9.0" deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +[[deps.LazyModules]] +git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" +uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" +version = "0.3.1" + [[deps.LearnAPI]] deps = ["InteractiveUtils", "Statistics"] git-tree-sha1 = "ec695822c1faaaa64cee32d0b21505e1977b4809" @@ -861,6 +1105,18 @@ git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "1.0.3" +[[deps.MAT]] +deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] +git-tree-sha1 = "ed1cf0a322d78cee07718bed5fd945e2218c35a1" +uuid = "23992714-dd62-5051-b70f-ba57cb901cac" +version = "0.10.6" + +[[deps.MLDatasets]] +deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] +git-tree-sha1 = "aab72207b3c687086a400be710650a57494992bd" +uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" +version = "0.7.14" + [[deps.MLFlowClient]] deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"] git-tree-sha1 = "049b39a208b052d020e18a0850ca9d228a11ef16" @@ -889,6 +1145,12 @@ weakdeps = ["StatisticalMeasures"] [deps.MLJBase.extensions] DefaultMeasuresExt = "StatisticalMeasures" +[[deps.MLJDecisionTreeInterface]] +deps = ["CategoricalArrays", "DecisionTree", "MLJModelInterface", "Random", "Tables"] +git-tree-sha1 = "1330eb4b8560bcc53d3878a2c9a08c75f99d530d" +uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" +version = "0.4.1" + [[deps.MLJEnsembles]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatisticalMeasuresBase", "StatsBase"] git-tree-sha1 = "94403b2c8f692011df6731913376e0e37f6c0fe9" @@ -942,12 +1204,35 @@ git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" version = "0.4.4" +[[deps.MPICH_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "656036b9ed6f942d35e536e249600bc31d0f9df8" +uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" +version = "4.2.0+0" + +[[deps.MPIPreferences]] +deps = ["Libdl", "Preferences"] +git-tree-sha1 = "8f6af051b9e8ec597fa09d8885ed79fd582f33c9" +uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +version = "0.1.10" + +[[deps.MPItrampoline_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "77c3bd69fdb024d75af38713e883d0f249ce19c2" +uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" +version = "5.3.2+0" + [[deps.MacroTools]] deps = ["Markdown", "Random"] git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.13" +[[deps.MappedArrays]] +git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.2" + [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -973,19 +1258,23 @@ deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "Lazy git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" version = "0.9.3" +weakdeps = ["CUDA"] [deps.Metalhead.extensions] MetalheadCUDAExt = "CUDA" - [deps.Metalhead.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" uuid = "128add7d-3638-4c79-886c-908ea0c25c34" version = "0.1.4" +[[deps.MicrosoftMPI_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" +uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" +version = "10.1.4+2" + [[deps.Missings]] deps = ["DataAPI"] git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" @@ -995,10 +1284,22 @@ version = "1.1.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[deps.MosaicViews]] +deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] +git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" +uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" +version = "0.3.4" + [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.1.10" +[[deps.MultivariateStats]] +deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] +git-tree-sha1 = "68bf5103e002c44adfd71fea6bd770b3f0586843" +uuid = "6f286f6a-111f-5878-ab1e-185364afe411" +version = "0.10.2" + [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] git-tree-sha1 = "877f15c331337d54cf24c797d5bcb2e48ce21221" @@ -1017,6 +1318,24 @@ version = "0.9.12" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +[[deps.NPZ]] +deps = ["FileIO", "ZipFile"] +git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" +uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" +version = "0.4.3" + +[[deps.NVTX]] +deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] +git-tree-sha1 = "53046f0483375e3ed78e49190f1154fa0a4083a1" +uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" +version = "0.3.4" + +[[deps.NVTX_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b" +uuid = "e98f9f5b-d649-5603-91fd-7774390e6439" +version = "3.1.0+2" + [[deps.NaNMath]] deps = ["OpenLibm_jll"] git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" @@ -1029,10 +1348,43 @@ git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" version = "0.1.5" +[[deps.NearestNeighborModels]] +deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "e411143a8362926e4284a54e745972e939fbab78" +uuid = "636a865e-7cf4-491e-846c-de09b730eb36" +version = "0.2.3" + +[[deps.NearestNeighbors]] +deps = ["Distances", "StaticArrays"] +git-tree-sha1 = "ded64ff6d4fdd1cb68dfcbb818c69e144a5b2e4c" +uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +version = "0.4.16" + +[[deps.NetworkLayout]] +deps = ["GeometryBasics", "LinearAlgebra", "Random", "Requires", "StaticArrays"] +git-tree-sha1 = "91bb2fedff8e43793650e7a677ccda6e6e6e166b" +uuid = "46757867-2c16-5918-afeb-47bfcb05e46a" +version = "0.4.6" + + [deps.NetworkLayout.extensions] + NetworkLayoutGraphsExt = "Graphs" + + [deps.NetworkLayout.weakdeps] + Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" + [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" +[[deps.OffsetArrays]] +git-tree-sha1 = "6a731f2b5c03157418a20c12195eb4b74c8f8621" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.13.0" +weakdeps = ["Adapt"] + + [deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + [[deps.Ogg_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" @@ -1061,6 +1413,12 @@ git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33" uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66" version = "0.3.1" +[[deps.OpenMPI_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" +uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" +version = "4.1.6+0" + [[deps.OpenSSL]] deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] git-tree-sha1 = "af81a32750ebc831ee28bdaaba6e1067decef51e" @@ -1107,6 +1465,18 @@ git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" version = "0.11.31" +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + +[[deps.PaddedViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" +uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" +version = "0.5.12" + [[deps.Parameters]] deps = ["OrderedCollections", "UnPack"] git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" @@ -1125,6 +1495,18 @@ git-tree-sha1 = "47b49a4dbc23b76682205c646252c0f9e1eb75af" uuid = "570af359-4316-4cb7-8c74-252c00c2016b" version = "1.2.0" +[[deps.PeriodicTable]] +deps = ["Base64", "Unitful"] +git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" +uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" +version = "1.2.1" + +[[deps.Pickle]] +deps = ["BFloat16s", "DataStructures", "InternedStrings", "Serialization", "SparseArrays", "Strided", "StringEncodings", "ZipFile"] +git-tree-sha1 = "2e71d7dbcab8dc47306c0ed6ac6018fbc1a7070f" +uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" +version = "0.3.3" + [[deps.Pipe]] git-tree-sha1 = "6842804e7867b115ca9de748a0cf6b364523c16d" uuid = "b98c9c47-44ae-5843-9183-064241ee97a0" @@ -1243,6 +1625,18 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "4743b43e5a9c4a2ede372de7061eed81795b12e7" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.7.0" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + [[deps.RealDot]] deps = ["LinearAlgebra"] git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" @@ -1305,6 +1699,12 @@ git-tree-sha1 = "a8e18eb383b5ecf1b5e6fc237eb39255044fd92b" uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161" version = "3.0.0" +[[deps.ScikitLearnBase]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "7877e55c1523a4b336b433da39c8e8c08d2f221f" +uuid = "6e75b9c4-186b-50bd-896f-2d2496a4843e" +version = "0.5.0" + [[deps.Scratch]] deps = ["Dates"] git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" @@ -1390,6 +1790,12 @@ git-tree-sha1 = "ddc1a7b85e760b5285b50b882fa91e40c603be47" uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" version = "1.0.1" +[[deps.StackViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" +uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" +version = "0.1.1" + [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" @@ -1463,6 +1869,18 @@ version = "1.3.1" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +[[deps.Strided]] +deps = ["LinearAlgebra", "TupleTools"] +git-tree-sha1 = "a7a664c91104329c88222aa20264e1a05b6ad138" +uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" +version = "1.2.3" + +[[deps.StringEncodings]] +deps = ["Libiconv_jll"] +git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" +uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" +version = "0.3.7" + [[deps.StringManipulation]] deps = ["PrecompileTools"] git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" @@ -1482,6 +1900,12 @@ weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] StructArraysSparseArraysExt = "SparseArrays" StructArraysStaticArraysExt = "StaticArrays" +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.10.0" + [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -1508,6 +1932,12 @@ git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.11.1" +[[deps.TaijaData]] +deps = ["CSV", "CounterfactualExplanations", "DataAPI", "DataFrames", "Flux", "LazyArtifacts", "MLDatasets", "MLJBase", "MLJModels", "Random", "StatsBase"] +git-tree-sha1 = "1b27d27767404cc3d55d15158e1701be3ce084b4" +uuid = "9d524318-b4e6-4a65-86d2-b2b72d07866c" +version = "0.1.0" + [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" @@ -1523,6 +1953,12 @@ version = "0.1.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.23" + [[deps.TranscodingStreams]] git-tree-sha1 = "a09c933bebed12501890d8e92946bbab6a1690f1" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" @@ -1570,6 +2006,11 @@ version = "0.3.7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +[[deps.TupleTools]] +git-tree-sha1 = "41d61b1c545b06279871ef1a4b5fcb2cac2191cd" +uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" +version = "1.5.0" + [[deps.URIs]] git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" @@ -1607,6 +2048,12 @@ version = "1.19.0" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +[[deps.UnitfulAtomic]] +deps = ["Unitful"] +git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" +uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" +version = "1.0.0" + [[deps.UnitfulLatexify]] deps = ["LaTeXStrings", "Latexify", "Unitful"] git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" @@ -1820,6 +2267,12 @@ git-tree-sha1 = "e92a1a012a10506618f10b7047e478403a046c77" uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" version = "1.5.0+0" +[[deps.ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.10.1" + [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" @@ -1853,6 +2306,12 @@ git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" uuid = "700de1a5-db45-46bc-99cf-38207098b444" version = "0.2.5" +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] +git-tree-sha1 = "d433ec29756895512190cac9c96666d879f07b92" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.3.0" + [[deps.eudev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"] git-tree-sha1 = "431b678a28ebb559d224c0b6b6d01afce87c51ba" @@ -1871,6 +2330,12 @@ git-tree-sha1 = "3516a5630f741c9eecb3720b1ec9d8edc3ecc033" uuid = "1a1c6b14-54f6-533d-8383-74cd7377aa70" version = "3.1.1+0" +[[deps.libaec_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" +uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" +version = "1.1.2+0" + [[deps.libaom_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "3a2ea60308f0996d26f1e5354e10c24e9ef905d4" diff --git a/test/Project.toml b/test/Project.toml index 3023deb3..42d47090 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,23 +1,25 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +TaijaData = "9d524318-b4e6-4a65-86d2-b2b72d07866c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" -Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Aqua = "0.8" \ No newline at end of file +Aqua = "0.8" diff --git a/test/counterfactual_explanations.jl b/test/counterfactual_explanations.jl new file mode 100644 index 00000000..d6703059 --- /dev/null +++ b/test/counterfactual_explanations.jl @@ -0,0 +1,18 @@ +using CounterfactualExplanations +using CounterfactualExplanations.Models +using TaijaData + +counterfactual_data = TaijaData.load_linearly_separable() |> + x -> (Float32.(x[1]), x[2]) |> + x -> CounterfactualData(x...) +M = Models.fit_model(counterfactual_data, :LaplaceRedux) + +# Select a factual instance: +target = 2 +factual = 1 +chosen = rand(findall(predict_label(M, counterfactual_data) .== factual)) +x = select_factual(counterfactual_data, chosen) + +# Search: +generator = GenericGenerator() +ce = generate_counterfactual(x, target, counterfactual_data, M, generator) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 5fa82a88..92ac8bfc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,4 +35,8 @@ using Test @testset "MLJFlux" begin include("mlj_flux_interfacing.jl") end + + @testset "CounterfactualExplanations" begin + include("counterfactual_explanations.jl") + end end From 18070a9509e7ba7c1da9bb548835ac20f05673b6 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Thu, 21 Mar 2024 17:52:40 +0100 Subject: [PATCH 2/3] formatter --- Project.toml | 2 +- test/counterfactual_explanations.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index ec740ef5..474d9ab5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LaplaceRedux" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" authors = ["Patrick Altmeyer"] -version = "0.1.5" +version = "0.1.6" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/test/counterfactual_explanations.jl b/test/counterfactual_explanations.jl index d6703059..b618b38a 100644 --- a/test/counterfactual_explanations.jl +++ b/test/counterfactual_explanations.jl @@ -2,9 +2,9 @@ using CounterfactualExplanations using CounterfactualExplanations.Models using TaijaData -counterfactual_data = TaijaData.load_linearly_separable() |> - x -> (Float32.(x[1]), x[2]) |> - x -> CounterfactualData(x...) +counterfactual_data = + TaijaData.load_linearly_separable() |> + x -> (Float32.(x[1]), x[2]) |> x -> CounterfactualData(x...) M = Models.fit_model(counterfactual_data, :LaplaceRedux) # Select a factual instance: @@ -15,4 +15,4 @@ x = select_factual(counterfactual_data, chosen) # Search: generator = GenericGenerator() -ce = generate_counterfactual(x, target, counterfactual_data, M, generator) \ No newline at end of file +ce = generate_counterfactual(x, target, counterfactual_data, M, generator) From e880e0cee57557bdc5576ece55a8f8f9b4de87d5 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Thu, 21 Mar 2024 18:06:38 +0100 Subject: [PATCH 3/3] hmm --- Project.toml | 1 + src/curvature/utils.jl | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 474d9ab5..b0738c62 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Aqua = "0.8" +ChainRulesCore = "1.23.0" Compat = "4.7.0" ComputationalResources = "0.3.2" Flux = "0.12, 0.13, 0.14" diff --git a/src/curvature/utils.jl b/src/curvature/utils.jl index 16edab47..f71ffeac 100644 --- a/src/curvature/utils.jl +++ b/src/curvature/utils.jl @@ -27,9 +27,12 @@ function jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray) ŷ = vec(ŷ) # Jacobian: # Differentiate f with regards to the model parameters + J = [] ChainRulesCore.ignore_derivatives() do 𝐉 = jacobian(() -> nn(X), Flux.params(nn)) + push!(J, 𝐉) end + 𝐉 = J[1] # Concatenate Jacobians for the selected parameters, to produce a matrix (K, P), where P is the total number of parameter scalars. 𝐉 = reduce(hcat, [𝐉[θ] for θ in curvature.params]) if curvature.subset_of_weights == :subnetwork @@ -50,9 +53,12 @@ function jacobians_batched(curvature::CurvatureInterface, X::AbstractArray) batch_size = size(X)[end] out_size = outdim(nn) # Jacobian: + grads = [] ChainRulesCore.ignore_derivatives() do - grads = jacobian(() -> nn(X), Flux.params(nn)) + g = jacobian(() -> nn(X), Flux.params(nn)) + push!(grads, g) end + grads = grads[1] grads_joint = reduce(hcat, [grads[θ] for θ in curvature.params]) views = [ @view grads_joint[batch_start:(batch_start + out_size - 1), :] for