Skip to content

Commit

Permalink
commented out unused GGN
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Nov 25, 2022
1 parent ab5efd4 commit ce22a2d
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions src/curvature/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,49 +32,49 @@ function gradients(curvature::CurvatureInterface, X::AbstractArray, y::Union{Num
return 𝐠
end

"Constructor for Generalized Gauss Newton."
struct GGN <: CurvatureInterface
model::Any
likelihood::Symbol
loss_fun::Function
params::AbstractArray
factor::Union{Nothing,Real}
end
# "Constructor for Generalized Gauss Newton."
# struct GGN <: CurvatureInterface
# model::Any
# likelihood::Symbol
# loss_fun::Function
# params::AbstractArray
# factor::Union{Nothing,Real}
# end

function GGN(model::Any, likelihood::Symbol, params::AbstractArray)
# function GGN(model::Any, likelihood::Symbol, params::AbstractArray)

@error "GGN not yet implemented."
# @error "GGN not yet implemented."

# Define loss function:
loss_fun = get_loss_fun(likelihood, model)
factor = likelihood == :regression ? 0.5 : 1.0
# # Define loss function:
# loss_fun = get_loss_fun(likelihood, model)
# factor = likelihood == :regression ? 0.5 : 1.0

GGN(model, likelihood, loss_fun, params, factor)
end
# GGN(model, likelihood, loss_fun, params, factor)
# end

"""
full(curvature::GGN, d::Union{Tuple,NamedTuple})
# """
# full(curvature::GGN, d::Union{Tuple,NamedTuple})

Compute the full GGN.
"""
function full(curvature::GGN, d::Tuple)
x, y = d
# Compute the full GGN.
# """
# function full(curvature::GGN, d::Tuple)
# x, y = d

loss = curvature.factor * curvature.loss_fun(x, y)
# loss = curvature.factor * curvature.loss_fun(x, y)

𝐉, fμ = jacobians(curvature, x)
# 𝐉, fμ = jacobians(curvature, x)

if curvature.likelihood == :regression
H = 𝐉 * 𝐉'
else
p = outdim(curvature.model) > 1 ? softmax(fμ) : sigmoid(fμ)
H = map(j -> j * (diagm(p) - p * p') * j', eachcol(𝐉))
println(H)
end
# if curvature.likelihood == :regression
# H = 𝐉 * 𝐉'
# else
# p = outdim(curvature.model) > 1 ? softmax(fμ) : sigmoid(fμ)
# H = map(j -> j * (diagm(p) - p * p') * j', eachcol(𝐉))
# println(H)
# end

return loss, H
# return loss, H

end
# end

"Constructor for Empirical Fisher."
struct EmpiricalFisher <: CurvatureInterface
Expand Down

0 comments on commit ce22a2d

Please sign in to comment.