
aurora: Model Benchmarking

Clouke edited this page May 6, 2023 · 1 revision

Benchmarking

In this example, we'll benchmark a neural network model.

Every model that implements Trainable (NeuralNetwork, LinearRegression, etc.) can be benchmarked with Aurora's core framework.

The model we're using in this example:

NeuralNetworkTrainer trainer = new NeuralNetworkBuilder() // your model; this could also be a LinearRegression or any other Trainable model
  .learningRate(0.01)
  .epochs(100_000)
  .layers(mapper -> mapper
    .inputLayers(3)
    .hiddenLayers(2)
    .outputLayers(1))
  .disableStatsPrint()
  .build();
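The benchmark builder below refers to inputs and outputs, which this page doesn't define. As a minimal sketch, assuming Aurora takes training data as double[][] arrays with one row per sample, here is a hypothetical dataset that fits the 3-input, 1-output network above: the 3-bit parity function. The class ParityData and its methods are illustrative names, not part of Aurora.

```java
import java.util.Arrays;

/**
 * Hypothetical training data for a 3-input, 1-output network:
 * 3-bit parity (target is 1.0 when an odd number of inputs are 1).
 * Assumes Aurora accepts inputs and targets as double[][]; check the
 * builder signatures in your Aurora version before relying on this.
 */
final class ParityData {

    /** Builds all 8 combinations of three binary inputs; row i is i written in binary. */
    static double[][] inputs() {
        double[][] in = new double[8][3];
        for (int i = 0; i < 8; i++) {
            for (int bit = 0; bit < 3; bit++) {
                // Column 0 holds the most significant bit.
                in[i][bit] = (i >> (2 - bit)) & 1;
            }
        }
        return in;
    }

    /** Parity target for each input row, stored as a single-element row. */
    static double[][] targets(double[][] in) {
        double[][] out = new double[in.length][1];
        for (int i = 0; i < in.length; i++) {
            int ones = 0;
            for (double v : in[i]) {
                ones += (int) v;
            }
            out[i][0] = ones % 2;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] inputs = inputs();
        double[][] outputs = targets(inputs);
        for (int i = 0; i < inputs.length; i++) {
            System.out.println(Arrays.toString(inputs[i]) + " -> " + outputs[i][0]);
        }
    }
}
```

Whatever data you use, each input row should presumably have 3 values and each target row 1, matching inputLayers(3) and outputLayers(1).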

Build the Benchmark

Benchmark benchmark = new TrainableBenchmarkBuilder()
  .trainable(trainer)
  .inputs(inputs) // apply the inputs to the model
  .targets(outputs) // apply the targets to the model
  .warmupCycles(20) // run 20 warm-up cycles before measuring, so the slower first cycles don't skew the results (optional)
  .benchmarkCycles(1000) // benchmark the model for 1000 cycles (optional)
  .build(); // build the benchmark

Or build the model inline:

Benchmark benchmark = new TrainableBenchmarkBuilder()
  .trainable(new NeuralNetworkBuilder() // your model; this could also be a LinearRegression or any other Trainable model
    .learningRate(0.01)
    .epochs(100_000)
    .layers(mapper -> mapper
      .inputLayers(3)
      .hiddenLayers(2)
      .outputLayers(1))
    .disableStatsPrint()
    .build())
  .inputs(inputs) // apply the inputs to the model
  .targets(outputs) // apply the targets to the model
  .warmupCycles(20) // run 20 warm-up cycles before measuring, so the slower first cycles don't skew the results (optional)
  .benchmarkCycles(1000) // benchmark the model for 1000 cycles (optional)
  .build(); // build the benchmark

Execute the benchmarking process

BenchmarkResult res = benchmark.compose(); // benchmark the model and get the result
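compose() presumably runs the warm-up and measured cycles configured in the builder. Aurora's internals aren't shown on this page, so the following is only a generic sketch of a warm-up-then-measure loop in plain Java, assuming one "cycle" is one call to a task. CycleBenchmark and averageNanos are hypothetical names, not Aurora's API.

```java
import java.util.concurrent.TimeUnit;

/**
 * Generic warm-up-then-measure loop illustrating what warmupCycles and
 * benchmarkCycles control. NOT Aurora's implementation; it only assumes
 * that one cycle is one call to the task.
 */
final class CycleBenchmark {

    /** Runs the task warmupCycles times untimed, then returns the mean nanoseconds over benchmarkCycles timed runs. */
    static double averageNanos(Runnable task, int warmupCycles, int benchmarkCycles) {
        if (benchmarkCycles <= 0) {
            throw new IllegalArgumentException("benchmarkCycles must be positive");
        }
        // Warm-up: lets the JIT compiler and caches settle; these timings are discarded.
        for (int i = 0; i < warmupCycles; i++) {
            task.run();
        }
        long total = 0;
        for (int i = 0; i < benchmarkCycles; i++) {
            long start = System.nanoTime();
            task.run();
            total += System.nanoTime() - start;
        }
        return (double) total / benchmarkCycles;
    }

    public static void main(String[] args) {
        // Stand-in workload; in this page's example a cycle would train the model instead.
        Runnable work = () -> {
            double acc = 0;
            for (int i = 0; i < 100_000; i++) {
                acc += Math.sqrt(i);
            }
            if (acc < 0) {
                System.out.println(acc); // uses the result so the loop is not trivially dead code
            }
        };
        double avg = averageNanos(work, 20, 1000);
        System.out.println("Average per cycle: " + TimeUnit.NANOSECONDS.toMicros((long) avg) + "µs");
    }
}
```

Discarding the warm-up runs and timing each cycle with System.nanoTime is the simplest approach; for rigorous JVM measurements, a dedicated harness such as JMH deals with dead-code elimination and JIT effects more carefully.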

Print the result

// Requires: import java.util.concurrent.TimeUnit;
// The time metrics are converted from nanoseconds; TimeUnit conversions truncate toward zero.
System.out.println("Average time per epoch: " + TimeUnit.NANOSECONDS.toMillis((long) res.averageTimePerEpoch()) + "ms");
System.out.println("Average accuracy: " + res.averageAccuracy() + "%");
System.out.println("Time per sample: " + TimeUnit.NANOSECONDS.toMicros((long) res.timePerSample()) + "µs");
System.out.println("Throughput: " + res.throughput() + " samples/s");
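The timing values printed above are related to each other. As a sketch, assuming each epoch processes every sample once and all timings are in nanoseconds (Aurora's BenchmarkResult may define them differently), time per sample is total time divided by epochs × samples, and throughput is its reciprocal scaled to seconds. BenchmarkMetrics and its methods are hypothetical helpers, and the total time used is an illustrative number, not a measured result.

```java
import java.util.concurrent.TimeUnit;

/**
 * Hypothetical relationships between the reported timing metrics, assuming
 * every epoch processes each sample once and timings are in nanoseconds.
 */
final class BenchmarkMetrics {

    static double timePerEpochNanos(long totalNanos, int epochs) {
        return (double) totalNanos / epochs;
    }

    static double timePerSampleNanos(long totalNanos, int epochs, int samplesPerEpoch) {
        // Widen to long before multiplying to avoid int overflow for large runs.
        return (double) totalNanos / ((long) epochs * samplesPerEpoch);
    }

    /** Samples per second: the reciprocal of time per sample, scaled from nanoseconds to seconds. */
    static double throughputPerSecond(double timePerSampleNanos) {
        return TimeUnit.SECONDS.toNanos(1) / timePerSampleNanos;
    }

    public static void main(String[] args) {
        long totalNanos = 2_500_000_000L; // illustrative value, not a measured result
        int epochs = 1000;
        int samples = 8;

        double perEpoch = timePerEpochNanos(totalNanos, epochs);
        double perSample = timePerSampleNanos(totalNanos, epochs, samples);
        double throughput = throughputPerSecond(perSample);

        System.out.println("Average time per epoch: " + TimeUnit.NANOSECONDS.toMillis((long) perEpoch) + "ms");
        System.out.println("Time per sample: " + TimeUnit.NANOSECONDS.toMicros((long) perSample) + "µs");
        System.out.println("Throughput: " + throughput + " samples/s");
    }
}
```

Running it prints 2ms, 312µs and 3200.0 samples/s. The exact per-epoch time is 2.5 ms, but TimeUnit.toMillis truncates toward zero, which is worth keeping in mind when reading the millisecond figure from a real benchmark.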