Skip to content

Commit

Permalink
Merge pull request #964 from JuliaSymbolics/s/reg
Browse files Browse the repository at this point in the history
Array registration
  • Loading branch information
ChrisRackauckas authored Dec 11, 2023
2 parents 7c8179d + 46195a8 commit 634b485
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 119 deletions.
2 changes: 1 addition & 1 deletion src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ include("utils.jl")
using ConstructionBase
include("arrays.jl")

export @register, @register_symbolic
export @register, @register_symbolic, @register_array_symbolic
include("register.jl")

export @variables, Variable
Expand Down
3 changes: 1 addition & 2 deletions src/array-lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ end
propagate_eltype(::typeof(getindex), x, idx...) = geteltype(x)

function SymbolicUtils.promote_symtype(::typeof(getindex), X, ii...)
@assert all(i -> i <: Integer, ii) "user arrterm to create arr term."

@assert all(i -> i <: Integer, ii)
eltype(X)
end

Expand Down
76 changes: 44 additions & 32 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,49 +352,57 @@ end
#

"""
arrterm(f, args...; arrayop=nothing)
array_term(f, args...;
container_type = propagate_atype(f, args...),
eltype = propagate_eltype(f, args...),
size = map(length, propagate_shape(f, args...)),
ndims = propagate_ndims(f, args...))
Create a term of `Term{<: AbstractArray}` which
is the representation of `f(args...)`.
- Calls `propagate_atype(f, args...)` to determine the
container type, i.e. `Array` or `StaticArray` etc.
- Calls `propagate_eltype(f, args...)` to determine the
output element type.
- Calls `propagate_ndims(f, args...)` to determine the
output dimension.
- Calls `propagate_shape(f, args...)` to determine the
output array shape.
Default arguments:
- `container_type=propagate_atype(f, args...)` - the container type,
i.e. `Array` or `StaticArray` etc.
- `eltype=propagate_eltype(f, args...)` - the output element type.
- `size=map(length, propagate_shape(f, args...))` - the
output array size. `propagate_shape` returns a tuple of index ranges.
- `ndims=propagate_ndims(f, args...)` the output dimension.
`propagate_shape`, `propagate_atype`, `propagate_eltype` may
return `Unknown()` to say that the output cannot be determined
But `propagate_ndims` must work and return a non-negative integer.
"""
function arrterm(f, args...)
atype = propagate_atype(f, args...)
etype = propagate_eltype(f, args...)
nd = propagate_ndims(f, args...)

S = if etype === Unknown() && nd === Unknown()
atype
elseif etype === Unknown()
atype{T, nd} where T
elseif nd === Unknown()
atype{etype, N} where N
else
atype{etype, nd}
function array_term(f, args...;
container_type = propagate_atype(f, args...),
eltype = propagate_eltype(f, args...),
size = Unknown(),
ndims = size !== Unknown() ? length(size) : propagate_ndims(f, args...),
shape = size !== Unknown() ? Tuple(map(x->1:x, size)) : propagate_shape(f, args...))

if container_type == Unknown()
# There's always a fallback for this
container_type = propagate_atype(f, args...)
end

setmetadata(Term{S}(f, Any[args...]),
ArrayShapeCtx,
propagate_shape(f, args...))
if eltype == Unknown()
eltype = Base.propagate_eltype(container_type)
end

if ndims == Unknown()
ndims = if shape == Unknown()
Any
else
length(shape)
end
end
S = container_type{eltype, ndims}
setmetadata(Term{S}(f, Any[args...]), ArrayShapeCtx, shape)
end

"""
shape(s::Any)
Returns `axes(s)` or throws.
Returns `axes(s)` or Unknown().
"""
shape(s) = axes(s)

Expand All @@ -413,7 +421,7 @@ function shape(s::Symbolic{<:AbstractArray})
end

## `propagate_` interface:
# used in the `arrterm` construction.
# used in the `array_term` construction.

atype(::Type{<:Array}) = Array
atype(::Type{<:SArray}) = SArray
Expand Down Expand Up @@ -443,7 +451,11 @@ function propagate_eltype(f, args...)
end

function propagate_ndims(f, args...)
error("Could not determine the output dimension of $f$args")
if propagate_shape(f, args...) == Unknown()
error("Could not determine the output dimension of $f$args")
else
length(propagate_shape(f, args...))
end
end

function propagate_shape(f, args...)
Expand Down Expand Up @@ -644,12 +656,12 @@ function scalarize_op(f, arr, idx)
end

@wrapped function Base.:(\)(A::AbstractMatrix, b::AbstractVecOrMat)
t = arrterm(\, A, b)
t = array_term(\, A, b)
setmetadata(t, ScalarizeCache, Ref{Any}(nothing))
end

@wrapped function Base.inv(A::AbstractMatrix)
t = arrterm(inv, A)
t = array_term(inv, A)
setmetadata(t, ScalarizeCache, Ref{Any}(nothing))
end

Expand Down
2 changes: 1 addition & 1 deletion src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ end
(D::Differential)(x) = Term{symtype(x)}(D, [x])
(D::Differential)(x::Num) = Num(D(value(x)))
(D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x)))))
SymbolicUtils.promote_symtype(::Differential, x) = x
SymbolicUtils.promote_symtype(::Differential, T) = T

is_derivative(x) = istree(x) ? operation(x) isa Differential : false

Expand Down
166 changes: 88 additions & 78 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,111 +21,121 @@ overwriting.
@register_symbolic goo(x, y::Int) # `y` is not overloaded to take symbolic objects
@register_symbolic hoo(x, y)::Int # `hoo` returns `Int`
```
See `@register_array_symbolic` to register functions which return arrays.
"""
macro register_symbolic(expr, define_promotion = true, Ts = :([]))
f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, Ts)

args′ = map((a, T) -> :($a::$T), argnames, Ts)
ret_type = isnothing(ret_type) ? Real : ret_type

fexpr = :(@wrapped function $f($(args′...))
args = [$(argnames...),]
unwrapped_args = map($unwrap, args)
res = if !any(x->$issym(x) || $istree(x), unwrapped_args)
$f(unwrapped_args...) # partial-eval if all args are unwrapped
else
$Term{$ret_type}($f, unwrapped_args)
end
if typeof.(args) == typeof.(unwrapped_args)
return res
else
return $wrap(res)
end
end)

if define_promotion
fexpr = :($fexpr; (::$typeof($promote_symtype))(::$ftype, args...) = $ret_type)
end
esc(fexpr)
end

function destructure_registration_expr(expr, Ts)
if expr.head === :(::)
ret_type = expr.args[2]
expr = expr.args[1]
else
ret_type = Real
ret_type = nothing
end

@assert expr.head === :call
@assert Ts.head === :vect
Ts = Ts.args

f = expr.args[1]
args = expr.args[2:end]

if f isa Expr && f.head == :(::)
@assert length(f.args) == 2
end

types = map(args) do x
if x isa Symbol
if isempty(Ts)
Ts = [Real]
end
:(($(Ts...), $wrapper_type($Real), $Symbolic{<:$Real}))
elseif Meta.isexpr(x, :(::))
T = x.args[2]
:($has_symwrapper($T) ?
($T, $Symbolic{<:$T}, $wrapper_type($T),) :
($T, $Symbolic{<:$T}))
else
error("Invalid argument format $x")
end
end
# Default arg types to Real
Ts = map(a -> a isa Symbol ? Real : (@assert(a.head == :(::)); a.args[2]), args)
argnames = map(a -> a isa Symbol ? a : a.args[1], args)

eval_method = :(@eval function $f($(Expr(:$, :(var"##_register_macro_s"...))),)
args = [$(Expr(:$, :(var"##_register_macro_s_syms"...)))]
unwrapped_args = map($unwrap, args)
res = if !any(x->$issym(x) || $istree(x), unwrapped_args)
$f(unwrapped_args...)
else
$Term{$ret_type}($f, unwrapped_args)
end
if typeof.(args) == typeof.(unwrapped_args)
return res
else
return $wrap(res)
end
end)
verbose = false
mod, fname = f isa Expr && f.head == :(.) ? f.args : (:(@__MODULE__), QuoteNode(f))
Ts = Symbol("##__Ts")
ftype = if f isa Expr && f.head == :(::)
if length(f.args) == 1
error("please name the callable object, i.e. use (f::$(f.args[end])) instead of $f")
end
@assert length(f.args) == 2
f.args[end]
else
:($typeof($f))
end
f, ftype, argnames, Ts, ret_type
end


function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :())
def_assignments = MacroTools.rmlines(partial_defs).args
defs = map(def_assignments) do ex
@assert ex.head == :(=)
ex.args[1] => ex.args[2]
end |> Dict


args′ = map((a, T) -> :($a::$T), argnames, Ts)
quote
$Ts = [Tuple{x...} for x in Iterators.product($(types...),)
if any(x->x <: $Symbolic || Symbolics.is_wrapper_type(x), x)]
if $verbose
println("Candidates")
map(println, $Ts)
end
@wrapped function $f($(args′...))
args = [$(argnames...),]
unwrapped_args = map($unwrap, args)
res = if !any(x->$issym(x) || $istree(x), unwrapped_args)
$f(unwrapped_args...) # partial-eval if all args are unwrapped
elseif $ret_type == nothing || ($ret_type <: AbstractArray)
$array_term($(Expr(:parameters, [Expr(:kw, k, v) for (k, v) in defs]...)), $f, unwrapped_args...)
else
$Term{$ret_type}($f, unwrapped_args)
end

for sig in $Ts
var"##_register_macro_s" = map(((i,T,),)->Expr(:(::), Symbol("arg", i), T), enumerate(sig.parameters))
var"##_register_macro_s_syms" = map(x->x.args[1], var"##_register_macro_s")
$eval_method
end
if $define_promotion
(::$typeof($promote_symtype))(::$ftype, args...) = $ret_type
if typeof.(args) == typeof.(unwrapped_args)
return res
else
return $wrap(res)
end
end
end |> esc
end

Base.@deprecate_binding var"@register" var"@register_symbolic"
"""
@register_array_symbolic(expr)
# Ensure that Num that get @registered from outside the ModelingToolkit
# module can work without having to bring in the associated function into the
# ModelingToolkit namespace. We basically store information about functions
# registered at runtime in a ModelingToolkit variable,
# `registered_external_functions`. It's not pretty, but we are limited by the
# way GeneralizedGenerated builds a function (adding "ModelingToolkit" to every
# function call).
# ---
const registered_external_functions = Dict{Symbol,Module}()
function inject_registered_module_functions(expr)
MacroTools.postwalk(expr) do x
# Find all function calls in the expression and extract the function
# name and calling module.
MacroTools.@capture(x, f_module_.f_name_(xs__))
if isnothing(f_module)
MacroTools.@capture(x, f_name_(xs__))
end
Example:
if !isnothing(f_name)
# Set the calling module to the module that registered it.
mod = get(registered_external_functions, f_name, f_module)
if !isnothing(mod) && mod != Base
x.args[1] = :($mod.$f_name)
end
end
```julia
# Let's say vandermonde takes an n-vector and returns an n x n matrix
@register_array_symbolic vandermonde(x::AbstractVector) begin
size=(length(x), length(x))
eltype=eltype(x) # optional, will default to the promoted eltypes of x
end
```
return x
end
You can also register calls on callable structs:
```julia
@register_array_symbolic (c::Conv)(x::AbstractMatrix) begin
size=size(x) .- size(c.kernel) .+ 1
eltype=promote_type(eltype(x), eltype(c))
end
```
"""
macro register_array_symbolic(expr, block)
f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, :([]))
return register_array_symbolic(f, ftype, argnames, Ts, ret_type, block)
end

Base.@deprecate_binding var"@register" var"@register_symbolic"
19 changes: 16 additions & 3 deletions src/wrapper-types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ function wrap_func_expr(mod, expr)
args = get(def, :args, [])
kwargs = get(def, :kwargs, [])

if fname isa Expr && fname.head == :(::) && length(fname.args) > 1
self = fname.args[1]
else
self = :nothing # LOL -- in this case the argument named nothing is passed nothing
end
impl_name = Symbol(fname,"_", hash(string(args)*string(kwargs)))

function kwargname(kwarg)
Expand All @@ -96,6 +101,14 @@ function wrap_func_expr(mod, expr)
names = vcat(argname.(args), kwargname.(kwargs))

function type_options(arg)
# for every argument find the types that
# should be allowed as argument. These are:
#
# (1) T (2) wrapper_type(T) (3) Symbolic{T}
#
# However later while emiting methods we omit the one
# method where all arguments are (1) since those are
# expected to be defined outside Symbolics
if arg isa Expr && arg.head == :(::)
T = Base.eval(mod, arg.args[2])
has_symwrapper(T) ? (T, :(SymbolicUtils.Symbolic{<:$T}), wrapper_type(T)) :
Expand All @@ -110,7 +123,7 @@ function wrap_func_expr(mod, expr)

types = map(type_options, args)

impl = :(function $impl_name($(names...))
impl = :(function $impl_name($self, $(names...))
$body
end)
# TODO: maybe don't drop first lol
Expand All @@ -120,9 +133,9 @@ function wrap_func_expr(mod, expr)
end

fbody = :(if any($iswrapped, ($(names...),))
$wrap($impl_name($([:($unwrap($arg)) for arg in names]...)))
$wrap($impl_name($self, $([:($unwrap($arg)) for arg in names]...)))
else
$impl_name($(names...))
$impl_name($self, $(names...))
end)

if isempty(kwargs)
Expand Down
Loading

0 comments on commit 634b485

Please sign in to comment.