From 304104c54cd9ee732c3329f86c6066b12465faed Mon Sep 17 00:00:00 2001
From: Vaibhav Dixit
Date: Sun, 22 Sep 2024 00:03:57 -0400
Subject: [PATCH] Add minibatching tests

---
 lib/OptimizationOptimisers/Project.toml                  |  8 ++--
 lib/OptimizationOptimisers/src/OptimizationOptimisers.jl |  2 +-
 lib/OptimizationOptimisers/test/runtests.jl              | 40 +++++++++++++++++
 3 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml
index b7356ce20..2c4c8cc5e 100644
--- a/lib/OptimizationOptimisers/Project.toml
+++ b/lib/OptimizationOptimisers/Project.toml
@@ -10,14 +10,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 
-[extensions]
-OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
-OptimizationOptimisersMLUtilsExt = "MLUtils"
-
 [weakdeps]
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 
+[extensions]
+OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
+OptimizationOptimisersMLUtilsExt = "MLUtils"
+
 [compat]
 MLDataDevices = "1.1"
 MLUtils = "0.4.4"
diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
index 12b021da3..ea2ef9202 100644
--- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
+++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
@@ -117,7 +117,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
                 opt = min_opt
                 x = min_err
                 θ = min_θ
-                cache.f.grad(G, θ, d...)
+                cache.f.grad(G, θ, d)
                 opt_state = Optimization.OptimizationState(iter = i,
                     u = θ,
                     objective = x[1],
diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl
index ddee2ea4c..3ff3c9a3d 100644
--- a/lib/OptimizationOptimisers/test/runtests.jl
+++ b/lib/OptimizationOptimisers/test/runtests.jl
@@ -68,3 +68,43 @@ using Zygote
 
     @test_throws ArgumentError sol=solve(prob, Optimisers.Adam())
 end
+
+@testset "Minibatching" begin
+    using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Statistics, Random, ComponentArrays
+
+    x = rand(10000)
+    y = sin.(x)
+    data = MLUtils.DataLoader((x, y), batchsize = 100)
+
+    # Define the neural network
+    model = Chain(Dense(1, 32, tanh), Dense(32, 1))
+    ps, st = Lux.setup(Random.default_rng(), model)
+    ps_ca = ComponentArray(ps)
+    smodel = StatefulLuxLayer{true}(model, nothing, st)
+
+    function callback(state, l)
+        state.iter % 25 == 1 && @info "Iteration: $(state.iter), Loss: $(l)"
+        return l < 1e-4
+    end
+
+    function loss(ps, data)
+        ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
+        return sum(abs2, ypred .- data[2])
+    end
+
+    optf = OptimizationFunction(loss, AutoZygote())
+    prob = OptimizationProblem(optf, ps_ca, data)
+
+    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100)
+
+    @test res.objective < 1e-4
+
+    using MLDataDevices
+    data = CPUDevice()(data)
+    optf = OptimizationFunction(loss, AutoZygote())
+    prob = OptimizationProblem(optf, ps_ca, data)
+
+    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100)
+
+    @test res.objective < 1e-4
+end