Investigate using a different AD for tests #96

Open
ToucheSir opened this issue Jul 3, 2022 · 5 comments · Fixed by #105

Comments

@ToucheSir
Member

Zygote compile times make what should ordinarily be a pretty fast test suite rather sluggish. Perhaps we could borrow some functionality from ChainRulesTestUtils (CRTU): https://github.com/JuliaDiff/ChainRulesTestUtils.jl/blob/v1.8.1/src/testers.jl#L224.

@mcabbott
Member

mcabbott commented Jul 3, 2022

It would be nice to be quicker. But the tests aren't just about mathematical correctness; some also check that missing branches, represented by nothing, are handled correctly, etc. If JuliaDiff/Diffractor.jl#66 ever happens, these should ideally be supplemented by checks that ZeroTangent works equally well.

@ToucheSir
Member Author

Would FiniteDifferences.jl (FD.jl) not generate NoTangent for those same branches? I've not tried it with nested structs.

@mcabbott
Member

mcabbott commented Jul 4, 2022

It seems to make dense zeros:

julia> using FiniteDifferences

julia> grad(central_fdm(5, 1), x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
((a = Float32[0.9999863, 1.0000029], b = (Float32[0.0, 0.0], 0.0)),)

I wondered if Tracker might work well, but it doesn't seem to like NamedTuples.

For comparison:

julia> Zygote.gradient(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
((a = Fill(1.0f0, 2), b = nothing),)

julia> Diffractor.gradient(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
(Tangent{NamedTuple{(:a, :b), Tuple{Vector{Float32}, Tuple{Vector{Float32}, Float64}}}}(a = InplaceableThunk(ChainRules.var"#1547#1550"{Float32, Colon}(1.0f0, Colon()), Thunk(ChainRules.var"#1548#1551"{Float32, Colon, Vector{Float32}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}(1.0f0, Colon(), Float32[1.0, 2.0], ProjectTo{AbstractArray}(element = ProjectTo{Float32}(), axes = (Base.OneTo(2),))))),),)

julia> ans[1].b
ZeroTangent()
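
To make the mismatch above concrete: a finite-difference reference returns dense zeros where Zygote returns nothing, so a test comparing the two would need to treat nothing as a structural zero. The following is only a minimal sketch of such a comparison in plain Julia; `agrees` is a hypothetical helper name, not a function from FiniteDifferences, Zygote, or ChainRulesTestUtils, and the gradient values are copied by hand from the REPL output above (with Zygote's Fill replaced by an ordinary vector). It does not handle ZeroTangent or Tangent.

```julia
# Hypothetical helper (not part of any package): does an AD gradient that may
# contain `nothing` agree with a dense finite-difference gradient?

# A `nothing` branch agrees if the finite-difference result is (near) zero there.
agrees(::Nothing, fd::Number; atol=1e-3) = abs(fd) <= atol
agrees(::Nothing, fd::AbstractArray; atol=1e-3) = all(x -> abs(x) <= atol, fd)
agrees(::Nothing, fd::Union{Tuple,NamedTuple}; atol=1e-3) =
    all(x -> agrees(nothing, x; atol=atol), values(fd))

# Leaves are compared elementwise with an absolute tolerance.
agrees(ad::Number, fd::Number; atol=1e-3) = abs(ad - fd) <= atol
agrees(ad::AbstractArray, fd::AbstractArray; atol=1e-3) =
    all(abs.(ad .- fd) .<= atol)

# Containers are walked field by field; both sides must have the same shape.
agrees(ad::Union{Tuple,NamedTuple}, fd::Union{Tuple,NamedTuple}; atol=1e-3) =
    all(map((a, f) -> agrees(a, f; atol=atol), values(ad), values(fd)))

# Values copied from the transcript above (assumed, not recomputed here).
fd_grad = (a = Float32[0.9999863, 1.0000029], b = (Float32[0.0, 0.0], 0.0))
zygote_like = (a = Float32[1.0, 1.0], b = nothing)

agrees(zygote_like, fd_grad)
```

Note that a check like this only confirms numerical agreement. It cannot tell whether the AD returned nothing or an explicit zero array, which is exactly the property some of the existing tests are meant to cover.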

@dfdx

dfdx commented Jul 9, 2022

@mcabbott Since you linked the issue in Yota, I guess you want to test this example too. So let me save you a few minutes:

julia> grad(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))[2][2:end]
(Tangent{NamedTuple{(:a, :b), Tuple{Vector{Float32}, Tuple{Vector{Float32}, Float64}}}}(a = Float32[1.0, 1.0],),)

julia> ans[1].b
ZeroTangent()

@cossio
Contributor

cossio commented Dec 11, 2022

@mcabbott Funny that #105 explicitly says it won't close this. Maybe it was closed automatically because GitHub saw the "close #105" string.

Just checking: was there a real intention to close this issue?

@mcabbott mcabbott reopened this Dec 11, 2022