diff --git a/cmd/mrmath.cpp b/cmd/mrmath.cpp index 94a1abd685..bdcf5c9aeb 100644 --- a/cmd/mrmath.cpp +++ b/cmd/mrmath.cpp @@ -20,9 +20,11 @@ #include "command.h" #include "dwi/gradient.h" #include "image.h" +#include "image_helpers.h" #include "math/math.h" #include "math/median.h" #include "memory.h" +#include "misc/voxel2vector.h" #include "phase_encoding.h" #include "progressbar.h" @@ -277,40 +279,38 @@ class ImageKernelBase { template class ImageKernel : public ImageKernelBase { protected: - class InitFunctor { - public: - template void operator()(ImageType &out) const { out.value() = Operation(); } - }; class ProcessFunctor { public: - template void operator()(ImageType1 &out, ImageType2 &in) const { - Operation op = out.value(); - op(in.value()); - out.value() = op; - } + ProcessFunctor(ImageKernel &master) : master(master) {} + template void operator()(ImageType &in) const { master.data[master.v2v(in)](in.value()); } + + protected: + ImageKernel &master; }; class ResultFunctor { public: - template void operator()(ImageType1 &out, ImageType2 &in) const { - Operation op = in.value(); - out.value() = op.result(); + ResultFunctor(ImageKernel &master) : master(master) {} + template void operator()(ImageType &out) const { + out.value() = master.data[master.v2v(out)].result(); } + + protected: + ImageKernel &master; }; public: - ImageKernel(const Header &header) : image(Header::scratch(header).get_image()) { - ThreadedLoop(image).run(InitFunctor(), image); - } + ImageKernel(const Header &header) : v2v(header), data(voxel_count(header)) {} - void write_back(Image &out) { ThreadedLoop(image).run(ResultFunctor(), out, image); } + void write_back(Image &out) { ThreadedLoop(out).run(ResultFunctor(*this), out); } void process(Header &header_in) { auto in = header_in.get_image(); - ThreadedLoop(image).run(ProcessFunctor(), image, in); + ThreadedLoop(in).run(ProcessFunctor(*this), in); } protected: - Image image; + Voxel2Vector v2v; + vector data; }; void run() { diff --git a/core/misc/voxel2vector.h b/core/misc/voxel2vector.h index 2e265a2f5d..d7798a3371 100644 --- a/core/misc/voxel2vector.h +++ b/core/misc/voxel2vector.h @@ -41,6 +41,18 @@ class Voxel2Vector { template Voxel2Vector(MaskType &mask) : Voxel2Vector(mask, Header(mask)) {} + Voxel2Vector(const Header &header) + : forward(Image::scratch(header, "Voxel to vector index conversion scratch image")) { + reverse.reserve(voxel_count(header)); + index_t counter = 0; + for (auto l = Loop(header)(forward); l; ++l) { + forward.value() = counter++; + reverse.push_back(pos()); + } + DEBUG("Voxel2vector class for image \"" + header.name() + "\" of size " + join(pos(), "x") + " initialised with " + + str(reverse.size()) + " elements"); + } + size_t size() const { return reverse.size(); } const std::vector &operator[](const size_t index) const { @@ -59,6 +71,13 @@ class Voxel2Vector { private: Image forward; std::vector> reverse; + + vector pos() const { + vector result; + for (size_t index = 0; index != forward.ndim(); ++index) + result.push_back(forward.index(index)); + return result; + } }; template @@ -75,15 +94,13 @@ Voxel2Vector::Voxel2Vector(MaskType &mask, const Header &data) for (auto l = Loop(data)(r_mask, forward); l; ++l) { if (r_mask.value()) { forward.value() = counter++; - std::vector pos; - for (size_t index = 0; index != data.ndim(); ++index) - pos.push_back(forward.index(index)); - reverse.push_back(pos); + reverse.push_back(pos()); } else { forward.value() = invalid; } } - DEBUG("Voxel2Vector class has " + str(reverse.size()) + " non-zero entries"); + DEBUG("Voxel2vector class for image \"" + data.name() + "\" of size " + join(pos(), "x") + " initialised with " + + str(reverse.size()) + " elements"); } } // namespace MR