From fec5cb4c9fc832e44e0960f2c8161a2e81ceb16e Mon Sep 17 00:00:00 2001
From: David Hall
Date: Mon, 18 Nov 2024 16:25:09 -0800
Subject: [PATCH] move this to another branch

---
 src/levanter/store/cache.py | 115 ++++++++++++++----------------------
 1 file changed, 45 insertions(+), 70 deletions(-)

diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py
index 05c5a9d18..fda6fdb0c 100644
--- a/src/levanter/store/cache.py
+++ b/src/levanter/store/cache.py
@@ -899,15 +899,15 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
     # This enables us to expose data quickly
     first_group = next(iter(shard_groups), None)

-    for group_name, group_shards in shard_groups.items():
+    for group_name, shards in shard_groups.items():
         if group_name == first_group:
-            group_cache_path = cache_dir
+            group_out_path = cache_dir
         else:
-            group_cache_path = os.path.join(temporary_cache_path, group_name)
+            group_out_path = os.path.join(temporary_cache_path, group_name)

-        group_cache_paths[group_name] = group_cache_path
+        group_cache_paths[group_name] = group_out_path

-        ledger = _try_load(group_cache_path)
+        ledger = _try_load(group_out_path)
         group_ledgers[group_name] = ledger

         if ledger is not None:
@@ -928,28 +928,12 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
                 retry_exceptions=True,
                 max_retries=10,
             )
-            .remote(
-                group_cache_path=group_cache_path,
-                source=source_ref,
-                shards=group_shards,
-                processor=processor_ref,
-                options=options,
-                report_fn=report_fn_to_use,
-                # don't finalize the first group b/c we write it directly to the output cache
-                force_unfinalized=group_name == first_group,
-            )
+            .remote(group_out_path, source_ref, shards, processor_ref, options, report_fn_to_use, parent)
         )

         write_refs[group_name] = ref

-    # wait for the first group to finish
-    if first_group is not None:
-        logger.info(f"Waiting for first group {first_group} to finish")
-        ray.get(write_refs[first_group])
-
-        logger.info(f"First group {first_group} finished. Copying other groups into permanent cache.")
-
-    ledger = _copy_temp_caches_to_final_cache(
+    ledger = _start_copies(
         parent,
         cache_dir,
         shard_groups,
@@ -965,30 +949,29 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
     ledger._serialize_and_commit(cache_dir)
     ray.get(parent._notify_updated_ledger.remote(ledger))

-    _clean_up_temp_caches(temporary_cache_path)
+    temporary_cache_paths = set(group_cache_paths.values()) - {cache_dir}
+    _clean_up_temp_caches(temporary_cache_paths)


-def _clean_up_temp_caches(path):
-    if fsspec_exists(path):
-        for i in range(10):
-            # this is crashy for some reason
-            try:
-                fsspec_remove(path, recursive=True)
-                break
-            except Exception:
-                logger.exception(f"Failed to remove {path} on attempt {i}")
-                time.sleep(1)
+def _clean_up_temp_caches(paths):
+    for path in paths:
+        if fsspec_exists(path):
+            for i in range(10):
+                # this is crashy for some reason
+                try:
+                    fsspec_remove(path, recursive=True)
+                    break
+                except Exception:
+                    logger.exception(f"Failed to remove {path} on attempt {i}")
+                    time.sleep(1)


 def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]:
     if num_groups is None or num_groups >= len(source.shard_names):
         return {shard_name: [shard_name] for shard_name in source.shard_names}

-    if num_groups <= 0:
-        raise ValueError("num_groups must be a positive integer.")
-
     shard_names = source.shard_names
-    num_shards_per_group = len(shard_names) // num_groups
+    num_shards_per_group = (len(shard_names)) // num_groups
     num_groups_with_extra = len(shard_names) % num_groups

     # if we have a remainder, we want to distribute the extra shards evenly
@@ -999,12 +982,13 @@ def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None)
         out_groups[f"group_{i}"] = list(shard_names[start : start + num_shards])
         start += num_shards

-    assert sum(len(shards) for shards in out_groups.values()) == len(shard_names), "Mismatch in shard assignment."
+    # make sure we got all the shards
+    assert sum(len(shards) for shards in out_groups.values()) == len(shard_names)

     return out_groups  # type: ignore


-def _merge_ledgers(dest: CacheLedger, source: CacheLedger):
+def _merge_ledgers(dest, source):
     assert not dest.is_finished
     dest.total_num_rows += source.total_num_rows
     for shard, rows in source.shard_rows.items():
@@ -1019,7 +1003,7 @@ def _merge_ledgers(dest: CacheLedger, source: CacheLedger):
     return dest


-def _copy_temp_caches_to_final_cache(
+def _start_copies(
     parent,
     cache_dir,
     shard_groups,
@@ -1049,7 +1033,7 @@ def _copy_temp_caches_to_final_cache(
     """
     # This logic is a bit hairy thanks to resumes.
     # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these
-    # separately.
+    # separately. We also need to update the ledger as we go.
     # Second, note that JaggedArrayStores have two notions of length: the number of rows, and the data size.
     # We store the number of rows in offsets[0], and the data size in offsets[offsets[0]], which is just the final offset.
     # So we can keep a cache "locked" to a particular read size until we're ready by controlling the offsets.
@@ -1084,10 +1068,6 @@ def _copy_temp_caches_to_final_cache(
         _ProgressReport(new_shards=len(overall_ledger.finished_shards), new_rows=overall_ledger.total_num_rows)
     )

-    parent._report_copy_progress.remote(
-        _ProgressReport(new_shards=len(overall_ledger.finished_shards), new_rows=overall_ledger.total_num_rows)
-    )
-
     found_one_to_copy = False

     for group in shard_groups:
@@ -1143,11 +1123,11 @@ def _copy_temp_caches_to_final_cache(

     # commit each group in order. We need to do two things:
     # "commit" here means:
     # 1. update the data offsets in the permanent cache with the number of rows available from all groups so far
     # 2. update the ledger with the combined information from all groups so far
     num_available_rows = overall_ledger.total_num_rows
     for group, ref in copy_refs.items():
-        ray.get(ref)  # block on data copy
+        ray.get(ref)  # block on copy
         group_ledger = group_ledgers[group]
         num_available_rows += group_ledger.total_num_rows

@@ -1158,9 +1138,8 @@ def _copy_temp_caches_to_final_cache(
         overall_ledger._serialize_and_commit(cache_dir)
         ray.get(parent._notify_updated_ledger.remote(overall_ledger))
         parent._report_copy_progress.remote(
-            _ProgressReport(new_shards=len(group_ledger.finished_shards), new_rows=group_ledger.total_num_rows)
+            _ProgressReport(new_shards=len(group_ledger.shard_rows), new_rows=group_ledger.total_num_rows)
         )
-        logger.info(f"Group {group} copied. Updating ledger.")

     return overall_ledger

@@ -1175,12 +1154,7 @@ def _expose_available_rows(permanent_cache, num_available_rows):
         future.result()


-@ray.remote(
-    num_cpus=4,
-    memory=6 * 1024 * 1024 * 1024,
-    runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"}),
-    scheduling_strategy="SPREAD",
-)
+@ray.remote(num_cpus=4, memory=6 * 1024 * 1024 * 1024)
 def _copy_cache_data(dest_path, source_path, processor, data_offset_tree, rows_so_far, parent):
     """
     Copies the data from one cache to another, appending it to the end of the destination cache.
@@ -1223,7 +1197,8 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
         data = source_array.data[0:data_size]
         futures: list[ts.Future] = []

-        # To prevent OOM, copy in smaller batches
+        # my experience is that naively doing this in one go
+        # can OOM and lock up the machines, so we break it up
         MAX_ELEMS = 1024 * 1024 * 1024
         f = await _copy_in_batches(dest_array.data, data_offset, data, data_size, MAX_ELEMS)
         if f is not None:
@@ -1232,23 +1207,24 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
         if source_array.shapes is not None:
             source_shapes = source_array.shapes[0:source_num_rows]
             async with ts.Transaction() as txn:
-                dest_shapes = dest_array.shapes
-                assert dest_shapes is not None
+                dest = dest_array.shapes
+                assert dest is not None
                 out_end = row_offset + source_num_rows
-                shape_future = dest_shapes.with_transaction(txn)[row_offset:out_end].write(source_shapes)
+                shape_future = dest.with_transaction(txn)[row_offset:out_end].write(source_shapes)
                 futures.append(shape_future)

         source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]]
         source_offsets = _virtual_offset(source_offsets, data_offset)

         async with ts.Transaction() as txn:
-            dest_offsets = dest_array.offsets
+            dest = dest_array.offsets
             out_end = row_offset + 1 + source_num_rows
-            offset_future = dest_offsets.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets)
+            offset_future = dest.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets)
             futures.append(offset_future)

-        await asyncio.gather(*futures)
+        out = await asyncio.gather(*futures)
+        return out

     futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree)

@@ -1256,8 +1232,8 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
         logger.info(f"Finished copying data from {source_path} to {dest_path}.")

         return source_num_rows
-    except Exception as e:
-        logger.exception(f"Failed to copy data from {source_path} to {dest_path}: {e}")
+    except Exception:
+        logger.exception(f"Failed to copy data from {source_path} to {dest_path}.")
         raise


@@ -1298,13 +1274,13 @@ async def do_read(domain: ts.IndexDomain, array: np.ndarray, read_params: ts.Vir
 @dataclass
 class _ProgressReport:
     new_rows: int = 0
-    new_bytes: int = 0
+    new_bytes: float = 0
     new_shards: int = 0
     # TODO: other counts


 def _tokenize_one_shard_group(
-    group_cache_path: str,
+    temporary_cache_path: str,
     source: ShardedDataSource,
     shards: list[str],
     processor: BatchProcessor,
@@ -1323,14 +1299,13 @@ def _tokenize_one_shard_group(
     # we encounter significant overhead just parsing the shard names from the json
     source = _RestrictedShardedDataSource(source, shards)

-    ledger = CacheLedger.load_or_initialize(group_cache_path, source, processor)
+    ledger = CacheLedger.load_or_initialize(temporary_cache_path, source, processor)

     if ledger.is_finished:
-        report_fn(_ProgressReport(new_rows=ledger.total_num_rows, new_shards=len(ledger.finished_shards)), ledger)
         logger.info("Shard group already processed.")
         return ledger

-    writer = ShardGroupCacheWriter(group_cache_path, ledger, shards, processor.output_exemplar)
+    writer = ShardGroupCacheWriter(temporary_cache_path, ledger, shards, processor.output_exemplar)

     total_rows = ledger.total_num_rows
     found_shard_with_rows = False
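
Reviewer note (illustrative only, not part of the patch): the _assign_shards_to_groups hunks above restore an even split of shards into contiguous groups, with the remainder spread over the first groups. A minimal sketch of that scheme, using a hypothetical assign_to_groups helper rather than the function in the patch:

    from typing import Sequence

    def assign_to_groups(shard_names: Sequence[str], num_groups: int) -> dict[str, list[str]]:
        base = len(shard_names) // num_groups   # every group gets at least this many shards
        extra = len(shard_names) % num_groups   # the first `extra` groups get one more
        out: dict[str, list[str]] = {}
        start = 0
        for i in range(num_groups):
            size = base + (1 if i < extra else 0)
            out[f"group_{i}"] = list(shard_names[start : start + size])
            start += size
        assert sum(len(s) for s in out.values()) == len(shard_names)  # every shard assigned exactly once
        return out

    # e.g. 10 shards into 3 groups -> group sizes 4, 3, 3
    sizes = {name: len(shards) for name, shards in assign_to_groups([f"shard_{i}" for i in range(10)], 3).items()}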
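Reviewer note (illustrative only, not part of the patch): the _start_copies docstring describes the two lengths a JaggedArrayStore tracks: offsets[0] is the visible row count and offsets[offsets[0]] is the end of the visible data, so a cache stays "locked" to a particular read size until new offsets are committed. A toy illustration of that bookkeeping; the numbers and the plain numpy array are made up, the real stores are tensorstore-backed:

    import numpy as np

    offsets = np.array([3, 10, 25, 40, 0, 0])  # 3 visible rows; row i ends at offsets[i + 1]
    num_rows = int(offsets[0])                 # 3
    data_size = int(offsets[num_rows])         # 40, the end of the visible data

    # after two more rows' data have been copied, expose them by committing new offsets
    offsets[4], offsets[5] = 55, 70
    offsets[0] = 5                             # readers now see 5 rows and data up to offsets[5]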
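Reviewer note (illustrative only, not part of the patch): the comment restored in _copy_one_array explains why the data array is copied in batches of at most MAX_ELEMS elements rather than one giant write. The patch's _copy_in_batches goes through tensorstore transactions and futures; copy_in_chunks below is a hypothetical numpy stand-in that shows only the chunking idea:

    import numpy as np

    MAX_ELEMS = 1024 * 1024 * 1024  # same per-batch cap the patch uses

    def copy_in_chunks(dest: np.ndarray, dest_offset: int, src: np.ndarray, max_elems: int = MAX_ELEMS) -> None:
        # each write touches at most max_elems elements, bounding peak memory
        for start in range(0, len(src), max_elems):
            stop = min(start + max_elems, len(src))
            dest[dest_offset + start : dest_offset + stop] = src[start:stop]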