Skip to content

Commit

Permalink
Constrain generated signature for nondiff frule to need tuple first a…
Browse files Browse the repository at this point in the history
…rg so no ambig with ruleconfig first arg
  • Loading branch information
oxinabox committed Jan 9, 2024
1 parent 6befc08 commit 67036a8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
function (::Core.kwftype(typeof(ChainRulesCore.frule)))(
@nospecialize($kwargs::Any),
frule::typeof(ChainRulesCore.frule),
::$RuleConfig,
::Tuple,
$(map(esc, primal_sig_parts)...),
)
return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent())
end
function ChainRulesCore.frule(
::$RuleConfig, $(map(esc, primal_sig_parts)...)
::Tuple, $(map(esc, primal_sig_parts)...)
)
$(__source__)
# Julia functions always only have 1 output, so return a single NoTangent()
Expand Down
4 changes: 2 additions & 2 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,14 @@ end

foo_ndc1(x) = string(x)
@non_differentiable foo_ndc1(x)
@test frule(AllConfig(), foo_ndc1, 2.0) == (string(2.0), NoTangent())
@test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc1, 2.0) == (string(2.0), NoTangent())
r1, pb1 = rrule(AllConfig(), foo_ndc1, 2.0)
@test r1 == string(2.0)
@test pb1(NoTangent()) == (NoTangent(), NoTangent())

foo_ndc2(x; y=0) = string(x + y)
@non_differentiable foo_ndc2(x)
@test frule(AllConfig(), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent())
@test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent())
r2, pb2 = rrule(AllConfig(), foo_ndc2, 2.0; y=4.0)
@test r2 == string(6.0)
@test pb2(NoTangent()) == (NoTangent(), NoTangent())
Expand Down

0 comments on commit 67036a8

Please sign in to comment.