cache fixes
dlwh committed Nov 19, 2024
1 parent 98e9296 commit 688ab49
Showing 1 changed file with 72 additions and 45 deletions.
117 changes: 72 additions & 45 deletions src/levanter/store/cache.py
@@ -899,15 +899,15 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
# This enables us to expose data quickly
first_group = next(iter(shard_groups), None)

for group_name, shards in shard_groups.items():
for group_name, group_shards in shard_groups.items():
if group_name == first_group:
group_out_path = cache_dir
group_cache_path = cache_dir
else:
group_out_path = os.path.join(temporary_cache_path, group_name)
group_cache_path = os.path.join(temporary_cache_path, group_name)

group_cache_paths[group_name] = group_out_path
group_cache_paths[group_name] = group_cache_path

ledger = _try_load(group_out_path)
ledger = _try_load(group_cache_path)
group_ledgers[group_name] = ledger

if ledger is not None:
@@ -928,12 +928,28 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
retry_exceptions=True,
max_retries=10,
)
.remote(group_out_path, source_ref, shards, processor_ref, options, report_fn_to_use, parent)
.remote(
group_cache_path=group_cache_path,
source=source_ref,
shards=group_shards,
processor=processor_ref,
options=options,
report_fn=report_fn_to_use,
# don't finalize the first group b/c we write it directly to the output cache
force_unfinalized=group_name == first_group,
)
)

write_refs[group_name] = ref

ledger = _start_copies(
# wait for the first group to finish
if first_group is not None:
logger.info(f"Waiting for first group {first_group} to finish")
ray.get(write_refs[first_group])

logger.info(f"First group {first_group} finished. Copying other groups into permanent cache.")

ledger = _copy_temp_caches_to_final_cache(
parent,
cache_dir,
shard_groups,
@@ -949,29 +965,30 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
ledger._serialize_and_commit(cache_dir)
ray.get(parent._notify_updated_ledger.remote(ledger))

temporary_cache_paths = set(group_cache_paths.values()) - {cache_dir}
_clean_up_temp_caches(temporary_cache_paths)
_clean_up_temp_caches(temporary_cache_path)


def _clean_up_temp_caches(paths):
for path in paths:
if fsspec_exists(path):
for i in range(10):
# this is crashy for some reason
try:
fsspec_remove(path, recursive=True)
break
except Exception:
logger.exception(f"Failed to remove {path} on attempt {i}")
time.sleep(1)
def _clean_up_temp_caches(path):
if fsspec_exists(path):
for i in range(10):
# this is crashy for some reason
try:
fsspec_remove(path, recursive=True)
break
except Exception:
logger.exception(f"Failed to remove {path} on attempt {i}")
time.sleep(1)


def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]:
if num_groups is None or num_groups >= len(source.shard_names):
return {shard_name: [shard_name] for shard_name in source.shard_names}

if num_groups <= 0:
raise ValueError("num_groups must be a positive integer.")

shard_names = source.shard_names
num_shards_per_group = (len(shard_names)) // num_groups
num_shards_per_group = len(shard_names) // num_groups
num_groups_with_extra = len(shard_names) % num_groups

# if we have a remainder, we want to distribute the extra shards evenly
@@ -982,13 +999,12 @@ def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None)
out_groups[f"group_{i}"] = list(shard_names[start : start + num_shards])
start += num_shards

# make sure we got all the shards
assert sum(len(shards) for shards in out_groups.values()) == len(shard_names)
assert sum(len(shards) for shards in out_groups.values()) == len(shard_names), "Mismatch in shard assignment."

return out_groups # type: ignore
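
A small, self-contained sketch of the grouping rule above (the per-group `num_shards` bump for the first `len % num_groups` groups is implied by the modulo arithmetic; the names here are illustrative, not from this commit):

def assign_groups_sketch(shard_names: list[str], num_groups: int) -> dict[str, list[str]]:
    # 7 shards into 3 groups -> sizes 3, 2, 2: the first (len % num_groups) groups take one extra shard
    base, extra = divmod(len(shard_names), num_groups)
    out: dict[str, list[str]] = {}
    start = 0
    for i in range(num_groups):
        size = base + (1 if i < extra else 0)
        out[f"group_{i}"] = shard_names[start : start + size]
        start += size
    return out

assert assign_groups_sketch([f"s{i}" for i in range(7)], 3) == {
    "group_0": ["s0", "s1", "s2"],
    "group_1": ["s3", "s4"],
    "group_2": ["s5", "s6"],
}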


def _merge_ledgers(dest, source):
def _merge_ledgers(dest: CacheLedger, source: CacheLedger):
assert not dest.is_finished
dest.total_num_rows += source.total_num_rows
for shard, rows in source.shard_rows.items():
@@ -1003,7 +1019,7 @@ def _merge_ledgers(dest, source):
return dest


def _start_copies(
def _copy_temp_caches_to_final_cache(
parent,
cache_dir,
shard_groups,
@@ -1033,7 +1049,7 @@ def _start_copies(
"""
# This logic is a bit hairy thanks to resumes.
# First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these
# separately. We also need to update the ledger as we go.
# separately.
# Second, note that JaggedArrayStores have two notions of length: the number of rows, and the data size.
# We store the number of rows in offsets[0], and the data size in offsets[offsets[0]], which is just the final offset.
# So we can keep a cache "locked" to a particular read size until we're ready by controlling the offsets.
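# Illustration (made-up numbers, not code from this commit): for three rows with
# lengths 3, 2, and 4, offsets == [3, 3, 5, 9]; offsets[0] == 3 is the row count and
# offsets[offsets[0]] == offsets[3] == 9 is the data size. Readers only see offsets[0]
# rows, so copying row data first and only then bumping offsets[0] (e.g., something like
# dest_array.offsets[0].write(num_available_rows).result(), hypothetically) exposes the
# newly copied rows without ever showing a partially copied row.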
@@ -1068,6 +1084,12 @@ def _start_copies(
_ProgressReport(new_shards=len(overall_ledger.finished_shards), new_rows=overall_ledger.total_num_rows)
)

parent._report_copy_progress.remote(
_ProgressReport(new_shards=len(overall_ledger.finished_shards), new_rows=overall_ledger.total_num_rows)
)

found_one_to_copy = False

for group in shard_groups:
# first make sure it's either done this run or already done
if write_refs.get(group) is not None:
@@ -1121,11 +1143,11 @@ def _start_copies(

# commit each group in order. We need to do two things:
# "commit" here means:
# 1. update the data offsets in the permanent cache with the number of rows avaialble from all groups so far
# 1. update the data offsets in the permanent cache with the number of rows available from all groups so far
# 2. update the ledger with the combined information from all groups so far
num_available_rows = overall_ledger.total_num_rows
for group, ref in copy_refs.items():
ray.get(ref) # block on copy
ray.get(ref) # block on data copy

group_ledger = group_ledgers[group]
num_available_rows += group_ledger.total_num_rows
@@ -1136,8 +1158,9 @@ def _start_copies(
overall_ledger._serialize_and_commit(cache_dir)
ray.get(parent._notify_updated_ledger.remote(overall_ledger))
parent._report_copy_progress.remote(
_ProgressReport(new_shards=len(group_ledger.shard_rows), new_rows=group_ledger.total_num_rows)
_ProgressReport(new_shards=len(group_ledger.finished_shards), new_rows=group_ledger.total_num_rows)
)
logger.info(f"Group {group} copied. Updating ledger.")

return overall_ledger

Expand All @@ -1152,7 +1175,12 @@ def _expose_available_rows(permanent_cache, num_available_rows):
future.result()
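
The copy task defined just below is pinned to CPU-only JAX and spread across nodes. A minimal standalone sketch of that Ray pattern (the resource numbers and the probe function are placeholders, not this module's values):

import ray
from ray.runtime_env import RuntimeEnv

@ray.remote(
    num_cpus=1,
    runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"}),  # keep worker-side JAX off accelerators
    scheduling_strategy="SPREAD",  # spread copy tasks across nodes instead of packing one node
)
def probe_platform() -> str:
    import jax

    return jax.devices()[0].platform  # reports "cpu" even on a GPU/TPU node

# ray.get(probe_platform.remote()) -> "cpu"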


@ray.remote(num_cpus=4, memory=6 * 1024 * 1024 * 1024)
@ray.remote(
num_cpus=4,
memory=6 * 1024 * 1024 * 1024,
runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"}),
scheduling_strategy="SPREAD",
)
def _copy_cache_data(dest_path, source_path, processor, data_offset_tree, rows_so_far, parent):
"""
Copies the data from one cache to another, appending it to the end of the destination cache.
@@ -1195,8 +1223,7 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
data = source_array.data[0:data_size]
futures: list[ts.Future] = []

# my experience is that naively doing this in one go
# can OOM and lock up the machines, so we break it up
# To prevent OOM, copy in smaller batches
MAX_ELEMS = 1024 * 1024 * 1024
f = await _copy_in_batches(dest_array.data, data_offset, data, data_size, MAX_ELEMS)
if f is not None:
Expand All @@ -1205,33 +1232,32 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
if source_array.shapes is not None:
source_shapes = source_array.shapes[0:source_num_rows]
async with ts.Transaction() as txn:
dest = dest_array.shapes
assert dest is not None
dest_shapes = dest_array.shapes
assert dest_shapes is not None
out_end = row_offset + source_num_rows
shape_future = dest.with_transaction(txn)[row_offset:out_end].write(source_shapes)
shape_future = dest_shapes.with_transaction(txn)[row_offset:out_end].write(source_shapes)
futures.append(shape_future)

source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]]
source_offsets = _virtual_offset(source_offsets, data_offset)

async with ts.Transaction() as txn:
dest = dest_array.offsets
dest_offsets = dest_array.offsets
out_end = row_offset + 1 + source_num_rows
offset_future = dest.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets)
offset_future = dest_offsets.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets)

futures.append(offset_future)

out = await asyncio.gather(*futures)
return out
await asyncio.gather(*futures)

futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree)

await asyncio.gather(*jax.tree.leaves(futures))
logger.info(f"Finished copying data from {source_path} to {dest_path}.")

return source_num_rows
except Exception:
logger.exception(f"Failed to copy data from {source_path} to {dest_path}.")
except Exception as e:
logger.exception(f"Failed to copy data from {source_path} to {dest_path}: {e}")
raise
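
_copy_in_batches, used in _copy_one_array above, is not shown in this diff. A rough sketch of the idea it implements, assuming TensorStore arrays for both source and destination (the real helper may differ):

import tensorstore as ts


async def copy_in_batches_sketch(dest: ts.TensorStore, dest_offset: int, source: ts.TensorStore, num_elems: int, max_elems: int):
    """Copy num_elems elements from source into dest[dest_offset:], at most max_elems per write."""
    last_write = None
    for start in range(0, num_elems, max_elems):
        end = min(start + max_elems, num_elems)
        batch = await source[start:end].read()  # materialize one bounded batch
        if last_write is not None:
            await last_write  # keep at most one write in flight to bound memory
        last_write = dest[dest_offset + start : dest_offset + end].write(batch)
    return last_write  # caller can await or collect this final write future, or get None if nothing was copied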


@@ -1272,13 +1298,13 @@ async def do_read(domain: ts.IndexDomain, array: np.ndarray, read_params: ts.Vir
@dataclass
class _ProgressReport:
new_rows: int = 0
new_bytes: float = 0
new_bytes: int = 0
new_shards: int = 0
# TODO: other counts


def _tokenize_one_shard_group(
temporary_cache_path: str,
group_cache_path: str,
source: ShardedDataSource,
shards: list[str],
processor: BatchProcessor,
@@ -1297,13 +1323,14 @@ def _tokenize_one_shard_group(
# we encounter significant overhead just parsing the shard names from the json
source = _RestrictedShardedDataSource(source, shards)

ledger = CacheLedger.load_or_initialize(temporary_cache_path, source, processor)
ledger = CacheLedger.load_or_initialize(group_cache_path, source, processor)

if ledger.is_finished:
report_fn(_ProgressReport(new_rows=ledger.total_num_rows, new_shards=len(ledger.finished_shards)), ledger)
logger.info("Shard group already processed.")
return ledger

writer = ShardGroupCacheWriter(temporary_cache_path, ledger, shards, processor.output_exemplar)
writer = ShardGroupCacheWriter(group_cache_path, ledger, shards, processor.output_exemplar)

total_rows = ledger.total_num_rows
found_shard_with_rows = False
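
The _RestrictedShardedDataSource wrapper used above is not part of this hunk; a plausible minimal sketch (the delegated method names are assumptions about the ShardedDataSource interface, not confirmed by this diff):

class _RestrictedShardedDataSourceSketch:
    """Expose only this group's shards; delegate actual reads to the wrapped source."""

    def __init__(self, source, shards):
        self._source = source
        self._shards = shards

    @property
    def shard_names(self):
        return self._shards

    def open_shard_at_row(self, shard_name, row):
        # hypothetical delegation; the real interface likely has more methods
        return self._source.open_shard_at_row(shard_name, row)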
