Skip to content

Commit

Permalink
Implement sort in C++/WebAssembly, 2~3x faster
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 9, 2023
1 parent 81faa31 commit ffb5973
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 98 deletions.
2 changes: 1 addition & 1 deletion src/viser/client/src/WebsocketInterface.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import { isGuiConfig, useViserMantineTheme } from "./ControlPanel/GuiState";
import { useFrame } from "@react-three/fiber";
import GeneratedGuiContainer from "./ControlPanel/Generated";
import { MantineProvider, Paper } from "@mantine/core";
import GaussianSplats from "./splatting/GaussianSplats";
import GaussianSplats from "./Splatting/GaussianSplats";

/** Convert raw RGB color buffers to linear color buffers. **/
function linearColorArrayFromSrgbColorArray(colors: ArrayBuffer) {
Expand Down
128 changes: 31 additions & 97 deletions src/viser/client/src/splatting/SplatSortWorker.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
/** Worker for sorting splats.
*
* Adapted from Kevin Kwok:
* https://github.com/antimatter15/splat/blob/main/main.js
*/

import MakeSorterModulePromise from "./WasmSorter/Sorter.mjs";

export type GaussianBuffersSplitCov = {
// (N, 3)
centers: Float32Array;
Expand All @@ -16,92 +15,25 @@ export type GaussianBuffersSplitCov = {
// (N, 3)
covB: Float32Array;
};
{
// Worker state.
let buffers: GaussianBuffersSplitCov | null = null;
let sortedBuffers: GaussianBuffersSplitCov | null = null;

{
let sorter: any = null;
let viewProj: number[] | null = null;
let depthList = new Int32Array();
let sortedIndices: number[] = [];

// Counting sort buffers.
const counts0 = new Uint32Array(256 * 256);
const starts0 = new Uint32Array(256 * 256);

const runSort = (viewProj: number[] | null) => {
if (buffers === null || sortedBuffers === null || viewProj === null) return;

console.time("Start");
const numGaussians = buffers.centers.length / 3;

// Create new buffers.
if (sortedIndices.length !== numGaussians) {
depthList = new Int32Array(numGaussians);
sortedIndices = [...Array(numGaussians).keys()];
}

// Compute depth for each Gaussian.
let maxDepth = -Infinity;
let minDepth = Infinity;
for (let i = 0; i < depthList.length; i++) {
const depth =
((viewProj[2] * buffers.centers[i * 3 + 0] +
viewProj[6] * buffers.centers[i * 3 + 1] +
viewProj[10] * buffers.centers[i * 3 + 2]) *
4096) |
0;
depthList[i] = depth;
if (depth > maxDepth) maxDepth = depth;
if (depth < minDepth) minDepth = depth;
}

// This is a 16 bit single-pass counting sort.
const depthInv = (256 * 256) / (maxDepth - minDepth);
counts0.fill(0);
for (let i = 0; i < numGaussians; i++) {
depthList[i] = ((depthList[i] - minDepth) * depthInv) | 0;
counts0[depthList[i]]++;
}
for (let i = 1; i < 256 * 256; i++)
starts0[i] = starts0[i - 1] + counts0[i - 1];
console.timeEnd("Start");
console.time("Fill");
for (let i = 0; i < numGaussians; i++)
sortedIndices[starts0[depthList[i]]++] = i;

// Sort and post underlying buffers.
for (let i = 0; i < sortedIndices.length; i++) {
const j = sortedIndices[sortedIndices.length - i - 1];
sortedBuffers.centers[i * 3 + 0] = buffers.centers[j * 3 + 0];
sortedBuffers.centers[i * 3 + 1] = buffers.centers[j * 3 + 1];
sortedBuffers.centers[i * 3 + 2] = buffers.centers[j * 3 + 2];

sortedBuffers.rgbs[i * 3 + 0] = buffers.rgbs[j * 3 + 0];
sortedBuffers.rgbs[i * 3 + 1] = buffers.rgbs[j * 3 + 1];
sortedBuffers.rgbs[i * 3 + 2] = buffers.rgbs[j * 3 + 2];

sortedBuffers.opacities[i] = buffers.opacities[j];

sortedBuffers.covA[i * 3 + 0] = buffers.covA[j * 3 + 0];
sortedBuffers.covA[i * 3 + 1] = buffers.covA[j * 3 + 1];
sortedBuffers.covA[i * 3 + 2] = buffers.covA[j * 3 + 2];

sortedBuffers.covB[i * 3 + 0] = buffers.covB[j * 3 + 0];
sortedBuffers.covB[i * 3 + 1] = buffers.covB[j * 3 + 1];
sortedBuffers.covB[i * 3 + 2] = buffers.covB[j * 3 + 2];
}
console.timeEnd("Fill");
self.postMessage(sortedBuffers);
};

let sortRunning = false;
const throttledSort = () => {
if (sortRunning) return;
if (sorter === null || viewProj === null || sortRunning) return;

sortRunning = true;
const lastView = viewProj;
runSort(lastView);
sorter.sort(viewProj[2], viewProj[6], viewProj[10]);
self.postMessage({
centers: sorter.getSortedCenters(),
rgbs: sorter.getSortedRgbs(),
opacities: sorter.getSortedOpacities(),
covA: sorter.getSortedCovA(),
covB: sorter.getSortedCovB(),
});

setTimeout(() => {
sortRunning = false;
if (lastView !== viewProj) {
Expand All @@ -110,7 +42,9 @@ export type GaussianBuffersSplitCov = {
}, 0);
};

self.onmessage = (e) => {
const SorterModulePromise = MakeSorterModulePromise();

self.onmessage = async (e) => {
const data = e.data as
| {
setBuffers: GaussianBuffersSplitCov;
Expand All @@ -121,22 +55,22 @@ export type GaussianBuffersSplitCov = {
| { close: true };

if ("setBuffers" in data) {
buffers = data.setBuffers;
sortedBuffers = {
centers: new Float32Array(buffers.centers.length),
rgbs: new Float32Array(buffers.rgbs.length),
opacities: new Float32Array(buffers.opacities.length),
covA: new Float32Array(buffers.covA.length),
covB: new Float32Array(buffers.covB.length),
};
}

if ("setViewProj" in data) {
// Instantiate sorter with buffers populated.
const buffers = data.setBuffers as GaussianBuffersSplitCov;
sorter = new (await SorterModulePromise).Sorter(
buffers.centers,
buffers.rgbs,
buffers.opacities,
buffers.covA,
buffers.covB,
);
throttledSort();
} else if ("setViewProj" in data) {
// Update view projection matrix.
viewProj = data.setViewProj;
throttledSort();
}

if ("close" in data) {
} else if ("close" in data) {
// Done!
self.close();
}
};
Expand Down
16 changes: 16 additions & 0 deletions src/viser/client/src/splatting/WasmSorter/Sorter.mjs

Large diffs are not rendered by default.

Binary file not shown.
3 changes: 3 additions & 0 deletions src/viser/client/src/splatting/WasmSorter/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/usr/bin/env bash

emcc --bind -Oz sorter.cpp -o Sorter.mjs -s WASM=1 -s NO_EXIT_RUNTIME=1 -s "EXPORTED_RUNTIME_METHODS=['addOnPostRun']" -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=1GB;
156 changes: 156 additions & 0 deletions src/viser/client/src/splatting/WasmSorter/sorter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include <emscripten/bind.h>
#include <emscripten/val.h>

struct Gaussian {
float center[3];
float rgb[3];
float opacity;
float cov_a[3];
float cov_b[3];
};

struct Buffers {
std::vector<float> centers;
std::vector<float> rgbs;
std::vector<float> opacities;
std::vector<float> cov_a;
std::vector<float> cov_b;
};

// Build a Float32Array view of a C++ vector.
emscripten::val getFloat32Array(const std::vector<float> &v) {
return emscripten::val(emscripten::typed_memory_view(v.size(), &(v[0])));
}

class Sorter {
// Properties for each unsorted Gaussian. A vector of structs (as opposed to a
// struct of vectors) produces less fragmented reads. This can result in an
// ~30% runtime improvement.
std::vector<Gaussian> unsorted_gaussians;

// Sorted buffers. The memory layout here is intended to match WebGL.
Buffers sorted_buffers;

public:
Sorter(const emscripten::val &centers, const emscripten::val &rgbs,
const emscripten::val &opacities, const emscripten::val &cov_a,
const emscripten::val &cov_b) {
Buffers unsorted_buffers{
emscripten::convertJSArrayToNumberVector<float>(centers),
emscripten::convertJSArrayToNumberVector<float>(rgbs),
emscripten::convertJSArrayToNumberVector<float>(opacities),
emscripten::convertJSArrayToNumberVector<float>(cov_a),
emscripten::convertJSArrayToNumberVector<float>(cov_b)};

int num_gaussians = unsorted_buffers.centers.size() / 3;
for (int i = 0; i < num_gaussians; i++) {
unsorted_gaussians.push_back({
{unsorted_buffers.centers[i * 3 + 0],
unsorted_buffers.centers[i * 3 + 1],
unsorted_buffers.centers[i * 3 + 2]},
{unsorted_buffers.rgbs[i * 3 + 0], unsorted_buffers.rgbs[i * 3 + 1],
unsorted_buffers.rgbs[i * 3 + 2]},
unsorted_buffers.opacities[i],
{unsorted_buffers.cov_a[i * 3 + 0], unsorted_buffers.cov_a[i * 3 + 1],
unsorted_buffers.cov_a[i * 3 + 2]},
{unsorted_buffers.cov_b[i * 3 + 0], unsorted_buffers.cov_b[i * 3 + 1],
unsorted_buffers.cov_b[i * 3 + 2]},
});
}

sorted_buffers = unsorted_buffers;
};

// Run sorting using the newest view projection matrix. Mutates internal
// buffers.
void sort(float view_proj_2, float view_proj_6, float view_proj_10) {
const int num_gaussians = unsorted_gaussians.size();

// We do a 16-bit counting sort. This is mostly translated from Kevin Kwok's
// Javascript implementation:
// https://github.com/antimatter15/splat/blob/main/main.js
std::vector<int> depths(num_gaussians);
std::vector<int> counts0(256 * 256, 0);
std::vector<int> starts0(256 * 256, 0);

int min_depth;
int max_depth;
for (int i = 0; i < num_gaussians; i++) {
const int depth = (((view_proj_2 * unsorted_gaussians[i].center[0] +
view_proj_6 * unsorted_gaussians[i].center[1] +
view_proj_10 * unsorted_gaussians[i].center[2]) *
4096.0));
depths[i] = depth;

if (i == 0 || depth < min_depth)
min_depth = depth;
if (i == 0 || depth > max_depth)
max_depth = depth;
}
const float depth_inv = (256 * 256 - 1) / (max_depth - min_depth + 1e-5);
for (int i = 0; i < num_gaussians; i++) {
const int depth_bin = ((depths[i] - min_depth) * depth_inv);
depths[i] = depth_bin;
counts0[depth_bin]++;
}
for (int i = 1; i < 256 * 256; i++) {
starts0[i] = starts0[i - 1] + counts0[i - 1];
}

std::vector<int> sorted_indices(num_gaussians);
for (int i = 0; i < num_gaussians; i++)
sorted_indices[starts0[depths[i]]++] = i;

// Rearrange values in underlying buffers. This is the slowest part of the
// sort.
for (int i = 0; i < num_gaussians; i++) {
const int j = sorted_indices[num_gaussians - i - 1];

const Gaussian &gaussian = unsorted_gaussians[j];
memcpy(&(sorted_buffers.centers[i * 3]), &gaussian.center, 4 * 3);
memcpy(&(sorted_buffers.rgbs[i * 3]), &gaussian.rgb, 4 * 3);
sorted_buffers.opacities[i] = gaussian.opacity;
memcpy(&(sorted_buffers.cov_a[i * 3]), &gaussian.cov_a, 4 * 3);
memcpy(&(sorted_buffers.cov_b[i * 3]), &gaussian.cov_b, 4 * 3);
}
}

// Access outputs.
emscripten::val getSortedCenters() {
return getFloat32Array(sorted_buffers.centers);
}
emscripten::val getSortedRgbs() {
return getFloat32Array(sorted_buffers.rgbs);
}
emscripten::val getSortedOpacities() {
return getFloat32Array(sorted_buffers.opacities);
}
emscripten::val getSortedCovA() {
return getFloat32Array(sorted_buffers.cov_a);
}
emscripten::val getSortedCovB() {
return getFloat32Array(sorted_buffers.cov_b);
}
};

EMSCRIPTEN_BINDINGS(c) {
emscripten::class_<Sorter>("Sorter")
.constructor<emscripten::val, emscripten::val, emscripten::val,
emscripten::val, emscripten::val>()
.function("sort", &Sorter::sort)
.function("getSortedCenters", &Sorter::getSortedCenters,
emscripten::allow_raw_pointers())
.function("getSortedRgbs", &Sorter::getSortedRgbs,
emscripten::allow_raw_pointers())
.function("getSortedOpacities", &Sorter::getSortedOpacities,
emscripten::allow_raw_pointers())
.function("getSortedCovA", &Sorter::getSortedCovA,
emscripten::allow_raw_pointers())
.function("getSortedCovB", &Sorter::getSortedCovB,
emscripten::allow_raw_pointers());
};

0 comments on commit ffb5973

Please sign in to comment.