Skip to content

Commit

Permalink
add forward finite difference based frules for measurement functions
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 19, 2023
1 parent 1c9dff2 commit 4be4e84
Showing 1 changed file with 75 additions and 31 deletions.
106 changes: 75 additions & 31 deletions test/signal_measurement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Its an important use case so we test it directly.
"""
module SignalMeasurement
using Test
using ChainRulesCore
using Diffractor
using Diffractor: ∂☆, ZeroBundle, TaylorBundle
using Diffractor: bundle, first_partial, TaylorTangentIndex, primal
Expand All @@ -17,48 +18,91 @@ function make_soft_square_pulse(width, hardness=100)
end
end
#soft_square = make_soft_square_pulse(0.5, 8)
#signal = soft_square.(0:0.001:1)
#signal = soft_square.(0:0.001:1)]st
#scatter(signal)

function determine_width(xs, ts)
# vs real signal processing functions this is not very robust, but for sake of demonstration it is fine.
@assert eachindex(xs) == eachindex(ts)

start_idx = nothing
end_idx = nothing
for ii in eachindex(xs)
x = xs[ii]
if isnothing(start_idx)
if x > 0.5
start_idx = ii
end
else
if x < 0.5
end_idx = ii
break
#@testset "pulse width" begin
function determine_width(xs, ts)
# vs real signal processing functions this is not very robust, but for sake of demonstration it is fine.
@assert eachindex(xs) == eachindex(ts)

start_idx = nothing
end_idx = nothing
for ii in eachindex(xs)
x = xs[ii]
if isnothing(start_idx)
if x > 0.5
start_idx = ii
end
else
if x < 0.5
end_idx = ii
break
end
end
end

(isnothing(start_idx) || isnothing(end_idx)) && throw(DomainError("no pulse found"))
return ts[end_idx] - ts[start_idx]
end

(isnothing(start_idx) || isnothing(end_idx)) && throw(DomainError("no pulse found"))
return ts[end_idx] - ts[start_idx]
end
function signal_problem(width)
func = make_soft_square_pulse(width, 8)
ts = 0.0:0.001:1.0
signal = map(func, ts)
return determine_width(signal, ts)
end

#ts = 0:0.001:1
#determine_width(make_soft_square_pulse(0.5, 5).(ts), ts)
function ChainRulesCore.frule((_, ẋs, ṫs), ::typeof(determine_width), xs, ts)
iszero(ṫs) || throw(ArgumentError("not supporting nonzero tangents to `ts` (time steps)"))
# If we needed to support this we could do something with interpolation and resampling

# Apply finite difference
y = determine_width(xs, ts)
y⃑ = determine_width(xs .+ ẋs, ts)
= y⃑ - y
return y, ẏ
end


function signal_problem(width)
func = make_soft_square_pulse(width, 8)
ts = 0.0:0.001:1.0
signal = map(func, ts)
return determine_width(signal, ts)
end
for δ in (0.001, 0.003, 0.0045, 0.1, 0.04)
🐰 = ∂☆{1}()(ZeroBundle{1}(signal_problem), TaylorBundle{1}(0.5, (δ,)))
@test primal(🐰) signal_problem(0.5)
@test convert(Float64, first_partial(🐰)) δ rtol=0.2
end
#end

@testset "risetime" begin
function determine_risetime(xs, ts)
start_ind = findfirst(>(0.2), xs)
end_ind = findfirst(>(0.8), @view(xs[Base.IdentityUnitRange(start_ind:end)]))
return ts[end_ind] - ts[start_ind]
end

function ChainRulesCore.frule((_, ẋs, ṫs), ::typeof(determine_risetime), xs, ts)
iszero(ṫs) || throw(ArgumentError("not supporting nonzero tangents to `ts` (time steps)"))
# If we needed to support this we could do something with interpolation and resampling

# Apply finite difference
y = determine_risetime(xs, ts)
y⃑ = determine_risetime(xs .+ ẋs, ts)
= y⃑ - y
return y, ẏ
end

function signal_risetime_problem(hardness)
func = make_soft_square_pulse(0.5, hardness)
ts = 0.0:0.001:1.0
signal = map(func, ts)
return determine_risetime(signal, ts)
end



🐰 = ∂☆{1}()(ZeroBundle{1}(signal_problem), TaylorBundle{1}(0.5, (1.0,)))
@test primal(🐰) signal_problem(0.5)
@test first_partial(🐰) 1.0
🐇 = ∂☆{1}()(ZeroBundle{1}(signal_risetime_problem), TaylorBundle{1}(12, (1.0,)))
@test primal(🐇) signal_risetime_problem(12)
@test convert(Float64, first_partial(🐰)) < 0 # As you increase the hardness the risetime decreases
end


end # module

0 comments on commit 4be4e84

Please sign in to comment.