Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make @non_differentiable use identical pullbacks when possible #679

Merged
merged 3 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module ChainRulesCore
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
using Base.Meta
using LinearAlgebra
using Compat: hasfield, hasproperty, ismutabletype
using Compat: hasfield, hasproperty, ismutabletype, Returns

export frule, rrule # core function
# rule configurations
Expand Down
4 changes: 1 addition & 3 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
tup_expr = tuple_expression(primal_sig_parts)
primal_name = first(primal_invoke.args)
pullback_expr = @strip_linenos quote
function $(esc(propagator_name(primal_name, :pullback)))(@nospecialize(_))
return $(tup_expr)
end
Returns($(tup_expr))
end

@gensym kwargs
Expand Down
14 changes: 14 additions & 0 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@ end

@testset "rule_definition_tools.jl" begin
@testset "@non_differentiable" begin
@testset "issue #678: identical pullback objects" begin
issue_678_f(::Any) = nothing
issue_678_g(::Any) = nothing
issue_678_h(::Any...) = nothing
@non_differentiable issue_678_f(::Any)
@non_differentiable issue_678_g(::Any)
@non_differentiable issue_678_h(::Any...)
@test (
last(rrule(issue_678_f, 0.1)) ===
last(rrule(issue_678_g, 0.2)) ===
last(rrule(issue_678_h, 0.3))
)
end

@testset "two input one output function" begin
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
@non_differentiable nondiff_2_1(::Any, ::Any)
Expand Down
Loading