Skip to content

Commit

Permalink
Merge pull request #2991 from boutproject/lazy-grid-loading
Browse files Browse the repository at this point in the history
Lazy grid loading
  • Loading branch information
bendudson authored Nov 6, 2024
2 parents 36e3da5 + df57fcc commit daffcf9
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 53 deletions.
2 changes: 0 additions & 2 deletions include/bout/invert/laplacexy2_hypre.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
#if not BOUT_HAS_HYPRE
// If no Hypre

#warning LaplaceXY requires Hypre. No LaplaceXY available

#include "bout/globalindexer.hxx"
#include <bout/boutexception.hxx>
#include <bout/mesh.hxx>
Expand Down
25 changes: 25 additions & 0 deletions include/bout/options.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Options;
#include <fmt/core.h>

#include <cmath>
#include <functional>
#include <map>
#include <ostream>
#include <set>
Expand Down Expand Up @@ -833,6 +834,25 @@ public:

static std::string getDefaultSource();

/// API for delayed loading of data from the grid file.
/// Currently only used for 3D data.
///
/// Owning pointer to a callable that reads and returns one chunk of the
/// data, selected by a start and end index in each of x, y and z.
/// A null pointer means there is nothing to load (see is_loaded()).
/// NOTE(review): the loaders visible alongside this declaration treat the
/// end indices as inclusive -- confirm before adding another backend.
using lazyLoadFunction = std::unique_ptr<std::function<Tensor<BoutReal>(
int xstart, int xend, int ystart, int yend, int zstart, int zend)>>;
/// Install the loader that doLazyLoad() will call to fetch data on demand.
/// Ownership of \p func is transferred to this object; passing a null
/// pointer marks the option as loaded again (see is_loaded()).
void setLazyLoad(lazyLoadFunction func) {
  lazyLoad = std::move(func);
}
/// Read one chunk of the data through the installed loader and return it.
///
/// The six indices are forwarded unchanged to the loader set with
/// setLazyLoad(). Asserts (ASSERT1) that a loader is actually present, so
/// callers should check is_loaded() first.
Tensor<BoutReal> doLazyLoad(int xstart, int xend, int ystart, int yend, int zstart,
                            int zend) const {
  ASSERT1(lazyLoad != nullptr);
  const auto& loader = *lazyLoad;
  return loader(xstart, xend, ystart, yend, zstart, zend);
}
/// Some backends only read the data when it is needed. This reports
/// whether the data is already loaded (true) or still has to be
/// fetched with doLazyLoad() (false).
bool is_loaded() const {
  // No pending loader means nothing remains to be read
  return not lazyLoad;
}
/// Get the shape of the value, one entry per dimension.
/// While the data is not loaded (see is_loaded()) this is the shape
/// recorded with setLazyShape() rather than that of the stored value.
std::vector<int> getShape() const;
/// Record the shape of data that has not been loaded yet, so that
/// getShape() can report it without reading the data.
void setLazyShape(std::vector<int> shape) {
  lazy_shape = std::move(shape);
}

private:
/// The source label given to default values
static const std::string DEFAULT_SOURCE;
Expand All @@ -845,6 +865,11 @@ private:
std::map<std::string, Options> children; ///< If a section then has children
mutable bool value_used = false; ///< Record whether this value is used

// Function to load data that has not been read yet; null when there is
// nothing to load, which is exactly what is_loaded() tests
lazyLoadFunction lazyLoad{nullptr};
// Shape of the underlying data, stored by setLazyShape() and reported by
// getShape() while the data is not loaded
std::vector<int> lazy_shape;

template <typename T>
void _set_no_check(T val, std::string source) {
if (not children.empty()) {
Expand Down
2 changes: 1 addition & 1 deletion include/bout/options_io.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public:
OptionsIO& operator=(OptionsIO&&) noexcept = default;

/// Read options from file
virtual Options read() = 0;
virtual Options read(bool lazy = true) = 0;

/// Write options to file
void write(const Options& options) { write(options, "t"); }
Expand Down
61 changes: 26 additions & 35 deletions src/mesh/data/gridfromfile.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -131,28 +131,6 @@ bool GridFile::get(Mesh* m, Field3D& var, const std::string& name, BoutReal def,
return getField(m, var, name, def, location);
}

namespace {
/// Visitor that returns the shape of its argument
struct GetDimensions {
std::vector<int> operator()([[maybe_unused]] bool value) { return {1}; }
std::vector<int> operator()([[maybe_unused]] int value) { return {1}; }
std::vector<int> operator()([[maybe_unused]] BoutReal value) { return {1}; }
std::vector<int> operator()([[maybe_unused]] const std::string& value) { return {1}; }
std::vector<int> operator()(const Array<BoutReal>& array) { return {array.size()}; }
std::vector<int> operator()(const Matrix<BoutReal>& array) {
const auto shape = array.shape();
return {std::get<0>(shape), std::get<1>(shape)};
}
std::vector<int> operator()(const Tensor<BoutReal>& array) {
const auto shape = array.shape();
return {std::get<0>(shape), std::get<1>(shape), std::get<2>(shape)};
}
std::vector<int> operator()(const Field& array) {
return {array.getNx(), array.getNy(), array.getNz()};
}
};
} // namespace

template <typename T>
bool GridFile::getField(Mesh* m, T& var, const std::string& name, BoutReal def,
CELL_LOC location) {
Expand All @@ -175,7 +153,7 @@ bool GridFile::getField(Mesh* m, T& var, const std::string& name, BoutReal def,
Options& option = data[name];

// Global (x, y, z) dimensions of field
const std::vector<int> size = bout::utils::visit(GetDimensions{}, option.value);
const std::vector<int> size = option.getShape();

switch (size.size()) {
case 1: {
Expand Down Expand Up @@ -515,7 +493,7 @@ bool GridFile::get(Mesh* UNUSED(m), std::vector<BoutReal>& var, const std::strin
bool GridFile::hasXBoundaryGuards(Mesh* m) {
// Global (x,y) dimensions of some field
// a grid file should always contain "dx"
const std::vector<int> size = bout::utils::visit(GetDimensions{}, data["dx"].value);
const std::vector<int> size = data["dx"].getShape();

if (size.empty()) {
// handle case where "dx" is not present - non-standard grid file
Expand Down Expand Up @@ -550,7 +528,7 @@ bool GridFile::readgrid_3dvar_fft(Mesh* m, const std::string& name, int yread, i
}

/// Check the size of the data
const std::vector<int> size = bout::utils::visit(GetDimensions{}, data[name].value);
const std::vector<int> size = data[name].getShape();

if (size.size() != 3) {
output_warn.write("\tWARNING: Number of dimensions of {:s} incorrect\n", name);
Expand Down Expand Up @@ -639,21 +617,34 @@ bool GridFile::readgrid_3dvar_real(const std::string& name, int yread, int ydest
Options& option = data[name];

/// Check the size of the data
const std::vector<int> size = bout::utils::visit(GetDimensions{}, option.value);
const std::vector<int> size = option.getShape();

if (size.size() != 3) {
output_warn.write("\tWARNING: Number of dimensions of {:s} incorrect\n", name);
return false;
}

const auto full_var = option.as<Tensor<BoutReal>>();
if (not option.is_loaded()) {
const auto& chunk = option.doLazyLoad(xread, xread + xsize - 1, yread,
yread + ysize - 1, 0, size[2] - 1);
for (int jx = 0; jx < xsize; jx++) {
for (int jy = 0; jy < ysize; jy++) {
for (int jz = 0; jz < size[2]; ++jz) {
var(jx + xdest, jy + ydest, jz) = chunk(jx, jy, jz);
}
}
}

for (int jx = xread; jx < xread + xsize; jx++) {
// jx is global x-index to start from
for (int jy = yread; jy < yread + ysize; jy++) {
// jy is global y-index to start from
for (int jz = 0; jz < size[2]; ++jz) {
var(jx - xread + xdest, jy - yread + ydest, jz) = full_var(jx, jy, jz);
} else {
const auto full_var = option.as<Tensor<BoutReal>>();

for (int jx = xread; jx < xread + xsize; jx++) {
// jx is global x-index to start from
for (int jy = yread; jy < yread + ysize; jy++) {
// jy is global y-index to start from
for (int jz = 0; jz < size[2]; ++jz) {
var(jx - xread + xdest, jy - yread + ydest, jz) = full_var(jx, jy, jz);
}
}
}
}
Expand All @@ -680,7 +671,7 @@ bool GridFile::readgrid_perpvar_fft(Mesh* m, const std::string& name, int xread,

/// Check the size of the data
Options& option = data[name];
const std::vector<int> size = bout::utils::visit(GetDimensions{}, option.value);
const std::vector<int> size = option.getShape();

if (size.size() != 2) {
output_warn.write("\tWARNING: Number of dimensions of {:s} incorrect\n", name);
Expand Down Expand Up @@ -763,7 +754,7 @@ bool GridFile::readgrid_perpvar_real(const std::string& name, int xread, int xde

/// Check the size of the data
Options& option = data[name];
const std::vector<int> size = bout::utils::visit(GetDimensions{}, option.value);
const std::vector<int> size = option.getShape();

if (size.size() != 2) {
output_warn.write("\tWARNING: Number of dimensions of {:s} incorrect\n", name);
Expand Down
2 changes: 1 addition & 1 deletion src/physics/physicsmodel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void PhysicsModel::initialise(Solver* s) {
const bool restarting = Options::root()["restart"].withDefault(false);

if (restarting) {
restart_options = restart_file->read();
restart_options = restart_file->read(false);
}

// Call user init code to specify evolving variables
Expand Down
34 changes: 33 additions & 1 deletion src/sys/options.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,10 @@ Field3D Options::as<Field3D>(const Field3D& similar_to) const {
}

// Get a reference, to try and avoid copying
const auto& tensor = bout::utils::get<Tensor<BoutReal>>(value);
const auto& tensor =
is_loaded() ? bout::utils::get<Tensor<BoutReal>>(value)
: doLazyLoad(0, localmesh->LocalNx - 1, 0, localmesh->LocalNy - 1, 0,
localmesh->LocalNz - 1);

// Check if the dimension sizes are the same as a Field3D
if (tensor.shape()
Expand Down Expand Up @@ -936,6 +939,35 @@ std::vector<std::string> Options::getFlattenedKeys() const {
return flattened_names;
}

namespace {
/// Visitor that returns the shape of its argument
struct GetDimensions {
std::vector<int> operator()([[maybe_unused]] bool value) { return {1}; }
std::vector<int> operator()([[maybe_unused]] int value) { return {1}; }
std::vector<int> operator()([[maybe_unused]] BoutReal value) { return {1}; }
std::vector<int> operator()([[maybe_unused]] const std::string& value) { return {1}; }
std::vector<int> operator()(const Array<BoutReal>& array) { return {array.size()}; }
std::vector<int> operator()(const Matrix<BoutReal>& array) {
const auto shape = array.shape();
return {std::get<0>(shape), std::get<1>(shape)};
}
std::vector<int> operator()(const Tensor<BoutReal>& array) {
const auto shape = array.shape();
return {std::get<0>(shape), std::get<1>(shape), std::get<2>(shape)};
}
std::vector<int> operator()(const Field& array) {
return {array.getNx(), array.getNy(), array.getNz()};
}
};
} // namespace

/// Shape of this option's data, one entry per dimension.
///
/// Data that has not been loaded has no in-memory value to inspect, so
/// the shape recorded for it is returned instead; otherwise the shape is
/// taken from the stored value itself.
std::vector<int> Options::getShape() const {
  if (not is_loaded()) {
    return lazy_shape;
  }
  return bout::utils::visit(GetDimensions{}, value);
}

fmt::format_parse_context::iterator
bout::details::OptionsFormatterBase::parse(fmt::format_parse_context& ctx) {

Expand Down
57 changes: 46 additions & 11 deletions src/sys/options/options_netcdf.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "bout/mesh.hxx"
#include "bout/sys/timer.hxx"

#include <climits>
#include <exception>
#include <iostream>
#include <netcdf>
Expand Down Expand Up @@ -56,7 +57,8 @@ T readAttribute(const NcAtt& attribute) {
return value;
}

void readGroup(const std::string& filename, const NcGroup& group, Options& result) {
void readGroup(const std::string& filename, const NcGroup& group, Options& result,
const std::shared_ptr<netCDF::NcFile>& file) {

// Iterate over all variables
for (const auto& varpair : group.getVars()) {
Expand Down Expand Up @@ -107,11 +109,44 @@ void readGroup(const std::string& filename, const NcGroup& group, Options& resul
}
case 3: {
if (var_type == ncDouble or var_type == ncFloat) {
Tensor<double> value(static_cast<int>(dims[0].getSize()),
static_cast<int>(dims[1].getSize()),
static_cast<int>(dims[2].getSize()));
var.getVar(value.begin());
result[var_name] = value;
if (file) {
result[var_name] = Tensor<double>(0, 0, 0);
const auto s2i = [](size_t s) {
if (s > INT_MAX) {
throw BoutException("BadCast {} > {}", s, INT_MAX);
}
return static_cast<int>(s);
};
result[var_name].setLazyShape(
{s2i(dims[0].getSize()), s2i(dims[1].getSize()), s2i(dims[2].getSize())});
// Capture `file` by value in the lambda below: the copy of the shared pointer
// keeps the file open for as long as the loader exists. Without it the file
// would be closed, and reading the data later would fail.
result[var_name].setLazyLoad(std::make_unique<std::function<Tensor<double>(
int, int, int, int, int, int)>>(
[file, var](int xstart, int xend, int ystart, int yend, int zstart,
int zend) {
const auto i2s = [](int i) {
if (i < 0) {
throw BoutException("BadCast {} < 0", i);
}
return static_cast<size_t>(i);
};
Tensor<double> value(xend - xstart + 1, yend - ystart + 1,
zend - zstart + 1);
const std::vector<size_t> index{i2s(xstart), i2s(ystart), i2s(zstart)};
const std::vector<size_t> count{i2s(xend - xstart + 1),
i2s(yend - ystart + 1),
i2s(zend - zstart + 1)};
var.getVar(index, count, value.begin());
return value;
}));
} else {
Tensor<double> value(static_cast<int>(dims[0].getSize()),
static_cast<int>(dims[1].getSize()),
static_cast<int>(dims[2].getSize()));
var.getVar(value.begin());
result[var_name] = value;
}
}
}
}
Expand Down Expand Up @@ -144,25 +179,25 @@ void readGroup(const std::string& filename, const NcGroup& group, Options& resul
const auto& name = grouppair.first;
const auto& subgroup = grouppair.second;

readGroup(filename, subgroup, result[name]);
readGroup(filename, subgroup, result[name], file);
}
}
} // namespace

namespace bout {

Options OptionsNetCDF::read() {
Options OptionsNetCDF::read(bool lazy) {
Timer timer("io");

// Open file
const NcFile read_file(filename, NcFile::read);
auto read_file = std::make_shared<netCDF::NcFile>(filename, NcFile::read);

if (read_file.isNull()) {
if (read_file->isNull()) {
throw BoutException("Could not open NetCDF file '{:s}' for reading", filename);
}

Options result;
readGroup(filename, read_file, result);
readGroup(filename, *read_file, result, lazy ? read_file : nullptr);

return result;
}
Expand Down
2 changes: 1 addition & 1 deletion src/sys/options/options_netcdf.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public:
OptionsNetCDF& operator=(OptionsNetCDF&&) noexcept = default;

/// Read options from file
Options read();
Options read(bool lazy = true) override;

/// Write options to file
void write(const Options& options) { write(options, "t"); }
Expand Down
2 changes: 1 addition & 1 deletion tests/integrated/test-options-netcdf/runtest
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# requires: not legacy_netcdf

from boututils.datafile import DataFile
from boututils.run_wrapper import build_and_log, shell, launch
from boututils.run_wrapper import build_and_log, shell, launch_safe as launch
from boutdata.data import BoutOptionsFile

import math
Expand Down

0 comments on commit daffcf9

Please sign in to comment.