From 67036a86ab03953f392446b888b04a2ef3b3c096 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 9 Jan 2024 14:43:29 +0800 Subject: [PATCH] Constrain generated signature for nondiff frule to need tuple first arg so no ambig with ruleconfig first arg --- src/rule_definition_tools.jl | 4 ++-- test/rule_definition_tools.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 170798d14..ef2db6b42 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) function (::Core.kwftype(typeof(ChainRulesCore.frule)))( @nospecialize($kwargs::Any), frule::typeof(ChainRulesCore.frule), - ::$RuleConfig, + ::Tuple, $(map(esc, primal_sig_parts)...), ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - ::$RuleConfig, $(map(esc, primal_sig_parts)...) + ::Tuple, $(map(esc, primal_sig_parts)...) ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 05c4d8389..43863a915 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -220,14 +220,14 @@ end foo_ndc1(x) = string(x) @non_differentiable foo_ndc1(x) - @test frule(AllConfig(), foo_ndc1, 2.0) == (string(2.0), NoTangent()) + @test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc1, 2.0) == (string(2.0), NoTangent()) r1, pb1 = rrule(AllConfig(), foo_ndc1, 2.0) @test r1 == string(2.0) @test pb1(NoTangent()) == (NoTangent(), NoTangent()) foo_ndc2(x; y=0) = string(x + y) @non_differentiable foo_ndc2(x) - @test frule(AllConfig(), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent()) + @test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent()) r2, pb2 = rrule(AllConfig(), foo_ndc2, 2.0; y=4.0) @test r2 == string(6.0) @test pb2(NoTangent()) == (NoTangent(), NoTangent())