Skip to content

Commit

Permalink
try modifying @non_differentiable instead
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter authored Jan 16, 2024
1 parent ee4e1c8 commit 52c8bc6
Showing 1 changed file with 1 addition and 10 deletions.
11 changes: 1 addition & 10 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,6 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
return @strip_linenos quote
# _ is the input derivative w.r.t. function internals. since we do not
# allow closures/functors with @scalar_rule, it is always ignored
function ChainRulesCore.frule((_, $(Δs...))::Tuple, ::Core.Typeof($f), $(inputs...))
$(__source__)
$(esc()) = $call
$(setup_stmts...)
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output(
$(esc()), $f, $(inputs...)
)
return $(esc()), $pushforward_returns
end
function ChainRulesCore.frule((_, $(Δs...)), ::Core.Typeof($f), $(inputs...))
$(__source__)
$(esc()) = $call
Expand Down Expand Up @@ -418,7 +409,7 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent())
end
function ChainRulesCore.frule(
@nospecialize(::Tuple), $(map(esc, primal_sig_parts)...)
@nospecialize(::Any), $(map(esc, primal_sig_parts)...)
)
$(__source__)
# Julia functions always only have 1 output, so return a single NoTangent()
Expand Down

0 comments on commit 52c8bc6

Please sign in to comment.