Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Gaussian rendering from virtual cameras #344

Merged
merged 4 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/viser/client/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -466,12 +466,12 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) {
}}
>
{inView ? null : <DisableRender />}
{children}
<BackgroundImage />
<AdaptiveDpr />
<SceneContextSetter />
{memoizedCameraControls}
<SplatRenderContext>
{children}
<SceneNodeThreeObject name="" parent={null} />
</SplatRenderContext>
<DefaultLights />
Expand Down
23 changes: 23 additions & 0 deletions src/viser/client/src/MessageHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { Progress } from "@mantine/core";
import { IconCheck } from "@tabler/icons-react";
import { computeT_threeworld_world } from "./WorldTransformUtils";
import { rootNodeTemplate } from "./SceneTreeState";
import { GaussianSplatsContext } from "./Splatting/GaussianSplats";

/** Returns a handler for all incoming messages. */
function useMessageHandler() {
Expand Down Expand Up @@ -521,6 +522,7 @@ export function FrameSynchronizedMessageHandler() {
const handleMessage = useMessageHandler();
const viewer = useContext(ViewerContext)!;
const messageQueueRef = viewer.messageQueueRef;
const splatContext = React.useContext(GaussianSplatsContext)!;

useFrame(
() => {
Expand Down Expand Up @@ -567,6 +569,21 @@ export function FrameSynchronizedMessageHandler() {
),
);

// Update splatting camera if needed.
// We'll back up the current sorted indices, and restore them after rendering.
const splatMeshProps = splatContext.meshPropsRef.current;
const sortedIndicesOrig =
splatMeshProps !== null
? splatMeshProps.sortedIndexAttribute.array.slice()
: null;
if (splatContext.updateCamera.current !== null)
splatContext.updateCamera.current!(
camera,
targetWidth,
targetHeight,
true,
);

// Note: We don't need to add the camera to the scene for rendering
// The renderer.render() function uses the camera directly
// Create a new renderer
Expand All @@ -583,6 +600,12 @@ export function FrameSynchronizedMessageHandler() {
// Render the scene.
renderer.render(viewer.sceneRef.current!, camera);

// Restore splatting indices.
if (sortedIndicesOrig !== null && splatMeshProps !== null) {
splatMeshProps.sortedIndexAttribute.array = sortedIndicesOrig;
splatMeshProps.sortedIndexAttribute.needsUpdate = true;
}

// Get the rendered image.
viewer.getRenderRequestState.current = "in_progress";
renderer.domElement.toBlob(async (blob) => {
Expand Down
265 changes: 174 additions & 91 deletions src/viser/client/src/Splatting/GaussianSplats.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
* between multiple splat objects.
*/

import MakeSorterModulePromise from "./WasmSorter/Sorter.mjs";

import React from "react";
import * as THREE from "three";
import SplatSortWorker from "./SplatSortWorker?worker";
Expand Down Expand Up @@ -64,9 +66,22 @@ function useGaussianSplatStore() {
})),
)[0];
}
const GaussianSplatsContext = React.createContext<ReturnType<
typeof useGaussianSplatStore
> | null>(null);

export const GaussianSplatsContext = React.createContext<{
useGaussianSplatStore: ReturnType<typeof useGaussianSplatStore>;
updateCamera: React.MutableRefObject<
| null
| ((
camera: THREE.PerspectiveCamera,
width: number,
height: number,
blockingSort: boolean,
) => void)
>;
meshPropsRef: React.MutableRefObject<ReturnType<
typeof useGaussianMeshProps
> | null>;
} | null>(null);

/**Provider for creating splat rendering context.*/
export function SplatRenderContext({
Expand All @@ -75,9 +90,18 @@ export function SplatRenderContext({
children: React.ReactNode;
}) {
const store = useGaussianSplatStore();
const numGroups = Object.keys(
store((state) => state.groupBufferFromId),
).length;
return (
<GaussianSplatsContext.Provider value={store}>
<SplatRenderer />
<GaussianSplatsContext.Provider
value={{
useGaussianSplatStore: store,
updateCamera: React.useRef(null),
meshPropsRef: React.useRef(null),
}}
>
{numGroups > 0 ? <SplatRenderer /> : null}
{children}
</GaussianSplatsContext.Provider>
);
Expand Down Expand Up @@ -251,9 +275,15 @@ export const SplatObject = React.forwardRef<
}
>(function SplatObject({ buffer }, ref) {
const splatContext = React.useContext(GaussianSplatsContext)!;
const setBuffer = splatContext((state) => state.setBuffer);
const removeBuffer = splatContext((state) => state.removeBuffer);
const nodeRefFromId = splatContext((state) => state.nodeRefFromId);
const setBuffer = splatContext.useGaussianSplatStore(
(state) => state.setBuffer,
);
const removeBuffer = splatContext.useGaussianSplatStore(
(state) => state.removeBuffer,
);
const nodeRefFromId = splatContext.useGaussianSplatStore(
(state) => state.nodeRefFromId,
);
const name = React.useMemo(() => uuidv4(), [buffer]);

const [obj, setRef] = React.useState<THREE.Group | null>(null);
Expand Down Expand Up @@ -281,15 +311,20 @@ export const SplatObject = React.forwardRef<
/** External interface. Component should be added to the root of canvas. */
function SplatRenderer() {
const splatContext = React.useContext(GaussianSplatsContext)!;
const groupBufferFromId = splatContext((state) => state.groupBufferFromId);
const nodeRefFromId = splatContext((state) => state.nodeRefFromId);
const groupBufferFromId = splatContext.useGaussianSplatStore(
(state) => state.groupBufferFromId,
);
const nodeRefFromId = splatContext.useGaussianSplatStore(
(state) => state.nodeRefFromId,
);

// Consolidate Gaussian groups into a single buffer.
const merged = mergeGaussianGroups(groupBufferFromId);
const meshProps = useGaussianMeshProps(
merged.gaussianBuffer,
merged.numGroups,
);
splatContext.meshPropsRef.current = meshProps;

// Create sorting worker.
const sortWorker = new SplatSortWorker();
Expand All @@ -310,7 +345,6 @@ function SplatRenderer() {
function postToWorker(message: SorterWorkerIncoming) {
sortWorker.postMessage(message);
}

postToWorker({
setBuffer: merged.gaussianBuffer,
setGroupIndices: merged.groupIndices,
Expand All @@ -337,96 +371,145 @@ function SplatRenderer() {
.slice()
.fill(0);
const prevVisibles: boolean[] = [];
useFrame((state, delta) => {
const mesh = meshRef.current;
if (mesh === null || sortWorker === null) return;

// Update camera parameter uniforms.
const dpr = state.viewport.dpr;
const fovY =
((state.camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0;
const fovX = 2 * Math.atan(Math.tan(fovY / 2) * state.viewport.aspect);
const fy = (dpr * state.size.height) / (2 * Math.tan(fovY / 2));
const fx = (dpr * state.size.width) / (2 * Math.tan(fovX / 2));
// Make local sorter. This will be used for blocking sorts, eg for rendering
// from virtual cameras.
const SorterRef = React.useRef<any>(null);
React.useEffect(() => {
(async () => {
SorterRef.current = new (await MakeSorterModulePromise()).Sorter(
merged.gaussianBuffer,
merged.groupIndices,
);
})();
}, [merged.gaussianBuffer, merged.groupIndices]);

const updateCamera = React.useCallback(
function updateCamera(
camera: THREE.PerspectiveCamera,
width: number,
height: number,
blockingSort: boolean,
) {
// Update camera parameter uniforms.
const fovY = ((camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0;

const aspect = width / height;
const fovX = 2 * Math.atan(Math.tan(fovY / 2) * aspect);
const fy = height / (2 * Math.tan(fovY / 2));
const fx = width / (2 * Math.tan(fovX / 2));

if (meshProps.material === undefined) return;

const uniforms = meshProps.material.uniforms;
uniforms.focal.value = [fx, fy];
uniforms.near.value = camera.near;
uniforms.far.value = camera.far;
uniforms.viewport.value = [width, height];

// Update group transforms.
camera.updateMatrixWorld();
const T_camera_world = camera.matrixWorldInverse;
const groupVisibles: boolean[] = [];
let visibilitiesChanged = false;
for (const [groupIndex, name] of Object.keys(
groupBufferFromId,
).entries()) {
const node = nodeRefFromId.current[name];
if (node === undefined) continue;
tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld);
const colMajorElements = tmpT_camera_group.elements;
Tz_camera_groups.set(
[
colMajorElements[2],
colMajorElements[6],
colMajorElements[10],
colMajorElements[14],
],
groupIndex * 4,
);
const rowMajorElements = tmpT_camera_group.transpose().elements;
meshProps.rowMajorT_camera_groups.set(
rowMajorElements.slice(0, 12),
groupIndex * 12,
);

// Determine visibility. If the parent has unmountWhenInvisible=true, the
// first frame after showing a hidden parent can have visible=true with
// an incorrect matrixWorld transform. There might be a better fix, but
// `prevVisible` is an easy workaround for this.
let visibleNow = node.visible && node.parent !== null;
if (visibleNow) {
node.traverseAncestors((ancestor) => {
visibleNow = visibleNow && ancestor.visible;
});
}
groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true);
if (prevVisibles[groupIndex] !== visibleNow) {
prevVisibles[groupIndex] = visibleNow;
visibilitiesChanged = true;
}
}

const groupsMovedWrtCam = !meshProps.rowMajorT_camera_groups.every(
(v, i) => v === prevRowMajorT_camera_groups[i],
);

if (groupsMovedWrtCam) {
// Gaussians need to be re-sorted.
if (blockingSort && SorterRef.current !== null) {
const sortedIndices = SorterRef.current.sort(
Tz_camera_groups,
) as Uint32Array;
meshProps.sortedIndexAttribute.set(sortedIndices);
meshProps.sortedIndexAttribute.needsUpdate = true;
} else {
postToWorker({
setTz_camera_groups: Tz_camera_groups,
});
}
}
if (groupsMovedWrtCam || visibilitiesChanged) {
// If a group is not visible, we'll throw it off the screen with some Big
// Numbers. It's important that this only impacts the coordinates used
// for the shader and not for the sorter; that way when we "show" a group
// of Gaussians the correct rendering order is immediately available.
for (const [i, visible] of groupVisibles.entries()) {
if (!visible) {
meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10;
meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10;
meshProps.rowMajorT_camera_groups[i * 12 + 11] = 1e10;
}
}
prevRowMajorT_camera_groups.set(meshProps.rowMajorT_camera_groups);
meshProps.textureT_camera_groups.needsUpdate = true;
}
},
[meshProps],
);
splatContext.updateCamera.current = updateCamera;

if (meshProps.material === undefined) return;
useFrame((state, delta) => {
const mesh = meshRef.current;
if (
mesh === null ||
sortWorker === null ||
meshProps.rowMajorT_camera_groups.length === 0
)
return;

const uniforms = meshProps.material.uniforms;
uniforms.transitionInState.value = Math.min(
uniforms.transitionInState.value + delta * 2.0,
1.0,
);
uniforms.focal.value = [fx, fy];
uniforms.near.value = state.camera.near;
uniforms.far.value = state.camera.far;
uniforms.viewport.value = [state.size.width * dpr, state.size.height * dpr];

// Update group transforms.
const T_camera_world = state.camera.matrixWorldInverse;
const groupVisibles: boolean[] = [];
let visibilitiesChanged = false;
for (const [groupIndex, name] of Object.keys(groupBufferFromId).entries()) {
const node = nodeRefFromId.current[name];
if (node === undefined) continue;
tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld);
const colMajorElements = tmpT_camera_group.elements;
Tz_camera_groups.set(
[
colMajorElements[2],
colMajorElements[6],
colMajorElements[10],
colMajorElements[14],
],
groupIndex * 4,
);
const rowMajorElements = tmpT_camera_group.transpose().elements;
meshProps.rowMajorT_camera_groups.set(
rowMajorElements.slice(0, 12),
groupIndex * 12,
);

// Determine visibility. If the parent has unmountWhenInvisible=true, the
// first frame after showing a hidden parent can have visible=true with
// an incorrect matrixWorld transform. There might be a better fix, but
// `prevVisible` is an easy workaround for this.
let visibleNow = node.visible && node.parent !== null;
if (visibleNow) {
node.traverseAncestors((ancestor) => {
visibleNow = visibleNow && ancestor.visible;
});
}
groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true);
if (prevVisibles[groupIndex] !== visibleNow) {
prevVisibles[groupIndex] = visibleNow;
visibilitiesChanged = true;
}
}

const groupsMovedWrtCam = !meshProps.rowMajorT_camera_groups.every(
(v, i) => v === prevRowMajorT_camera_groups[i],
updateCamera(
state.camera as THREE.PerspectiveCamera,
state.viewport.dpr * state.size.width,
state.viewport.dpr * state.size.height,
false /* blockingSort */,
);

if (groupsMovedWrtCam) {
// Gaussians need to be re-sorted.
postToWorker({
setTz_camera_groups: Tz_camera_groups,
});
}
if (groupsMovedWrtCam || visibilitiesChanged) {
// If a group is not visible, we'll throw it off the screen with some Big
// Numbers. It's important that this only impacts the coordinates used
// for the shader and not for the sorter; that way when we "show" a group
// of Gaussians the correct rendering order is immediately available.
for (const [i, visible] of groupVisibles.entries()) {
if (!visible) {
meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10;
meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10;
meshProps.rowMajorT_camera_groups[i * 12 + 11] = 1e10;
}
}
prevRowMajorT_camera_groups.set(meshProps.rowMajorT_camera_groups);
meshProps.textureT_camera_groups.needsUpdate = true;
}
}, -100 /* This should be called early to reduce group transform artifacts. */);

return (
Expand Down