diff --git a/src/viser/client/src/App.tsx b/src/viser/client/src/App.tsx
index 717de5cb..6e9806b0 100644
--- a/src/viser/client/src/App.tsx
+++ b/src/viser/client/src/App.tsx
@@ -466,12 +466,12 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) {
}}
>
{inView ? null : }
- {children}
{memoizedCameraControls}
+ {children}
diff --git a/src/viser/client/src/MessageHandler.tsx b/src/viser/client/src/MessageHandler.tsx
index 8a26ca87..d4929677 100644
--- a/src/viser/client/src/MessageHandler.tsx
+++ b/src/viser/client/src/MessageHandler.tsx
@@ -19,6 +19,7 @@ import { Progress } from "@mantine/core";
import { IconCheck } from "@tabler/icons-react";
import { computeT_threeworld_world } from "./WorldTransformUtils";
import { rootNodeTemplate } from "./SceneTreeState";
+import { GaussianSplatsContext } from "./Splatting/GaussianSplats";
/** Returns a handler for all incoming messages. */
function useMessageHandler() {
@@ -521,6 +522,7 @@ export function FrameSynchronizedMessageHandler() {
const handleMessage = useMessageHandler();
const viewer = useContext(ViewerContext)!;
const messageQueueRef = viewer.messageQueueRef;
+ const splatContext = React.useContext(GaussianSplatsContext)!;
useFrame(
() => {
@@ -567,6 +569,21 @@ export function FrameSynchronizedMessageHandler() {
),
);
+ // Update splatting camera if needed.
+ // We'll back up the current sorted indices, and restore them after rendering.
+ const splatMeshProps = splatContext.meshPropsRef.current;
+ const sortedIndicesOrig =
+ splatMeshProps !== null
+ ? splatMeshProps.sortedIndexAttribute.array.slice()
+ : null;
+ if (splatContext.updateCamera.current !== null)
+ splatContext.updateCamera.current!(
+ camera,
+ targetWidth,
+ targetHeight,
+ true,
+ );
+
// Note: We don't need to add the camera to the scene for rendering
// The renderer.render() function uses the camera directly
// Create a new renderer
@@ -583,6 +600,12 @@ export function FrameSynchronizedMessageHandler() {
// Render the scene.
renderer.render(viewer.sceneRef.current!, camera);
+ // Restore splatting indices.
+ if (sortedIndicesOrig !== null && splatMeshProps !== null) {
+ splatMeshProps.sortedIndexAttribute.array = sortedIndicesOrig;
+ splatMeshProps.sortedIndexAttribute.needsUpdate = true;
+ }
+
// Get the rendered image.
viewer.getRenderRequestState.current = "in_progress";
renderer.domElement.toBlob(async (blob) => {
diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx
index 2e77374e..ceabf35d 100644
--- a/src/viser/client/src/Splatting/GaussianSplats.tsx
+++ b/src/viser/client/src/Splatting/GaussianSplats.tsx
@@ -22,6 +22,8 @@
* between multiple splat objects.
*/
+import MakeSorterModulePromise from "./WasmSorter/Sorter.mjs";
+
import React from "react";
import * as THREE from "three";
import SplatSortWorker from "./SplatSortWorker?worker";
@@ -64,9 +66,22 @@ function useGaussianSplatStore() {
})),
)[0];
}
-const GaussianSplatsContext = React.createContext | null>(null);
+
+export const GaussianSplatsContext = React.createContext<{
+ useGaussianSplatStore: ReturnType;
+ updateCamera: React.MutableRefObject<
+ | null
+ | ((
+ camera: THREE.PerspectiveCamera,
+ width: number,
+ height: number,
+ blockingSort: boolean,
+ ) => void)
+ >;
+ meshPropsRef: React.MutableRefObject | null>;
+} | null>(null);
/**Provider for creating splat rendering context.*/
export function SplatRenderContext({
@@ -75,9 +90,18 @@ export function SplatRenderContext({
children: React.ReactNode;
}) {
const store = useGaussianSplatStore();
+ const numGroups = Object.keys(
+ store((state) => state.groupBufferFromId),
+ ).length;
return (
-
-
+
+ {numGroups > 0 ? : null}
{children}
);
@@ -251,9 +275,15 @@ export const SplatObject = React.forwardRef<
}
>(function SplatObject({ buffer }, ref) {
const splatContext = React.useContext(GaussianSplatsContext)!;
- const setBuffer = splatContext((state) => state.setBuffer);
- const removeBuffer = splatContext((state) => state.removeBuffer);
- const nodeRefFromId = splatContext((state) => state.nodeRefFromId);
+ const setBuffer = splatContext.useGaussianSplatStore(
+ (state) => state.setBuffer,
+ );
+ const removeBuffer = splatContext.useGaussianSplatStore(
+ (state) => state.removeBuffer,
+ );
+ const nodeRefFromId = splatContext.useGaussianSplatStore(
+ (state) => state.nodeRefFromId,
+ );
const name = React.useMemo(() => uuidv4(), [buffer]);
const [obj, setRef] = React.useState(null);
@@ -281,8 +311,12 @@ export const SplatObject = React.forwardRef<
/** External interface. Component should be added to the root of canvas. */
function SplatRenderer() {
const splatContext = React.useContext(GaussianSplatsContext)!;
- const groupBufferFromId = splatContext((state) => state.groupBufferFromId);
- const nodeRefFromId = splatContext((state) => state.nodeRefFromId);
+ const groupBufferFromId = splatContext.useGaussianSplatStore(
+ (state) => state.groupBufferFromId,
+ );
+ const nodeRefFromId = splatContext.useGaussianSplatStore(
+ (state) => state.nodeRefFromId,
+ );
// Consolidate Gaussian groups into a single buffer.
const merged = mergeGaussianGroups(groupBufferFromId);
@@ -290,6 +324,7 @@ function SplatRenderer() {
merged.gaussianBuffer,
merged.numGroups,
);
+ splatContext.meshPropsRef.current = meshProps;
// Create sorting worker.
const sortWorker = new SplatSortWorker();
@@ -310,7 +345,6 @@ function SplatRenderer() {
function postToWorker(message: SorterWorkerIncoming) {
sortWorker.postMessage(message);
}
-
postToWorker({
setBuffer: merged.gaussianBuffer,
setGroupIndices: merged.groupIndices,
@@ -337,96 +371,145 @@ function SplatRenderer() {
.slice()
.fill(0);
const prevVisibles: boolean[] = [];
- useFrame((state, delta) => {
- const mesh = meshRef.current;
- if (mesh === null || sortWorker === null) return;
- // Update camera parameter uniforms.
- const dpr = state.viewport.dpr;
- const fovY =
- ((state.camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0;
- const fovX = 2 * Math.atan(Math.tan(fovY / 2) * state.viewport.aspect);
- const fy = (dpr * state.size.height) / (2 * Math.tan(fovY / 2));
- const fx = (dpr * state.size.width) / (2 * Math.tan(fovX / 2));
+ // Make local sorter. This will be used for blocking sorts, eg for rendering
+ // from virtual cameras.
+ const SorterRef = React.useRef(null);
+ React.useEffect(() => {
+ (async () => {
+ SorterRef.current = new (await MakeSorterModulePromise()).Sorter(
+ merged.gaussianBuffer,
+ merged.groupIndices,
+ );
+ })();
+ }, [merged.gaussianBuffer, merged.groupIndices]);
+
+ const updateCamera = React.useCallback(
+ function updateCamera(
+ camera: THREE.PerspectiveCamera,
+ width: number,
+ height: number,
+ blockingSort: boolean,
+ ) {
+ // Update camera parameter uniforms.
+ const fovY = ((camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0;
+
+ const aspect = width / height;
+ const fovX = 2 * Math.atan(Math.tan(fovY / 2) * aspect);
+ const fy = height / (2 * Math.tan(fovY / 2));
+ const fx = width / (2 * Math.tan(fovX / 2));
+
+ if (meshProps.material === undefined) return;
+
+ const uniforms = meshProps.material.uniforms;
+ uniforms.focal.value = [fx, fy];
+ uniforms.near.value = camera.near;
+ uniforms.far.value = camera.far;
+ uniforms.viewport.value = [width, height];
+
+ // Update group transforms.
+ camera.updateMatrixWorld();
+ const T_camera_world = camera.matrixWorldInverse;
+ const groupVisibles: boolean[] = [];
+ let visibilitiesChanged = false;
+ for (const [groupIndex, name] of Object.keys(
+ groupBufferFromId,
+ ).entries()) {
+ const node = nodeRefFromId.current[name];
+ if (node === undefined) continue;
+ tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld);
+ const colMajorElements = tmpT_camera_group.elements;
+ Tz_camera_groups.set(
+ [
+ colMajorElements[2],
+ colMajorElements[6],
+ colMajorElements[10],
+ colMajorElements[14],
+ ],
+ groupIndex * 4,
+ );
+ const rowMajorElements = tmpT_camera_group.transpose().elements;
+ meshProps.rowMajorT_camera_groups.set(
+ rowMajorElements.slice(0, 12),
+ groupIndex * 12,
+ );
+
+ // Determine visibility. If the parent has unmountWhenInvisible=true, the
+ // first frame after showing a hidden parent can have visible=true with
+ // an incorrect matrixWorld transform. There might be a better fix, but
+ // `prevVisible` is an easy workaround for this.
+ let visibleNow = node.visible && node.parent !== null;
+ if (visibleNow) {
+ node.traverseAncestors((ancestor) => {
+ visibleNow = visibleNow && ancestor.visible;
+ });
+ }
+ groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true);
+ if (prevVisibles[groupIndex] !== visibleNow) {
+ prevVisibles[groupIndex] = visibleNow;
+ visibilitiesChanged = true;
+ }
+ }
+
+ const groupsMovedWrtCam = !meshProps.rowMajorT_camera_groups.every(
+ (v, i) => v === prevRowMajorT_camera_groups[i],
+ );
+
+ if (groupsMovedWrtCam) {
+ // Gaussians need to be re-sorted.
+ if (blockingSort && SorterRef.current !== null) {
+ const sortedIndices = SorterRef.current.sort(
+ Tz_camera_groups,
+ ) as Uint32Array;
+ meshProps.sortedIndexAttribute.set(sortedIndices);
+ meshProps.sortedIndexAttribute.needsUpdate = true;
+ } else {
+ postToWorker({
+ setTz_camera_groups: Tz_camera_groups,
+ });
+ }
+ }
+ if (groupsMovedWrtCam || visibilitiesChanged) {
+ // If a group is not visible, we'll throw it off the screen with some Big
+ // Numbers. It's important that this only impacts the coordinates used
+ // for the shader and not for the sorter; that way when we "show" a group
+ // of Gaussians the correct rendering order is immediately available.
+ for (const [i, visible] of groupVisibles.entries()) {
+ if (!visible) {
+ meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10;
+ meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10;
+ meshProps.rowMajorT_camera_groups[i * 12 + 11] = 1e10;
+ }
+ }
+ prevRowMajorT_camera_groups.set(meshProps.rowMajorT_camera_groups);
+ meshProps.textureT_camera_groups.needsUpdate = true;
+ }
+ },
+ [meshProps],
+ );
+ splatContext.updateCamera.current = updateCamera;
- if (meshProps.material === undefined) return;
+ useFrame((state, delta) => {
+ const mesh = meshRef.current;
+ if (
+ mesh === null ||
+ sortWorker === null ||
+ meshProps.rowMajorT_camera_groups.length === 0
+ )
+ return;
const uniforms = meshProps.material.uniforms;
uniforms.transitionInState.value = Math.min(
uniforms.transitionInState.value + delta * 2.0,
1.0,
);
- uniforms.focal.value = [fx, fy];
- uniforms.near.value = state.camera.near;
- uniforms.far.value = state.camera.far;
- uniforms.viewport.value = [state.size.width * dpr, state.size.height * dpr];
-
- // Update group transforms.
- const T_camera_world = state.camera.matrixWorldInverse;
- const groupVisibles: boolean[] = [];
- let visibilitiesChanged = false;
- for (const [groupIndex, name] of Object.keys(groupBufferFromId).entries()) {
- const node = nodeRefFromId.current[name];
- if (node === undefined) continue;
- tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld);
- const colMajorElements = tmpT_camera_group.elements;
- Tz_camera_groups.set(
- [
- colMajorElements[2],
- colMajorElements[6],
- colMajorElements[10],
- colMajorElements[14],
- ],
- groupIndex * 4,
- );
- const rowMajorElements = tmpT_camera_group.transpose().elements;
- meshProps.rowMajorT_camera_groups.set(
- rowMajorElements.slice(0, 12),
- groupIndex * 12,
- );
- // Determine visibility. If the parent has unmountWhenInvisible=true, the
- // first frame after showing a hidden parent can have visible=true with
- // an incorrect matrixWorld transform. There might be a better fix, but
- // `prevVisible` is an easy workaround for this.
- let visibleNow = node.visible && node.parent !== null;
- if (visibleNow) {
- node.traverseAncestors((ancestor) => {
- visibleNow = visibleNow && ancestor.visible;
- });
- }
- groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true);
- if (prevVisibles[groupIndex] !== visibleNow) {
- prevVisibles[groupIndex] = visibleNow;
- visibilitiesChanged = true;
- }
- }
-
- const groupsMovedWrtCam = !meshProps.rowMajorT_camera_groups.every(
- (v, i) => v === prevRowMajorT_camera_groups[i],
+ updateCamera(
+ state.camera as THREE.PerspectiveCamera,
+ state.viewport.dpr * state.size.width,
+ state.viewport.dpr * state.size.height,
+ false /* blockingSort */,
);
-
- if (groupsMovedWrtCam) {
- // Gaussians need to be re-sorted.
- postToWorker({
- setTz_camera_groups: Tz_camera_groups,
- });
- }
- if (groupsMovedWrtCam || visibilitiesChanged) {
- // If a group is not visible, we'll throw it off the screen with some Big
- // Numbers. It's important that this only impacts the coordinates used
- // for the shader and not for the sorter; that way when we "show" a group
- // of Gaussians the correct rendering order is immediately available.
- for (const [i, visible] of groupVisibles.entries()) {
- if (!visible) {
- meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10;
- meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10;
- meshProps.rowMajorT_camera_groups[i * 12 + 11] = 1e10;
- }
- }
- prevRowMajorT_camera_groups.set(meshProps.rowMajorT_camera_groups);
- meshProps.textureT_camera_groups.needsUpdate = true;
- }
}, -100 /* This should be called early to reduce group transform artifacts. */);
return (