Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Jan 25, 2024
1 parent 8514b3e commit ddd1647
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,36 +126,38 @@ end

@generated function zero_tangent(primal)
fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero.
zfield_exprs = map(fieldnames(primal)) do fname
fval = :(
if isdefined(primal, $(QuoteNode(fname)))
zero_tangent(getfield(primal, $(QuoteNode(fname))))
else
# This is going to be potentially bad, but that's what they get for not giving us a primal
# This will never me mutated inplace, rather it will alway be replaced with an actual value first
ZeroTangent()
end
)
Expr(:kw, fname, fval)
end


# easy case exit early, can't hold references, can't be a reference.
if isbitstype(primal)
zfield_exprs = map(fieldnames(primal)) do fname
fval = :(zero_tangent(getfield(primal, $(QuoteNode(fname)))))
Expr(:kw, fname, fval)
end
return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...))))
end

# hard case need to be prepared for cycic references to this, or that are contained within this
# hard case need to be prepared for references to this, or that are contained within this
quote
counts = $count_references!(primal)
counts = $count_references(primal)
any_mask = $(Expr(:tuple, Expr(:parameters, map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
# If it is is unassigned, or if it doesn't have a concrete type, or we have multiple reference to it
# then let it take any value for its tangent
fdef = :(
!isdefined(primal, $(QuoteNode(fname))) ||
!isconcretetype($ftype) ||
get(counts, $(QuoteNode(fname)), 0) > 1
)
Expr(:kw, fname, fdef)
end...)))

# Construct tangents

# Go back and fill in tangents that were not ready
end

## TODO rewrite below
has_mutable_tangent(primal)
any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
# If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype))
Expr(:kw, fname, fdef)
end
any_mask =
:($MutableTangent{$primal}(
$(Expr(:tuple, Expr(:parameters, any_mask...))),
$(Expr(:tuple, Expr(:parameters, zfield_exprs...))),
Expand Down Expand Up @@ -184,7 +186,7 @@ function zero_tangent(x::Array{P,N}) where {P,N}
end

###############################################
count_references!(x) = count_references(IdDict{Any, Int}(), x)
count_references(x) = count_references(IdDict{Any, Int}(), x)
function count_references!(counts::IdDict{Any, Int}, x)
isbits(x) && return counts # can't be a refernece and can't hold a reference
counts[x] = get(counts, x, 0) + 1 # Increment *before* recursing
Expand Down

0 comments on commit ddd1647

Please sign in to comment.