diff --git a/CMakeLists.txt b/CMakeLists.txt index da0690a..e4be369 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,10 +17,6 @@ find_package(JPEG REQUIRED) find_library(LIBRAW_LIBRARY NAMES raw raw_r) set(src_files - src/align.cpp - src/finish.cpp - src/merge.cpp - src/util.cpp src/InputSource.cpp src/Burst.cpp src/LibRaw2DngConverter.cpp) @@ -31,10 +27,36 @@ set(header_files src/dngwriter.h src/LibRaw2DngConverter.h) -include_directories(${HALIDE_DISTRIB_DIR}/include ${HALIDE_DISTRIB_DIR}/tools ${RAW2DNG_INCLUDE_DIRS}) - -add_executable(hdrplus src/HDRPlus.cpp ${src_files}) -target_link_libraries(hdrplus Halide png ${LIBRAW_LIBRARY} ${TIFF_LIBRARIES} ${TIFFXX_LIBRARY}) - -add_executable(stack_frames src/stack_frames.cpp ${src_files}) -target_link_libraries(stack_frames Halide ${LIBRAW_LIBRARY} ${PNG_LIBRARIES} ${JPEG_LIBRARIES} ${TIFF_LIBRARIES} ${TIFFXX_LIBRARY}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${HALIDE_DISTRIB_DIR}/include ${HALIDE_DISTRIB_DIR}/tools ${RAW2DNG_INCLUDE_DIRS}) +include_directories(${CMAKE_BINARY_DIR}/genfiles) + +set(HALIDE_TARGET "") +set(HALIDE_TARGET_FEATURES "") +halide_library(hdrplus_pipeline + SRCS src/hdrplus_pipeline_generator.cpp src/align.cpp src/merge.cpp src/finish.cpp src/util.cpp + # GENERATOR_DEPS # We don't have any yet + GENERATOR_NAME hdrplus_pipeline + # GENERATOR_ARGS # We don't have any yet + FUNCTION_NAME hdrplus_pipeline + # HALIDE_TARGET ${HALIDE_TARGET} # TODO: add option with custom HALIDE_TARGET + # HALIDE_TARGET_FEATURES ${HALIDE_TARGET_FEATURES} # TODO: add option with custom HALIDE_TARGET + # EXTRA_OUTPUTS "stmt;html;schedule") # uncomment for extra output +) + +halide_library(align_and_merge + SRCS src/align_and_merge_generator.cpp src/align.cpp src/merge.cpp src/util.cpp + GENERATOR_NAME align_and_merge + FUNCTION_NAME align_and_merge + # HALIDE_TARGET ${HALIDE_TARGET} # TODO: add option with custom HALIDE_TARGET + # HALIDE_TARGET_FEATURES ${HALIDE_TARGET_FEATURES} # TODO: add option with custom HALIDE_TARGET + # EXTRA_OUTPUTS "stmt;html;schedule") # uncomment for extra output +) + + +add_executable(hdrplus bin/HDRPlus.cpp ${src_files}) +add_dependencies(hdrplus hdrplus_pipeline) +target_link_libraries(hdrplus hdrplus_pipeline Halide png ${LIBRAW_LIBRARY} ${TIFF_LIBRARIES} ${TIFFXX_LIBRARY}) + +add_executable(stack_frames bin/stack_frames.cpp ${src_files}) +add_dependencies(stack_frames align_and_merge) +target_link_libraries(stack_frames Halide align_and_merge ${LIBRAW_LIBRARY} ${PNG_LIBRARIES} ${JPEG_LIBRARIES} ${TIFF_LIBRARIES} ${TIFFXX_LIBRARY}) diff --git a/src/HDRPlus.cpp b/bin/HDRPlus.cpp similarity index 56% rename from src/HDRPlus.cpp rename to bin/HDRPlus.cpp index 0bea860..3b476f5 100644 --- a/src/HDRPlus.cpp +++ b/bin/HDRPlus.cpp @@ -1,32 +1,22 @@ -#include "Halide.h" -#include "halide_load_raw.h" +#include +#include +#include -#include "Burst.h" +#include #define STB_IMAGE_WRITE_IMPLEMENTATION -#include "../include/stb_image_write.h" -#include -#include "align.h" -#include "merge.h" -#include "finish.h" -#include -#include +#include + +#include +#include -using namespace Halide; -using namespace std; /* * HDRPlus Class -- Houses file I/O, defines pipeline attributes and calls * processes main stages of the pipeline. */ class HDRPlus { - -private: - - const Halide::Runtime::Buffer imgs; - + Halide::Runtime::Buffer imgs; public: - - const int width; const int height; @@ -35,9 +25,6 @@ class HDRPlus { const WhiteBalance wb; const Compression c; const Gain g; - char *readfile; - std::string readfilestring; - void imagesize(); HDRPlus(Halide::Runtime::Buffer imgs, BlackPoint bp, WhitePoint wp, WhiteBalance wb, Compression c, Gain g) : imgs(imgs) @@ -49,53 +36,31 @@ class HDRPlus { , c(c) , g(g) { - assert(imgs.dimensions() == 3); // width * height * img_idx - assert(imgs.extent(2) >= 2); // must have at least one alternate image + if (imgs.dimensions() != 3 || imgs.extent(2) < 2) { + throw std::invalid_argument("The input of HDRPlus must be a 3-dimensional buffer with at least two channels."); + } } - /* - * process -- Calls all of the main stages (align, merge, finish) of the pipeline. - */ - Buffer process() { - Halide::Buffer imgsBuffer(*imgs.raw_buffer()); - - Func alignment = align(imgsBuffer); - Func merged = merge(imgsBuffer, alignment); - Func finished = finish(merged, width, height, bp, wp, wb, c, g); - - /////////////////////////////////////////////////////////////////////////// - // realize image - /////////////////////////////////////////////////////////////////////////// - - Buffer output_img(3, width, height); - - finished.realize(output_img); + Halide::Runtime::Buffer process() { + Halide::Runtime::Buffer output_img(3, width, height); + hdrplus_pipeline(imgs, bp, wp, wb.r, wb.g0, wb.g1, wb.b, c, g, output_img); // transpose to account for interleaved layout - output_img.transpose(0, 1); output_img.transpose(1, 2); return output_img; } - /* - * save_png -- Writes an interleaved Halide image to an output file. - */ - static bool save_png(std::string dir_path, std::string img_name, Buffer &img) { - + static bool save_png(const std::string& dir_path, const std::string& img_name, const Halide::Runtime::Buffer &img) { std::string img_path = dir_path + "/" + img_name; - std::remove(img_path.c_str()); - int stride_in_bytes = img.width() * img.channels(); if(!stbi_write_png(img_path.c_str(), img.width(), img.height(), img.channels(), img.data(), stride_in_bytes)) { - std::cerr << "Unable to write output image '" << img_name << "'" << std::endl; return false; } - return true; } }; @@ -114,20 +79,15 @@ int main(int argc, char* argv[]) { int i = 1; while(argv[i][0] == '-') { - if(argv[i][1] == 'c') { - c = atof(argv[++i]); i++; continue; - } - else if(argv[i][1] == 'g') { - + } else if(argv[i][1] == 'g') { g = atof(argv[++i]); i++; continue; - } - else { + } else { std::cerr << "Invalid flag '" << argv[i][1] << "'" << std::endl; return 1; } @@ -142,23 +102,15 @@ int main(int argc, char* argv[]) { std::string out_name = argv[i++]; std::vector in_names; - while (i < argc) in_names.push_back(argv[i++]); + while (i < argc) { + in_names.emplace_back(argv[i++]); + } Burst burst(dir_path, in_names); - Halide::Runtime::Buffer imgs = burst.ToBuffer(); - if (imgs.channels() < 2) { - return EXIT_FAILURE; - } - HDRPlus hdr_plus( - imgs, - burst.GetBlackLevel(), - burst.GetWhiteLevel(), - burst.GetWhiteBalance(), - c, - g); + HDRPlus hdr_plus(burst.ToBuffer(), burst.GetBlackLevel(), burst.GetWhiteLevel(), burst.GetWhiteBalance(), c, g); - Buffer output = hdr_plus.process(); + Halide::Runtime::Buffer output = hdr_plus.process(); if (!HDRPlus::save_png(dir_path, out_name, output)) { return EXIT_FAILURE; diff --git a/src/stack_frames.cpp b/bin/stack_frames.cpp similarity index 62% rename from src/stack_frames.cpp rename to bin/stack_frames.cpp index 1049bd6..666c56f 100644 --- a/src/stack_frames.cpp +++ b/bin/stack_frames.cpp @@ -5,27 +5,16 @@ #include #include -#include "align.h" -#include "Burst.h" -#include "finish.h" -#include "merge.h" +#include -using namespace Halide; -using namespace std; +#include -Halide::Buffer align_and_merge(const Halide::Runtime::Buffer& burst) { +Halide::Runtime::Buffer align_and_merge(Halide::Runtime::Buffer burst) { if (burst.channels() < 2) { return {}; } - - Halide::Buffer imgsBuffer(*burst.raw_buffer()); - - Func alignment = align(imgsBuffer); - Func merged = merge(imgsBuffer, alignment); - - Halide::Buffer merged_buffer(burst.width(), burst.height()); - merged.realize(merged_buffer); - + Halide::Runtime::Buffer merged_buffer(burst.width(), burst.height()); + align_and_merge(burst, merged_buffer); return merged_buffer; } @@ -49,12 +38,12 @@ int main(int argc, char* argv[]) { Burst burst(dir_path, in_names); - Halide::Buffer merged = align_and_merge(burst.ToBuffer()); + const auto merged = align_and_merge(burst.ToBuffer()); std::cerr << "merged size: " << merged.width() << " " << merged.height() << std::endl; const RawImage& raw = burst.GetRaw(0); const std::string merged_filename = dir_path + "/" + out_name; - raw.WriteDng(merged_filename, *merged.get()); + raw.WriteDng(merged_filename, merged); return EXIT_SUCCESS; } diff --git a/src/align.cpp b/src/align.cpp index d491cbf..b858983 100644 --- a/src/align.cpp +++ b/src/align.cpp @@ -65,7 +65,7 @@ Func align_layer(Func layer, Func prev_alignment, Point prev_min, Point prev_max * by T_SIZE_2 in each dimension. align(imgs)(tile_x, tile_y, n) is a point representing the x and y offset * for a tile in layer n that most closely matches that tile in the reference (relative to the reference tile's location) */ -Func align(const Buffer imgs) { +Func align(const Halide::Func imgs, Halide::Expr width, Halide::Expr height) { Func alignment_3("layer_3_alignment"); Func alignment("alignment"); @@ -74,7 +74,7 @@ Func align(const Buffer imgs) { // mirror input with overlapping edges - Func imgs_mirror = BoundaryConditions::mirror_interior(imgs, 0, imgs.width(), 0, imgs.height()); + Func imgs_mirror = BoundaryConditions::mirror_interior(imgs, 0, width, 0, height); // downsampled layers for alignment @@ -107,8 +107,8 @@ Func align(const Buffer imgs) { // number of tiles in the x and y dimensions - int num_tx = imgs.width() / T_SIZE_2 - 1; - int num_ty = imgs.height() / T_SIZE_2 - 1; + Expr num_tx = width / T_SIZE_2 - 1; + Expr num_ty = height / T_SIZE_2 - 1; // final alignment offsets for the original mosaic image; tiles outside of the bounds use the nearest alignment offset @@ -118,3 +118,8 @@ Func align(const Buffer imgs) { return alignment_repeat; } + +Halide::Func align(Halide::Buffer imgs) { + Halide::Func imgs_function(imgs); + return align(imgs_function, imgs.width(), imgs.height()); +} \ No newline at end of file diff --git a/src/align.h b/src/align.h index 707bcd8..398e383 100644 --- a/src/align.h +++ b/src/align.h @@ -55,6 +55,7 @@ inline Halide::Expr idx_layer(Halide::Expr t, Halide::Expr i) { return t * T_SIZ * by T_SIZE_2 in each dimension. align(imgs)(tile_x, tile_y, n) is a point representing the x and y offset * for a tile in layer n that most closely matches that tile in the reference (relative to the reference tile's location) */ -Halide::Func align(const Halide::Buffer imgs); +Halide::Func align(Halide::Buffer imgs); +Halide::Func align(const Halide::Func imgs, Halide::Expr width, Halide::Expr height); #endif diff --git a/src/align_and_merge_generator.cpp b/src/align_and_merge_generator.cpp new file mode 100644 index 0000000..048185f --- /dev/null +++ b/src/align_and_merge_generator.cpp @@ -0,0 +1,25 @@ +#include + +#include "align.h" +#include "merge.h" +#include "finish.h" + +namespace { + + class StackFrames : public Halide::Generator { + public: + // 'inputs' is really a series of raw 2d frames; extent[2] specifies the count + Input> inputs{"inputs", 3}; + // Merged buffer + Output> output{"output", 2}; + + void generate() { + Func alignment = align(inputs, inputs.width(), inputs.height()); + Func merged = merge(inputs, inputs.width(), inputs.height(), inputs.dim(2).extent(), alignment); + output = merged; + } + }; + +} // namespace + +HALIDE_REGISTER_GENERATOR(StackFrames, align_and_merge) diff --git a/src/finish.cpp b/src/finish.cpp index 0f08f51..e23a367 100644 --- a/src/finish.cpp +++ b/src/finish.cpp @@ -11,13 +11,13 @@ using namespace Halide::ConciseCasts; * levels to take advantage of the full 16-bit integer depth. This is a * necessary step for camera white balance levels to be valid. */ -Func black_white_level(Func input, const BlackPoint bp, const BlackPoint wp) { +Func black_white_level(Func input, const Expr bp, const Expr wp) { Func output("black_white_level_output"); Var x, y; - float white_factor = 65535.f / (wp - bp); + Expr white_factor = 65535.f / (wp - bp); output(x, y) = u16_sat((i32(input(x, y)) - bp) * white_factor); @@ -29,7 +29,7 @@ Func black_white_level(Func input, const BlackPoint bp, const BlackPoint wp) { * color multipliers. Note that the two green channels in the bayer pattern * are white-balanced separately. */ -Func white_balance(Func input, int width, int height, const WhiteBalance &wb) { +Func white_balance(Func input, Expr width, Expr height, const CompiletimeWhiteBalance &wb) { Func output("white_balance_output"); @@ -62,12 +62,17 @@ Func white_balance(Func input, int width, int height, const WhiteBalance &wb) { * work of Malvar et al. Assumes that data is laid out in an RG/GB pattern. * https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/Demosaicing_ICASSP04.pdf */ -Func demosaic(Func input, int width, int height) { +Func demosaic(Func input, Expr width, Expr height) { - Func f0("demosaic_f0"); // G at R locations; G at B locations - Func f1("demosaic_f1"); // R at green in R row, B column; B at green in B row, R column - Func f2("demosaic_f2"); // R at green in B row, R column; B at green in R row, B column - Func f3("demosaic_f3"); // R at blue in B row, B column; B at red in R row, R column + Buffer f0(5, 5, "demosaic_f0"); // G at R locations; G at B locations + Buffer f1(5, 5, "demosaic_f1"); // R at green in R row, B column; B at green in B row, R column + Buffer f2(5, 5, "demosaic_f2"); // R at green in B row, R column; B at green in R row, B column + Buffer f3(5, 5, "demosaic_f3"); // R at blue in B row, B column; B at red in R row, R column + + f0.translate({-2, -2}); + f1.translate({-2, -2}); + f2.translate({-2, -2}); + f3.translate({-2, -2}); Func d0("demosaic_0"); Func d1("demosaic_1"); @@ -86,10 +91,10 @@ Func demosaic(Func input, int width, int height) { // demosaic filters - f0(x,y) = 0; - f1(x,y) = 0; - f2(x,y) = 0; - f3(x,y) = 0; + f0.fill(0); + f1.fill(0); + f2.fill(0); + f3.fill(0); int f0_sum = 8; int f1_sum = 16; @@ -129,47 +134,36 @@ Func demosaic(Func input, int width, int height) { // resulting demosaicked function - output(x, y, c) = input(x, y); // initialize each channel to input mosaicked image - - // red - output(r1.x * 2 + 1, r1.y * 2, 0) = d1(r1.x * 2 + 1, r1.y * 2); // R at green in R row, B column - output(r1.x * 2, r1.y * 2 + 1, 0) = d2(r1.x * 2, r1.y * 2 + 1); // R at green in B row, R column - output(r1.x * 2 + 1, r1.y * 2 + 1, 0) = d3(r1.x * 2 + 1, r1.y * 2 + 1); // R at blue in B row, B column - - // green - output(r1.x * 2, r1.y * 2, 1) = d0(r1.x * 2, r1.y * 2); // G at R locations - output(r1.x * 2 + 1, r1.y * 2 + 1, 1) = d0(r1.x * 2 + 1, r1.y * 2 + 1); // G at B locations - - // blue - output(r1.x * 2, r1.y * 2 + 1, 2) = d1(r1.x * 2, r1.y * 2 + 1); // B at green in B row, R column - output(r1.x * 2 + 1, r1.y * 2, 2) = d2(r1.x * 2 + 1, r1.y * 2); // B at green in R row, B column - output(r1.x * 2, r1.y * 2, 2) = d3(r1.x * 2, r1.y * 2); // B at red in R row, R column + Expr R_row = y % 2 == 0; + Expr B_row = !R_row; + Expr R_col = x % 2 == 0; + Expr B_col = !R_col; + Expr at_R = c == 0; + Expr at_G = c == 1; + Expr at_B = c == 2; + + output(x, y, c) = select(at_R && R_row && B_col, d1(x, y), + at_R && B_row && R_col, d2(x, y), + at_R && B_row && B_col, d3(x, y), + at_G && R_row && R_col, d0(x, y), + at_G && B_row && B_col, d0(x, y), + at_B && B_row && R_col, d1(x, y), + at_B && R_row && B_col, d2(x, y), + at_B && R_row && R_col, d3(x, y), + input(x, y)); /////////////////////////////////////////////////////////////////////////// // schedule /////////////////////////////////////////////////////////////////////////// - - f0.compute_root().parallel(y).parallel(x); - f1.compute_root().parallel(y).parallel(x); - f2.compute_root().parallel(y).parallel(x); - f3.compute_root().parallel(y).parallel(x); - d0.compute_root().parallel(y).vectorize(x, 16); d1.compute_root().parallel(y).vectorize(x, 16); d2.compute_root().parallel(y).vectorize(x, 16); d3.compute_root().parallel(y).vectorize(x, 16); - output.compute_root().parallel(y).vectorize(x, 16); - - output.update(0).parallel(r1.y); - output.update(1).parallel(r1.y); - output.update(2).parallel(r1.y); - output.update(3).parallel(r1.y); - output.update(4).parallel(r1.y); - output.update(5).parallel(r1.y); - output.update(6).parallel(r1.y); - output.update(7).parallel(r1.y); - + output.compute_root().parallel(y) + .align_bounds(x, 2).unroll(x, 2) + .align_bounds(y, 2).unroll(y, 2) + .vectorize(x, 16); return output; } @@ -179,9 +173,11 @@ Func demosaic(Func input, int width, int height) { * weighted as 0 to decrease amplification of saturation artifacts, which can * occur around bright highlights. */ -Func bilateral_filter(Func input, int width, int height) { +Func bilateral_filter(Func input, Expr width, Expr height) { + + Buffer k(7, 7, "gauss_kernel"); + k.translate({-3, -3}); - Func k("gauss_kernel"); Func weights("bilateral_weights"); Func total_weights("bilateral_total_weights"); Func bilateral("bilateral"); @@ -192,8 +188,7 @@ Func bilateral_filter(Func input, int width, int height) { // gaussian kernel - k(dx, dy) = f32(0.f); - + k.fill(0.f); k(-3, -3) = 0.000690f; k(-2, -3) = 0.002646f; k(-1, -3) = 0.005923f; k(0, -3) = 0.007748f; k(1, -3) = 0.005923f; k(2, -3) = 0.002646f; k(3, -3) = 0.000690f; k(-3, -2) = 0.002646f; k(-2, -2) = 0.010149f; k(-1, -2) = 0.022718f; k(0, -2) = 0.029715f; k(1, -2) = 0.022718f; k(2, -2) = 0.010149f; k(3, -2) = 0.002646f; k(-3, -1) = 0.005923f; k(-2, -1) = 0.022718f; k(-1, -1) = 0.050855f; k(0, -1) = 0.066517f; k(1, -1) = 0.050855f; k(2, -1) = 0.022718f; k(3, -1) = 0.005923f; @@ -233,7 +228,7 @@ Func bilateral_filter(Func input, int width, int height) { // schedule /////////////////////////////////////////////////////////////////////////// - k.parallel(dy).parallel(dx).compute_root(); + //k.parallel(dy).parallel(dx).compute_root(); weights.compute_at(output, y).vectorize(x, 16); @@ -250,7 +245,7 @@ Func bilateral_filter(Func input, int width, int height) { * input in and using the result only if it falls within constraints on by what * factor and absolute threshold the chroma magnitudes fall. */ -Func desaturate_noise(Func input, int width, int height) { +Func desaturate_noise(Func input, Expr width, Expr height) { Func output("desaturate_noise_output"); @@ -314,7 +309,7 @@ Func increase_saturation(Func input, float strength) { * will be applied iteratively in order of increasing aggressiveness, with the * total number of passes determined by input. */ -Func chroma_denoise(Func input, int width, int height, int num_passes) { +Func chroma_denoise(Func input, Expr width, Expr height, int num_passes) { Func output = rgb_to_yuv(input); @@ -341,7 +336,7 @@ Func chroma_denoise(Func input, int width, int height, int num_passes) { * by Mertens et al. * http://ntp-0.cs.ucl.ac.uk/staff/j.kautz/publications/exposure_fusion.pdf */ -Func combine(Func im1, Func im2, int width, int height, Func dist) { +Func combine(Func im1, Func im2, Expr width, Expr height, Func dist) { Func init_mask1("mask1_layer_0"); Func init_mask2("mask2_layer_0"); @@ -423,11 +418,8 @@ Func combine(Func im1, Func im2, int width, int height, Func dist) { /////////////////////////////////////////////////////////////////////////// init_mask1.compute_root().parallel(y).vectorize(x, 16); - accumulator.compute_root().parallel(y).vectorize(x, 16); - for (int layer = 0; layer < num_layers; layer++) { - accumulator.update(layer).parallel(y).vectorize(x, 16); } @@ -437,7 +429,7 @@ Func combine(Func im1, Func im2, int width, int height, Func dist) { /* * brighten -- Applies a specified gain to an input. */ -Func brighten(Func input, float gain) { +Func brighten(Func input, Expr gain) { Func output("brighten_output"); @@ -454,7 +446,7 @@ Func brighten(Func input, float gain) { * with an increasing strength in each iteration to ensure a natural looking * dynamic range compression. */ -Func tone_map(Func input, int width, int height, float comp, float gain) { +Func tone_map(Func input, Expr width, Expr height, Expr comp, Expr gain) { Func normal_dist("luma_weight_distribution"); Func grayscale("grayscale"); @@ -481,18 +473,18 @@ Func tone_map(Func input, int width, int height, float comp, float gain) { // constants used to determine compression and gain values at each iteration - float comp_const = 1.f + comp / num_passes; - float gain_const = 1.f + gain / num_passes; + Expr comp_const = 1.f + comp / num_passes; + Expr gain_const = 1.f + gain / num_passes; - float comp_slope = (comp - comp_const) / (num_passes - 1); - float gain_slope = (gain - gain_const) / (num_passes - 1); + Expr comp_slope = (comp - comp_const) / (num_passes - 1); + Expr gain_slope = (gain - gain_const) / (num_passes - 1); for (int pass = 0; pass < num_passes; pass++) { // compute compression and gain at given iteration - float norm_comp = pass * comp_slope + comp_const; - float norm_gain = pass * gain_slope + gain_const; + Expr norm_comp = pass * comp_slope + comp_const; + Expr norm_gain = pass * gain_slope + gain_const; bright = brighten(dark, norm_comp); @@ -530,16 +522,13 @@ Func tone_map(Func input, int width, int height, float comp, float gain) { */ Func srgb(Func input) { - Func srgb_matrix("srgb_matrix"); + Buffer srgb_matrix(3, 3, "srgb_matrix"); Func output("srgb_output"); Var x, y, c; RDom r(0, 3); // srgb conversion matrix; - - srgb_matrix(x, y) = 0.f; - srgb_matrix(0, 0) = 1.964399f; srgb_matrix(1, 0) = -1.119710f; srgb_matrix(2, 0) = 0.155311f; srgb_matrix(0, 1) = -0.241156f; srgb_matrix(1, 1) = 1.673722f; srgb_matrix(2, 1) = -0.432566f; srgb_matrix(0, 2) = 0.013887f; srgb_matrix(1, 2) = -0.549820f; srgb_matrix(2, 2) = 1.535933f; @@ -548,12 +537,6 @@ Func srgb(Func input) { output(x, y, c) = u16_sat(sum(srgb_matrix(r, c) * input(x, y, r))); - /////////////////////////////////////////////////////////////////////////// - // schedule - /////////////////////////////////////////////////////////////////////////// - - srgb_matrix.compute_root().parallel(y).parallel(x); - return output; } @@ -647,7 +630,7 @@ Func sharpen(Func input, float strength) { */ Func u8bit_interleaved(Func input) { - Func output("8bit_interleaved_output"); + Func output("_8bit_interleaved_output"); Var c, x, y; @@ -672,8 +655,7 @@ Func u8bit_interleaved(Func input) { * and gain amounts. This produces natural-looking brightened shadows, without * blowing out highlights. The output values are 8-bit. */ -Func finish(Func input, int width, int height, const BlackPoint bp, const WhitePoint wp, const WhiteBalance &wb, const Compression c, const Gain g) { - +Halide::Func finish(Halide::Func input, Expr width, Expr height, Expr bp, Expr wp, const CompiletimeWhiteBalance &wb, const Expr c, const Expr g) { int denoise_passes = 1; float contrast_strength = 5.f; int black_level = 2000; @@ -717,3 +699,7 @@ Func finish(Func input, int width, int height, const BlackPoint bp, const WhiteP return u8bit_interleaved(contrast_output); } + +Func finish(Func input, int width, int height, const BlackPoint bp, const WhitePoint wp, const WhiteBalance &wb, const Compression c, const Gain g) { + return finish(input, width, height, bp, wp, wb, c, g); +} diff --git a/src/finish.h b/src/finish.h index 2784c71..f11d531 100644 --- a/src/finish.h +++ b/src/finish.h @@ -2,13 +2,31 @@ #define HDRPLUS_FINISH_H_ #include "Halide.h" + +template +struct TypedWhiteBalance { + template + TypedWhiteBalance(const TypedWhiteBalance& other) + : r(other.r) + , g0(other.g0) + , g1(other.g1) + , b(other.b) + {} -struct WhiteBalance { - float r; - float g0; - float g1; - float b; + TypedWhiteBalance(T r, T g0, T g1, T b) + : r(r) + , g0(g0) + , g1(g1) + , b(b) + {} + + T r; + T g0; + T g1; + T b; }; +using WhiteBalance = TypedWhiteBalance; +using CompiletimeWhiteBalance = TypedWhiteBalance; typedef uint16_t BlackPoint; typedef uint16_t WhitePoint; @@ -25,5 +43,6 @@ typedef float Gain; * blowing out highlights. The output values are 8-bit. */ Halide::Func finish(Halide::Func input, int width, int height, const BlackPoint bp, const WhitePoint wp, const WhiteBalance &wb, const Compression c, const Gain g); +Halide::Func finish(Halide::Func input, Halide::Expr width, Halide::Expr height, const Halide::Expr bp, const Halide::Expr wp, const CompiletimeWhiteBalance &wb, const Halide::Expr c, const Halide::Expr g); #endif \ No newline at end of file diff --git a/src/hdrplus_pipeline_generator.cpp b/src/hdrplus_pipeline_generator.cpp new file mode 100644 index 0000000..b14307c --- /dev/null +++ b/src/hdrplus_pipeline_generator.cpp @@ -0,0 +1,38 @@ +#include + +#include "align.h" +#include "merge.h" +#include "finish.h" + +namespace { + + class HdrPlusPipeline : public Halide::Generator { + public: + // 'inputs' is really a series of raw 2d frames; extent[2] specifies the count + Input> inputs{"inputs", 3}; + Input black_point{"black_point"}; + Input white_point{"white_point"}; + Input white_balance_r{"white_balance_r"}; + Input white_balance_g0{"white_balance_g0"}; + Input white_balance_g1{"white_balance_g1"}; + Input white_balance_b{"white_balance_b"}; + Input compression{"compression"}; + Input gain{"gain"}; + + // RGB output + Output> output{"output", 3}; + + void generate() { + // Algorithm + Func alignment = align(inputs, inputs.width(), inputs.height()); + Func merged = merge(inputs, inputs.width(), inputs.height(), inputs.dim(2).extent(), alignment); + CompiletimeWhiteBalance wb{ white_balance_r, white_balance_g0, white_balance_g1, white_balance_b }; + Func finished = finish(merged, inputs.width(), inputs.height(), black_point, white_point, wb, compression, gain); + output = finished; + // Schedule handled inside included functions + } + }; + +} // namespace + +HALIDE_REGISTER_GENERATOR(HdrPlusPipeline, hdrplus_pipeline) diff --git a/src/merge.cpp b/src/merge.cpp index 1430128..1b2e56c 100644 --- a/src/merge.cpp +++ b/src/merge.cpp @@ -14,7 +14,7 @@ using namespace Halide::ConciseCasts; * tile. Thresholds L1 scores so that tiles above a certain distance are completely * discounted, and tiles below a certain distance are assumed to be perfectly aligned. */ -Func merge_temporal(Buffer imgs, Func alignment) { +Func merge_temporal(Halide::Func imgs, Expr width, Expr height, Expr frames, Func alignment) { Func weight("merge_temporal_weights"); Func total_weight("merge_temporal_total_weights"); @@ -22,11 +22,11 @@ Func merge_temporal(Buffer imgs, Func alignment) { Var ix, iy, tx, ty, n; RDom r0(0, 16, 0, 16); // reduction over pixels in downsampled tile - RDom r1(1, imgs.extent(2) - 1); // reduction over alternate images + RDom r1(1, frames - 1); // reduction over alternate images // mirror input with overlapping edges - Func imgs_mirror = BoundaryConditions::mirror_interior(imgs, 0, imgs.width(), 0, imgs.height()); + Func imgs_mirror = BoundaryConditions::mirror_interior(imgs, 0, width, 0, height); // downsampled layer for computing L1 distances @@ -146,9 +146,11 @@ Func merge_spatial(Func input) { * merge -- fully merges aligned frames in the temporal and spatial * dimension to produce one denoised bayer frame. */ -Func merge(Buffer imgs, Func alignment) { - - Func merge_temporal_output = merge_temporal(imgs, alignment); - +Func merge(Halide::Func imgs, Halide::Expr width, Halide::Expr height, Halide::Expr frames, Halide::Func alignment) { + Func merge_temporal_output = merge_temporal(imgs, width, height, frames, alignment); return merge_spatial(merge_temporal_output); } + +Halide::Func merge(Halide::Buffer imgs, Halide::Func alignment) { + return merge(Halide::Func(imgs), imgs.width(), imgs.height(), imgs.extent(2), alignment); +} diff --git a/src/merge.h b/src/merge.h index 98be677..1084715 100644 --- a/src/merge.h +++ b/src/merge.h @@ -7,6 +7,6 @@ * merge -- fully merges aligned frames in the temporal and spatial * dimension to produce one denoised bayer frame. */ +Halide::Func merge(Halide::Func imgs, Halide::Expr width, Halide::Expr height, Halide::Expr frames, Halide::Func alignment); Halide::Func merge(Halide::Buffer imgs, Halide::Func alignment); - #endif diff --git a/src/util.cpp b/src/util.cpp index fee4abf..9c984f0 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -37,15 +37,15 @@ Func box_down2(Func input, std::string name) { Func gauss_down4(Func input, std::string name) { Func output(name); - Func k(name + "_filter"); + Buffer k(5, 5, "gauss_down4_kernel"); + k.translate({-2, -2}); Var x, y, n; RDom r(-2, 5, -2, 5); // gaussian kernel - k(x, y) = 0; - + k.fill(0); k(-2,-2) = 2; k(-1,-2) = 4; k(0,-2) = 5; k(1,-2) = 4; k(2,-2) = 2; k(-2,-1) = 4; k(-1,-1) = 9; k(0,-1) = 12; k(1,-1) = 9; k(2,-1) = 4; k(-2, 0) = 5; k(-1, 0) = 12; k(0, 0) = 15; k(1, 0) = 12; k(2, 0) = 5; @@ -60,8 +60,6 @@ Func gauss_down4(Func input, std::string name) { // schedule /////////////////////////////////////////////////////////////////////////// - k.compute_root().parallel(y).parallel(x); - output.compute_root().parallel(y).vectorize(x, 16); return output; @@ -70,7 +68,7 @@ Func gauss_down4(Func input, std::string name) { /* * gauss_7x7 -- Applies a 7x7 gauss kernel with a std deviation of 4/3. Requires its input to handle boundaries. */ -Func gauss(Func input, Func k, RDom r, std::string name) { +Func gauss(Func input, Buffer k, RDom r, std::string name) { Func blur_x(name + "_x"); Func output(name); @@ -117,46 +115,33 @@ Func gauss_7x7(Func input, std::string name) { // gaussian kernel - Func k("gauss_7x7_kernel"); + Buffer k(7, "gauss_7x7_kernel"); + k.translate({-3}); Var x; RDom r(-3, 7); - k(x) = f32(0.f); - + k.fill(0.f); k(-3) = 0.026267f; k(-2) = 0.100742f; k(-1) = 0.225511f; k(0) = 0.29496f; k( 3) = 0.026267f; k( 2) = 0.100742f; k( 1) = 0.225511f; - /////////////////////////////////////////////////////////////////////////// - // schedule - /////////////////////////////////////////////////////////////////////////// - - k.compute_root().parallel(x); - return gauss(input, k, r, name); - } Func gauss_15x15(Func input, std::string name) { // gaussian kernel - Func k("gauss_7x7_kernel"); + Buffer k(15, "gauss_15x15"); + k.translate({-7}); Var x; RDom r(-7, 15); - k(x) = f32(0.f); - + k.fill(0.f); k(-7) = 0.004961f; k(-6) = 0.012246f; k(-5) = 0.026304f; k(-4) = 0.049165f; k(-3) = 0.079968f; k(-2) = 0.113193f; k(-1) = 0.139431f; k(0) = 0.149464f; k( 7) = 0.004961f; k( 6) = 0.012246f; k( 5) = 0.026304f; k( 4) = 0.049165f; k( 3) = 0.079968f; k( 2) = 0.113193f; k( 1) = 0.139431f; - /////////////////////////////////////////////////////////////////////////// - // schedule - /////////////////////////////////////////////////////////////////////////// - - k.compute_root().parallel(x); - return gauss(input, k, r, name); }