Skip to content

Commit

Permalink
fix reset_search_direction! failure when training with GPU (#1034)
Browse files Browse the repository at this point in the history
* fix reset_search_direction! error when training with GPU

* fix reduced codecov
  • Loading branch information
wei3li authored Aug 7, 2023
1 parent b706464 commit 934cee0
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/utilities/perform_linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@ _alphaguess(a::Number) = LineSearches.InitialStatic(alpha=a)
# project_tangent! here, because we already did that inplace on gradient(d) after
# the last evaluation (we basically just always do it)
function reset_search_direction!(state, d, method::BFGS)
n = length(state.x)
T = eltype(state.x)

if method.initial_invH === nothing
n = length(state.x)
T = typeof(state.invH)
if method.initial_stepnorm === nothing
state.invH .= Matrix{T}(I, n, n)
state.invH .= T(I, n, n)
else
initial_scale = method.initial_stepnorm * inv(norm(gradient(d), Inf))
state.invH.= Matrix{T}(initial_scale*I, n, n)
state.invH.= T(initial_scale*I, n, n)
end
else
state.invH .= method.initial_invH(state.x)
Expand Down

0 comments on commit 934cee0

Please sign in to comment.