Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] : Fix resume issues with combined streaming dataset in dataloader #362

Draft
wants to merge 38 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
0ba50c5
chore: Add tests for CombinedStreamingDataset in test_dataloader.py
bhimrazy Sep 3, 2024
3756b9e
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Sep 3, 2024
f8ed272
Adds resuming for dataloading states for combined dataset case with w…
bhimrazy Sep 3, 2024
8c53791
update
bhimrazy Sep 3, 2024
e193eb9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2024
837efde
update
bhimrazy Sep 3, 2024
242a13c
Refactor dataloader to fix num_samples_yieled calculation
bhimrazy Sep 5, 2024
dff88ca
Adds more tests
bhimrazy Sep 5, 2024
bc137a3
removes the subtraction in the epoch
bhimrazy Sep 5, 2024
e38592a
update initialize part
bhimrazy Sep 5, 2024
e095407
updated epoch numbers
bhimrazy Sep 5, 2024
97510e2
format imports
bhimrazy Sep 5, 2024
5c6925d
reverted current epoch
bhimrazy Sep 6, 2024
8502bb1
removed combined data test and moved to `test_combined.py`
bhimrazy Sep 6, 2024
5cda0c6
reverted epoch and also moved the test combined dataset
bhimrazy Sep 6, 2024
8cdafeb
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Sep 6, 2024
560032c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
f099242
revert epoch
bhimrazy Sep 6, 2024
dc484a0
updated current epoch
bhimrazy Sep 9, 2024
53b360c
fix epoch number
bhimrazy Sep 9, 2024
0c7cf3e
updated params
bhimrazy Sep 9, 2024
8010d33
Update current_epoch in test_dataloader.py
bhimrazy Sep 9, 2024
7355532
Update num_workers in test_combined.py
bhimrazy Sep 9, 2024
d40e3ca
updated the conditions
bhimrazy Sep 9, 2024
be77b58
updated tests: added case for the complete last iteration
bhimrazy Sep 9, 2024
86ccd99
Refactor test_combined.py to fix restore state issue
bhimrazy Sep 9, 2024
206b574
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Sep 17, 2024
5927331
fix: separated test cases for complete and partial last epoch.
bhimrazy Sep 17, 2024
da10c3f
Merge branch 'fix/combined-dataset-loading-states' of github.com:bhim…
bhimrazy Sep 17, 2024
17de9e7
fix type errors
bhimrazy Sep 17, 2024
a05e6f1
Refactor test_combined.py: Remove print statement
bhimrazy Sep 17, 2024
3762b11
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Sep 19, 2024
d506972
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Sep 22, 2024
868aa42
Adds conftest for combined dataset to reuse
bhimrazy Sep 22, 2024
357c29c
Simplified tests with parametrize to test for different conditions
bhimrazy Sep 22, 2024
ec3b840
update test
bhimrazy Sep 22, 2024
0ee0617
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2024
424bc64
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Dec 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/litdata/streaming/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,14 +665,16 @@ def state_dict(self) -> Dict[str, Any]:
"latest_worker_idx": self._latest_worker_idx,
}

num_samples_yieled = [0 for _ in range(len(list(self._num_samples_yielded_combined.values())[0]))]
# Initialize a list to track the number of samples yielded for each dataset
num_samples_yieled = [0 for _ in range(len(self.dataset._datasets))]

for worker_idx in self._num_samples_yielded_combined:
for dataset_idx, samples_yieled in enumerate(self._num_samples_yielded_combined[worker_idx]):
num_samples_yieled[dataset_idx] += samples_yieled

return {
"dataset": self.dataset.state_dict(self.num_workers, self.batch_size, num_samples_yieled),
"current_epoch": self.current_epoch if self.restore else self.current_epoch - 1,
"current_epoch": self.current_epoch,
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"latest_worker_idx": self._latest_worker_idx,
"num_samples_yielded": deepcopy(self._num_samples_yielded_combined),
}
Expand Down Expand Up @@ -701,22 +703,30 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:

# Inform we are resuming and disable resetting the StreamingDataLoader state.
# This is toggle back to False when the `__iter__` method of the StreamingDataLoader completes.
# self.restore = True

if isinstance(self.dataset, CombinedStreamingDataset):
self.dataset._set_use_streaming_dataloader(True)
self.dataset.load_state_dict(obj)

# Inform that the dataloader is resuming.
# TODO: Check if the number of samples yielded is less than the length of the dataset.
# Also, len is not available for CombinedStreamingDataset in case of provided weights.
self.restore = True
total_samples_yielded = sum([sum(samples) for samples in self._num_samples_yielded_combined.values()])

# Check if we need to restore for the case without weights.
if (
self.dataset._iterate_over_all
and total_samples_yielded > 0
and total_samples_yielded < len(self.dataset) # type: ignore
):
self.restore = True

# Check if we need to restore for the case with weights.
# Note: `len` is not available for CombinedStreamingDataset in case of provided weights.
# TODO: handle the case with weights.
if not self.dataset._iterate_over_all:
self.restore = True

elif isinstance(self.dataset, StreamingDataset):
self.dataset.load_state_dict(obj["dataset"])

# Inform that the dataloader is resuming.
if self._num_samples_yielded_streaming < len(self.dataset):
if self._num_samples_yielded_streaming > 0 and self._num_samples_yielded_streaming < len(self.dataset):
self.restore = True
else:
raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")
Expand Down
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import pytest
import torch.distributed
from litdata import CombinedStreamingDataset, StreamingDataset
from litdata.streaming.cache import Cache
from litdata.streaming.reader import PrepareChunksThread


Expand Down Expand Up @@ -108,3 +110,19 @@ def _thread_police():
thread.join(timeout=20)
else:
raise AssertionError(f"Test left zombie thread: {thread}")


@pytest.fixture()
def combined_dataset(tmpdir_factory):
    """Build a ``CombinedStreamingDataset`` over two freshly written optimized datasets.

    Each of the two datasets lives in its own temp directory and holds the
    integers 0..49, so the combined dataset yields 100 items in total.
    """
    root = tmpdir_factory.mktemp("data")

    dataset_dirs = []
    for idx in range(2):
        dataset_dir = str(root.join(f"dataset_{idx}"))
        # Write 50 integer samples into an optimized cache for this dataset.
        cache = Cache(input_dir=dataset_dir, chunk_bytes="64MB")
        for item in range(50):
            cache[item] = item
        cache.done()
        cache.merge()
        dataset_dirs.append(dataset_dir)

    streaming_datasets = [StreamingDataset(path, shuffle=True) for path in dataset_dirs]
    return CombinedStreamingDataset(datasets=streaming_datasets)
107 changes: 91 additions & 16 deletions tests/streaming/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def test_combined_dataset_with_dataloader_and_one_worker(batch_size):
"0": {"num_samples_yielded": 9, "num_workers": 1, "batch_size": batch_size},
"1": {"num_samples_yielded": 3, "num_workers": 1, "batch_size": batch_size},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [9, 3]},
}
Expand Down Expand Up @@ -421,7 +421,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [2, 0]},
},
Expand Down Expand Up @@ -460,7 +460,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 1,
"num_samples_yielded": {0: [2, 0], 1: [2, 0]},
},
Expand Down Expand Up @@ -499,7 +499,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 2,
"num_samples_yielded": {0: [2, 0], 1: [2, 0], 2: [2, 0]},
},
Expand Down Expand Up @@ -538,7 +538,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [3, 1], 1: [2, 0], 2: [2, 0]},
},
Expand Down Expand Up @@ -577,7 +577,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 1,
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]},
},
Expand Down Expand Up @@ -616,7 +616,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 2,
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
},
Expand Down Expand Up @@ -655,7 +655,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [3, 1]},
},
Expand Down Expand Up @@ -697,7 +697,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [2, 0]},
},
Expand Down Expand Up @@ -736,7 +736,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 1,
"num_samples_yielded": {0: [2, 0], 1: [2, 0]},
},
Expand Down Expand Up @@ -775,7 +775,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 2,
"num_samples_yielded": {0: [2, 0], 1: [2, 0], 2: [2, 0]},
},
Expand Down Expand Up @@ -814,7 +814,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [3, 1], 1: [2, 0], 2: [2, 0]},
},
Expand Down Expand Up @@ -853,7 +853,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 1,
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]},
},
Expand Down Expand Up @@ -892,7 +892,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 2,
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
},
Expand Down Expand Up @@ -931,7 +931,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [3, 1]},
},
Expand All @@ -948,6 +948,81 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
states_23.append(dataloader.state_dict())

assert sum(not torch.equal(b1, b2) for b1, b2 in zip(batches_2[2:], batches_23)) == 0
assert states_23[0]["current_epoch"] == 1
assert states_23[0]["current_epoch"] == 2

assert not dataloader.restore


def test_combined_dataset_dataloader_states_without_any_iterations(combined_dataset):
    """A state round-trip before any iteration must leave the restore flag unset."""
    loader = StreamingDataLoader(combined_dataset, batch_size=4)
    assert not loader.restore
    # Saving and immediately reloading an untouched state is a no-op resume.
    loader.load_state_dict(loader.state_dict())
    assert not loader.restore


@pytest.mark.timeout(120)
@pytest.mark.parametrize("num_workers", [0, 2, 4])
def test_combined_dataset_dataloader_states_complete_iterations(combined_dataset, num_workers):
    """After a fully consumed epoch, a state round-trip must not flag a resume.

    Runs one complete epoch, reloads the saved state, and verifies the
    dataloader neither enters restore mode nor mislabels the epoch counter
    on the following epoch.
    """
    # Fix: dropped the leftover debug `print` and the redundant `pass`
    # statements that followed the in-loop assertions.
    dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=num_workers)
    assert len(dataloader) == 25, "Dataloader length should be 25 (50+50 items / batch size 4)"

    # First epoch: consume every batch; the epoch counter must stay at 1 throughout.
    for _ in dataloader:
        assert dataloader.current_epoch == 1, "Current epoch should be 1"

    # A completed epoch leaves nothing to resume, so restore must remain False.
    dataloader.load_state_dict(dataloader.state_dict())
    assert not dataloader.restore

    # Second epoch proceeds normally with the epoch counter advanced.
    for _ in dataloader:
        assert dataloader.current_epoch == 2, "Current epoch should be 2"

    assert not dataloader.restore

    del dataloader


@pytest.mark.timeout(300)
@pytest.mark.parametrize(("num_workers", "break_at"), [(0, 10), (0, 15), (2, 10), (2, 15), (4, 10), (4, 15)])
def test_combined_dataset_dataloader_states_partial_iterations(combined_dataset, num_workers, break_at):
    """Breaking mid-epoch and reloading the state must put the dataloader in restore mode.

    Iterates up to ``break_at`` batches, round-trips the state, then verifies
    the remainder of the first epoch is replayed and the second epoch yields
    every sample.
    """
    # Fix: removed the leftover debug `print` and added the missing space in
    # the "at least {expected_batches}" assertion message.

    # Verify dataloader state after partial last iteration
    dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=num_workers)

    total_batches = len(dataloader)
    assert total_batches == 25, "Dataloader length should be 25 (100 items / batch size 4)"

    assert not dataloader.restore, "Dataloader should not be in restore state initially."

    # Partial iteration up to 'break_at'
    for batch_idx, batch in enumerate(dataloader):
        assert dataloader.current_epoch == 1, "Current epoch should be 1 during first iteration"
        if batch_idx == break_at:
            break

    assert (
        not dataloader.restore
    ), "Dataloader should not be in restore state after partial iteration, before loading state."
    dataloader.load_state_dict(dataloader.state_dict())
    assert dataloader.restore, "Dataloader should be in restore state after loading the state from a partial iteration."

    # Verify remaining batches in the first epoch
    count = 0
    for _ in dataloader:
        assert dataloader.current_epoch == 1, "Current epoch should be 1 during restore"
        count += 1
    expected_batches = total_batches - break_at - 1
    assert (
        count >= expected_batches
    ), f"There should be at least {expected_batches} remaining batches in the first epoch."
    assert not dataloader.restore, "Dataloader should not be in restore state after completing first epoch."

    # Verify batches in the second epoch
    samples_yielded = 0
    for batch in dataloader:
        assert dataloader.current_epoch == 2, "Current epoch should be 2 in the second iteration"
        samples_yielded += len(batch)
    assert samples_yielded == len(combined_dataset), "All samples should be yielded in the second epoch."
2 changes: 1 addition & 1 deletion tests/streaming/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_streaming_dataloader():

assert dataloader.state_dict() == {
"dataset": {"0": {"counter": 10}, "1": {"counter": 9}},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [10, 9]},
}
Expand Down
Loading