Skip to content

Commit

Permalink
Sketch WSS implementation (mostly stolen from warp-tls)
Browse files Browse the repository at this point in the history
  • Loading branch information
deepfire committed May 31, 2023
1 parent 2a5670e commit 7d2423e
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 9 deletions.
102 changes: 102 additions & 0 deletions src/Network/WebSockets/Connection/Options.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ module Network.WebSockets.Connection.Options

, SizeLimit (..)
, atMostSizeLimit

, TLSSettings (..)
, defaultTlsSettings

, CertSettings (..)
, defaultCertSettings
) where


Expand All @@ -18,6 +24,14 @@ import Data.Int (Int64)
import Data.Monoid (Monoid (..))
import Prelude

import qualified Crypto.PubKey.DH as DH
import qualified Data.ByteString as B
import Data.Default.Class (def)
import qualified Data.IORef as IO
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra as TLSExtra
import qualified Network.TLS.SessionManager as SM


--------------------------------------------------------------------------------
-- | Set options for a 'Connection'. Please do not use this constructor
Expand Down Expand Up @@ -50,6 +64,7 @@ data ConnectionOptions = ConnectionOptions
-- compressed messages, as well as the size of the uncompressed messages
-- as we are deflating them to ensure we don't use too much memory in any
-- case.
, connectionTlsSettings :: !(Maybe TLSSettings)
}


Expand All @@ -66,6 +81,7 @@ defaultConnectionOptions = ConnectionOptions
, connectionStrictUnicode = False
, connectionFramePayloadSizeLimit = mempty
, connectionMessageDataSizeLimit = mempty
, connectionTlsSettings = Nothing
}


Expand Down Expand Up @@ -126,3 +142,89 @@ atMostSizeLimit :: Int64 -> SizeLimit -> Bool
atMostSizeLimit _ NoSizeLimit = True
atMostSizeLimit s (SizeLimit l) = s <= l
{-# INLINE atMostSizeLimit #-}

--------------------------------------------------------------------------------
-- | Determines where to load the certificate, chain
-- certificates, and key from.
data CertSettings
= CertFromFile !FilePath ![FilePath] !FilePath
| CertFromMemory !B.ByteString ![B.ByteString] !B.ByteString
| CertFromRef !(IO.IORef B.ByteString) ![IO.IORef B.ByteString] !(IO.IORef B.ByteString)

-- | The default 'CertSettings'.
defaultCertSettings :: CertSettings
defaultCertSettings = CertFromFile "certificate.pem" [] "key.pem"

--------------------------------------------------------------------------------
data TLSSettings = TLSSettings {
certSettings :: CertSettings
-- ^ Where are the certificate, chain certificates, and key
-- loaded from?
--
-- >>> certSettings defaultTlsSettings
-- tlsSettings "certificate.pem" "key.pem"
, tlsLogging :: TLS.Logging
-- ^ The level of logging to turn on.
--
-- Default: 'TLS.defaultLogging'.
, tlsAllowedVersions :: [TLS.Version]
-- ^ The TLS versions this server accepts.
--
-- >>> tlsAllowedVersions defaultTlsSettings
-- [TLS13,TLS12,TLS11,TLS10]
, tlsCiphers :: [TLS.Cipher]
-- ^ The TLS ciphers this server accepts.
--
-- >>> tlsCiphers defaultTlsSettings
-- [ECDHE-ECDSA-AES256GCM-SHA384,ECDHE-ECDSA-AES128GCM-SHA256,ECDHE-RSA-AES256GCM-SHA384,ECDHE-RSA-AES128GCM-SHA256,DHE-RSA-AES256GCM-SHA384,DHE-RSA-AES128GCM-SHA256,ECDHE-ECDSA-AES256CBC-SHA384,ECDHE-RSA-AES256CBC-SHA384,DHE-RSA-AES256-SHA256,ECDHE-ECDSA-AES256CBC-SHA,ECDHE-RSA-AES256CBC-SHA,DHE-RSA-AES256-SHA1,RSA-AES256GCM-SHA384,RSA-AES256-SHA256,RSA-AES256-SHA1,AES128GCM-SHA256,AES256GCM-SHA384]
, tlsWantClientCert :: Bool
-- ^ Whether or not to demand a certificate from the client. If this
-- is set to True, you must handle received certificates in a server hook
-- or all connections will fail.
--
-- >>> tlsWantClientCert defaultTlsSettings
-- False
, tlsServerHooks :: TLS.ServerHooks
-- ^ The server-side hooks called by the tls package, including actions
-- to take when a client certificate is received. See the "Network.TLS"
-- module for details.
--
-- Default: def
, tlsServerDHEParams :: Maybe DH.Params
-- ^ Configuration for ServerDHEParams
-- more function lives in `cryptonite` package
--
-- Default: Nothing
, tlsSessionManagerConfig :: Maybe SM.Config
-- ^ Configuration for in-memory TLS session manager.
-- If Nothing, 'TLS.noSessionManager' is used.
-- Otherwise, an in-memory TLS session manager is created
-- according to 'Config'.
--
-- Default: Nothing
, tlsCredentials :: Maybe TLS.Credentials
-- ^ Specifying 'TLS.Credentials' directly. If this value is
-- specified, other fields such as 'certFile' are ignored.
, tlsSessionManager :: Maybe TLS.SessionManager
-- ^ Specifying 'TLS.SessionManager' directly. If this value is
-- specified, 'tlsSessionManagerConfig' is ignored.
}

defaultTlsSettings :: TLSSettings
defaultTlsSettings =
TLSSettings
{ certSettings = defaultCertSettings
, tlsLogging = def
, tlsAllowedVersions = [TLS.TLS13,TLS.TLS12]
, tlsCiphers = ciphers
, tlsWantClientCert = False
, tlsServerHooks = def
, tlsServerDHEParams = Nothing
, tlsSessionManagerConfig = Nothing
, tlsCredentials = Nothing
, tlsSessionManager = Nothing
}
where
-- taken from stunnel example in tls-extra
ciphers :: [TLS.Cipher]
ciphers = TLSExtra.ciphersuite_strong
4 changes: 3 additions & 1 deletion src/Network/WebSockets/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ runApp socket opts app =
makePendingConnection
:: Socket -> ConnectionOptions -> IO PendingConnection
makePendingConnection socket opts = do
stream <- Stream.makeSocketStream socket
stream <- case connectionTlsSettings opts of
Nothing -> Stream.makeSocketStream socket
Just tls -> Stream.makeTlsSocketStream tls socket
makePendingConnectionFromStream stream opts


Expand Down
104 changes: 97 additions & 7 deletions src/Network/WebSockets/Stream.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
--------------------------------------------------------------------------------
-- | Lightweight abstraction over an input/output stream.
{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
module Network.WebSockets.Stream
( Stream
, makeStream
Expand All @@ -10,16 +11,23 @@ module Network.WebSockets.Stream
, parseBin
, write
, close
-- * TLS
, makeTlsSocketStream
, streamTlsContext
) where

import Control.Applicative ((<|>))
import Control.Concurrent.MVar (MVar, newEmptyMVar, newMVar,
putMVar, takeMVar, withMVar)
import Control.Exception (SomeException, SomeAsyncException, throwIO, catch, fromException)
import Control.Exception (SomeException, SomeAsyncException, catch, handle, throwIO, fromException)
import Control.Monad (forM_)
import qualified Data.Attoparsec.ByteString as Atto
import qualified Data.Binary.Get as BIN
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Default.Class (def)
import Data.Functor ((<&>))
import qualified Data.IORef as IO
import Data.IORef (IORef, atomicModifyIORef',
newIORef, readIORef,
writeIORef)
Expand All @@ -32,7 +40,11 @@ import qualified Network.Socket.ByteString.Lazy as SBL (sendAll)
import qualified Network.Socket.ByteString as SB (sendAll)
#endif

import qualified Network.TLS as TLS
import qualified Network.TLS.SessionManager as SM
import Network.WebSockets.Types
import Network.WebSockets.Connection.Options
import System.IO.Error (isEOFError)


--------------------------------------------------------------------------------
Expand All @@ -45,12 +57,12 @@ data StreamState
--------------------------------------------------------------------------------
-- | Lightweight abstraction over an input/output stream.
data Stream = Stream
{ streamIn :: IO (Maybe B.ByteString)
, streamOut :: (Maybe BL.ByteString -> IO ())
, streamState :: !(IORef StreamState)
{ streamIn :: IO (Maybe B.ByteString)
, streamOut :: (Maybe BL.ByteString -> IO ())
, streamState :: !(IORef StreamState)
, streamTlsContext :: Maybe TLS.Context
}


--------------------------------------------------------------------------------
-- | Create a stream from a "receive" and "send" action. The following
-- properties apply:
Expand All @@ -72,7 +84,7 @@ makeStream receive send = do
ref <- newIORef (Open B.empty)
receiveLock <- newMVar ()
sendLock <- newMVar ()
return $ Stream (receive' ref receiveLock) (send' ref sendLock) ref
return $ Stream (receive' ref receiveLock) (send' ref sendLock) ref Nothing
where
closeRef :: IORef StreamState -> IO ()
closeRef ref = atomicModifyIORef' ref $ \state -> case state of
Expand Down Expand Up @@ -110,7 +122,6 @@ makeStream receive send = do
Nothing -> what *> pure ()
throwIO e


--------------------------------------------------------------------------------
makeSocketStream :: S.Socket -> IO Stream
makeSocketStream socket = makeStream receive send
Expand All @@ -127,6 +138,85 @@ makeSocketStream socket = makeStream receive send
forM_ (BL.toChunks bs) (SB.sendAll socket)
#endif

loadCredentials :: TLSSettings -> IO TLS.Credentials
loadCredentials TLSSettings{ tlsCredentials = Just creds } = return creds
loadCredentials TLSSettings{..} = case certSettings of
CertFromFile cert chainFiles key -> do
cred <- either error id <$> TLS.credentialLoadX509Chain cert chainFiles key
return $ TLS.Credentials [cred]
CertFromRef certRef chainCertsRef keyRef -> do
cert <- IO.readIORef certRef
chainCerts <- mapM IO.readIORef chainCertsRef
key <- IO.readIORef keyRef
cred <- either error return $ TLS.credentialLoadX509ChainFromMemory cert chainCerts key
return $ TLS.Credentials [cred]
CertFromMemory certMemory chainCertsMemory keyMemory -> do
cred <- either error return $ TLS.credentialLoadX509ChainFromMemory certMemory chainCertsMemory keyMemory
return $ TLS.Credentials [cred]

makeTlsSocketStream :: TLSSettings -> S.Socket -> IO Stream
makeTlsSocketStream stts socket = do
creds <- loadCredentials stts
mgr <- getSessionManager stts
ctx <- TLS.contextNew socket (params mgr creds)
TLS.contextHookSetLogging ctx (tlsLogging stts)
TLS.handshake ctx
makeStream (receive ctx) (send ctx) <&>
\s -> s { streamTlsContext = Just ctx }
where
receive ctx = handle onEOF go
where
onEOF e
| Just TLS.Error_EOF <- fromException e = pure Nothing
| Just ioe <- fromException e, isEOFError ioe = pure Nothing
| otherwise = throwIO e
go = do
x <- TLS.recvData ctx
if B.null x then
go
else
pure $ Just x

send _ Nothing = return ()
send ctx (Just bs) =
TLS.sendData ctx bs

params mgr creds = def { -- TLS.ServerParams
TLS.serverWantClientCert = tlsWantClientCert stts
, TLS.serverCACertificates = []
, TLS.serverDHEParams = tlsServerDHEParams stts
, TLS.serverHooks = hooks
, TLS.serverShared = shared mgr creds
, TLS.serverSupported = supported
, TLS.serverEarlyDataSize = 2018
}
-- Adding alpn to user's tlsServerHooks.
hooks = (tlsServerHooks stts)
{ TLS.onALPNClientSuggest = TLS.onALPNClientSuggest (tlsServerHooks stts)
-- <|> (if settingsHTTP2Enabled set then Just alpn else Nothing)
}

shared mgr creds = def {
TLS.sharedCredentials = creds
, TLS.sharedSessionManager = mgr
}
supported = def { -- TLS.Supported
TLS.supportedVersions = tlsAllowedVersions stts
, TLS.supportedCiphers = tlsCiphers stts
, TLS.supportedCompressions = [TLS.nullCompression]
, TLS.supportedSecureRenegotiation = True
, TLS.supportedClientInitiatedRenegotiation = False
, TLS.supportedSession = True
, TLS.supportedFallbackScsv = True
, TLS.supportedGroups = [TLS.X25519,TLS.P256,TLS.P384]
}

getSessionManager :: TLSSettings -> IO TLS.SessionManager
getSessionManager TLSSettings{ tlsSessionManager = Just mgr } = return mgr
getSessionManager stts' = case tlsSessionManagerConfig stts' of
Nothing -> return TLS.noSessionManager
Just config -> SM.newSessionManager config


--------------------------------------------------------------------------------
makeEchoStream :: IO Stream
Expand Down
7 changes: 6 additions & 1 deletion websockets.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ Library
Network.WebSockets
Network.WebSockets.Client
Network.WebSockets.Connection
Network.WebSockets.Connection.Options
Network.WebSockets.Extensions
Network.WebSockets.Stream
-- Network.WebSockets.Util.PubSub TODO

Other-modules:
Network.WebSockets.Connection.Options
Network.WebSockets.Extensions.Description
Network.WebSockets.Extensions.PermessageDeflate
Network.WebSockets.Extensions.StrictUnicode
Expand All @@ -94,12 +94,17 @@ Library
bytestring-builder < 0.11,
case-insensitive >= 0.3 && < 1.3,
clock >= 0.8 && < 0.9,
connection,
containers >= 0.3 && < 0.7,
cryptonite,
data-default-class,
network >= 2.3 && < 3.2,
random >= 1.0.1 && < 1.3,
SHA >= 1.5 && < 1.7,
streaming-commons >= 0.1 && < 0.3,
text >= 0.10 && < 2.1,
tls,
tls-session-manager,
entropy >= 0.2.1 && < 0.5

Test-suite websockets-tests
Expand Down

0 comments on commit 7d2423e

Please sign in to comment.