diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index dbc5aecd2..371467455 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -1,17 +1,25 @@ name = "OptimizationOptimisers" uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1" authors = ["Vaibhav Dixit and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +[weakdeps] +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" + +[extensions] +OptimizationOptimisersMLDataDevicesExt = "MLDataDevices" +OptimizationOptimisersMLUtilsExt = "MLUtils" + [compat] +MLDataDevices = "1.1" MLUtils = "0.4.4" Optimisers = "0.2, 0.3" Optimization = "4" @@ -20,9 +28,14 @@ Reexport = "1.2" julia = "1" [extras] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ForwardDiff", "Test", "Zygote"] +test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"] diff --git a/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl b/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl new file mode 100644 index 000000000..ed5020daa --- /dev/null +++ b/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl @@ -0,0 +1,8 @@ +module OptimizationOptimisersMLDataDevicesExt + +using MLDataDevices +using OptimizationOptimisers + +OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = (@show "dkjht"; true) + +end diff --git a/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLUtilsExt.jl b/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLUtilsExt.jl new file mode 100644 index 000000000..1790d7aea --- /dev/null +++ b/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLUtilsExt.jl @@ -0,0 +1,8 @@ +module OptimizationOptimisersMLUtilsExt + +using MLUtils +using OptimizationOptimisers + +OptimizationOptimisers.isa_dataiterator(::MLUtils.DataLoader) = true + +end diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index b3811bbd7..ea2ef9202 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -2,7 +2,7 @@ module OptimizationOptimisers using Reexport, Printf, ProgressLogging @reexport using Optimisers, Optimization -using Optimization.SciMLBase, MLUtils +using Optimization.SciMLBase SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true SciMLBase.requiresgradient(opt::AbstractRule) = true @@ -16,6 +16,8 @@ function SciMLBase.__init( kwargs...) end +isa_dataiterator(data) = false + function SciMLBase.__solve(cache::OptimizationCache{ F, RC, @@ -57,13 +59,14 @@ function SciMLBase.__solve(cache::OptimizationCache{ throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg.")) end - if cache.p isa MLUtils.DataLoader + if isa_dataiterator(cache.p) data = cache.p dataiterate = true else data = [cache.p] dataiterate = false end + opt = cache.opt θ = copy(cache.u0) G = copy(θ) @@ -114,7 +117,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ opt = min_opt x = min_err θ = min_θ - cache.f.grad(G, θ, d...) + cache.f.grad(G, θ, d) opt_state = Optimization.OptimizationState(iter = i, u = θ, objective = x[1], diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index ddee2ea4c..02b764df2 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -68,3 +68,43 @@ using Zygote @test_throws ArgumentError sol=solve(prob, Optimisers.Adam()) end + +@testset "Minibatching" begin + using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random, ComponentArrays + + x = rand(10000) + y = sin.(x) + data = MLUtils.DataLoader((x, y), batchsize = 100) + + # Define the neural network + model = Chain(Dense(1, 32, tanh), Dense(32, 1)) + ps, st = Lux.setup(Random.default_rng(), model) + ps_ca = ComponentArray(ps) + smodel = StatefulLuxLayer{true}(model, nothing, st) + + function callback(state, l) + state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l + return l < 1e-4 + end + + function loss(ps, data) + ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])] + return sum(abs2, ypred .- data[2]) + end + + optf = OptimizationFunction(loss, AutoZygote()) + prob = OptimizationProblem(optf, ps_ca, data) + + res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000) + + @test res.objective < 1e-4 + + using MLDataDevices + data = CPUDevice()(data) + optf = OptimizationFunction(loss, AutoZygote()) + prob = OptimizationProblem(optf, ps_ca, data) + + res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000) + + @test res.objective < 1e-4 +end