manager: expand API to include errors, participant information and numeric test (#19)

* manager: added participant information

* manager: error reporting APIs and numerics test
d4l3k authored Nov 28, 2024
1 parent 4ad6f1b commit 63ee40c
Showing 2 changed files with 132 additions and 23 deletions.
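This commit adds report_error()/errored(), wrap_future(), num_participants() and is_participating() to the Manager. The sketch below is illustrative only and not part of the commit: the Manager constructor arguments come from the diff, while pg, model, optimizer and dataloader are placeholders.

    # Illustrative sketch only: pg, model, optimizer and dataloader are
    # placeholders, not part of this commit.
    from torchft.manager import Manager

    manager = Manager(
        pg=pg,
        load_state_dict=lambda state: model.load_state_dict(state),
        state_dict=lambda: model.state_dict(),
        min_replica_size=2,
    )

    for batch in dataloader:
        manager.step()
        optimizer.zero_grad()
        model(batch).sum().backward()

        for p in model.parameters():
            # allreduce_grad() zeroes the gradient on a healing replica
            # (is_participating() == False), divides by num_participants(),
            # and records failures via report_error() instead of raising.
            manager.allreduce_grad(p.grad).wait()

        # If any collective failed, errored() is True and the manager
        # reports that this step should not be committed.
        if manager.should_commit():
            optimizer.step()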
108 changes: 88 additions & 20 deletions torchft/manager.py
@@ -32,7 +32,7 @@
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional, TYPE_CHECKING

import torch
from torch.distributed import PrefixStore, ReduceOp, TCPStore, Work
@@ -42,6 +42,9 @@
# pyre-fixme[21]: can't find rust module
from torchft.torchft import Manager as _Manager, ManagerClient

if TYPE_CHECKING:
from torchft.process_group import ProcessGroup

logger: logging.Logger = logging.getLogger(__name__)

MANAGER_ADDR_KEY: str = "manager_addr"
@@ -58,9 +61,9 @@ class Manager:

def __init__(
self,
pg,
load_state_dict,
state_dict,
pg: "ProcessGroup",
load_state_dict: Callable[[object], None],
state_dict: Callable[[], object],
min_replica_size: int,
port: int = MANAGER_DEFAULT_PORT,
use_async_quorum: bool = True,
@@ -175,15 +178,14 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
Returns:
a Future that will be completed with the allreduced gradient
"""
if self._errored:
if self.errored():
fut = torch.futures.Future()
fut.set_result(grad)
return fut

self._quorum_future.result()

if self._healing:
assert self._use_async_quorum
if not self.is_participating():
grad.zero_()

# TODO: increase timeout when waiting when healing
@@ -193,38 +195,81 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
work = self._pg.allreduce([grad], ReduceOp.SUM)
fut = work.get_future()

# schedule error handling and grad normalization as a continuation
# schedule grad normalization as a continuation
# on the Future
def callback(
fut: torch.futures.Future[List[torch.Tensor]],
) -> torch.futures.Future[torch.Tensor]:
nonlocal grad

try:
val = fut.value()
except Exception:
logger.exception(
"got exception in all reduce future -- skipping remaining"
)
self._errored = True
return grad
fut.value()

grad /= self._participating_replicas
grad /= self.num_participants()

return grad

fut = fut.then(callback)
self._pending_work.append(fut)
fut = self.wrap_future(fut, grad)
return fut

except Exception as e:
logger.exception("got exception in all reduce -- skipping remaining")
self._errored = True
logger.exception(f"got exception in all reduce -- skipping remaining: {e}")
self.report_error()

fut = torch.futures.Future()
fut.set_result(grad)
return fut

def report_error(self) -> None:
"""
Report an error to the manager.
This will cause the manager to skip the current step and to be
reconfigured on the next step.
This should be called when an error occurs that leads to a corrupted
gradient that needs to be discarded.
"""
self._errored = True

def errored(self) -> bool:
"""
Get whether an error has occurred.
Returns:
whether an error has occurred
"""
return self._errored

def wrap_future(self, fut: torch.futures.Future[object], default: object) -> torch.futures.Future[object]:
"""
Wrap a Future so that any error is swallowed and reported to the manager.
If an error occurs, the Future is completed with the default value instead.
Args:
fut: the Future to wrap
default: the default value to complete the Future with if an error occurs
Returns:
the wrapped Future
"""

# schedule error handling and grad normalization as a continuation
# on the Future
def callback(
fut: torch.futures.Future[object],
) -> object:
nonlocal default

try:
return fut.value()
except Exception as e:
logger.exception(f"got exception in future -- skipping remaining: {e}")
self.report_error()
return default

fut = fut.then(callback)
self._pending_work.append(fut)
return fut

def step(self) -> None:
"""
.. note::
@@ -411,3 +456,26 @@ def batches_committed(self) -> int:
the total number of batches committed
"""
return self._batches_committed

def num_participants(self) -> int:
"""
Get the number of participants in the current quorum.
This is the number of replicas participating in the current step.
Returns:
the number of participants in the current quorum
"""
return self._participating_replicas

def is_participating(self) -> bool:
"""
Get whether this replica is participating in the current quorum.
Returns:
whether this replica is participating in the current quorum
"""
if self._healing:
assert self._use_async_quorum
return False
return True
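
The test file below exercises these additions. As a quick illustration of the error path, a wrapped Future resolves to its default value and flips errored() when the underlying work fails; this sketch mirrors test_manager_wrap_future, with an illustrative tensor default rather than anything from the commit.

    # Sketch of the error-reporting flow (mirrors test_manager_wrap_future).
    fut = torch.futures.Future()
    # wrap_future() attaches a continuation that returns the default value
    # and calls report_error() if the underlying Future raises.
    safe_fut = manager.wrap_future(fut, torch.zeros(1))

    fut.set_exception(RuntimeError("simulated collective failure"))

    assert torch.equal(safe_fut.value(), torch.zeros(1))  # default, no exception
    assert manager.errored()  # the manager can now skip this step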
47 changes: 44 additions & 3 deletions torchft/manager_test.py
@@ -5,14 +5,14 @@
# LICENSE file in the root directory of this source tree.

from unittest import TestCase
from unittest.mock import patch, create_autospec, MagicMock
from unittest.mock import create_autospec, MagicMock, patch

import torch
from torch.distributed import TCPStore
from torchft.manager import Manager, MANAGER_ADDR_KEY
from torchft.process_group import _DummyWork, ProcessGroup

from torchft.torchft import ManagerClient
from torchft.manager import Manager, MANAGER_ADDR_KEY
from torchft.process_group import ProcessGroup


class TestManager(TestCase):
@@ -129,6 +129,8 @@ def test_quorum_heal_sync(self, client_mock) -> None:
manager.step()
manager.allreduce_grad(torch.tensor([1.0])).wait()
self.assertFalse(manager._healing)
self.assertTrue(manager.is_participating())
self.assertEqual(manager.num_participants(), 2)
self.assertTrue(manager.should_commit())

self.assertEqual(manager._quorum_id, 123)
@@ -164,6 +166,8 @@ def test_quorum_heal_async_not_enough_participants(self, client_mock) -> None:
manager.step()
manager._quorum_future.result()
self.assertTrue(manager._healing)
self.assertFalse(manager.is_participating())
self.assertEqual(manager.num_participants(), 1)

grad = torch.tensor([1.0])
manager.allreduce_grad(grad).wait()
@@ -307,3 +311,40 @@ def test_allreduce_error(self, client_mock) -> None:
manager.step()
manager.allreduce_grad(torch.tensor([1.0])).wait()
self.assertTrue(manager.should_commit())

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_report_error(self, client_mock) -> None:
manager = self._create_manager()

self.assertFalse(manager.errored())
manager.report_error()
self.assertTrue(manager.errored())

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_wrap_future(self, client_mock) -> None:
manager = self._create_manager()

self.assertFalse(manager.errored())

fut = torch.futures.Future()
wrapped_fut = manager.wrap_future(fut, 2)

fut.set_exception(RuntimeError("injected failure"))

self.assertEqual(wrapped_fut.value(), 2)
self.assertTrue(manager.errored())
self.assertEqual(manager._pending_work, [wrapped_fut])

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_numerics(self, client_mock) -> None:
manager = self._create_manager()

manager._quorum_future = MagicMock()
manager._participating_replicas = 5
self.assertEqual(manager.num_participants(), 5)
manager._pg.allreduce.return_value = _DummyWork(None)

fut = torch.futures.Future()
fut = manager.allreduce_grad(torch.tensor([1.0]))
result = fut.value()
torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
