Commit 074477f

Fix actor pool in python 3.11, add better scaling down logic (#760)

dlwh authored Oct 10, 2024
1 parent 6499656 commit 074477f

Showing 3 changed files with 40 additions and 14 deletions.
3 changes: 3 additions & 0 deletions config/data/openwebtext_source.yaml
@@ -4,3 +4,6 @@ validation_urls:
   - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
 cache_dir: "gs://levanter-data/tokenized/openwebtext/"
 tokenizer: "gpt2"
+cache_options:
+  batch_size: 1024
+  num_shard_groups: 64
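
The new cache_options block tunes how the tokenizer cache is built: write batches of 1024 rows, with the input shards processed in 64 groups. As a rough sketch of how a block like this is typically consumed (the CacheOptions dataclass and load_cache_options helper below are illustrative assumptions, not levanter's actual API):

# Illustrative only: a hypothetical loader for the YAML block above.
# Field names mirror the YAML keys; the defaults are made up for the sketch.
from dataclasses import dataclass
from typing import Optional

import yaml


@dataclass
class CacheOptions:
    batch_size: int = 128
    num_shard_groups: Optional[int] = None


def load_cache_options(path: str) -> CacheOptions:
    with open(path) as f:
        config = yaml.safe_load(f)
    return CacheOptions(**config.get("cache_options", {}))


options = load_cache_options("config/data/openwebtext_source.yaml")
print(options.batch_size, options.num_shard_groups)  # 1024 64
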
3 changes: 0 additions & 3 deletions src/levanter/store/cache.py
@@ -1061,11 +1061,8 @@ def _write_batches(writer: ShardedCacheWriter, shard_totals, batches, finished_s


 def _fetch_batches(batches) -> tuple[dict[str, int], list[PreparedBatch]]:
-    time_in = time.time()
     shards_for_batches, payloads_for_batches = zip(*batches)
     payloads_for_batches = ray.get(list(payloads_for_batches))
-    time_out = time.time()
-    logger.info(f"Fetched {len(batches)} batches in {time_out - time_in} seconds")

     shard_row_totals: dict[str, int] = {}
     for shard, payload in zip(shards_for_batches, payloads_for_batches):

48 changes: 37 additions & 11 deletions src/levanter/utils/actor_pool.py
@@ -15,6 +15,11 @@
 # https://github.com/ray-project/ray/blob/1bab09bf842edee51c3778be4cfb16f8b900d764/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py


+def _wrap_ray_future(ray_future):
+    # work around https://github.com/ray-project/ray/issues/45895#issuecomment-2165164129
+    return asyncio.wrap_future(ray_future.future())
+
+
 class AutoScalingActorPool:
     """Utility class to operate on a dynamically scaling pool of actors."""
@@ -37,6 +42,7 @@ def __init__(
         self._actor_locations: Dict[ray.actor.ActorHandle, str] = {}
         self._tasks_waiting_for_actor: list[asyncio.Future] = []
         self._next_task_id = 0
+        self._scale_down_task: Optional[asyncio.Task] = None

         self._scale_up(self._min_size)

@@ -45,14 +51,17 @@ def num_pending_tasks(self):
         return len(self._tasks_waiting_for_actor)

     def _scale_up(self, num_actors: int):
+        if self._scale_down_task and not self._scale_down_task.done():
+            self._scale_down_task.cancel()
+
         for _ in range(num_actors):
             try:
                 actor = self._create_actor_fn()
                 ready_ref = actor.get_location.remote()
                 self._pending_actors[ready_ref] = actor

                 async def wait_for_ready(actor, ready_ref):
-                    loc = await ready_ref
+                    loc = await _wrap_ray_future(ready_ref)
                     # pending -> floating
                     if ready_ref not in self._pending_actors:
                         logger.info("Actor was cancelled before it was ready.")
@@ -67,8 +76,8 @@ async def wait_for_ready(actor, ready_ref):
             except Exception as e:
                 logger.error("Failed to create actor.", exc_info=e)

-    def _scale_down(self, num_actors: int):
-        for _ in range(num_actors):
+    def _scale_down(self, target_num_actors: int):
+        while len(self._idle_actors) + len(self._pending_actors) > target_num_actors:
             if self._pending_actors:
                 actor = self._pending_actors.popitem()[1]
                 # let it die through gc
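
Note the semantic shift in _scale_down: instead of removing a fixed count of actors, it now drains the pool until it reaches a target size. Target-based shrinking is idempotent, so repeated or overlapping calls cannot overshoot below the intended size. A toy illustration (plain lists stand in for the pool's pending/idle actor maps):

# Toy comparison of count-based vs. target-based scale-down.
idle_actors = ["a1", "a2", "a3", "a4"]


def scale_down_by_count(pool: list, num_actors: int) -> None:
    # Old shape: removes a fixed number; overlapping calls can overshoot.
    for _ in range(num_actors):
        if pool:
            pool.pop()


def scale_down_to_target(pool: list, target_num_actors: int) -> None:
    # New shape: converges on a size; repeating it is a no-op.
    while len(pool) > target_num_actors:
        pool.pop()


scale_down_to_target(idle_actors, 2)
scale_down_to_target(idle_actors, 2)  # idempotent: still 2 actors
print(idle_actors)  # ['a1', 'a2']
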
@@ -102,10 +111,20 @@ def _adjust_pool_size(self):
                 f" {self._max_size}"
             )
             self._scale_up(min(self._max_size - num_busy_actors, num_pending_tasks))
+
+        # Schedule scale down if idle
         elif num_pending_tasks == 0 and num_nonworking_actors > self._min_size:
-            return  # never scal edown. too many issues
-            logger.info(f"Scaling down due to no pending tasks. Current pool size: {total_actors}")
-            self._scale_down(num_nonworking_actors - self._min_size)
+            if self._scale_down_task is None or self._scale_down_task.done():
+                self._scale_down_task = asyncio.create_task(self._schedule_scale_down())
+
+    async def _schedule_scale_down(self):
+        try:
+            await asyncio.sleep(10)
+            if self.num_pending_tasks == 0:
+                logger.info("Scaling down due to no pending tasks.")
+                self._scale_down(self._min_size)
+        except asyncio.CancelledError:
+            logger.info("Scale down task was cancelled due to new activity.")

     def _get_object_location(self, obj_ref: ray.ObjectRef) -> Optional[str]:
         """Get the location of the given object reference."""
@@ -153,10 +172,11 @@ def _assign_task_to_actor(self, actor, fn, value):
         # floating -> busy
         ray_future = fn(actor, value)
         self._busy_actors[ray_future] = actor
+        if self._scale_down_task and not self._scale_down_task.done():
+            self._scale_down_task.cancel()
         self._adjust_pool_size()

-        # return ray_future
-        return asyncio.ensure_future(self._wrap_ray_future(ray_future))
+        return asyncio.ensure_future(self._set_up_actor_return_on_finished(ray_future))

     async def _enqueue_pending_task(self, fn, obj_ref, value, actor_future):
         actor = await actor_future
@@ -181,10 +201,11 @@ def _maybe_start_pending_task(self, actor):
             assigned = False
         return assigned

-    async def _wrap_ray_future(self, ray_future):
-        await asyncio.wait([ray_future])
+    async def _set_up_actor_return_on_finished(self, ray_future):
+        future = _wrap_ray_future(ray_future)
+        await asyncio.wait([future])
         self._on_task_done(ray_future)
-        return await ray_future
+        return await future

     def _on_task_done(self, ray_future):
         actor = self._busy_actors.pop(ray_future)
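
The rename from _wrap_ray_future to _set_up_actor_return_on_finished frees the old name for the module-level helper and says what the coroutine actually does: wrap the Ray future once, wait for it to finish, release the actor back to the pool via _on_task_done, and only then hand the result to the caller, so the actor is reusable before the caller resumes. A condensed sketch of that shape (a standalone helper; on_done stands in for the pool's bookkeeping):

import asyncio
from typing import Any, Callable


async def await_and_release(ray_future: Any, on_done: Callable[[Any], None]) -> Any:
    # Convert the Ray future exactly once, then await the wrapper.
    future = asyncio.wrap_future(ray_future.future())
    await asyncio.wait([future])
    on_done(ray_future)  # actor goes back to the idle set first
    return await future  # already done: this just unwraps the result
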
@@ -218,6 +239,11 @@ def push(self, actor: "ray.actor.ActorHandle"):
             self._actor_locations[actor] = location
         self._maybe_start_pending_task(actor)

+    def __del__(self):
+        if self._scale_down_task and not self._scale_down_task.done():
+            self._scale_down_task.cancel()
+        # just let ray kill the actors naturally
+

 class PoolWorkerBase(ABC):
     def get_location(self) -> str:
