Dataloader now reshuffles after each epoch
fraserlove committed Jul 4, 2024
1 parent 4fb2aa2 commit faaed08
Showing 2 changed files with 11 additions and 6 deletions.
Binary file removed .DS_Store
17 changes: 11 additions & 6 deletions train_gpt2.py
@@ -219,6 +219,7 @@ def __init__(self, B, T, process_rank, num_processes, split):
         self.num_processes = num_processes
         self.split = split
         assert split in {'train', 'val'}
+        self.rng = np.random.default_rng(1337)
 
         # get the shard filenames
         data_root = "edu_fineweb10B"
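
Note: seeding a per-instance Generator (rather than mutating NumPy's global state) makes the shuffle reproducible, and since every DDP rank constructs its loader with the same constant seed, all ranks agree on shard and document order without communicating. A minimal sketch of that property, using hypothetical shard filenames (only the seed 1337 comes from the diff):

    import numpy as np

    # Two generators with the same seed (standing in for two DDP ranks)
    # produce identical permutations, so every rank visits the shards
    # in the same shuffled order.
    rng_a = np.random.default_rng(1337)
    rng_b = np.random.default_rng(1337)

    shards = [f"shard_{i:02d}.npy" for i in range(4)]  # hypothetical filenames
    order_a, order_b = list(shards), list(shards)
    rng_a.shuffle(order_a)
    rng_b.shuffle(order_b)
    assert order_a == order_b  # ranks stay in lockstep
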
@@ -238,15 +239,15 @@ def load_shard(self, filename):
         # split tokens into documents using the <|endoftext|> token and shuffle
         eot_positions = (torch.where(shard == enc.eot_token)[0] + 1).tolist()
         documents = [shard[start:end] for start, end in zip([0] + eot_positions[:-1], eot_positions)]
-        np.random.shuffle(documents)
+        self.rng.shuffle(documents)
         shard = torch.cat(documents) # concatenate the documents back together
         return shard
 
     def reset(self):
         # state, init at shard zero
         self.current_shard = 0
         if self.split == "train":
-            np.random.shuffle(self.shards)
+            self.rng.shuffle(self.shards)
         self.tokens = self.load_shard(self.shards[self.current_shard])
         self.current_position = self.B * self.T * self.process_rank
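
The document-level shuffle in load_shard keeps each document's tokens contiguous while decorrelating which documents end up adjacent in a batch. A self-contained sketch of the same idea, using GPT-2's <|endoftext|> id (50256) and made-up token values:

    import numpy as np
    import torch

    EOT = 50256  # GPT-2's <|endoftext|> token id
    rng = np.random.default_rng(1337)

    # A tiny fake shard: three documents, each terminated by EOT.
    shard = torch.tensor([5, 6, EOT, 7, EOT, 8, 9, 10, EOT])
    eot_positions = (torch.where(shard == EOT)[0] + 1).tolist()  # [3, 5, 9]
    documents = [shard[s:e] for s, e in zip([0] + eot_positions[:-1], eot_positions)]
    rng.shuffle(documents)        # permute whole documents, not individual tokens
    shard = torch.cat(documents)  # documents reordered, each internally intact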

@@ -259,9 +260,13 @@ def next_batch(self):
         self.current_position += B * T * self.num_processes
         # if loading the next batch would be out of bounds, advance to next shard
         if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
-            self.current_shard = (self.current_shard + 1) % len(self.shards)
-            self.tokens = self.load_shard(self.shards[self.current_shard])
-            self.current_position = B * T * self.process_rank
+            self.current_shard += 1
+            # reshuffle after each epoch
+            if self.current_shard == len(self.shards):
+                self.reset()
+            else:
+                self.tokens = self.load_shard(self.shards[self.current_shard])
+                self.current_position = B * T * self.process_rank
         return x, y
 
 # -----------------------------------------------------------------------------
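
The next_batch change is the heart of the commit: the old (self.current_shard + 1) % len(self.shards) wrapped back to shard zero in the same order every epoch, whereas walking off the end of the shard list now calls reset(), which draws a fresh shard permutation (and, via load_shard, a fresh document permutation) for the new epoch. A toy sketch of just this epoch-boundary control flow, with integer stand-ins for shards (this simplified reset() omits the split check and token loading of the real class):

    import numpy as np

    class ToyLoader:
        def __init__(self, num_shards, seed=1337):
            self.rng = np.random.default_rng(seed)
            self.shards = list(range(num_shards))
            self.reset()

        def reset(self):
            self.current_shard = 0
            self.rng.shuffle(self.shards)  # new shard order each epoch

        def advance(self):
            self.current_shard += 1
            if self.current_shard == len(self.shards):
                self.reset()  # epoch boundary: reshuffle instead of wrapping

    loader = ToyLoader(num_shards=3)
    for _ in range(6):  # two epochs; the order is re-drawn at the boundary
        print(loader.shards[loader.current_shard])
        loader.advance()
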
@@ -531,4 +536,4 @@ def get_lr(it):
             f.write(f"{step} train {loss_accum.item():.6f}\n")
 
 if ddp:
-    destroy_process_group()
\ No newline at end of file
+    destroy_process_group()
