Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make MLUtils into a weakdep & suppport MLDataDevices #827

Merged
merged 5 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions lib/OptimizationOptimisers/Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
name = "OptimizationOptimisers"
uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.3.0"
version = "0.3.1"

[deps]
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[weakdeps]
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"

[extensions]
OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
OptimizationOptimisersMLUtilsExt = "MLUtils"

[compat]
MLDataDevices = "1.1"
MLUtils = "0.4.4"
Optimisers = "0.2, 0.3"
Optimization = "4"
Expand All @@ -20,9 +28,14 @@ Reexport = "1.2"
julia = "1"

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ForwardDiff", "Test", "Zygote"]
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module OptimizationOptimisersMLDataDevicesExt

using MLDataDevices
using OptimizationOptimisers

OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = (@show "dkjht"; true)

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module OptimizationOptimisersMLUtilsExt

using MLUtils
using OptimizationOptimisers

OptimizationOptimisers.isa_dataiterator(::MLUtils.DataLoader) = true

end
9 changes: 6 additions & 3 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module OptimizationOptimisers

using Reexport, Printf, ProgressLogging
@reexport using Optimisers, Optimization
using Optimization.SciMLBase, MLUtils
using Optimization.SciMLBase

SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true
SciMLBase.requiresgradient(opt::AbstractRule) = true
Expand All @@ -16,6 +16,8 @@ function SciMLBase.__init(
kwargs...)
end

isa_dataiterator(data) = false

function SciMLBase.__solve(cache::OptimizationCache{
F,
RC,
Expand Down Expand Up @@ -57,13 +59,14 @@ function SciMLBase.__solve(cache::OptimizationCache{
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
end

if cache.p isa MLUtils.DataLoader
if isa_dataiterator(cache.p)
data = cache.p
dataiterate = true
else
data = [cache.p]
dataiterate = false
end

opt = cache.opt
θ = copy(cache.u0)
G = copy(θ)
Expand Down Expand Up @@ -114,7 +117,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt = min_opt
x = min_err
θ = min_θ
cache.f.grad(G, θ, d...)
cache.f.grad(G, θ, d)
opt_state = Optimization.OptimizationState(iter = i,
u = θ,
objective = x[1],
Expand Down
40 changes: 40 additions & 0 deletions lib/OptimizationOptimisers/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,43 @@ using Zygote

@test_throws ArgumentError sol=solve(prob, Optimisers.Adam())
end

@testset "Minibatching" begin
using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random, ComponentArrays

x = rand(10000)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 100)

# Define the neural network
model = Chain(Dense(1, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(Random.default_rng(), model)
ps_ca = ComponentArray(ps)
smodel = StatefulLuxLayer{true}(model, nothing, st)

function callback(state, l)
state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l
return l < 1e-4
end

function loss(ps, data)
ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
return sum(abs2, ypred .- data[2])
end

optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)

res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)

@test res.objective < 1e-4

using MLDataDevices
data = CPUDevice()(data)
optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)

res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)

@test res.objective < 1e-4
end
Loading