Skip to content

Commit

Permalink
fix: ConditionalVAE on CI (#1159)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Jan 3, 2025
1 parent 6341b3d commit 8792097
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions examples/ConditionalVAE/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,27 +122,23 @@ end
@concrete struct TensorDataset
dataset
transform
total_samples::Int
end

Base.length(ds::TensorDataset) = length(ds.dataset)
Base.length(ds::TensorDataset) = ds.total_samples

function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
return stack(parent itemdata Base.Fix1(apply, ds.transform), img)
end

function loadmnist(batchsize, image_size::Dims{2})
## Load MNIST: Only 1500 for demonstration purposes
N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
## Load MNIST: Only 1500 for demonstration purposes on CI
train_dataset = MNIST(; split=:train)
test_dataset = MNIST(; split=:test)
if N !== nothing
train_dataset = train_dataset[1:N]
test_dataset = test_dataset[1:N]
end
N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : length(train_dataset)

train_transform = ScaleKeepAspect(image_size) |> ImageToTensor()
trainset = TensorDataset(train_dataset, train_transform)
trainset = TensorDataset(train_dataset, train_transform, N)
trainloader = DataLoader(trainset; batchsize, shuffle=true, partial=false)

return trainloader
Expand Down Expand Up @@ -247,7 +243,7 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f

for (i, X) in enumerate(train_dataloader)
throughput_tic = time()
(_, loss, stats, train_state) = Training.single_train_step!(
(_, loss, _, train_state) = Training.single_train_step!(
AutoEnzyme(), loss_function, X, train_state)
throughput_toc = time()

Expand Down

0 comments on commit 8792097

Please sign in to comment.