diff --git a/grapesy/test-grapesy/Test/Sanity/Disconnect.hs b/grapesy/test-grapesy/Test/Sanity/Disconnect.hs index 2243d12d..32c8ff9a 100644 --- a/grapesy/test-grapesy/Test/Sanity/Disconnect.hs +++ b/grapesy/test-grapesy/Test/Sanity/Disconnect.hs @@ -17,12 +17,11 @@ module Test.Sanity.Disconnect (tests) where import Control.Concurrent import Control.Concurrent.Async +import Control.Concurrent.STM import Control.Exception import Control.Monad import Data.ByteString.Lazy qualified as Lazy (ByteString) -import Data.Either import Data.IORef -import Data.Maybe import Data.Word import Foreign.C.Types (CInt(..)) import Network.Socket @@ -42,18 +41,19 @@ import Proto.API.Trivial import Test.Util +{------------------------------------------------------------------------------- + Top-level +-------------------------------------------------------------------------------} + tests :: TestTree tests = testGroup "Test.Sanity.Disconnect" [ testCase "client" test_clientDisconnect , testCase "server" test_serverDisconnect ] --- | We want two distinct handlers running at the same time, so we have two --- trivial RPCs -type RPC1 = Trivial' "rpc1" - --- | See 'RPC1' -type RPC2 = Trivial' "rpc2" +{------------------------------------------------------------------------------- + Disconnecting clients +-------------------------------------------------------------------------------} -- | Two separate clients make many concurrent calls, one of them disconnects. test_clientDisconnect :: Assertion @@ -93,50 +93,55 @@ test_clientDisconnect = do -- Start a client in a separate process - let numCalls = 50 - void $ forkProcess $ + let numCalls = 10 + dyingChild <- forkProcess $ Client.withConnection def serverAddress $ \conn -> do - -- Make 50 concurrent calls. 49 of them sending infinite messages. One - -- of them kills this client process after 100 messages. - mapConcurrently_ - ( Client.withRPC conn def (Proxy @RPC1) - . 
runSteps - ) - $ replicate (numCalls - 1) stepsInfinite ++ - [ mkClientSteps Nothing [ (100, c_exit 1) ] ] - - -- Start two more clients that make 50 calls to each handler, all calls - -- counting up to 100 - let numSteps = 100 - steps = replicate numCalls $ stepsN numSteps + inLockstep conn (Proxy @RPC1) numCalls NeverTerminate $ \results _getFinal -> do + -- Wait until we are sure that all clients have started their RPC, + -- then kill the process. This avoids race conditions and guarantees + -- that the server will see @numCalls@ clients disconnecting. + _ <- waitForHistoryOfMinLen results 1 + c_exit 1 + + -- Start two more clients; these will not disconnect + let numSteps = 5 (result1, result2) <- concurrently ( Client.withConnection def serverAddress $ \conn -> do - sum <$> mapConcurrently - ( Client.withRPC conn def (Proxy @RPC1) - . runSteps - ) - steps + inLockstep conn (Proxy @RPC1) numCalls (TerminateAfter numSteps) $ \_results getFinal -> + getFinal ) ( Client.withConnection def serverAddress $ \conn -> do - sum <$> mapConcurrently - ( Client.withRPC conn def (Proxy @RPC2) - . runSteps - ) - steps + inLockstep conn (Proxy @RPC2) numCalls (TerminateAfter numSteps) $ \_results getFinal -> + getFinal ) - -- All calls by clients in /this/ process (not the ones we killed) should - -- have finished with a result of 'countTo' - assertEqual "" - (2 * sum (replicate numCalls numSteps)) - (fromIntegral $ result1 + result2) + -- Wait for the forked process to terminate + _status <- getProcessStatus True False dyingChild - -- We should also see only 50 client disconnects for the first handler and - -- none for the second + -- All calls by clients in /this/ process (not the ones we killed) should + -- have finished normally + let expectedResult = [ + replicate numCalls (StepOk i) + | i <- reverse [1 .. 
numSteps] + ] + assertEqual "" expectedResult result1 + assertEqual "" expectedResult result2 + + -- We should also see only @numCalls@ client disconnects for the first + -- handler and none for the second clientDisconnects1 <- readIORef disconnectCounter1 clientDisconnects2 <- readIORef disconnectCounter2 - assertEqual "" 50 clientDisconnects1 - assertEqual "" 0 clientDisconnects2 + assertEqual "" numCalls clientDisconnects1 + assertEqual "" 0 clientDisconnects2 + +-- We need to use this to properly simulate the execution environment crashing +-- in an unrecoverable way. In particular, we don't want to give the program a +-- chance to do any of its normal exception handling/cleanup behavior. +foreign import ccall unsafe "exit" c_exit :: CInt -> IO () + +{------------------------------------------------------------------------------- + Disconnecting servers +-------------------------------------------------------------------------------} -- | Client makes many concurrent calls, server disconnects test_serverDisconnect :: Assertion @@ -161,7 +166,7 @@ test_serverDisconnect = withTemporaryFile $ \ipcFile -> do server <- Server.mkGrpcServer def [ Server.someRpcHandler $ - Server.mkRpcHandler @Trivial $ echoHandler Nothing + Server.mkRpcHandler @RPC1 $ echoHandler Nothing ] let serverConfig = ServerConfig { @@ -217,105 +222,193 @@ test_serverDisconnect = withTemporaryFile $ \ipcFile -> do connParams = def { Client.connReconnectPolicy = reconnectPolicy } Client.withConnection connParams (serverAddress port1) $ \conn -> do - -- Make 50 concurrent calls. 49 of them sending infinite messages. One - -- of them kills the server after 100 messages. - let numCalls = 50 + let numCalls = 10 results <- - mapConcurrently - ( try @Client.ServerDisconnected - . Client.withRPC conn def (Proxy @Trivial) - . 
runSteps - ) - $ replicate (numCalls - 1) stepsInfinite ++ - [ mkClientSteps Nothing [(100, killServer)] ] - - -- All calls should have failed - assertBool "" (null (rights results)) - assertEqual "" numCalls (length (lefts results)) + inLockstep conn (Proxy @RPC1) numCalls NeverTerminate $ \results getFinal -> do + -- Once all clients have started their RPC, kill the server + _ <- waitForHistoryOfMinLen results 1 + killServer + getFinal + + -- All calls should have failed (but we don't know in which step) + assertEqual "" numCalls $ length $ filter stepFailed (concat results) -- New calls should succeed (after reconnection) killRestarted <- takeMVar signalRestart result <- - Client.withRPC conn def (Proxy @Trivial) $ - runSteps (stepsN numCalls) - assertEqual "" numCalls (fromIntegral result) + inLockstep conn (Proxy @RPC1) numCalls (TerminateAfter 1) $ \_results getFinal -> + getFinal + + let expectedResult = [replicate numCalls $ StepOk 1] + assertEqual "" expectedResult result -- Do not leave the server process hanging around killRestarted {------------------------------------------------------------------------------- - Client and handler functions + Auxiliary: echo handler -------------------------------------------------------------------------------} --- | Execute the client steps -runSteps :: forall rpc. 
- ( Input rpc ~ Lazy.ByteString - , Output rpc ~ Lazy.ByteString - , ResponseTrailingMetadata rpc ~ NoMetadata - ) => ClientStep -> Client.Call rpc -> IO Word64 -runSteps = - go 0 - where - go :: Word64 -> ClientStep -> Client.Call rpc -> IO Word64 - go n step call = do - case step of - KeepGoing mact next -> do - fromMaybe (return ()) mact - Binary.sendNextInput @Word64 call n - _ <- Binary.recvNextOutput @Word64 call - go (n + 1) next call - Done -> do - Binary.sendFinalInput @Word64 call n - (_, NoMetadata) <- Binary.recvFinalOutput @Word64 call - return n - -- | Echos any input echoHandler :: - ( Input rpc ~ Lazy.ByteString - , Output rpc ~ Lazy.ByteString - , ResponseTrailingMetadata rpc ~ NoMetadata - ) => Maybe (IORef Int) -> Server.Call rpc -> IO () -echoHandler disconnectCounter call = trackDisconnects disconnectCounter $ do - Binary.recvInput @Word64 call >>= \case - StreamElem n -> do - Binary.sendNextOutput @Word64 call n - echoHandler disconnectCounter call - FinalElem n _ -> do - Binary.sendFinalOutput @Word64 call (n, NoMetadata) - NoMoreElems _ -> do - Server.sendTrailers call NoMetadata + TrivialRpc rpc + => Maybe (IORef Int) + -> Server.Call rpc -> IO () +echoHandler disconnectCounter call = + trackDisconnects disconnectCounter $ loop where - trackDisconnects Nothing = - id + loop :: IO () + loop = do + inp <- Binary.recvInput @Word64 call + case inp of + StreamElem n -> Binary.sendNextOutput @Word64 call n >> loop + FinalElem n _ -> Binary.sendFinalOutput @Word64 call (n, NoMetadata) + NoMoreElems _ -> Server.sendTrailers call NoMetadata + + trackDisconnects :: Maybe (IORef Int) -> IO () -> IO () + trackDisconnects Nothing = id trackDisconnects (Just counter) = - handle ( - \(_e :: Server.ClientDisconnected) -> - atomicModifyIORef' counter $ \n -> (n + 1, ()) - ) + handle $ \(_e :: Server.ClientDisconnected) -> + atomicModifyIORef' counter $ \n -> (n + 1, ()) {------------------------------------------------------------------------------- - 
Auxiliary + Bunch of clients all executing in lockstep -------------------------------------------------------------------------------} --- We need to use this to properly simulate the execution environment crashing --- in an unrecoverable way. In particular, we don't want to give the program a --- chance to do any of its normal exception handling/cleanup behavior. -foreign import ccall unsafe "exit" c_exit :: CInt -> IO () - -data ClientStep = KeepGoing (Maybe (IO ())) ClientStep | Done +data NumSteps = TerminateAfter Int | NeverTerminate + +data Results = Results { + -- | Results for the current step + resultsCurr :: TVar [StepResult] + + -- | Number of the current step + , resultsStep :: Int + + -- | Previous results (in reverse order) + , resultsHist :: [[StepResult]] + } + +data StepResult = StepOk Int | StepFailed SomeException + deriving stock (Show) + +stepFailed :: StepResult -> Bool +stepFailed StepOk{} = False +stepFailed StepFailed{} = True + +instance Eq StepResult where + StepOk i == StepOk i' = i == i' + StepFailed _ == StepFailed _ = True -- the exception is merely for debugging + StepOk _ == StepFailed _ = False + StepFailed _ == StepOk _ = False + +initResults :: IO (TVar Results) +initResults = do + resultsCurr <- newTVarIO [] + newTVarIO Results{ + resultsCurr + , resultsStep = 1 + , resultsHist = [] + } + +-- | Keep collecting results (never terminates) +collectResults :: Int -> TVar Results -> IO a +collectResults numClients results = + forever $ + atomically $ do + Results{resultsCurr, resultsStep, resultsHist} <- readTVar results + current <- readTVar resultsCurr + if length current < numClients + then retry + else do + current' <- newTVar [] + writeTVar results Results{ + resultsCurr = current' + , resultsStep = succ resultsStep + , resultsHist = current : resultsHist + } -mkClientSteps :: Maybe Int -> [(Int, IO ())] -> ClientStep -mkClientSteps = go 0 +-- | Get the 'TVar' for the specified step, blocking until that step is reached +-- +-- 
This is executed by each client on each step. As a result, we can assume that +-- the required step can never be /before/ the current step (because all clients +-- must deliver their result for the current step before the step advances). +waitForStep :: TVar Results -> Int -> IO (TVar [StepResult]) +waitForStep results step = atomically $ do + Results{resultsCurr, resultsStep} <- readTVar results + if resultsStep < step + then retry + else return resultsCurr + +-- | Wait until a history of at least the specified length is ready +waitForHistoryOfMinLen :: TVar Results -> Int -> IO [[StepResult]] +waitForHistoryOfMinLen results numSteps = atomically $ do + Results{resultsHist} <- readTVar results + if length resultsHist < numSteps + then retry + else return resultsHist + +inLockstep :: forall rpc a. + TrivialRpc rpc + => Client.Connection -- ^ Server to connect to + -> Proxy rpc -- ^ Method to call + -> Int -- ^ Number of clients + -> NumSteps -- ^ How many steps each client should take + -> (TVar Results -> IO [[StepResult]] -> IO a) + -- ^ Monitor the results + -- + -- This is also passed a function to get the /final/ results, after all + -- clients have terminated. If some clients never terminate, this function + -- will block indefinitely. 
+ -> IO a +inLockstep conn rpc numClients numSteps monitor = do + results <- initResults + withAsync (collectResults numClients results) $ \_ -> + withAsync (runClients results) $ \clients -> + monitor results (wait clients) where - go !i mn acts - | maybe False (i >=) mn - = Done - | otherwise - = KeepGoing (lookup i acts) $ go (i + 1) mn acts - -stepsN :: Int -> ClientStep -stepsN n = mkClientSteps (Just n) [] - -{-# INLINE stepsInfinite #-} -stepsInfinite :: ClientStep -stepsInfinite = mkClientSteps Nothing [] + runClients :: TVar Results -> IO [[StepResult]] + runClients results = do + replicateConcurrently_ numClients $ + Client.withRPC conn def rpc (client results) + resultsHist <$> readTVarIO results + + client :: TVar Results -> Client.Call rpc -> IO () + client results call = loop 1 + where + loop :: Int -> IO () + loop n = do + current <- waitForStep results n + handle (recordException current) $ + case numSteps of + TerminateAfter n' | n == n' -> do + Binary.sendFinalInput call n + (resp, NoMetadata) <- Binary.recvFinalOutput call + atomically $ modifyTVar current (StepOk resp:) + _otherwise -> do + Binary.sendNextInput call n + resp <- Binary.recvNextOutput call + atomically $ modifyTVar current (StepOk resp:) + loop (succ n) + + recordException :: TVar [StepResult] -> SomeException -> IO () + recordException current e = + atomically $ modifyTVar current (StepFailed e:) + +{------------------------------------------------------------------------------- + Auxiliary: trivial RPCs + + We want two distinct handlers, so we have two trivial RPCs. +-------------------------------------------------------------------------------} + +type TrivialRpc rpc = ( + SupportsClientRpc rpc + , Input rpc ~ Lazy.ByteString + , Output rpc ~ Lazy.ByteString + , RequestMetadata rpc ~ NoMetadata + , ResponseInitialMetadata rpc ~ NoMetadata + , ResponseTrailingMetadata rpc ~ NoMetadata + ) + +type RPC1 = Trivial' "rpc1" +type RPC2 = Trivial' "rpc2" +