add lmc for nerf synthetic
Shakiba Kheradmand committed Nov 29, 2023
1 parent 6d6b273 commit ddf7ae1
Showing 7 changed files with 339 additions and 68 deletions.
147 changes: 122 additions & 25 deletions examples/datasets/nerf_synthetic.py
@@ -11,7 +11,8 @@
import torch
import torch.nn.functional as F

from .utils import Rays
from .utils import Rays, sample_from_pdf_with_indices, \
    compute_sobel_edge, Image


def _load_renderings(root_fp: str, subject_id: str, split: str):
@@ -80,6 +81,9 @@ def __init__(
        far: float = None,
        batch_over_images: bool = True,
        device: torch.device = torch.device("cpu"),
        sampling_type: str = "uniform",
        minpct: float = 0.1,
        lossminpc: float = 0.1,
    ):
        super().__init__()
        assert split in self.SPLITS, "%s" % split
@@ -124,13 +128,43 @@ def __init__(
        self.camtoworlds = self.camtoworlds.to(device)
        self.K = self.K.to(device)
        assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
        self.sampling_type = sampling_type
        self.simages = Image(self.images, device=device)
        if self.training:
            bs = np.ceil(self.num_rays / self.images.shape[0]).astype(np.int32)
            self.num_rays = bs * self.images.shape[0]
            self.const_img_id = torch.arange(0, self.images.shape[0], device=device).repeat_interleave(bs)

            if self.sampling_type == "lmc":
                self.image_edges = compute_sobel_edge(self.images.float()).reshape(self.images.shape[0], -1)
                self.image_edges = self.image_edges.to(device)
                probs = self.image_edges / self.image_edges.sum(dim=-1, keepdim=True)
                cdf = torch.cumsum(probs, dim=-1)
                cdf = torch.nn.functional.pad(cdf, pad=(1, 0), mode='constant', value=0)
                self.cdf = cdf.view(cdf.shape[0], -1)

                self.rand_ten = torch.empty((self.num_rays, 2), dtype=torch.float32, device=device)
                self.noise = torch.empty((self.num_rays, 2), dtype=torch.float32, device=device)

                self.u_num = int(minpct * self.num_rays)
                self.reinit = int(lossminpc * self.num_rays)
                self.prev_samples = None
                x = torch.randint(0, self.WIDTH, size=(self.num_rays,), device=device)
                y = torch.randint(0, self.HEIGHT, size=(self.num_rays,), device=device)
                x = x.float() / (self.WIDTH - 1)
                y = y.float() / (self.HEIGHT - 1)
                self.prev_samples = torch.cat([y[..., None], x[..., None]], dim=1)
                self.prev_samples.clamp_(min=0.0, max=1.0)
                self.HW = torch.tensor([self.HEIGHT - 1, self.WIDTH - 1], device=device)

    def __len__(self):
        return len(self.images)

    @torch.no_grad()
    def __getitem__(self, index, net_grad=None, loss_per_pix=None):
        if self.sampling_type == "uniform":
            data = self.fetch_data(index)
        elif self.sampling_type == "lmc":
            data = self.fetch_data_lmc(net_grad=net_grad, loss_per_pix=loss_per_pix)
        data = self.preprocess(data)
        return data

@@ -161,6 +195,30 @@ def preprocess(self, data):
    def update_num_rays(self, num_rays):
        self.num_rays = num_rays

    def get_origin_viewdirs(self, image_id, y, x):
        c2w = self.camtoworlds[image_id]  # (num_rays, 3, 4)
        camera_dirs = F.pad(
            torch.stack(
                [
                    (x - self.K[0, 2] + 0.5) / self.K[0, 0],
                    (y - self.K[1, 2] + 0.5)
                    / self.K[1, 1]
                    * (-1.0 if self.OPENGL_CAMERA else 1.0),
                ],
                dim=-1,
            ),
            (0, 1),
            value=(-1.0 if self.OPENGL_CAMERA else 1.0),
        )  # [num_rays, 3]

        # [num_rays, 3]
        directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
        origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
        viewdirs = directions / torch.linalg.norm(
            directions, dim=-1, keepdims=True
        )
        return origins, viewdirs

    def fetch_data(self, index):
        """Fetch the data (it may be cached for multiple batches)."""
        num_rays = self.num_rays
@@ -193,27 +251,7 @@ def fetch_data(self, index):

        # generate rays
        rgba = self.images[image_id, y, x] / 255.0  # (num_rays, 4)
        origins, viewdirs = self.get_origin_viewdirs(image_id, y, x)

        if self.training:
            origins = torch.reshape(origins, (num_rays, 3))
@@ -230,3 +268,62 @@
"rgba": rgba, # [h, w, 4] or [num_rays, 4]
"rays": rays, # [h, w, 3] or [num_rays, 3]
}

    def fetch_data_lmc(self, net_grad=None, a=2e1, b=2e-2, loss_per_pix=None):
        if net_grad is not None:
            with torch.no_grad():
                self.noise.normal_(mean=0.0, std=1.0)
                self.rand_ten.uniform_()

                # Langevin step: scaled gradient (drift) plus Gaussian noise
                net_grad.mul_(a).add_(self.noise, alpha=b)
                self.prev_samples.add_(net_grad)

                # re-initialize samples that left the image or have low loss
                threshold, _ = torch.topk(loss_per_pix, self.reinit + 1, largest=False)
                mask = torch.cat([
                    self.prev_samples < 0,
                    self.prev_samples > 1,
                    loss_per_pix.unsqueeze(1) <= threshold[-1],
                ], 1)
                mask = mask.sum(1)
                bound_idxs = torch.where(mask)[0]
                self.prev_samples[-self.u_num:].copy_(self.rand_ten[-self.u_num:])

                if bound_idxs.shape[0] > 0:
                    # sample from edges
                    count = torch.bincount(self.const_img_id[bound_idxs], minlength=self.images.shape[0])
                    batch1d = sample_from_pdf_with_indices(self.cdf, int(self.num_rays / self.images.shape[0]))
                    indices = torch.arange(batch1d.size(1), device=batch1d.device).unsqueeze_(0).repeat(batch1d.size(0), 1)
                    mask = indices < count.unsqueeze(1)
                    batch1d = batch1d.masked_select(mask)

                    self.prev_samples[bound_idxs, 0] = (batch1d // self.WIDTH) / (self.HEIGHT - 1)
                    self.prev_samples[bound_idxs, 1] = (batch1d % self.WIDTH) / (self.WIDTH - 1)

        self.prev_samples.clamp_(min=0.0, max=1.0)
        points_2d = self.prev_samples * self.HW
        points_2d.round_()

        # generate rays
        rgba = self.simages(self.const_img_id, points_2d[:, 0], points_2d[:, 1]) / 255.0
        points_2d.requires_grad = True
        x, y = points_2d[:, 1], points_2d[:, 0]
        origins, viewdirs = self.get_origin_viewdirs(self.const_img_id, y, x)

        if self.training:
            origins = torch.reshape(origins, (self.num_rays, 3))
            viewdirs = torch.reshape(viewdirs, (self.num_rays, 3))
            rgba = torch.reshape(rgba, (self.num_rays, 4))
        else:
            origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
            viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
            rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4))

        rays = Rays(origins=origins, viewdirs=viewdirs)

        return {
            "rgba": rgba,  # [h, w, 4] or [num_rays, 4]
            "rays": rays,  # [h, w, 3] or [num_rays, 3]
            "x": x,
            "y": y,
            "image_id": self.const_img_id,
            "points_2d": points_2d,
        }
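For context, "lmc" here plausibly stands for a Langevin Monte Carlo style update: sample coordinates drift along the loss gradient with added Gaussian noise. Below is a minimal sketch of the hand-shake a training loop might perform with this dataset, feeding the gradient w.r.t. `points_2d` and the per-ray loss back into the next fetch. Everything in it (`train_dataset`, `render_fn`, `optimizer`, `max_steps`) is an assumed placeholder, and it presumes the renderer keeps a graph from `points_2d` to the loss; it is not part of this commit.

```python
import torch

net_grad, loss_per_pix = None, None
for step in range(max_steps):
    # __getitem__ is called directly so the extra kwargs can be forwarded
    data = train_dataset.__getitem__(
        step, net_grad=net_grad, loss_per_pix=loss_per_pix
    )
    rgba, points_2d = data["rgba"], data["points_2d"]

    pred = render_fn(data["rays"])                    # [num_rays, 4], assumed renderer
    loss_per_pix = ((pred - rgba) ** 2).mean(dim=-1)  # [num_rays]
    loss = loss_per_pix.mean()

    # drift term for the next Langevin step: d(loss)/d(points_2d)
    net_grad = torch.autograd.grad(loss, points_2d, retain_graph=True)[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_per_pix = loss_per_pix.detach()
```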
68 changes: 68 additions & 0 deletions examples/datasets/utils.py
@@ -3,10 +3,78 @@
"""

import collections
import torch
import torch.nn.functional as F

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))


def namedtuple_map(fn, tup):
    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
    return type(tup)(*(None if x is None else fn(x) for x in tup))


@torch.cuda.amp.autocast(dtype=torch.float64)
def sample_from_pdf_with_indices(cdf, num_points):
    # Draw uniform samples and invert the CDF via binary search
    u = torch.rand((cdf.shape[0], num_points), device=cdf.device, dtype=torch.float64)
    batch1d = torch.searchsorted(cdf, u, right=True) - 1
    return batch1d
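For intuition, a small self-contained example of the inverse-CDF trick this helper implements: build a padded CDF from unnormalized weights, then map uniform draws through `searchsorted` (the weights are made up for illustration):

```python
import torch

# Three pixels with unnormalized weights 1:3:6 -> expect ~10%/30%/60% hit rates.
weights = torch.tensor([[1.0, 3.0, 6.0]], dtype=torch.float64)
probs = weights / weights.sum(dim=-1, keepdim=True)
cdf = torch.cumsum(probs, dim=-1)
cdf = torch.nn.functional.pad(cdf, pad=(1, 0), value=0)  # [0.0, 0.1, 0.4, 1.0]

u = torch.rand((1, 10000), dtype=torch.float64)
idx = torch.searchsorted(cdf, u, right=True) - 1
print(torch.bincount(idx.flatten(), minlength=3) / 10000.0)  # ~[0.1, 0.3, 0.6]
```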


def compute_sobel_edge(images):
    # Ensure the input is a torch tensor
    if not isinstance(images, torch.Tensor):
        images = torch.tensor(images)

    if images.max() > 1.0:
        images = images / 255.0

    # Convert the images to grayscale using a weighted sum of channels (shape: N x H x W x 1)
    gray_images = 0.2989 * images[..., 0] + 0.5870 * images[..., 1] + 0.1140 * images[..., 2]
    gray_images = gray_images.unsqueeze(-1)

    # Transpose the images to the shape (N, C, H, W)
    gray_images = gray_images.permute(0, 3, 1, 2)

    # Define Sobel kernels
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=images.device).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=images.device).view(1, 1, 3, 3)

    # Compute Sobel edges
    edge_x = F.conv2d(gray_images, sobel_x, padding=1)
    edge_y = F.conv2d(gray_images, sobel_y, padding=1)
    edges = torch.sqrt(edge_x ** 2 + edge_y ** 2)

    # Normalize, then floor at 1e-5 so every pixel keeps nonzero probability
    maxval, _ = edges.max(dim=1)[0].max(dim=1)
    edges = edges / (maxval.unsqueeze(1).unsqueeze(1) + 1e-7)
    edges = torch.clip(edges, min=1e-5, max=1.0)
    return edges.squeeze(1)
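To see how this feeds the sampler, a tiny sketch (random stand-in images) mirroring the CDF construction in `nerf_synthetic.py`'s `__init__`:

```python
import torch

images = torch.randint(0, 256, (2, 8, 8, 4)).float()  # stand-in RGBA batch
edges = compute_sobel_edge(images)                     # (2, 8, 8)
flat = edges.reshape(images.shape[0], -1)              # one row per image
probs = flat / flat.sum(dim=-1, keepdim=True)          # per-image pixel PDF
print(probs.sum(dim=-1))                               # tensor([1., 1.])
```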


# source: https://github.com/NVlabs/tiny-cuda-nn/blob/master/samples/mlp_learning_an_image_pytorch.py
class Image(torch.nn.Module):
    def __init__(self, images, device):
        super(Image, self).__init__()
        self.data = images.to(device, non_blocking=True)
        self.shape = self.data[0].shape

    @torch.cuda.amp.autocast(dtype=torch.float32)
    def forward(self, iind, ys, xs):
        shape = self.shape

        # bilinear interpolation: column 0 of xy is y, column 1 is x
        xy = torch.cat([ys.unsqueeze(1), xs.unsqueeze(1)], dim=1)
        indices = xy.long()
        lerp_weights = xy - indices.float()  # [:, 0:1] = frac(y), [:, 1:2] = frac(x)

        y0 = indices[:, 0].clamp(min=0, max=shape[0] - 1)
        x0 = indices[:, 1].clamp(min=0, max=shape[1] - 1)
        y1 = (y0 + 1).clamp(max=shape[0] - 1)
        x1 = (x0 + 1).clamp(max=shape[1] - 1)

        return (
            self.data[iind, y0, x0] * (1.0 - lerp_weights[:, 0:1]) * (1.0 - lerp_weights[:, 1:2])
            + self.data[iind, y0, x1] * (1.0 - lerp_weights[:, 0:1]) * lerp_weights[:, 1:2]
            + self.data[iind, y1, x0] * lerp_weights[:, 0:1] * (1.0 - lerp_weights[:, 1:2])
            + self.data[iind, y1, x1] * lerp_weights[:, 0:1] * lerp_weights[:, 1:2]
        )
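A small usage sketch of the bilinear lookup (the image tensor and coordinates are made-up values; `Image` is the class above):

```python
import torch

# Two 8x8 RGBA images, pixel values 0..255.
imgs = torch.randint(0, 256, (2, 8, 8, 4)).float()
simages = Image(imgs, device="cpu")

# Query image 0 at (y=2.5, x=3.5) and image 1 at an exact pixel (y=4, x=4).
iind = torch.tensor([0, 1])
ys = torch.tensor([2.5, 4.0])
xs = torch.tensor([3.5, 4.0])
rgba = simages(iind, ys, xs)  # (2, 4); first row averages its 4 neighbours
```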
26 changes: 26 additions & 0 deletions examples/losses.py
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn

from nerfacc.losses import DistortionLoss


class NeRFLoss(nn.Module):
    def __init__(self, lambda_opacity=0.0, lambda_distortion=0.01):
        super().__init__()

        self.lambda_opacity = lambda_opacity
        self.lambda_distortion = lambda_distortion

    def forward(self, rgb, target, opp=None, distkwargs=None):
        d = {}
        d['rgb'] = (rgb - target) ** 2

        if self.lambda_opacity > 0:
            o = opp + torch.finfo(torch.float16).eps
            # encourage opacity to be either 0 or 1 to avoid floaters
            d['opacity'] = self.lambda_opacity * (-o * torch.log(o))

        if self.lambda_distortion > 0 and distkwargs is not None:
            d['distortion'] = self.lambda_distortion * \
                DistortionLoss.apply(distkwargs['ws'], distkwargs['deltas'],
                                     distkwargs['ts'], distkwargs['rays_a'])

        return d
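A minimal sketch of how the loss might be called from a training step (tensor shapes are assumptions; the distortion branch would additionally need the `ws`/`deltas`/`ts`/`rays_a` buffers produced by the nerfacc renderer):

```python
import torch

criterion = NeRFLoss(lambda_opacity=1e-3, lambda_distortion=0.0)

rgb = torch.rand(4096, 3, requires_grad=True)       # predicted colors
target = torch.rand(4096, 3)                        # ground-truth pixels
opacity = torch.rand(4096, 1).clamp(1e-4, 1 - 1e-4)

losses = criterion(rgb, target, opp=opacity)        # dict of per-term tensors
loss = sum(v.mean() for v in losses.values())       # reduce to a scalar
loss.backward()
```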
2 changes: 2 additions & 0 deletions examples/radiance_fields/ngp.py
@@ -184,6 +184,8 @@ def _query_rgb(self, dir, embedding, apply_act: bool = True):
        if self.use_viewdirs:
            dir = (dir + 1.0) / 2.0
            d = self.direction_encoding(dir.reshape(-1, dir.shape[-1]))
            d = (d - d.mean(-1, keepdim=True)) / (d.std(-1, keepdim=True) + torch.finfo(torch.float16).eps)
            embedding = (embedding - embedding.mean(-1, keepdim=True)) / (embedding.std(-1, keepdim=True) + torch.finfo(torch.float16).eps)
            h = torch.cat([d, embedding.reshape(-1, self.geo_feat_dim)], dim=-1)
        else:
            h = embedding.reshape(-1, self.geo_feat_dim)
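The two added lines standardize each feature vector to zero mean and unit variance before it enters the MLP, presumably for numerical stability under mixed precision (the epsilon is float16's). A standalone toy illustration of the same transform, not the network code itself:

```python
import torch

feats = torch.randn(5, 16) * 3.0 + 7.0  # arbitrary scale and offset
eps = torch.finfo(torch.float16).eps
norm = (feats - feats.mean(-1, keepdim=True)) / (feats.std(-1, keepdim=True) + eps)
print(norm.mean(-1).abs().max(), norm.std(-1))  # ~0 mean, ~1 std per row
```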
38 changes: 38 additions & 0 deletions examples/schedulers.py
@@ -0,0 +1,38 @@
import torch.optim.lr_scheduler as lr_scheduler


def create_scheduler(optimizer_name, scheduler_type, max_steps, lr):
    if scheduler_type == "step":
        scheduler = lr_scheduler.StepLR(
            optimizer_name, step_size=1000, gamma=0.847
        )
    elif scheduler_type == "cosineannealing":
        scheduler = lr_scheduler.ChainedScheduler(
            [
                lr_scheduler.CosineAnnealingLR(
                    optimizer_name,
                    T_max=max_steps,
                    eta_min=lr / 10,
                )
            ]
        )
    elif scheduler_type == "chain":
        scheduler = lr_scheduler.ChainedScheduler(
            [
                lr_scheduler.LinearLR(
                    optimizer_name, start_factor=0.01, total_iters=100
                ),
                lr_scheduler.MultiStepLR(
                    optimizer_name,
                    milestones=[
                        max_steps // 2,
                        max_steps * 3 // 4,
                        max_steps * 9 // 10,
                    ],
                    gamma=0.33,
                ),
            ]
        )
    elif scheduler_type == "none":
        scheduler = None
    else:
        raise ValueError(f"Invalid scheduler type: {scheduler_type}")

    return scheduler
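A usage sketch (the optimizer and step count are placeholders; note the first argument is the optimizer object itself, despite the `optimizer_name` parameter name):

```python
import torch

params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = torch.optim.Adam(params, lr=1e-2)
scheduler = create_scheduler(optimizer, "chain", max_steps=20000, lr=1e-2)

for step in range(20000):
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
```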