Investigate using a different AD for tests #96
It would be nice to be quicker. But the tests aren't just about mathematical correctness; some also check that missing branches with …
Would FD.jl not generate …
It seems to make dense zeros:
julia> using FiniteDifferences
julia> grad(central_fdm(5, 1), x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
((a = Float32[0.9999863, 1.0000029], b = (Float32[0.0, 0.0], 0.0)),)
I wondered if Tracker might work well, but it doesn't seem to like NamedTuples. For comparison:
julia> Zygote.gradient(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
((a = Fill(1.0f0, 2), b = nothing),)
julia> Diffractor.gradient(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
(Tangent{NamedTuple{(:a, :b), Tuple{Vector{Float32}, Tuple{Vector{Float32}, Float64}}}}(a = InplaceableThunk(ChainRules.var"#1547#1550"{Float32, Colon}(1.0f0, Colon()), Thunk(ChainRules.var"#1548#1551"{Float32, Colon, Vector{Float32}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}(1.0f0, Colon(), Float32[1.0, 2.0], ProjectTo{AbstractArray}(element = ProjectTo{Float32}(), axes = (Base.OneTo(2),))))),),)
julia> ans[1].b
ZeroTangent()
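The outputs above show why swapping AD backends in the tests is not a drop-in change: Zygote reports the untouched branch `b` as `nothing`, while FiniteDifferences.jl returns dense zeros (and Diffractor returns `ZeroTangent()`). A minimal Julia sketch of a comparison that accepts `nothing` wherever the finite-difference result is all zeros, assuming hypothetical helper names `isallzero` and `approx_grad` (they are not part of Zygote, FiniteDifferences, or ChainRulesTestUtils). The input values are copied from the REPL output in this thread, with `fill(1.0f0, 2)` standing in for Zygote's `Fill(1.0f0, 2)` so the sketch needs no packages:

```julia
# Hypothetical helpers, not from any package: compare an AD gradient with a
# finite-difference gradient, treating `nothing` (Zygote's marker for a branch
# that received no gradient) as equal to an all-zero FD result.

# True if every leaf of a nested structure of numbers/arrays/tuples is zero.
isallzero(x::Number) = iszero(x)
isallzero(x::AbstractArray) = all(isallzero, x)
isallzero(x::Union{Tuple,NamedTuple}) = all(isallzero, x)

# A missing branch in the AD result must correspond to dense zeros in FD.
approx_grad(ad::Nothing, fd; kw...) = isallzero(fd)

# Leaves: fall back to `isapprox`, forwarding tolerances such as `rtol`.
approx_grad(ad::Number, fd::Number; kw...) = isapprox(ad, fd; kw...)
approx_grad(ad::AbstractArray, fd::AbstractArray; kw...) = isapprox(ad, fd; kw...)

# Containers: recurse element by element, requiring matching lengths/keys.
approx_grad(ad::Tuple, fd::Tuple; kw...) =
    length(ad) == length(fd) && all(approx_grad(a, f; kw...) for (a, f) in zip(ad, fd))
approx_grad(ad::NamedTuple, fd::NamedTuple; kw...) =
    keys(ad) == keys(fd) && all(approx_grad(ad[k], fd[k]; kw...) for k in keys(ad))

# Values copied from the REPL sessions above.
zygote_like = (a = fill(1.0f0, 2), b = nothing)
fd_like = (a = Float32[0.9999863, 1.0000029], b = (Float32[0.0, 0.0], 0.0))

println(approx_grad(zygote_like, fd_like; rtol = 1e-3))  # true
```

With Zygote and FiniteDifferences loaded, the same call could be applied to `Zygote.gradient(f, x)[1]` and `grad(central_fdm(5, 1), f, x)[1]`; since `Fill` is an `AbstractArray`, it would hit the array method. ChainRules tangent types such as `ZeroTangent()` or `Tangent{...}` would need additional methods.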
@mcabbott Since you linked the issue in Yota, I guess you want to test this example too. So let me save you a few minutes:
julia> grad(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))[2][2:end]
(Tangent{NamedTuple{(:a, :b), Tuple{Vector{Float32}, Tuple{Vector{Float32}, Float64}}}}(a = Float32[1.0, 1.0],),)
julia> ans[1].b
ZeroTangent()
Zygote's compile times make what should ordinarily be a fast test suite rather sluggish. Perhaps we could borrow some functionality from ChainRulesTestUtils (CRTU): https://github.com/JuliaDiff/ChainRulesTestUtils.jl/blob/v1.8.1/src/testers.jl#L224.