aurora: Model Benchmarking
Clouke edited this page May 6, 2023
In this example, we'll benchmark a neural network model.
Any model that implements `Trainable` (for example `NeuralNetwork` or `LinearRegression`) can be benchmarked with Aurora's core framework.
The model we're using in this example:
```java
// Your model: this could also be a LinearRegression or another Trainable model
NeuralNetworkTrainer trainer = new NeuralNetworkBuilder()
        .learningRate(0.01)
        .epochs(100_000)
        .layers(mapper -> mapper
                .inputLayers(3)
                .hiddenLayers(2)
                .outputLayers(1))
        .disableStatsPrint()
        .build();

// `inputs` and `outputs` hold your training data (not shown here)
Benchmark benchmark = new TrainableBenchmarkBuilder()
        .trainable(trainer)
        .inputs(inputs)          // apply the inputs to the model
        .targets(outputs)        // apply the targets to the model
        .warmupCycles(20)        // optional: run 20 warmup cycles so the slower first runs don't skew the results
        .benchmarkCycles(1000)   // optional: benchmark the model for 1000 cycles
        .build();                // build the benchmark
```
Alternatively, build the model inline when creating the benchmark:
```java
Benchmark benchmark = new TrainableBenchmarkBuilder()
        .trainable(new NeuralNetworkBuilder() // your model: this could also be a LinearRegression or another Trainable model
                .learningRate(0.01)
                .epochs(100_000)
                .layers(mapper -> mapper
                        .inputLayers(3)
                        .hiddenLayers(2)
                        .outputLayers(1))
                .disableStatsPrint()
                .build())
        .inputs(inputs)          // apply the inputs to the model
        .targets(outputs)        // apply the targets to the model
        .warmupCycles(20)        // optional: run 20 warmup cycles so the slower first runs don't skew the results
        .benchmarkCycles(1000)   // optional: benchmark the model for 1000 cycles
        .build();                // build the benchmark
```
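Conceptually, `warmupCycles` and `benchmarkCycles` split the runs into an untimed phase and a timed phase. On the JVM the first runs are typically slower because the JIT compiler has not yet optimized the hot code, so excluding them gives a steadier average. The sketch below illustrates that general pattern with plain `System.nanoTime()`. It is not Aurora's implementation: the `TrainableModel` interface and the `averageNanosPerCycle` method are hypothetical stand-ins for `Trainable` and the benchmark builder, and one cycle is assumed to be one call to `train`.

```java
// Hypothetical sketch of a warmup + measurement loop; not Aurora's code.
public class WarmupBenchmark {

    /** Hypothetical stand-in for Aurora's Trainable: one call = one cycle. */
    @FunctionalInterface
    public interface TrainableModel {
        void train(double[][] inputs, double[][] targets);
    }

    /**
     * Runs warmupCycles untimed calls, then times benchmarkCycles calls
     * and returns the mean duration of one measured cycle in nanoseconds.
     */
    public static double averageNanosPerCycle(TrainableModel model,
                                              double[][] inputs,
                                              double[][] targets,
                                              int warmupCycles,
                                              int benchmarkCycles) {
        if (benchmarkCycles <= 0) {
            throw new IllegalArgumentException("benchmarkCycles must be positive");
        }
        // Warmup phase: results and timings are discarded.
        for (int i = 0; i < warmupCycles; i++) {
            model.train(inputs, targets);
        }
        // Measurement phase: accumulate the elapsed time of each cycle.
        long totalNanos = 0L;
        for (int i = 0; i < benchmarkCycles; i++) {
            long start = System.nanoTime();
            model.train(inputs, targets);
            totalNanos += System.nanoTime() - start;
        }
        return (double) totalNanos / benchmarkCycles;
    }

    public static void main(String[] args) {
        double[][] inputs = {{0, 0, 1}, {1, 1, 1}, {1, 0, 1}, {0, 1, 1}};
        double[][] targets = {{0}, {1}, {1}, {0}};
        // Toy model for illustration only: a "training cycle" just sums the inputs.
        TrainableModel toy = (in, out) -> {
            double sum = 0;
            for (double[] row : in) {
                for (double v : row) {
                    sum += v;
                }
            }
        };
        double avg = averageNanosPerCycle(toy, inputs, targets, 20, 1000);
        System.out.printf("Average time per cycle: %.1f ns%n", avg);
    }
}
```

For serious JVM measurements, a dedicated harness such as JMH handles further pitfalls (dead-code elimination, forking, statistical reporting) that this simple loop ignores.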
With either variant, run the benchmark with `compose()` and print the results:
```java
import java.util.concurrent.TimeUnit;

BenchmarkResult res = benchmark.compose(); // run the benchmark and get the result
System.out.println("Average time per epoch: " + TimeUnit.NANOSECONDS.toMillis((long) res.averageTimePerEpoch()) + "ms");
System.out.println("Average accuracy: " + res.averageAccuracy() + "%");
System.out.println("Time per sample: " + TimeUnit.NANOSECONDS.toMicros((long) res.timePerSample()) + "µs");
System.out.println("Throughput: " + res.throughput() + " samples/s");
```
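The printing code above treats `averageTimePerEpoch()` and `timePerSample()` as nanosecond values, which is what the `TimeUnit.NANOSECONDS` conversions imply. Because `TimeUnit.toMillis` truncates, an epoch shorter than one millisecond prints as `0ms`. The self-contained sketch below uses a hypothetical `TimingFormat` helper (not part of Aurora) to show that truncation, a fractional alternative, and how a samples-per-second figure follows from a per-sample time. The input values are made up for illustration, and Aurora's own `throughput()` may be computed differently.

```java
import java.util.Locale;
import java.util.concurrent.TimeUnit;

// Hypothetical helper for formatting nanosecond timings; not part of Aurora.
public class TimingFormat {

    /** Truncating conversion, as in the snippet above: 850_000 ns -> 0 ms. */
    public static long truncatedMillis(double nanos) {
        return TimeUnit.NANOSECONDS.toMillis((long) nanos);
    }

    /** Fractional milliseconds, keeping sub-millisecond precision. */
    public static double fractionalMillis(double nanos) {
        return nanos / 1_000_000.0;
    }

    /** Samples per second when each sample takes nanosPerSample nanoseconds. */
    public static double throughputPerSecond(double nanosPerSample) {
        return 1_000_000_000.0 / nanosPerSample;
    }

    public static void main(String[] args) {
        double epochNanos = 850_000.0; // made-up example value
        double sampleNanos = 2_500.0;  // made-up example value
        System.out.println(truncatedMillis(epochNanos) + "ms");                                   // prints 0ms
        System.out.printf(Locale.ROOT, "%.3fms%n", fractionalMillis(epochNanos));                 // prints 0.850ms
        System.out.printf(Locale.ROOT, "%.0f samples/s%n", throughputPerSecond(sampleNanos));     // prints 400000 samples/s
    }
}
```

Passing `Locale.ROOT` to `printf` keeps the decimal separator a dot regardless of the system locale.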