Dataloading Revamp #3216

Open

AntonioMacaronio wants to merge 87 commits into main from dataloading-revamp

Changes from 27 commits

Commits (87)
0471543
initial debugging and testing works
AntonioMacaronio Jun 11, 2024
c6dde7d
pwais changes with RayBatchStream to alleviate training
AntonioMacaronio Jun 12, 2024
a09ea0c
Merge branch 'main' into dataloading-revamp
AntonioMacaronio Jun 12, 2024
78453cd
few bugs to iron out with multiprocessing, specifically pickled colla…
AntonioMacaronio Jun 12, 2024
f2bd96f
working version of RayBatchStream
AntonioMacaronio Jun 13, 2024
d8b7430
additional docstrings
AntonioMacaronio Jun 13, 2024
a5425d4
cleanup
AntonioMacaronio Jun 13, 2024
604f734
much more documentation
AntonioMacaronio Jun 13, 2024
0143803
successfully trained AEA-script2_seq2 closed_loop without OOM
AntonioMacaronio Jun 13, 2024
d3527e2
porting over aria dataset-size feature
AntonioMacaronio Jun 13, 2024
25f5f27
added logic to handle eviction of a worker's cached_collated_batch
AntonioMacaronio Jun 14, 2024
3a8b63b
antonio's implementation of stream batches
AntonioMacaronio Jun 15, 2024
536c6ca
training on a dataset with 4000 images works!
AntonioMacaronio Jun 15, 2024
43a0061
some configuration speedups, loops aren't actually needed!
AntonioMacaronio Jun 15, 2024
fa7cf30
quick fix adjustment to aria
AntonioMacaronio Jun 15, 2024
927cb6a
removed unnecessary looping
AntonioMacaronio Jun 16, 2024
814f2c2
much faster training when adding i variable to collate every 5 ray bu…
AntonioMacaronio Jun 25, 2024
247ac3e
cleanup unnecessary variables in Dataloader
AntonioMacaronio Jul 7, 2024
55d0803
further cleanup
AntonioMacaronio Jul 11, 2024
b6979a4
adding caching of compressed images to RAM to reduce disk bottleneck
AntonioMacaronio Jul 20, 2024
81dbf7c
added caching to RAM for masks
AntonioMacaronio Jul 22, 2024
55ca71d
found fast way to collate - many tricks applied
AntonioMacaronio Jul 26, 2024
3b4f091
quick update to aria to test on different datasets
AntonioMacaronio Jul 26, 2024
7de1922
cleaned up the accelerated pil_to_numpy function
AntonioMacaronio Jul 26, 2024
9ceaad1
cleaning up PR
AntonioMacaronio Jul 26, 2024
4147a6a
this commit was used to generate the time metrics and profiling metrics
AntonioMacaronio Jul 26, 2024
5a55b7a
REAL commit used to run tests
AntonioMacaronio Jul 26, 2024
78f02e6
testing with nerfacto-big
AntonioMacaronio Aug 15, 2024
19bc4b5
generated RayBundle collate and converting images from uint8s to floa…
AntonioMacaronio Aug 15, 2024
9245d05
updating nerfacto to support uint8 easily, will need to figure out a …
AntonioMacaronio Aug 20, 2024
3124c14
datamanager updates, both splat and nerf
AntonioMacaronio Aug 20, 2024
afb0612
must use writeable arrays because torch requires them
AntonioMacaronio Aug 20, 2024
288a740
cleaned up base_dataset, added pickle to utils, more code in full_ima…
AntonioMacaronio Aug 22, 2024
2fd0862
lots of progress on a parallel FullImageDatamanager
AntonioMacaronio Aug 23, 2024
846e2f3
can train big splats with pre-assertion hack or ROI hack and 0 workers
AntonioMacaronio Aug 24, 2024
8fb0b4d
fixed all undistortion issues with ParallelImageDatamanager
AntonioMacaronio Aug 27, 2024
ce3f83f
adding some downsampling and parallel tests with splatfacto!
AntonioMacaronio Aug 31, 2024
8ab9963
deleted commented code in dataloaders.py and added bugfix to shuffling
AntonioMacaronio Aug 31, 2024
c9e16bf
testing splatfacto-big
AntonioMacaronio Sep 1, 2024
ddac38d
cleaned up base_pipeline.py
AntonioMacaronio Sep 1, 2024
443719a
cleaned up base_pipeline.py ACTUALLY THIS TIME, forgot to save last time
AntonioMacaronio Sep 1, 2024
d16e519
cleaned up a lot of code
AntonioMacaronio Sep 1, 2024
367d512
process_project_aria back to main branch and some cleanup in full_ima…
AntonioMacaronio Sep 1, 2024
d3d99b4
clarifying docstrings
AntonioMacaronio Sep 1, 2024
6f763dc
further PR cleanup
AntonioMacaronio Sep 3, 2024
a5191bd
updating models
AntonioMacaronio Sep 9, 2024
7db70dc
further cleanup
AntonioMacaronio Sep 9, 2024
5c3262b
removed caching of images into bytestrings
AntonioMacaronio Sep 9, 2024
ff2bda1
adding caching of compressed images to RAM, forgot that hardware matters
AntonioMacaronio Sep 9, 2024
f6dd7dd
removing oom methods, adding the ability to add a flag to dataloading
AntonioMacaronio Sep 15, 2024
a6602c7
removed CacheDataloader, moved RayBatchStream to dataloaders.py, new …
AntonioMacaronio Sep 15, 2024
3dc2031
fixing base_piplines, deleting a weird datamanager_configs file that …
AntonioMacaronio Sep 15, 2024
89f3d98
cleaning up next_train
AntonioMacaronio Sep 15, 2024
14e60e5
replaced parallel datamanager with new datamanager
AntonioMacaronio Sep 19, 2024
204dfb2
reverted the original base_datamanager.py, new datamanager replaced p…
AntonioMacaronio Sep 19, 2024
5864bc9
modified VanillaConfig, but VanillaDataManager is the same as before
AntonioMacaronio Sep 19, 2024
6d97de3
cleaning up, 2 datamanagers now - original and new parallel one
AntonioMacaronio Sep 19, 2024
1f34017
able to train with new nerfstudio dataloader now
AntonioMacaronio Sep 19, 2024
99cf86a
side by side datamanagers, moved tons of logic into dataloaders.py an…
AntonioMacaronio Sep 23, 2024
4ebad85
added custom ray processing API to support implementations like LERF,…
AntonioMacaronio Sep 23, 2024
87921be
adding functionality for ns-eval by adding FixedIndicesEvalDataloader…
AntonioMacaronio Sep 24, 2024
b628c7c
adding both ray API and image-view API to datamanagers for custom par…
AntonioMacaronio Sep 27, 2024
d2785d1
updating splatfacto config for 4k tests
AntonioMacaronio Sep 30, 2024
436af9d
updating docstrings to be more descriptive
AntonioMacaronio Sep 30, 2024
dd4daaa
new datamanager API breaks when setup_eval() has multiple workers, no…
AntonioMacaronio Sep 30, 2024
43c66ae
adding custom_view_processor to ImageBatchStream
AntonioMacaronio Sep 30, 2024
ba81e11
merging with main!
AntonioMacaronio Sep 30, 2024
1922566
reverting full_images_datamanager to main branch
AntonioMacaronio Oct 1, 2024
beb74be
removing nn.Module inheritance from Datamanager class
AntonioMacaronio Oct 1, 2024
087cff0
don't need to move datamanager to device anymore since Datamanager is …
AntonioMacaronio Oct 1, 2024
48e6d15
finished integration test with nerfacto
AntonioMacaronio Oct 4, 2024
3f1799b
simplified config variables, integrated the parallelism/disk-data-loa…
AntonioMacaronio Oct 25, 2024
f46aa42
updated the splatfacto config to be simpler with the dataloading and …
AntonioMacaronio Oct 25, 2024
5aa51fb
style checks and some cleanup
AntonioMacaronio Oct 25, 2024
ec3c12a
new splatfacto test, cleaning up nerfacto integration test
AntonioMacaronio Oct 25, 2024
82bc5b2
removing redundant parallel_full_images_datamanager, as the OG full_i…
AntonioMacaronio Oct 26, 2024
377a56a
Merge branch 'main' into dataloading-revamp
AntonioMacaronio Oct 28, 2024
bbb5473
ruff linting and pyright fixing
AntonioMacaronio Oct 28, 2024
2e64120
further pyright fixing
AntonioMacaronio Oct 28, 2024
e9c2fd6
another pyright fixing
AntonioMacaronio Oct 28, 2024
e4dc9f9
fixing pyright error, camera optimization no longer part of datamanager
AntonioMacaronio Nov 1, 2024
8b0ec8e
fixing one pyright
AntonioMacaronio Nov 22, 2024
6349852
fixing dataloading error when camera is not undistorted with dataloader
AntonioMacaronio Dec 13, 2024
ad6b090
fixing comments and updating style
AntonioMacaronio Dec 21, 2024
8c678ee
undoing a style change i made
AntonioMacaronio Dec 21, 2024
64edabb
undoing another style change i made by accident
AntonioMacaronio Dec 21, 2024
cc63585
Merge branch 'main' into dataloading-revamp
AntonioMacaronio Dec 22, 2024
2 changes: 1 addition & 1 deletion nerfstudio/configs/method_configs.py
@@ -90,7 +90,7 @@
max_num_iterations=30000,
mixed_precision=True,
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(
datamanager=VanillaDataManagerConfig(
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=4096,
eval_num_rays_per_batch=4096,
290 changes: 267 additions & 23 deletions nerfstudio/data/datamanagers/base_datamanager.py

Large diffs are not rendered by default.
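
(The base_datamanager.py diff is collapsed here, but the commit history — RayBatchStream, removal of CacheDataloader, "collate every 5 ray bundles" — points at the core idea: stream ray batches from an IterableDataset instead of keeping one huge cached collated batch resident per worker. The sketch below is illustrative only; the class name, tensor layout, and sampling logic are assumptions, not the PR's actual code.)

```python
import torch
from torch.utils.data import IterableDataset


class RayBatchStreamSketch(IterableDataset):
    """Illustrative stand-in for a ray-batch stream: re-collate a small group of
    images every few steps and sample ray batches from it, instead of caching one
    giant collated batch in every dataloader worker."""

    def __init__(self, images: torch.Tensor, rays_per_batch: int = 4096,
                 images_per_group: int = 4, collate_every: int = 5):
        self.images = images                    # assumed (N, H, W, 3) image tensor
        self.rays_per_batch = rays_per_batch
        self.images_per_group = images_per_group
        self.collate_every = collate_every      # loosely mirrors "collate every 5 ray bundles"

    def __iter__(self):
        n, h, w, _ = self.images.shape
        step, group = 0, None
        while True:
            if step % self.collate_every == 0:
                # Re-selecting (and, in the real thing, re-loading/collating) images is
                # the expensive part, so it only happens every few ray batches.
                group = torch.randperm(n)[: self.images_per_group]
            cam = group[torch.randint(len(group), (self.rays_per_batch,))]
            ys = torch.randint(h, (self.rays_per_batch,))
            xs = torch.randint(w, (self.rays_per_batch,))
            # "indices" says which rays to generate; "pixels" is the supervision signal.
            yield {"indices": torch.stack([cam, ys, xs], dim=-1),
                   "pixels": self.images[cam, ys, xs]}
            step += 1
```

Wrapping such a stream in `torch.utils.data.DataLoader(stream, batch_size=None, num_workers=k)` moves sampling and collation off the main training process, which is presumably what lets the PR drop CacheDataloader.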

10 changes: 9 additions & 1 deletion nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -47,6 +47,14 @@
from nerfstudio.utils.misc import get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE

class ImageBatchStream(torch.utils.data.IterableDataset):
def __init__(
self,

):
return

# def

@dataclass
class FullImageDatamanagerConfig(DataManagerConfig):
@@ -79,7 +87,7 @@ class FullImageDatamanagerConfig(DataManagerConfig):
fps_reset_every: int = 100
"""The number of iterations before one resets fps sampler repeatly, which is essentially drawing fps_reset_every
samples from the pool of all training cameras without replacement before a new round of sampling starts."""


class FullImageDatamanager(DataManager, Generic[TDataset]):
"""
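
(At this point in the commit range, ImageBatchStream is still an empty stub. For context, here is a minimal sketch of the usual IterableDataset pattern it would need to follow — in particular, sharding indices across DataLoader workers so each image is decoded by exactly one worker. Everything below is illustrative, not the PR's implementation.)

```python
from torch.utils.data import IterableDataset, get_worker_info


class ImageStreamSketch(IterableDataset):
    """Illustrative only: stream items from a map-style dataset one at a time,
    sharded across DataLoader workers so no image is loaded twice per pass."""

    def __init__(self, dataset):
        self.dataset = dataset  # anything with __len__ / __getitem__

    def __iter__(self):
        worker = get_worker_info()
        indices = range(len(self.dataset))
        if worker is not None:
            # Stride the indices by worker id so worker shards don't overlap.
            indices = indices[worker.id :: worker.num_workers]
        while True:  # iterate forever; the training loop decides when to stop
            # (a real version would also shuffle each shard on every pass)
            for i in indices:
                yield self.dataset[i]


# Hypothetical usage: batch_size=None yields single items rather than stacked batches.
# loader = DataLoader(ImageStreamSketch(train_dataset), batch_size=None, num_workers=4)
```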
40 changes: 32 additions & 8 deletions nerfstudio/data/datasets/base_dataset.py
@@ -18,6 +18,7 @@
from __future__ import annotations

from copy import deepcopy
import io
from pathlib import Path
from typing import Dict, List, Literal

@@ -31,8 +32,8 @@

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.utils.data_utils import get_image_mask_tensor_from_path

from nerfstudio.data.utils.data_utils import get_image_mask_tensor_from_path, pil_to_numpy
from torch.profiler import record_function

class InputDataset(Dataset):
"""Dataset that returns images.
@@ -45,7 +46,7 @@ class InputDataset(Dataset):
exclude_batch_keys_from_device: List[str] = ["image", "mask"]
cameras: Cameras

def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0):
def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0, cache_images: bool = True):
super().__init__()
self._dataparser_outputs = dataparser_outputs
self.scale_factor = scale_factor
@@ -54,6 +55,18 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float =
self.cameras = deepcopy(dataparser_outputs.cameras)
self.cameras.rescale_output_resolution(scaling_factor=scale_factor)
self.mask_color = dataparser_outputs.metadata.get("mask_color", None)
self.cache_images = cache_images
"""If cache_images == True, cache all the image files into RAM in their compressed form (not as tensors yet)"""
if cache_images:
self.binary_images = []
self.binary_masks = []
for image_filename in self._dataparser_outputs.image_filenames:
with open(image_filename, 'rb') as f:
self.binary_images.append(io.BytesIO(f.read()))
if self._dataparser_outputs.mask_filenames is not None:
for mask_filename in self._dataparser_outputs.mask_filenames:
with open(mask_filename, 'rb') as f:
self.binary_masks.append(io.BytesIO(f.read()))

def __len__(self):
return len(self._dataparser_outputs.image_filenames)
@@ -65,12 +78,15 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]:
image_idx: The image index in the dataset.
"""
image_filename = self._dataparser_outputs.image_filenames[image_idx]
pil_image = Image.open(image_filename)
if self.cache_images:
pil_image = Image.open(self.binary_images[image_idx])
else:
pil_image = Image.open(image_filename)
if self.scale_factor != 1.0:
width, height = pil_image.size
newsize = (int(width * self.scale_factor), int(height * self.scale_factor))
pil_image = pil_image.resize(newsize, resample=Image.Resampling.BILINEAR)
image = np.array(pil_image, dtype="uint8") # shape is (h, w) or (h, w, 3 or 4)
image = pil_to_numpy(pil_image)  # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8"
if len(image.shape) == 2:
image = image[:, :, None].repeat(3, axis=2)
assert len(image.shape) == 3
@@ -84,7 +100,12 @@ def get_image_float32(self, image_idx: int) -> Float[Tensor, "image_height image
Args:
image_idx: The image index in the dataset.
"""
image = torch.from_numpy(self.get_numpy_image(image_idx).astype("float32") / 255.0)
with record_function("pil_to_numpy()"):
image = self.get_numpy_image(image_idx)
with record_function("divide by 255.0 + convert to float32"):
image = image / np.float32(255)
with record_function("torch.from_numpy()"):
image = torch.from_numpy(image)
if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4:
assert (self._dataparser_outputs.alpha_color >= 0).all() and (
self._dataparser_outputs.alpha_color <= 1
@@ -98,7 +119,7 @@ def get_image_uint8(self, image_idx: int) -> UInt8[Tensor, "image_height image_w
Args:
image_idx: The image index in the dataset.
"""
image = torch.from_numpy(self.get_numpy_image(image_idx))
image = torch.from_numpy(self.get_numpy_image(image_idx).astype(np.uint8))
if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4:
assert (self._dataparser_outputs.alpha_color >= 0).all() and (
self._dataparser_outputs.alpha_color <= 1
@@ -125,7 +146,10 @@ def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "fl

data = {"image_idx": image_idx, "image": image}
if self._dataparser_outputs.mask_filenames is not None:
mask_filepath = self._dataparser_outputs.mask_filenames[image_idx]
if self.cache_images:
mask_filepath = self.binary_masks[image_idx]
else:
mask_filepath = self._dataparser_outputs.mask_filenames[image_idx]
data["mask"] = get_image_mask_tensor_from_path(filepath=mask_filepath, scale_factor=self.scale_factor)
assert (
data["mask"].shape[:2] == data["image"].shape[:2]
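
(The cache_images path added above keeps the encoded PNG/JPEG bytes in RAM as BytesIO objects and decodes them on access, trading a little CPU per lookup for far less memory than caching decoded tensors — a compressed frame is typically a few hundred KB versus tens of MB as float32. Below is a standalone sketch of the same idea; the class name and API are made up for illustration.)

```python
import io

import numpy as np
from PIL import Image


class CompressedImageCache:
    """Illustrative: hold images in RAM in their compressed form, decode lazily."""

    def __init__(self, image_filenames):
        self._buffers = []
        for filename in image_filenames:
            with open(filename, "rb") as f:        # one disk read per image, ever
                self._buffers.append(io.BytesIO(f.read()))

    def get(self, idx: int) -> np.ndarray:
        buf = self._buffers[idx]
        buf.seek(0)                                # rewind so the buffer can be re-decoded
        return np.array(Image.open(buf), dtype=np.uint8)
```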
38 changes: 35 additions & 3 deletions nerfstudio/data/utils/data_utils.py
@@ -14,15 +14,47 @@

"""Utility functions to allow easy re-use of common operations across dataloaders"""
from pathlib import Path
from typing import List, Tuple, Union
from typing import List, Tuple, Union, IO

import cv2
import numpy as np
import torch
from PIL import Image


def get_image_mask_tensor_from_path(filepath: Path, scale_factor: float = 1.0) -> torch.Tensor:
def pil_to_numpy(im: Image) -> np.ndarray:
"""Converts a PIL Image object to a NumPy array.

Args:
im (PIL.Image.Image): The input PIL Image object.

Returns:
numpy.ndarray representing the image data.
"""
# Load in image completely (PIL defaults to lazy loading)
im.load()

# Unpack data
e = Image._getencoder(im.mode, "raw", im.mode)
e.setimage(im.im)

# NumPy buffer for the result
shape, typestr = Image._conv_type_shape(im)
data = np.empty(shape, dtype=np.dtype(typestr))
mem = data.data.cast("B", (data.data.nbytes,))

bufsize, s, offset = 65536, 0, 0
while not s:
l, s, d = e.encode(bufsize)
mem[offset:offset + len(d)] = d
offset += len(d)
if s < 0:
raise RuntimeError("encoder error %d in tobytes" % s)

return data


def get_image_mask_tensor_from_path(filepath: Union[Path, IO[bytes]], scale_factor: float = 1.0) -> torch.Tensor:
"""
Utility function to read a mask image from the given path and return a boolean tensor
"""
@@ -31,7 +63,7 @@ def get_image_mask_tensor_from_path(filepath: Path, scale_factor: float = 1.0) -
width, height = pil_mask.size
newsize = (int(width * scale_factor), int(height * scale_factor))
pil_mask = pil_mask.resize(newsize, resample=Image.Resampling.NEAREST)
mask_tensor = torch.from_numpy(np.array(pil_mask)).unsqueeze(-1).bool()
mask_tensor = torch.from_numpy(pil_to_numpy(pil_mask)).unsqueeze(-1).bool()
if len(mask_tensor.shape) != 3:
raise ValueError("The mask image should have 1 channel")
return mask_tensor
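
(pil_to_numpy copies PIL's already-decoded pixel data straight into a preallocated NumPy array via the raw encoder, skipping the intermediate bytes object that np.array(pil_image) goes through, and the result is writable, which matters for torch.from_numpy per the "must use writeable arrays" commit. A rough way to sanity-check both claims — "frame.jpg" is a placeholder path, and timings will vary with hardware and image size.)

```python
import time

import numpy as np
import torch
from PIL import Image

from nerfstudio.data.utils.data_utils import pil_to_numpy  # added in the diff above

pil_image = Image.open("frame.jpg")  # placeholder path
pil_image.load()                     # decode up front so both paths time only the copy

t0 = time.perf_counter()
baseline = np.array(pil_image, dtype="uint8")   # goes through an intermediate bytes copy
t1 = time.perf_counter()
fast = pil_to_numpy(pil_image)                  # writes straight into a preallocated array
t2 = time.perf_counter()

assert np.array_equal(baseline, fast)
assert fast.flags.writeable                      # safe to hand to torch.from_numpy()
tensor = torch.from_numpy(fast)
print(f"np.array: {t1 - t0:.4f}s  pil_to_numpy: {t2 - t1:.4f}s")
```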