diff --git a/cardano-client/CHANGELOG.md b/cardano-client/CHANGELOG.md index e3287b5c201..4c4ee5322f8 100644 --- a/cardano-client/CHANGELOG.md +++ b/cardano-client/CHANGELOG.md @@ -2,6 +2,16 @@ ## next version +### Breaking changes + +* Reimplemntation of `subscribe` without relaying on non-p2p stack. Its + arguments have changed. Note that the `NodeToClientProtocols` and + `OuroborosApplicationWithMinimalCtx` specify `Void` as return type of the + responder side. +* The default reconnect delay was increased from `0.025s` to `5s`. + +### Non-breaking changes + ## 0.3.1.5 -- 2024-08-27 ### Breaking changes diff --git a/cardano-client/cardano-client.cabal b/cardano-client/cardano-client.cabal index 3ffad65c293..7a052e360fd 100644 --- a/cardano-client/cardano-client.cabal +++ b/cardano-client/cardano-client.cabal @@ -23,11 +23,14 @@ library build-depends: base >=4.14 && <4.21, bytestring >=0.10 && <0.13, + cborg, containers, + contra-tracer, network-mux ^>=0.4.5, ouroboros-network >=0.9 && <0.18, ouroboros-network-api >=0.5.2 && <0.10, ouroboros-network-framework >=0.8 && <0.14, + si-timers, ghc-options: -Wall diff --git a/cardano-client/src/Cardano/Client/Subscription.hs b/cardano-client/src/Cardano/Client/Subscription.hs index 1e49a1a26b6..528dc2b16cd 100644 --- a/cardano-client/src/Cardano/Client/Subscription.hs +++ b/cardano-client/src/Cardano/Client/Subscription.hs @@ -1,24 +1,40 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} module Cardano.Client.Subscription - ( subscribe + ( -- * Subscription API + subscribe + , SubscriptionParams (..) + , SubscriptionTracers (..) + , SubscriptionTrace (..) + -- * Re-exports + -- ** Mux , MuxMode (..) - , ConnectionId - , LocalAddress + , MuxTrace + , WithMuxBearer + -- ** Connections + , ConnectionId (..) + , LocalAddress (..) + -- ** Protocol API , NodeToClientProtocols (..) , MiniProtocolCb (..) - , MuxTrace , RunMiniProtocol (..) - , WithMuxBearer , ControlMessage (..) ) where +import Codec.CBOR.Term qualified as CBOR +import Control.Exception +import Control.Monad (join) +import Control.Monad.Class.MonadTime.SI +import Control.Monad.Class.MonadTimer.SI +import Control.Tracer (Tracer, traceWith) import Data.ByteString.Lazy qualified as BSL import Data.Map.Strict (Map) import Data.Map.Strict qualified as Map +import Data.Maybe (fromMaybe) import Data.Void (Void) import Network.Mux.Trace (MuxTrace, WithMuxBearer) @@ -27,15 +43,43 @@ import Ouroboros.Network.ControlMessage (ControlMessage (..)) import Ouroboros.Network.Magic (NetworkMagic) import Ouroboros.Network.Mux (MiniProtocolCb (..), MuxMode (..), OuroborosApplicationWithMinimalCtx, RunMiniProtocol (..)) -import Ouroboros.Network.NodeToClient (ClientSubscriptionParams (..), - ConnectionId, LocalAddress, NetworkClientSubcriptionTracers, - NodeToClientProtocols (..), NodeToClientVersion, - NodeToClientVersionData (NodeToClientVersionData), - ncSubscriptionWorker, newNetworkMutableState, - versionedNodeToClientProtocols) -import Ouroboros.Network.Protocol.Handshake.Version (Versions, foldMapVersions) + +import Ouroboros.Network.ConnectionId (ConnectionId (..)) +import Ouroboros.Network.NodeToClient (Handshake, LocalAddress (..), + NetworkConnectTracers (..), NodeToClientProtocols, + NodeToClientVersion, NodeToClientVersionData (..), TraceSendRecv, + Versions) +import Ouroboros.Network.NodeToClient qualified as NtC import Ouroboros.Network.Snocket qualified as Snocket +data SubscriptionParams a = SubscriptionParams + { spAddress :: !LocalAddress + -- ^ unix socket or named pipe address + , spReconnectionDelay :: !(Maybe DiffTime) + -- ^ delay between connection attempts. The default value is `5s`. + , spCompleteCb :: Either SomeException a -> Decision + } + +data Decision = + Abort + -- ^ abort subscription loop + | Reconnect + -- ^ reconnect + +data SubscriptionTracers = SubscriptionTracers { + stMuxTracer :: Tracer IO (WithMuxBearer (ConnectionId LocalAddress) MuxTrace), + -- ^ low level mux-network tracer, which logs mux sdu (send and received) + -- and other low level multiplexing events. + stHandshakeTracer :: Tracer IO (WithMuxBearer (ConnectionId LocalAddress) + (TraceSendRecv (Handshake NodeToClientVersion CBOR.Term))), + -- ^ handshake protocol tracer; it is important for analysing version + -- negotation mismatches. + stSubscriptionTracer :: Tracer IO SubscriptionTrace + } + +data SubscriptionTrace = SubscriptionError SomeException + deriving Show + -- | Subscribe using `node-to-client` mini-protocol. -- -- 'blockVersion' ought to be instantiated with `BlockNodeToClientVersion blk`. @@ -44,34 +88,63 @@ import Ouroboros.Network.Snocket qualified as Snocket -- `Ouroboros.Consensus.Network.NodeToClient.clientCodecs`. -- subscribe - :: forall blockVersion x y. + :: forall blockVersion a. Snocket.LocalSnocket -> NetworkMagic -> Map NodeToClientVersion blockVersion -- ^ Use `supportedNodeToClientVersions` from `ouroboros-consensus`. - -> NetworkClientSubcriptionTracers - -> ClientSubscriptionParams () + -> SubscriptionTracers + -> SubscriptionParams a -> ( NodeToClientVersion -> blockVersion - -> NodeToClientProtocols 'InitiatorMode LocalAddress BSL.ByteString IO x y) - -> IO Void -subscribe snocket networkMagic supportedVersions tracers subscriptionParams protocols = do - networkState <- newNetworkMutableState - ncSubscriptionWorker - snocket - tracers - networkState - subscriptionParams - (versionedProtocols networkMagic supportedVersions protocols) + -> NodeToClientProtocols 'InitiatorMode LocalAddress BSL.ByteString IO a Void) + -> IO () +subscribe snocket networkMagic supportedVersions + SubscriptionTracers { + stMuxTracer = muxTracer, + stHandshakeTracer = handshakeTracer, + stSubscriptionTracer = tracer + } + SubscriptionParams { + spAddress = addr, + spReconnectionDelay = reConnDelay, + spCompleteCb = completeCb + } + protocols = + mask $ \unmask -> + loop unmask $ + NtC.connectTo + snocket + NetworkConnectTracers { + nctMuxTracer = muxTracer, + nctHandshakeTracer = handshakeTracer + } + (versionedProtocols networkMagic supportedVersions protocols) + (getFilePath addr) + where + loop :: (forall x. IO x -> IO x) -> IO (Either SomeException a) -> IO () + loop unmask act = do + r <- fn <$> try (unmask act) + case r of + Right _ -> pure () + Left e -> traceWith tracer (SubscriptionError e) + case completeCb r of + Abort -> pure () + Reconnect -> do + threadDelay (fromMaybe 5 reConnDelay) + loop unmask act + + fn :: forall x y. Either x (Either x y) -> Either x y + fn = join versionedProtocols :: - forall m appType bytes blockVersion a b. + forall m appType bytes blockVersion a. NetworkMagic -> Map NodeToClientVersion blockVersion -- ^ Use `supportedNodeToClientVersions` from `ouroboros-consensus`. -> ( NodeToClientVersion -> blockVersion - -> NodeToClientProtocols appType LocalAddress bytes m a b) + -> NodeToClientProtocols appType LocalAddress bytes m a Void) -- ^ callback which receives codecs, connection id and STM action which -- can be checked if the networking runtime system requests the protocols -- to stop. @@ -82,18 +155,21 @@ versionedProtocols :: -> Versions NodeToClientVersion NodeToClientVersionData - (OuroborosApplicationWithMinimalCtx appType LocalAddress bytes m a b) + (OuroborosApplicationWithMinimalCtx appType LocalAddress bytes m a Void) versionedProtocols networkMagic supportedVersions callback = - foldMapVersions applyVersion $ Map.toList supportedVersions + NtC.foldMapVersions applyVersion (Map.toList supportedVersions) where applyVersion :: (NodeToClientVersion, blockVersion) -> Versions NodeToClientVersion NodeToClientVersionData - (OuroborosApplicationWithMinimalCtx appType LocalAddress bytes m a b) + (OuroborosApplicationWithMinimalCtx appType LocalAddress bytes m a Void) applyVersion (version, blockVersion) = - versionedNodeToClientProtocols + NtC.versionedNodeToClientProtocols version - (NodeToClientVersionData networkMagic False) + NodeToClientVersionData { + networkMagic, + query = False + } (callback version blockVersion) diff --git a/network-mux/CHANGELOG.md b/network-mux/CHANGELOG.md index 5b63ba61832..926c96ad31f 100644 --- a/network-mux/CHANGELOG.md +++ b/network-mux/CHANGELOG.md @@ -4,6 +4,16 @@ ### Breaking changes +* Removed `Netowrk.Mux.Compat` module with legacy API. +* `Ouroboros.Network.Mux.toApplication` was removed. +* `Ouroboros.Network.Mux.mkMiniProtocolBundle` was renamed to + `mkMiniProtocolInfos`, its type changed. +* Removed `MiniProtocolBundle` newtype wrapper. +* Generalised `Channel` type and provide `ByteChannel` type alias. +* Provide additional APIs in the `Network.Mux.Channel` for creating channels + and byte channels. +* `MuxBearer` has a `name` field. + ### Non-breaking changes * Fix compilation with `tracetcpinfo` flag. diff --git a/network-mux/demo/mux-demo.hs b/network-mux/demo/mux-demo.hs index ab4dce124fc..4b85b6707fc 100644 --- a/network-mux/demo/mux-demo.hs +++ b/network-mux/demo/mux-demo.hs @@ -135,9 +135,8 @@ serverWorker bearer = do runMux nullTracer mux bearer where - ptcls :: MiniProtocolBundle ResponderMode - ptcls = MiniProtocolBundle - [ MiniProtocolInfo { + ptcls :: [MiniProtocolInfo ResponderMode] + ptcls = [ MiniProtocolInfo { miniProtocolNum = MiniProtocolNum 2, miniProtocolDir = ResponderDirectionOnly, miniProtocolLimits = defaultProtocolLimits @@ -195,9 +194,8 @@ clientWorker bearer n msg = do runMux nullTracer mux bearer where - ptcls :: MiniProtocolBundle InitiatorMode - ptcls = MiniProtocolBundle - [ MiniProtocolInfo { + ptcls :: [MiniProtocolInfo InitiatorMode] + ptcls = [ MiniProtocolInfo { miniProtocolNum = MiniProtocolNum 2, miniProtocolDir = InitiatorDirectionOnly, miniProtocolLimits = defaultProtocolLimits diff --git a/network-mux/network-mux.cabal b/network-mux/network-mux.cabal index e8cd2031aa8..13084eb782e 100644 --- a/network-mux/network-mux.cabal +++ b/network-mux/network-mux.cabal @@ -83,7 +83,6 @@ library Network.Mux.Bearer.Socket Network.Mux.Channel Network.Mux.Codec - Network.Mux.Compat Network.Mux.DeltaQ.TraceStats Network.Mux.DeltaQ.TraceStatsSupport Network.Mux.DeltaQ.TraceTransformer diff --git a/network-mux/src/Network/Mux.hs b/network-mux/src/Network/Mux.hs index 9dc29691712..eed58b15b11 100644 --- a/network-mux/src/Network/Mux.hs +++ b/network-mux/src/Network/Mux.hs @@ -3,7 +3,6 @@ {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTSyntax #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -16,7 +15,6 @@ module Network.Mux , MuxMode (..) , HasInitiator , HasResponder - , MiniProtocolBundle (..) , MiniProtocolInfo (..) , MiniProtocolNum (..) , MiniProtocolDirection (..) @@ -25,6 +23,8 @@ module Network.Mux , runMux , runMiniProtocol , StartOnDemandOrEagerly (..) + , ByteChannel + , Channel (..) , stopMux -- * Bearer , MuxBearer @@ -113,8 +113,15 @@ data MuxStatus | MuxStopped -newMux :: MonadSTM m => MiniProtocolBundle mode -> m (Mux mode m) -newMux (MiniProtocolBundle ptcls) = do +-- | Create a mux handle. +-- +newMux :: forall (mode :: MuxMode) m. + MonadLabelledSTM m + => [MiniProtocolInfo mode] + -- ^ description of protocols run by the mux layer. Only these protocols + -- one will be able to execute. + -> m (Mux mode m) +newMux ptcls = do muxMiniProtocols <- mkMiniProtocolStateMap ptcls muxControlCmdQueue <- atomically newTQueue muxStatus <- newTVarIO MuxReady @@ -210,9 +217,13 @@ runMux :: forall m mode. -> Mux mode m -> MuxBearer m -> m () -runMux tracer Mux {muxMiniProtocols, muxControlCmdQueue, muxStatus} bearer = do +runMux tracer Mux {muxMiniProtocols, muxControlCmdQueue, muxStatus} bearer@MuxBearer {name} = do egressQueue <- atomically $ newTBQueue 100 - labelTBQueueIO egressQueue "mux-eq" + + -- label shared variables + labelTBQueueIO egressQueue (name ++ "-mux-egress") + labelTVarIO muxStatus (name ++ "-mux-status") + labelTQueueIO muxControlCmdQueue (name ++ "-mux-ctrl") JobPool.withJobPool (\jobpool -> do @@ -241,13 +252,13 @@ runMux tracer Mux {muxMiniProtocols, muxControlCmdQueue, muxStatus} bearer = do JobPool.Job (muxer egressQueue bearer) (return . MuxerException) MuxJob - "muxer" + (name ++ "-muxer") demuxerJob = JobPool.Job (demuxer (Map.elems muxMiniProtocols) bearer) (return . DemuxerException) MuxJob - "demuxer" + (name ++ "-demuxer") miniProtocolJob :: forall mode m. @@ -324,7 +335,7 @@ data StartOnDemandOrEagerly = StartOnDemand | StartEagerly deriving Eq data MiniProtocolAction m where - MiniProtocolAction :: (Channel m -> m (a, Maybe BL.ByteString)) -- ^ Action + MiniProtocolAction :: (ByteChannel m -> m (a, Maybe BL.ByteString)) -- ^ Action -> StrictTMVar m (Either SomeException a) -- ^ Completion var -> MiniProtocolAction m @@ -333,8 +344,8 @@ type MiniProtocolKey = (MiniProtocolNum, MiniProtocolDir) newtype MonitorCtx m mode = MonitorCtx { -- | Mini-Protocols started on demand and waiting to be scheduled. -- - mcOnDemandProtocols :: (Map MiniProtocolKey - (MiniProtocolState mode m, MiniProtocolAction m)) + mcOnDemandProtocols :: Map MiniProtocolKey + (MiniProtocolState mode m, MiniProtocolAction m) } @@ -364,10 +375,10 @@ monitor tracer timeout jobpool egressQueue cmdQueue muxStatus = go !monitorCtx@MonitorCtx { mcOnDemandProtocols } = do result <- atomically $ runFirstToFinish $ -- wait for a mini-protocol thread to terminate - (FirstToFinish $ EventJobResult <$> JobPool.waitForJob jobpool) + FirstToFinish (EventJobResult <$> JobPool.waitForJob jobpool) -- wait for a new control command - <> (FirstToFinish $ EventControlCmd <$> readTQueue cmdQueue) + <> FirstToFinish (EventControlCmd <$> readTQueue cmdQueue) -- or wait for data to arrive on the channels that do not yet have -- responder threads running @@ -546,7 +557,7 @@ muxChannel -> MiniProtocolNum -> MiniProtocolDir -> IngressQueue m - -> Channel m + -> ByteChannel m muxChannel tracer egressQueue want@(Wanton w) mc md q = Channel { send, recv} where @@ -633,7 +644,7 @@ runMiniProtocol :: forall mode m a. -> MiniProtocolNum -> MiniProtocolDirection mode -> StartOnDemandOrEagerly - -> (Channel m -> m (a, Maybe BL.ByteString)) + -> (ByteChannel m -> m (a, Maybe BL.ByteString)) -> m (STM m (Either SomeException a)) runMiniProtocol Mux { muxMiniProtocols, muxControlCmdQueue , muxStatus} ptclNum ptclDir startMode protocolAction diff --git a/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs b/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs index b060c28aad9..87d63782cd6 100644 --- a/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs +++ b/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs @@ -262,7 +262,8 @@ attenuationChannelAsMuxBearer sduSize sduTimeout muxTracer chan = MuxBearer { read = readMux, write = writeMux, - sduSize + sduSize, + name = "attenuation-channel" } where readMux :: TimeoutFn m -> m (MuxSDU, Time) diff --git a/network-mux/src/Network/Mux/Bearer/NamedPipe.hs b/network-mux/src/Network/Mux/Bearer/NamedPipe.hs index f4fba4ebd8b..0bbe6775481 100644 --- a/network-mux/src/Network/Mux/Bearer/NamedPipe.hs +++ b/network-mux/src/Network/Mux/Bearer/NamedPipe.hs @@ -37,7 +37,8 @@ namedPipeAsBearer sduSize tracer h = Mx.MuxBearer { Mx.read = readNamedPipe, Mx.write = writeNamedPipe, - Mx.sduSize = sduSize + Mx.sduSize = sduSize, + Mx.name = "named-pipe" } where readNamedPipe :: Mx.TimeoutFn IO -> IO (Mx.MuxSDU, Time) diff --git a/network-mux/src/Network/Mux/Bearer/Pipe.hs b/network-mux/src/Network/Mux/Bearer/Pipe.hs index 9fe601257e6..1c759dd43d6 100644 --- a/network-mux/src/Network/Mux/Bearer/Pipe.hs +++ b/network-mux/src/Network/Mux/Bearer/Pipe.hs @@ -77,7 +77,8 @@ pipeAsMuxBearer sduSize tracer channel = Mx.MuxBearer { Mx.read = readPipe, Mx.write = writePipe, - Mx.sduSize = sduSize + Mx.sduSize = sduSize, + Mx.name = "pipe" } where readPipe :: Mx.TimeoutFn IO -> IO (Mx.MuxSDU, Time) diff --git a/network-mux/src/Network/Mux/Bearer/Queues.hs b/network-mux/src/Network/Mux/Bearer/Queues.hs index 81156268103..ce7ec3edc9b 100644 --- a/network-mux/src/Network/Mux/Bearer/Queues.hs +++ b/network-mux/src/Network/Mux/Bearer/Queues.hs @@ -42,7 +42,8 @@ queueChannelAsMuxBearer sduSize tracer QueueChannel { writeQueue, readQueue } = Mx.MuxBearer { Mx.read = readMux, Mx.write = writeMux, - Mx.sduSize = sduSize + Mx.sduSize = sduSize, + Mx.name = "queue-channel" } where readMux :: Mx.TimeoutFn m -> m (Mx.MuxSDU, Time) diff --git a/network-mux/src/Network/Mux/Bearer/Socket.hs b/network-mux/src/Network/Mux/Bearer/Socket.hs index 6529b0a0b7e..1d0c0185bda 100644 --- a/network-mux/src/Network/Mux/Bearer/Socket.hs +++ b/network-mux/src/Network/Mux/Bearer/Socket.hs @@ -53,7 +53,8 @@ socketAsMuxBearer sduSize sduTimeout tracer sd = Mx.MuxBearer { Mx.read = readSocket, Mx.write = writeSocket, - Mx.sduSize = sduSize + Mx.sduSize = sduSize, + Mx.name = "socket-bearer" } where hdrLenght = 8 diff --git a/network-mux/src/Network/Mux/Channel.hs b/network-mux/src/Network/Mux/Channel.hs index e95b5aaec47..099ddbba15c 100644 --- a/network-mux/src/Network/Mux/Channel.hs +++ b/network-mux/src/Network/Mux/Channel.hs @@ -1,45 +1,67 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -- | An extension of 'Network.TypedProtocol.Channel', with additional 'Channel' -- implementations. -- module Network.Mux.Channel - ( Channel (..) + ( -- * Channel + Channel (..) + -- ** Channel API + , isoKleisliChannel + , hoistChannel + , channelEffect + , delayChannel + , loggingChannel + -- ** create a `Channel` + , mvarsAsChannel + -- ** connected `Channel`s + , createConnectedChannels + -- * `ByteChannel` + , ByteChannel + -- ** create a `ByteChannel` + , handlesAsChannel + , withFifosAsChannel + , socketAsChannel + -- ** connected `ByteChannel`s , createBufferConnectedChannels , createPipeConnectedChannels #if !defined(mingw32_HOST_OS) , createSocketConnectedChannels #endif - , withFifosAsChannel - , socketAsChannel - , channelEffect - , delayChannel - , loggingChannel ) where -import qualified Data.ByteString as BS -import qualified Data.ByteString.Lazy as LBS -import qualified Data.ByteString.Lazy.Internal as LBS (smallChunkSize) -import qualified Network.Socket as Socket -import qualified Network.Socket.ByteString as Socket -import qualified System.IO as IO (Handle, IOMode (..), hFlush, hIsEOF, withFile) -import qualified System.Process as IO (createPipe) +import Data.ByteString qualified as BS +import Data.ByteString.Lazy qualified as LBS +import Data.ByteString.Lazy.Internal qualified as LBS (smallChunkSize) +import Network.Socket qualified as Socket +import Network.Socket.ByteString qualified as Socket +import System.IO qualified as IO (Handle, IOMode (..), hFlush, hIsEOF, withFile) +import System.Process qualified as IO (createPipe) -import Control.Concurrent.Class.MonadSTM -import Control.Monad.Class.MonadSay -import Control.Monad.Class.MonadTimer.SI +import Control.Concurrent.Class.MonadSTM +import Control.Concurrent.Class.MonadSTM.Strict qualified as StrictSTM +import Control.Monad ((>=>)) +import Control.Monad.Class.MonadSay +import Control.Monad.Class.MonadTimer.SI -data Channel m = Channel { +-- | A channel which can send and receive values. +-- +-- It is more general than what `network-mux` requires, see `ByteChannel` +-- instead. However this is useful for testing purposes when one is either +-- using `mux` or connecting two ends directly. +-- +data Channel m a = Channel { -- | Write bytes to the channel. -- -- It maybe raise exceptions. -- - send :: LBS.ByteString -> m (), + send :: a -> m (), -- | Read some input from the channel, or @Nothing@ to indicate EOF. -- @@ -49,9 +71,126 @@ data Channel m = Channel { -- It may raise exceptions (as appropriate for the monad and kind of -- channel). -- - recv :: m (Maybe LBS.ByteString) + recv :: m (Maybe a) } +-- | Given an isomorphism between @a@ and @b@ (in Kleisli category), transform +-- a @'Channel' m a@ into @'Channel' m b@. +-- +isoKleisliChannel + :: forall a b m. Monad m + => (a -> m b) + -> (b -> m a) + -> Channel m a + -> Channel m b +isoKleisliChannel f finv Channel{send, recv} = Channel { + send = finv >=> send, + recv = recv >>= traverse f + } + + +hoistChannel + :: (forall x . m x -> n x) + -> Channel m a + -> Channel n a +hoistChannel nat channel = Channel + { send = nat . send channel + , recv = nat (recv channel) + } + +channelEffect :: forall m a. + Monad m + => (a -> m ()) -- ^ Action before 'send' + -> (Maybe a -> m ()) -- ^ Action after 'recv' + -> Channel m a + -> Channel m a +channelEffect beforeSend afterRecv Channel{send, recv} = + Channel{ + send = \x -> do + beforeSend x + send x + + , recv = do + mx <- recv + afterRecv mx + return mx + } + +-- | Delay a channel on the receiver end. +-- +-- This is intended for testing, as a crude approximation of network delays. +-- More accurate models along these lines are of course possible. +-- +delayChannel :: MonadDelay m + => DiffTime + -> Channel m a + -> Channel m a +delayChannel delay = channelEffect (\_ -> return ()) + (\_ -> threadDelay delay) + +-- | Channel which logs sent and received messages. +-- +loggingChannel :: ( MonadSay m + , Show id + , Show a + ) + => id + -> Channel m a + -> Channel m a +loggingChannel ident Channel{send,recv} = + Channel { + send = loggingSend, + recv = loggingRecv + } + where + loggingSend a = do + say (show ident ++ ":send:" ++ show a) + send a + + loggingRecv = do + msg <- recv + case msg of + Nothing -> return () + Just a -> say (show ident ++ ":recv:" ++ show a) + return msg + + +-- | Make a 'Channel' from a pair of 'TMVar's, one for reading and one for +-- writing. +-- +mvarsAsChannel :: MonadSTM m + => StrictSTM.StrictTMVar m a + -> StrictSTM.StrictTMVar m a + -> Channel m a +mvarsAsChannel bufferRead bufferWrite = + Channel{send, recv} + where + send x = atomically (StrictSTM.putTMVar bufferWrite x) + recv = atomically (Just <$> StrictSTM.takeTMVar bufferRead) + + +-- | Create a pair of channels that are connected via one-place buffers. +-- +-- This is primarily useful for testing protocols. +-- +createConnectedChannels :: MonadSTM m => m (Channel m a, Channel m a) +createConnectedChannels = do + -- Create two TMVars to act as the channel buffer (one for each direction) + -- and use them to make both ends of a bidirectional channel + bufferA <- StrictSTM.newEmptyTMVarIO + bufferB <- StrictSTM.newEmptyTMVarIO + + return (mvarsAsChannel bufferB bufferA, + mvarsAsChannel bufferA bufferB) + +-- +-- ByteChannel +-- + +-- | Channel using `LBS.ByteString`. +-- +type ByteChannel m = Channel m LBS.ByteString + -- | Make a 'Channel' from a pair of IO 'Handle's, one for reading and one -- for writing. @@ -64,7 +203,7 @@ data Channel m = Channel { -- handlesAsChannel :: IO.Handle -- ^ Read handle -> IO.Handle -- ^ Write handle - -> Channel IO + -> Channel IO LBS.ByteString handlesAsChannel hndRead hndWrite = Channel{send, recv} where @@ -90,8 +229,8 @@ handlesAsChannel hndRead hndWrite = -- takes place on the /writer side and not the reader side/. -- createBufferConnectedChannels :: forall m. MonadSTM m - => m (Channel m, - Channel m) + => m (ByteChannel m, + ByteChannel m) createBufferConnectedChannels = do bufferA <- newEmptyTMVarIO bufferB <- newEmptyTMVarIO @@ -117,8 +256,8 @@ createBufferConnectedChannels = do -- -- This is primarily for testing purposes since it does not allow actual IPC. -- -createPipeConnectedChannels :: IO (Channel IO, - Channel IO) +createPipeConnectedChannels :: IO (ByteChannel IO, + ByteChannel IO) createPipeConnectedChannels = do -- Create two pipes (each one is unidirectional) to make both ends of -- a bidirectional channel @@ -138,7 +277,7 @@ createPipeConnectedChannels = do -- withFifosAsChannel :: FilePath -- ^ FIFO for reading -> FilePath -- ^ FIFO for writing - -> (Channel IO -> IO a) -> IO a + -> (ByteChannel IO -> IO a) -> IO a withFifosAsChannel fifoPathRead fifoPathWrite action = IO.withFile fifoPathRead IO.ReadMode $ \hndRead -> IO.withFile fifoPathWrite IO.WriteMode $ \hndWrite -> @@ -149,7 +288,7 @@ withFifosAsChannel fifoPathRead fifoPathWrite action = -- | Make a 'Channel' from a 'Socket'. The socket must be a stream socket --- type and status connected. --- -socketAsChannel :: Socket.Socket -> Channel IO +socketAsChannel :: Socket.Socket -> ByteChannel IO socketAsChannel socket = Channel{send, recv} where @@ -175,8 +314,8 @@ socketAsChannel socket = --- This is primarily for testing purposes since it does not allow actual IPC. --- createSocketConnectedChannels :: Socket.Family -- ^ Usually AF_UNIX or AF_INET - -> IO (Channel IO, - Channel IO) + -> IO (ByteChannel IO, + ByteChannel IO) createSocketConnectedChannels family = do -- Create a socket pair to make both ends of a bidirectional channel (socketA, socketB) <- Socket.socketPair family Socket.Stream @@ -185,58 +324,3 @@ createSocketConnectedChannels family = do return (socketAsChannel socketA, socketAsChannel socketB) #endif - -channelEffect :: forall m. - Monad m - => (LBS.ByteString -> m ()) -- ^ Action before 'send' - -> (Maybe LBS.ByteString -> m ()) -- ^ Action after 'recv' - -> Channel m - -> Channel m -channelEffect beforeSend afterRecv Channel{send, recv} = - Channel{ - send = \x -> do - beforeSend x - send x - - , recv = do - mx <- recv - afterRecv mx - return mx - } - --- | Delay a channel on the receiver end. --- --- This is intended for testing, as a crude approximation of network delays. --- More accurate models along these lines are of course possible. --- -delayChannel :: MonadDelay m - => DiffTime - -> Channel m - -> Channel m -delayChannel delay = channelEffect (\_ -> return ()) - (\_ -> threadDelay delay) - --- | Channel which logs sent and received messages. --- -loggingChannel :: ( MonadSay m - , Show id - ) - => id - -> Channel m - -> Channel m -loggingChannel ident Channel{send,recv} = - Channel { - send = loggingSend, - recv = loggingRecv - } - where - loggingSend a = do - say (show ident ++ ":send:" ++ show a) - send a - - loggingRecv = do - msg <- recv - case msg of - Nothing -> return () - Just a -> say (show ident ++ ":recv:" ++ show a) - return msg diff --git a/network-mux/src/Network/Mux/Compat.hs b/network-mux/src/Network/Mux/Compat.hs deleted file mode 100644 index 09c8739a738..00000000000 --- a/network-mux/src/Network/Mux/Compat.hs +++ /dev/null @@ -1,156 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ExistentialQuantification #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTSyntax #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeFamilies #-} - -module Network.Mux.Compat - ( muxStart - -- * Mux bearers - , MuxBearer - , MakeBearer (..) - -- * Defining 'MuxApplication's - , MuxMode (..) - , HasInitiator - , HasResponder - , MuxApplication (..) - , MuxMiniProtocol (..) - , RunMiniProtocol (..) - , MiniProtocolNum (..) - , MiniProtocolLimits (..) - , MiniProtocolDir (..) - -- * Errors - , MuxError (..) - , MuxErrorType (..) - -- * Tracing - , traceMuxBearerState - , MuxBearerState (..) - , MuxTrace (..) - , WithMuxBearer (..) - ) where - -import Data.ByteString.Lazy qualified as BL -import Data.Void (Void) - -import Control.Applicative (Alternative (..), (<|>)) -import Control.Concurrent.Class.MonadSTM.Strict -import Control.Monad -import Control.Monad.Class.MonadAsync -import Control.Monad.Class.MonadThrow -import Control.Monad.Class.MonadTimer.SI -import Control.Tracer - -import Network.Mux (StartOnDemandOrEagerly (..), newMux, runMiniProtocol, - runMux, stopMux, traceMuxBearerState) -import Network.Mux.Bearer -import Network.Mux.Channel -import Network.Mux.Trace -import Network.Mux.Types hiding (MiniProtocolInfo (..)) -import Network.Mux.Types qualified as Types - - -newtype MuxApplication (mode :: MuxMode) m a b = - MuxApplication [MuxMiniProtocol mode m a b] - -data MuxMiniProtocol (mode :: MuxMode) m a b = - MuxMiniProtocol { - miniProtocolNum :: !MiniProtocolNum, - miniProtocolLimits :: !MiniProtocolLimits, - miniProtocolRun :: !(RunMiniProtocol mode m a b) - } - -data RunMiniProtocol (mode :: MuxMode) m a b where - InitiatorProtocolOnly - -- Initiator application; most simple application will be @'runPeer'@ or - -- @'runPipelinedPeer'@ supplied with a codec and a @'Peer'@ for each - -- @ptcl@. But it allows to handle resources if just application of - -- @'runPeer'@ is not enough. It will be run as @'InitiatorDir'@. - :: (Channel m -> m (a, Maybe BL.ByteString)) - -> RunMiniProtocol InitiatorMode m a Void - - ResponderProtocolOnly - -- Responder application; similarly to the @'MuxInitiatorApplication'@ but it - -- will be run using @'ResponderDir'@. - :: (Channel m -> m (b, Maybe BL.ByteString)) - -> RunMiniProtocol ResponderMode m Void b - - InitiatorAndResponderProtocol - -- Initiator and server applications. - :: (Channel m -> m (a, Maybe BL.ByteString)) - -> (Channel m -> m (b, Maybe BL.ByteString)) - -> RunMiniProtocol InitiatorResponderMode m a b - - -muxStart - :: forall m mode a b. - ( MonadAsync m - , MonadFork m - , MonadLabelledSTM m - , Alternative (STM m) - , MonadThrow (STM m) - , MonadTimer m - , MonadMask m - ) - => Tracer m MuxTrace - -> MuxApplication mode m a b - -> MuxBearer m - -> m () -muxStart tracer muxapp bearer = do - mux <- newMux (toMiniProtocolBundle muxapp) - - resOps <- sequence - [ runMiniProtocol - mux - miniProtocolNum - ptclDir - StartEagerly - (\a -> do - r <- action a - return (r, Nothing) -- Compat interface doesn't do restarts - ) - | let MuxApplication ptcls = muxapp - , MuxMiniProtocol{miniProtocolNum, miniProtocolRun} <- ptcls - , (ptclDir, action) <- selectRunner miniProtocolRun - ] - - -- Wait for the first MuxApplication to finish, then stop the mux. - withAsync (runMux tracer mux bearer) $ \aid -> do - waitOnAny resOps - stopMux mux - wait aid - - where - waitOnAny :: [STM m (Either SomeException ())] -> m () - waitOnAny resOps = atomically $ void $ foldr (<|>) retry resOps - - toMiniProtocolBundle :: MuxApplication mode m a b -> MiniProtocolBundle mode - toMiniProtocolBundle (MuxApplication ptcls) = - MiniProtocolBundle - [ Types.MiniProtocolInfo { - Types.miniProtocolNum, - Types.miniProtocolDir, - Types.miniProtocolLimits - } - | MuxMiniProtocol { - miniProtocolNum, - miniProtocolLimits, - miniProtocolRun - } <- ptcls - , miniProtocolDir <- case miniProtocolRun of - InitiatorProtocolOnly{} -> [InitiatorDirectionOnly] - ResponderProtocolOnly{} -> [ResponderDirectionOnly] - InitiatorAndResponderProtocol{} -> [InitiatorDirection, ResponderDirection] - ] - - selectRunner :: RunMiniProtocol mode m a b - -> [(MiniProtocolDirection mode, Channel m -> m ())] - selectRunner (InitiatorProtocolOnly initiator) = - [(InitiatorDirectionOnly, void . initiator)] - selectRunner (ResponderProtocolOnly responder) = - [(ResponderDirectionOnly, void . responder)] - selectRunner (InitiatorAndResponderProtocol initiator responder) = - [(InitiatorDirection, void . initiator) - ,(ResponderDirection, void . responder)] diff --git a/network-mux/src/Network/Mux/Types.hs b/network-mux/src/Network/Mux/Types.hs index 4a2b93cf300..41caf9d0dd9 100644 --- a/network-mux/src/Network/Mux/Types.hs +++ b/network-mux/src/Network/Mux/Types.hs @@ -10,8 +10,7 @@ {-# LANGUAGE TypeFamilies #-} module Network.Mux.Types - ( MiniProtocolBundle (..) - , MiniProtocolInfo (..) + ( MiniProtocolInfo (..) , MiniProtocolNum (..) , MiniProtocolDirection (..) , MiniProtocolLimits (..) @@ -53,7 +52,7 @@ import GHC.Generics (Generic) import Control.Concurrent.Class.MonadSTM.Strict (StrictTVar) import Control.Monad.Class.MonadTime.SI -import Network.Mux.Channel (Channel (..)) +import Network.Mux.Channel (ByteChannel, Channel (..)) import Network.Mux.Timeout (TimeoutFn) @@ -113,27 +112,16 @@ type family HasResponder (mode :: MuxMode) :: Bool where HasResponder ResponderMode = True HasResponder InitiatorResponderMode = True --- | Application run by mux layer. +-- | A static description of a mini-protocol. -- --- * enumeration of client application, e.g. a wallet application communicating --- with a node using ChainSync and TxSubmission protocols; this only requires --- to run client side of each protocol. --- --- * enumeration of server applications: this application type is mostly useful --- tests. --- --- * enumeration of both client and server applications, e.g. a full node --- serving downstream peers using server side of each protocol and getting --- updates from upstream peers using client side of each of the protocols. --- -newtype MiniProtocolBundle (mode :: MuxMode) = - MiniProtocolBundle [MiniProtocolInfo mode] - data MiniProtocolInfo (mode :: MuxMode) = MiniProtocolInfo { miniProtocolNum :: !MiniProtocolNum, + -- ^ Unique mini-protocol number. miniProtocolDir :: !(MiniProtocolDirection mode), + -- ^ Mini-protocol direction. miniProtocolLimits :: !MiniProtocolLimits + -- ^ ingress queue limits for the protocol } data MiniProtocolDirection (mode :: MuxMode) where @@ -217,6 +205,8 @@ data MuxBearer m = MuxBearer { , read :: TimeoutFn m -> m (MuxSDU, Time) -- | Return a suitable MuxSDU payload size. , sduSize :: SDUSize + -- | Name of the bearer + , name :: String } newtype SDUSize = SDUSize { getSDUSize :: Word16 } @@ -231,7 +221,7 @@ muxBearerAsChannel => MuxBearer m -> MiniProtocolNum -> MiniProtocolDir - -> Channel m + -> ByteChannel m muxBearerAsChannel bearer ptclNum ptclDir = Channel { send = \blob -> void $ write bearer noTimeout (wrap blob), diff --git a/network-mux/test/Test/Mux.hs b/network-mux/test/Test/Mux.hs index 80aad090d7b..1cde5c9bcd1 100644 --- a/network-mux/test/Test/Mux.hs +++ b/network-mux/test/Test/Mux.hs @@ -66,11 +66,8 @@ import Network.Mux.Bearer import Network.Mux.Bearer.AttenuatedChannel as AttenuatedChannel import Network.Mux.Bearer.Pipe import Network.Mux.Bearer.Queues -import Network.Mux.Channel import Network.Mux.Codec -import Network.Mux.Compat qualified as Compat -import Network.Mux.Types (MiniProtocolDir (..), MuxSDU (..), MuxSDUHeader (..), - RemoteClockModel (..), muxBearerAsChannel) +import Network.Mux.Types import Network.Socket qualified as Socket import Text.Show.Functions () -- import qualified Debug.Trace as Debug @@ -236,8 +233,8 @@ instance Show InvalidSDU where (isRealLength a) (isPattern a) -data ArbitrarySDU = ArbitraryInvalidSDU InvalidSDU Compat.MuxErrorType - | ArbitraryValidSDU DummyPayload (Maybe Compat.MuxErrorType) +data ArbitrarySDU = ArbitraryInvalidSDU InvalidSDU MuxErrorType + | ArbitraryValidSDU DummyPayload (Maybe MuxErrorType) deriving Show instance Arbitrary ArbitrarySDU where @@ -259,7 +256,7 @@ instance Arbitrary ArbitrarySDU where -- This SDU is still considered valid, since the header itself will -- not cause a trouble, the error will be triggered by the fact that -- it is sent as a single message. - return $ ArbitraryValidSDU (DummyPayload pl) (Just Compat.MuxIngressQueueOverRun) + return $ ArbitraryValidSDU (DummyPayload pl) (Just MuxIngressQueueOverRun) unknownMiniProtocol = do ts <- arbitrary @@ -270,7 +267,7 @@ instance Arbitrary ArbitrarySDU where return $ ArbitraryInvalidSDU (InvalidSDU (RemoteClockModel ts) (mid .|. mode) len (8 + fromIntegral len) p) - Compat.MuxUnknownMiniProtocol + MuxUnknownMiniProtocol invalidLenght = do ts <- arbitrary mid <- arbitrary @@ -279,12 +276,10 @@ instance Arbitrary ArbitrarySDU where p <- arbitrary return $ ArbitraryInvalidSDU (InvalidSDU (RemoteClockModel ts) mid len realLen p) - Compat.MuxDecodeError + MuxDecodeError -instance Arbitrary Compat.MuxBearerState where - arbitrary = elements [ Compat.Mature - , Compat.Dead - ] +instance Arbitrary MuxBearerState where + arbitrary = elements [Mature, Dead] @@ -341,8 +336,8 @@ prop_mux_snd_recv (DummyRun messages) = ioProperty $ do serverTracer QueueChannel { writeQueue = server_w, readQueue = server_r } - clientTracer = contramap (Compat.WithMuxBearer "client") activeTracer - serverTracer = contramap (Compat.WithMuxBearer "server") activeTracer + clientTracer = contramap (WithMuxBearer "client") activeTracer + serverTracer = contramap (WithMuxBearer "server") activeTracer clientApp = MiniProtocolInfo { miniProtocolNum = MiniProtocolNum 2, @@ -356,9 +351,9 @@ prop_mux_snd_recv (DummyRun messages) = ioProperty $ do miniProtocolLimits = defaultMiniProtocolLimits } - clientMux <- newMux $ MiniProtocolBundle [clientApp] + clientMux <- newMux [clientApp] - serverMux <- newMux $ MiniProtocolBundle [serverApp] + serverMux <- newMux [serverApp] withAsync (runMux clientTracer clientMux clientBearer) $ \clientAsync -> withAsync (runMux serverTracer serverMux serverBearer) $ \serverAsync -> do @@ -398,8 +393,8 @@ prop_mux_snd_recv_bi (DummyRun messages) = ioProperty $ do let server_w = client_r server_r = client_w - clientTracer = contramap (Compat.WithMuxBearer "client") activeTracer - serverTracer = contramap (Compat.WithMuxBearer "server") activeTracer + clientTracer = contramap (WithMuxBearer "client") activeTracer + serverTracer = contramap (WithMuxBearer "server") activeTracer clientBearer <- getBearer makeQueueChannelBearer (-1) @@ -435,10 +430,10 @@ prop_mux_snd_recv_bi (DummyRun messages) = ioProperty $ do ] - clientMux <- newMux $ MiniProtocolBundle clientApps + clientMux <- newMux clientApps clientAsync <- async $ runMux clientTracer clientMux clientBearer - serverMux <- newMux $ MiniProtocolBundle serverApps + serverMux <- newMux serverApps serverAsync <- async $ runMux serverTracer serverMux serverBearer r <- step clientMux clientApps serverMux serverApps messages @@ -494,7 +489,7 @@ prop_mux_snd_recv_bi (DummyRun messages) = ioProperty $ do -- | Like prop_mux_snd_recv but using the Compat interface. prop_mux_snd_recv_compat :: DummyTrace - -> Property + -> Property prop_mux_snd_recv_compat messages = ioProperty $ do client_w <- atomically $ newTBQueue 10 client_r <- atomically $ newTBQueue 10 @@ -503,8 +498,8 @@ prop_mux_snd_recv_compat messages = ioProperty $ do let server_w = client_r server_r = client_w - clientTracer = contramap (Compat.WithMuxBearer "client") activeTracer - serverTracer = contramap (Compat.WithMuxBearer "server") activeTracer + clientTracer = contramap (WithMuxBearer "client") activeTracer + serverTracer = contramap (WithMuxBearer "server") activeTracer clientBearer <- getBearer makeQueueChannelBearer @@ -518,24 +513,53 @@ prop_mux_snd_recv_compat messages = ioProperty $ do (verify, client_mp, server_mp) <- setupMiniReqRspCompat (return ()) endMpsVar messages - let clientApp = Compat.MuxApplication - [ Compat.MuxMiniProtocol { - Compat.miniProtocolNum = Compat.MiniProtocolNum 2, - Compat.miniProtocolLimits = defaultMiniProtocolLimits, - Compat.miniProtocolRun = Compat.InitiatorProtocolOnly client_mp - } - ] + let clientBundle = [ MiniProtocolInfo { + miniProtocolNum = MiniProtocolNum 2, + miniProtocolLimits = defaultMiniProtocolLimits, + miniProtocolDir = InitiatorDirectionOnly } + ] - serverApp = Compat.MuxApplication - [ Compat.MuxMiniProtocol { - Compat.miniProtocolNum = Compat.MiniProtocolNum 2, - Compat.miniProtocolLimits = defaultMiniProtocolLimits, - Compat.miniProtocolRun = Compat.ResponderProtocolOnly server_mp - } - ] + serverBundle = [ MiniProtocolInfo { + miniProtocolNum = MiniProtocolNum 2, + miniProtocolLimits = defaultMiniProtocolLimits, + miniProtocolDir = ResponderDirectionOnly } + ] - clientAsync <- async $ Compat.muxStart clientTracer clientApp clientBearer - serverAsync <- async $ Compat.muxStart serverTracer serverApp serverBearer + clientAsync <- async $ do + clientMux <- newMux clientBundle + res <- runMiniProtocol + clientMux + (MiniProtocolNum 2) + InitiatorDirectionOnly + StartEagerly + (\chann -> do + r <- client_mp chann + return (r, Nothing) + ) + + -- Wait for the first MuxApplication to finish, then stop the mux. + withAsync (runMux clientTracer clientMux clientBearer) $ \aid -> do + _ <- atomically res + stopMux clientMux + wait aid + + serverAsync <- async $ do + serverMux <- newMux serverBundle + res <- runMiniProtocol + serverMux + (MiniProtocolNum 2) + ResponderDirectionOnly + StartEagerly + (\chann -> do + r <- server_mp chann + return (r, Nothing) + ) + + -- Wait for the first MuxApplication to finish, then stop the mux. + withAsync (runMux serverTracer serverMux serverBearer) $ \aid -> do + _ <- atomically res + stopMux serverMux + wait aid _ <- waitBoth clientAsync serverAsync property <$> verify @@ -551,8 +575,8 @@ setupMiniReqRspCompat :: IO () -> DummyTrace -- ^ Trace of messages -> IO ( IO Bool - , Channel IO -> IO ((), Maybe BL.ByteString) - , Channel IO -> IO ((), Maybe BL.ByteString) + , ByteChannel IO -> IO ((), Maybe BL.ByteString) + , ByteChannel IO -> IO ((), Maybe BL.ByteString) ) setupMiniReqRspCompat serverAction mpsEndVar (DummyTrace msgs) = do serverResultVar <- newEmptyTMVarIO @@ -591,7 +615,7 @@ setupMiniReqRspCompat serverAction mpsEndVar (DummyTrace msgs) = do go resps (req:reqs) = SendMsgReq req $ \resp -> return (go (resp:resps) reqs) clientApp :: StrictTMVar IO Bool - -> Channel IO + -> ByteChannel IO -> IO ((), Maybe BL.ByteString) clientApp clientResultVar clientChan = do (result, trailing) <- runClient nullTracer clientChan (reqRespClient requests) @@ -599,7 +623,7 @@ setupMiniReqRspCompat serverAction mpsEndVar (DummyTrace msgs) = do (,trailing) <$> end serverApp :: StrictTMVar IO Bool - -> Channel IO + -> ByteChannel IO -> IO ((), Maybe BL.ByteString) serverApp serverResultVar serverChan = do (result, trailing) <- runServer nullTracer serverChan (reqRespServer responses) @@ -626,8 +650,8 @@ setupMiniReqRsp :: IO () -- ^ Action performed by responder before processing the response -> DummyTrace -- ^ Trace of messages - -> IO ( Channel IO -> IO (Bool, Maybe BL.ByteString) - , Channel IO -> IO (Bool, Maybe BL.ByteString) + -> IO ( ByteChannel IO -> IO (Bool, Maybe BL.ByteString) + , ByteChannel IO -> IO (Bool, Maybe BL.ByteString) ) setupMiniReqRsp serverAction (DummyTrace msgs) = do @@ -658,11 +682,11 @@ setupMiniReqRsp serverAction (DummyTrace msgs) = do go resps [] = SendMsgDone (pure $ reverse resps == responses) go resps (req:reqs) = SendMsgReq req $ \resp -> return (go (resp:resps) reqs) - clientApp :: Channel IO + clientApp :: ByteChannel IO -> IO (Bool, Maybe BL.ByteString) clientApp clientChan = runClient nullTracer clientChan (reqRespClient requests) - serverApp :: Channel IO + serverApp :: ByteChannel IO -> IO (Bool, Maybe BL.ByteString) serverApp serverChan = runServer nullTracer serverChan (reqRespServer responses) @@ -672,24 +696,24 @@ setupMiniReqRsp serverAction (DummyTrace msgs) = do -- Run applications continuation type RunMuxApplications - = [Channel IO -> IO (Bool, Maybe BL.ByteString)] - -> [Channel IO -> IO (Bool, Maybe BL.ByteString)] + = [ByteChannel IO -> IO (Bool, Maybe BL.ByteString)] + -> [ByteChannel IO -> IO (Bool, Maybe BL.ByteString)] -> IO Bool -runMuxApplication :: [Channel IO -> IO (Bool, Maybe BL.ByteString)] +runMuxApplication :: [ByteChannel IO -> IO (Bool, Maybe BL.ByteString)] -> MuxBearer IO - -> [Channel IO -> IO (Bool, Maybe BL.ByteString)] + -> [ByteChannel IO -> IO (Bool, Maybe BL.ByteString)] -> MuxBearer IO -> IO Bool runMuxApplication initApps initBearer respApps respBearer = do - let clientTracer = contramap (Compat.WithMuxBearer "client") activeTracer - serverTracer = contramap (Compat.WithMuxBearer "server") activeTracer + let clientTracer = contramap (WithMuxBearer "client") activeTracer + serverTracer = contramap (WithMuxBearer "server") activeTracer protNum = [1..] respApps' = zip protNum respApps initApps' = zip protNum initApps - respMux <- newMux $ MiniProtocolBundle $ map (\(pn,_) -> + respMux <- newMux $ map (\(pn,_) -> MiniProtocolInfo (MiniProtocolNum pn) ResponderDirectionOnly defaultMiniProtocolLimits) respApps' respAsync <- async $ runMux serverTracer respMux respBearer @@ -702,7 +726,7 @@ runMuxApplication initApps initBearer respApps respBearer = do | (pn, app) <- respApps' ] - initMux <- newMux $ MiniProtocolBundle $ map (\(pn,_) -> + initMux <- newMux $ map (\(pn,_) -> MiniProtocolInfo (MiniProtocolNum pn) InitiatorDirectionOnly defaultMiniProtocolLimits) initApps' initAsync <- async $ runMux clientTracer initMux initBearer @@ -738,8 +762,8 @@ runWithQueues initApps respApps = do let server_w = client_r server_r = client_w - clientTracer = contramap (Compat.WithMuxBearer "client") activeTracer - serverTracer = contramap (Compat.WithMuxBearer "server") activeTracer + clientTracer = contramap (WithMuxBearer "client") activeTracer + serverTracer = contramap (WithMuxBearer "server") activeTracer clientBearer <- getBearer makeQueueChannelBearer (-1) @@ -807,8 +831,8 @@ runWithPipe initApps respApps = #endif where - clientTracer = contramap (Compat.WithMuxBearer "client") activeTracer - serverTracer = contramap (Compat.WithMuxBearer "server") activeTracer + clientTracer = contramap (WithMuxBearer "client") activeTracer + serverTracer = contramap (WithMuxBearer "server") activeTracer -- | Verify that it is possible to run two miniprotocols over the same bearer. @@ -875,15 +899,15 @@ prop_mux_starvation (Uneven response0 response1) = traceHeaderVar <- newTVarIO [] let headerTracer = Tracer $ \e -> case e of - Compat.MuxTraceRecvHeaderEnd header + MuxTraceRecvHeaderEnd header -> atomically (modifyTVar traceHeaderVar (header:)) _ -> return () let server_w = client_r server_r = client_w - clientTracer = contramap (Compat.WithMuxBearer "client") activeTracer - serverTracer = contramap (Compat.WithMuxBearer "server") activeTracer + clientTracer = contramap (WithMuxBearer "client") activeTracer + serverTracer = contramap (WithMuxBearer "server") activeTracer clientBearer <- getBearer makeQueueChannelBearer @@ -924,14 +948,14 @@ prop_mux_starvation (Uneven response0 response1) = miniProtocolLimits = defaultMiniProtocolLimits } - serverMux <- newMux $ MiniProtocolBundle [serverApp2, serverApp3] + serverMux <- newMux [serverApp2, serverApp3] serverMux_aid <- async $ runMux serverTracer serverMux serverBearer serverRes2 <- runMiniProtocol serverMux (miniProtocolNum serverApp2) (miniProtocolDir serverApp2) StartOnDemand server_short serverRes3 <- runMiniProtocol serverMux (miniProtocolNum serverApp3) (miniProtocolDir serverApp3) StartOnDemand server_long - clientMux <- newMux $ MiniProtocolBundle [clientApp2, clientApp3] + clientMux <- newMux [clientApp2, clientApp3] clientMux_aid <- async $ runMux (clientTracer <> headerTracer) clientMux clientBearer clientRes2 <- runMiniProtocol clientMux (miniProtocolNum clientApp2) (miniProtocolDir clientApp2) StartEagerly client_short @@ -1025,10 +1049,10 @@ prop_demux_sdu a = do return $ tabulate "SDU type" [stateLabel a] $ tabulate "SDU Violation " [violationLabel a] r where - run (ArbitraryValidSDU sdu (Just Compat.MuxIngressQueueOverRun)) = do + run (ArbitraryValidSDU sdu (Just MuxIngressQueueOverRun)) = do stopVar <- newEmptyTMVarIO - -- To trigger Compat.MuxIngressQueueOverRun we use a special test protocol + -- To trigger MuxIngressQueueOverRun we use a special test protocol -- with an ingress queue which is less than 0xffff so that it can be -- triggered by a single segment. let server_mps = MiniProtocolInfo { @@ -1049,7 +1073,7 @@ prop_demux_sdu a = do case res of Left e -> case fromException e of - Just me -> return $ Compat.errorType me === Compat.MuxIngressQueueOverRun + Just me -> return $ errorType me === MuxIngressQueueOverRun Nothing -> return $ property False Right _ -> return $ property False @@ -1074,7 +1098,7 @@ prop_demux_sdu a = do Left e -> case fromException e of Just me -> case err_m of - Just err -> return $ Compat.errorType me === err + Just err -> return $ errorType me === err Nothing -> return $ property False Nothing -> return $ property False Right _ -> return $ err_m === Nothing @@ -1104,7 +1128,7 @@ prop_demux_sdu a = do case res of Left e -> case fromException e of - Just me -> return $ Compat.errorType me === err + Just me -> return $ errorType me === err Nothing -> return $ counterexample ("unexpected: " ++ show e) False Right _ -> return $ counterexample "expected an exception" False @@ -1112,7 +1136,7 @@ prop_demux_sdu a = do server_w <- atomically $ newTBQueue 10 server_r <- atomically $ newTBQueue 10 - let serverTracer = contramap (Compat.WithMuxBearer "server") activeTracer + let serverTracer = contramap (WithMuxBearer "server") activeTracer serverBearer <- getBearer makeQueueChannelBearer (-1) @@ -1121,7 +1145,7 @@ prop_demux_sdu a = do readQueue = server_r } - serverMux <- newMux $ MiniProtocolBundle [serverApp] + serverMux <- newMux [serverApp] serverRes <- runMiniProtocol serverMux (miniProtocolNum serverApp) (miniProtocolDir serverApp) StartEagerly server_mp @@ -1149,8 +1173,8 @@ prop_demux_sdu a = do sdu' = MuxSDU (MuxSDUHeader (RemoteClockModel 0) - (Compat.MiniProtocolNum 2) - Compat.InitiatorDir + (MiniProtocolNum 2) + InitiatorDir (fromIntegral $ BL.length frag)) frag !pkt = encodeMuxSDU (sdu' :: MuxSDU) @@ -1164,11 +1188,11 @@ prop_demux_sdu a = do violationLabel (ArbitraryValidSDU _ err_m) = sduViolation err_m violationLabel (ArbitraryInvalidSDU _ err) = sduViolation $ Just err - sduViolation (Just Compat.MuxUnknownMiniProtocol) = "unknown miniprotocol" - sduViolation (Just Compat.MuxDecodeError ) = "decode error" - sduViolation (Just Compat.MuxIngressQueueOverRun) = "ingress queue overrun" - sduViolation (Just _ ) = "unknown violation" - sduViolation Nothing = "none" + sduViolation (Just MuxUnknownMiniProtocol) = "unknown miniprotocol" + sduViolation (Just MuxDecodeError ) = "decode error" + sduViolation (Just MuxIngressQueueOverRun) = "ingress queue overrun" + sduViolation (Just _ ) = "unknown violation" + sduViolation Nothing = "none" prop_demux_sdu_sim :: ArbitrarySDU -> Property @@ -1260,7 +1284,7 @@ dummyAppToChannel :: forall m. , MonadCatch m ) => DummyApp - -> (Channel m -> m ((), Maybe BL.ByteString)) + -> (ByteChannel m -> m ((), Maybe BL.ByteString)) dummyAppToChannel DummyApp {daAction, daRunTime} = \_ -> do threadDelay daRunTime case daAction of @@ -1295,7 +1319,7 @@ dummyRestartingAppToChannel :: forall a m. , MonadDelay m ) => (DummyApp, a) - -> (Channel m -> m ((DummyApp, a), Maybe BL.ByteString)) + -> (ByteChannel m -> m ((DummyApp, a), Maybe BL.ByteString)) dummyRestartingAppToChannel (app, r) = \_ -> do threadDelay $ daRunTime app case daAction app of @@ -1397,9 +1421,9 @@ prop_mux_restart_m (DummyRestartingInitiatorApps apps) = do (-1) nullTracer QueueChannel { writeQueue = mux_w, readQueue = mux_r } - let MiniProtocolBundle minis = MiniProtocolBundle $ map (appToInfo InitiatorDirectionOnly . fst) apps + let minis = map (appToInfo InitiatorDirectionOnly . fst) apps - mux <- newMux $ MiniProtocolBundle minis + mux <- newMux minis mux_aid <- async $ runMux nullTracer mux bearer getRes <- sequence [ runMiniProtocol mux @@ -1444,9 +1468,9 @@ prop_mux_restart_m (DummyRestartingResponderApps rapps) = do nullTracer QueueChannel { writeQueue = mux_r, readQueue = mux_w } let apps = map fst rapps - MiniProtocolBundle minis = MiniProtocolBundle $ map (appToInfo ResponderDirectionOnly) apps + minis = map (appToInfo ResponderDirectionOnly) apps - mux <- newMux $ MiniProtocolBundle minis + mux <- newMux minis mux_aid <- async $ runMux nullTracer mux bearer getRes <- sequence [ runMiniProtocol mux @@ -1495,7 +1519,7 @@ prop_mux_restart_m (DummyRestartingInitiatorResponderApps rapps) = do initMinis = map (appToInfo InitiatorDirection) apps respMinis = map (appToInfo ResponderDirection) apps - mux <- newMux $ MiniProtocolBundle $ initMinis ++ respMinis + mux <- newMux $ initMinis ++ respMinis mux_aid <- async $ runMux nullTracer mux bearer getInitRes <- sequence [ runMiniProtocol mux @@ -1567,10 +1591,10 @@ prop_mux_start_m :: forall m. -> DiffTime -> m Property prop_mux_start_m bearer _ checkRes (DummyInitiatorApps apps) runTime = do - let MiniProtocolBundle minis = MiniProtocolBundle $ map (appToInfo InitiatorDirectionOnly) apps + let minis = map (appToInfo InitiatorDirectionOnly) apps minRunTime = minimum $ runTime : (map daRunTime $ filter (\app -> daAction app == DummyAppFail) apps) - mux <- newMux $ MiniProtocolBundle minis + mux <- newMux minis mux_aid <- async $ runMux nullTracer mux bearer killer <- async $ (threadDelay runTime) >> stopMux mux getRes <- sequence [ runMiniProtocol @@ -1588,10 +1612,10 @@ prop_mux_start_m bearer _ checkRes (DummyInitiatorApps apps) runTime = do return (conjoin $ map fst rc) prop_mux_start_m bearer trigger checkRes (DummyResponderApps apps) runTime = do - let MiniProtocolBundle minis = MiniProtocolBundle $ map (appToInfo ResponderDirectionOnly) apps + let minis = map (appToInfo ResponderDirectionOnly) apps minRunTime = minimum $ runTime : (map (\a -> daRunTime a + daStartAfter a) $ filter (\app -> daAction app == DummyAppFail) apps) - mux <- newMux $ MiniProtocolBundle minis + mux <- newMux minis mux_aid <- async $ runMux verboseTracer mux bearer getRes <- sequence [ runMiniProtocol mux @@ -1615,9 +1639,9 @@ prop_mux_start_m bearer _trigger _checkRes (DummyResponderAppsKillMux apps) runT -- Start a mini-protocol on demand, but kill mux before the application is -- triggered. This test assures that mini-protocol completion action does -- not deadlocks. - let MiniProtocolBundle minis = MiniProtocolBundle $ map (appToInfo ResponderDirectionOnly) apps + let minis = map (appToInfo ResponderDirectionOnly) apps - mux <- newMux $ MiniProtocolBundle minis + mux <- newMux minis mux_aid <- async $ runMux verboseTracer mux bearer getRes <- sequence [ runMiniProtocol mux @@ -1640,7 +1664,7 @@ prop_mux_start_m bearer trigger checkRes (DummyInitiatorResponderApps apps) runT respMinis = map (appToInfo ResponderDirection) apps minRunTime = minimum $ runTime : (map (\a -> daRunTime a) $ filter (\app -> daAction app == DummyAppFail) apps) - mux <- newMux $ MiniProtocolBundle $ initMinis ++ respMinis + mux <- newMux $ initMinis ++ respMinis mux_aid <- async $ runMux verboseTracer mux bearer getInitRes <- sequence [ runMiniProtocol mux @@ -1795,8 +1819,7 @@ close_experiment fault tracer muxTracer clientCtx serverCtx reqs0 fn acc0 = do withAsync -- run client thread - (bracket (newMux $ MiniProtocolBundle - [ MiniProtocolInfo { + (bracket (newMux [ MiniProtocolInfo { miniProtocolNum, miniProtocolDir = InitiatorDirectionOnly, miniProtocolLimits = MiniProtocolLimits maxBound @@ -1814,8 +1837,7 @@ close_experiment $ \clientAsync -> withAsync -- run server thread - (bracket ( newMux $ MiniProtocolBundle - [ MiniProtocolInfo { + (bracket ( newMux [ MiniProtocolInfo { miniProtocolNum, miniProtocolDir = ResponderDirectionOnly, miniProtocolLimits = MiniProtocolLimits maxBound diff --git a/network-mux/test/Test/Mux/ReqResp.hs b/network-mux/test/Test/Mux/ReqResp.hs index 352ee95156f..f36b20e778e 100644 --- a/network-mux/test/Test/Mux/ReqResp.hs +++ b/network-mux/test/Test/Mux/ReqResp.hs @@ -79,7 +79,7 @@ data TraceSendRecv msg runDecoderWithChannel :: forall m a. MonadST m - => Channel m + => ByteChannel m -> Maybe LBS.ByteString -> Decoder (PrimState m) a -> m (Either CBOR.DeserialiseFailure (a, Maybe LBS.ByteString)) @@ -104,7 +104,7 @@ runDecoderWithChannel Channel{recv} trailing decoder = --- | Run a client using a byte 'Channel'. +-- | Run a client using a byte 'ByteChannel'. -- runClient :: forall req resp m a. ( MonadST m @@ -114,7 +114,7 @@ runClient :: forall req resp m a. , Show resp ) => Tracer m (TraceSendRecv (MsgReqResp req resp)) - -> Channel m + -> ByteChannel m -> ReqRespClient req resp m a -> m (a, Maybe LBS.ByteString) @@ -177,7 +177,7 @@ runServer :: forall req resp m a. , Show resp ) => Tracer m (TraceSendRecv (MsgReqResp req resp)) - -> Channel m + -> ByteChannel m -> ReqRespServer req resp m a -> m (a, Maybe LBS.ByteString) diff --git a/ouroboros-network-framework/CHANGELOG.md b/ouroboros-network-framework/CHANGELOG.md index 2e871ad3257..23f77a77bc8 100644 --- a/ouroboros-network-framework/CHANGELOG.md +++ b/ouroboros-network-framework/CHANGELOG.md @@ -6,10 +6,21 @@ * Added `createConnectedBufferedChannelsUnbounded`. * Use `typed-protocols-0.2.0.0`. +* Removed `Ouroboros.Network.Mux.toApplication` +* Renamed `Ouroboros.Network.Mux.mkMiniProtocolBundle` as `mkMiniProtocolInfos` + (its type has changed). +* Added `Ouroboros.Network.Mux.toMiniProtocolInfos`. +* Added `ConnectToArgs` for `Ouroboros.Network.Socket.connectToNode` & friends. +* `Ouroboros.Network.Socket.connectToNode` & friends return result (or an + error) of the first terminated mini-protocol. +* Added `Ouroboros.Network.Socket.connectToNodeWithMux` and + `connectToNodeWithMux'`. They give control over running mux, e.g. one can + start some of the mini-protocols, or implement a re-start policy. ### Non-breaking changes * Added tracing on CM connVars for testing purposes. +* Improved haddocks of `Hanshake` protocol codec. ## 0.13.2.4 -- 2024-08-27 diff --git a/ouroboros-network-framework/demo/ping-pong.hs b/ouroboros-network-framework/demo/ping-pong.hs index b4c9939f418..67ee72d521e 100644 --- a/ouroboros-network-framework/demo/ping-pong.hs +++ b/ouroboros-network-framework/demo/ping-pong.hs @@ -114,15 +114,18 @@ demoProtocol0 pingPong = clientPingPong :: Bool -> IO () clientPingPong pipelined = withIOManager $ \iomgr -> + void $ connectToNode (Snocket.localSnocket iomgr) makeLocalBearer + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } mempty - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) (unversionedProtocol app) Nothing defaultLocalSocketAddr @@ -206,16 +209,18 @@ demoProtocol1 pingPong pingPong' = clientPingPong2 :: IO () clientPingPong2 = - withIOManager $ \iomgr -> do + withIOManager $ \iomgr -> void $ do connectToNode (Snocket.localSnocket iomgr) makeLocalBearer + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } mempty - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) (unversionedProtocol app) Nothing defaultLocalSocketAddr diff --git a/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Socket.hs index 072ba2c15ad..45d1e0fe8eb 100644 --- a/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Socket.hs @@ -3,17 +3,18 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -Wno-orphans #-} module Test.Ouroboros.Network.Socket (tests) where import Data.Bifoldable (bitraverse_) import Data.ByteString.Lazy qualified as BL +import Data.Either (fromRight) import Data.List (mapAccumL) +import Data.Monoid.Synchronisation (FirstToFinish (..)) import Data.Time.Clock (UTCTime, getCurrentTime) import Data.Void (Void) #ifndef mingw32_HOST_OS @@ -55,12 +56,10 @@ import Ouroboros.Network.Socket -- TODO: remove Mx prefixes import Ouroboros.Network.Mux -import Network.Mux qualified as Mx (MuxError (..), MuxErrorType (..)) +import Network.Mux qualified as Mx import Network.Mux.Bearer qualified as Mx -import Network.Mux.Compat qualified as Mx (muxStart) -import Network.Mux.Timeout -import Network.Mux.Types (MiniProtocolDir (..), MuxSDU (..), MuxSDUHeader (..), - RemoteClockModel (..), write) +import Network.Mux.Timeout qualified as Mx +import Network.Mux.Types qualified as Mx import Ouroboros.Network.Protocol.Handshake.Codec import Ouroboros.Network.Protocol.Handshake.Unversioned @@ -174,7 +173,7 @@ prop_socket_send_recv_unix request response = ioProperty $ do mempty request response cleanUp serverName cleanUp clientName - return $ r + return r where cleanUp name = do catchJust (\e -> if isDoesNotExistErrorType (ioeGetErrorType e) then Just () else Nothing) @@ -254,15 +253,17 @@ prop_socket_send_recv initiatorAddr responderAddr configureSock f xs = (unversionedProtocol (SomeResponderApplication responderApp)) nullErrorPolicies $ \_ _ -> do - connectToNode + void $ connectToNode snocket Mx.makeSocketBearer - (flip configureSock Nothing) - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - (NetworkConnectTracers activeMuxTracer nullTracer) - (HandshakeCallbacks acceptableVersion queryVersion) + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = NetworkConnectTracers activeMuxTracer nullTracer, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + (`configureSock` Nothing) (unversionedProtocol initiatorApp) (Just initiatorAddr) responderAddr @@ -352,10 +353,30 @@ prop_socket_recv_error f rerr = _ <- async $ do threadDelay 0.1 atomically $ putTMVar lock () - Mx.muxStart nullTracer (toApplication MinimalInitiatorContext { micConnectionId = connectionId } - ResponderContext { rcConnectionId = connectionId } - app) - bearer + mux <- Mx.newMux (toMiniProtocolInfos app) + let respCtx = ResponderContext connectionId + resOps <- sequence + [ Mx.runMiniProtocol + mux + miniProtocolNum + miniProtocolDir + Mx.StartEagerly + (\a -> do + r <- action a + return (r, Nothing) + ) + | MiniProtocol{miniProtocolNum, miniProtocolRun} + <- getOuroborosApplication app + , (miniProtocolDir, action) <- + case miniProtocolRun of + ResponderProtocolOnly initiator -> + [(Mx.ResponderDirectionOnly, void . runMiniProtocolCb initiator respCtx)] + ] + + withAsync (Mx.runMux nullTracer mux bearer) $ \aid -> do + _ <- atomically $ runFirstToFinish $ foldMap FirstToFinish resOps + Mx.stopMux mux + wait aid ) $ \muxAsync -> do @@ -429,10 +450,10 @@ prop_socket_send_error rerr = else (-1) -- No timeout blob = BL.pack $ replicate 0xffff 0xa5 bearer <- Mx.getBearer Mx.makeSocketBearer sduTimeout nullTracer sd' - withTimeoutSerial $ \timeout -> + Mx.withTimeoutSerial $ \timeout -> -- send maximum mux sdus until we've filled the window. replicateM 100 $ do - ((), Nothing) <$ write bearer timeout (wrap blob ResponderDir (MiniProtocolNum 0)) + ((), Nothing) <$ Mx.write bearer timeout (wrap blob Mx.ResponderDir (MiniProtocolNum 0)) ) $ \muxAsync -> do @@ -459,16 +480,16 @@ prop_socket_send_error rerr = return result where -- wrap a 'ByteString' as 'MuxSDU' - wrap :: BL.ByteString -> MiniProtocolDir -> MiniProtocolNum -> MuxSDU - wrap blob ptclDir ptclNum = MuxSDU { + wrap :: BL.ByteString -> Mx.MiniProtocolDir -> MiniProtocolNum -> Mx.MuxSDU + wrap blob ptclDir ptclNum = Mx.MuxSDU { -- it will be filled when the 'MuxSDU' is send by the 'bearer' - msHeader = MuxSDUHeader { - mhTimestamp = RemoteClockModel 0, - mhNum = ptclNum, - mhDir = ptclDir, - mhLength = fromIntegral $ BL.length blob + Mx.msHeader = Mx.MuxSDUHeader { + Mx.mhTimestamp = Mx.RemoteClockModel 0, + Mx.mhNum = ptclNum, + Mx.mhDir = ptclDir, + Mx.mhLength = fromIntegral $ BL.length blob }, - msBlob = blob + Mx.msBlob = blob } prop_socket_client_connect_error :: (Int -> Int -> (Int, Int)) @@ -501,18 +522,20 @@ prop_socket_client_connect_error _ xs = <- try $ False <$ connectToNode (socketSnocket iomgr) Mx.makeSocketBearer - (flip configureSocket Nothing) - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + (`configureSocket` Nothing) (unversionedProtocol app) (Just $ Socket.addrAddress clientAddr) (Socket.addrAddress serverAddr) -- XXX Disregarding the exact exception type - pure $ either (const True) id res + pure $ fromRight True res diff --git a/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Subscription.hs b/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Subscription.hs index 5d0f7c5320d..1b2dd8c05f8 100644 --- a/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Subscription.hs +++ b/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Subscription.hs @@ -632,13 +632,15 @@ prop_send_recv f xs _first = ioProperty $ withIOManager $ \iocp -> do } (\_ -> waitSiblingSTM siblingVar) (connectToNodeSocket - iocp - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) - (unversionedProtocol initiatorApp)) + iocp + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + (unversionedProtocol initiatorApp)) res <- atomically $ (,) <$> takeTMVar sv <*> takeTMVar cv return (res == L.mapAccumL f 0 xs) @@ -806,13 +808,15 @@ prop_send_recv_init_and_rsp f xs = ioProperty $ withIOManager $ \iocp -> do nullErrorPolicies (\_ -> waitSiblingSTM (rrcSiblingVar rrcfg)) (connectToNodeSocket - iocp - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) - (unversionedProtocol (appX rrcfg))) + iocp + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + (unversionedProtocol (appX rrcfg))) atomically $ (,) <$> takeTMVar (rrcServerVar rrcfg) <*> takeTMVar (rrcClientVar rrcfg) @@ -879,13 +883,15 @@ _demo = ioProperty $ withIOManager $ \iocp -> do } (connectToNodeSocket - iocp - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) - (unversionedProtocol appReq)) + iocp + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + (unversionedProtocol appReq)) threadDelay 130 -- bring the servers back again diff --git a/ouroboros-network-framework/ouroboros-network-framework.cabal b/ouroboros-network-framework/ouroboros-network-framework.cabal index b092fd38c4e..a8ea399c408 100644 --- a/ouroboros-network-framework/ouroboros-network-framework.cabal +++ b/ouroboros-network-framework/ouroboros-network-framework.cabal @@ -253,6 +253,7 @@ test-suite io-tests io-classes, io-sim, iproute, + monoidal-synchronisation, network, network-mux, ouroboros-network-api, diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs index baa673a57f9..a682c26e318 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs @@ -345,7 +345,8 @@ makeFDBearer = MakeBearer $ \_ _ _ -> return MuxBearer { write = \_ _ -> getMonotonicTime, read = \_ -> forever (threadDelay 3600), - sduSize = SDUSize 1500 + sduSize = SDUSize 1500, + name = "FD" } -- | We only keep exceptions here which should not be handled by the test diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Socket.hs index eab489bee10..0eb9001f9d8 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Socket.hs @@ -3,6 +3,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} @@ -14,6 +15,7 @@ module Test.Ouroboros.Network.Socket (tests) where import Data.Bifoldable (bitraverse_) import Data.ByteString.Lazy qualified as BL import Data.List (mapAccumL) +import Data.Monoid.Synchronisation (FirstToFinish (..)) import Data.Time.Clock (UTCTime, getCurrentTime) import Data.Void (Void) #ifndef mingw32_HOST_OS @@ -55,9 +57,8 @@ import Ouroboros.Network.Socket -- TODO: remove Mx prefixes import Ouroboros.Network.Mux -import Network.Mux qualified as Mx (MuxError (..), MuxErrorType (..)) +import Network.Mux qualified as Mx import Network.Mux.Bearer qualified as Mx -import Network.Mux.Compat qualified as Mx (muxStart) import Network.Mux.Timeout import Network.Mux.Types (MiniProtocolDir (..), MuxSDU (..), MuxSDUHeader (..), RemoteClockModel (..), write) @@ -254,15 +255,17 @@ prop_socket_send_recv initiatorAddr responderAddr configureSock f xs = (unversionedProtocol (SomeResponderApplication responderApp)) nullErrorPolicies $ \_ _ -> do - connectToNode + void $ connectToNode snocket Mx.makeSocketBearer - (flip configureSock Nothing) - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - (NetworkConnectTracers activeMuxTracer nullTracer) - (HandshakeCallbacks acceptableVersion queryVersion) + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = NetworkConnectTracers activeMuxTracer nullTracer, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + (`configureSock` Nothing) (unversionedProtocol initiatorApp) (Just initiatorAddr) responderAddr @@ -352,10 +355,30 @@ prop_socket_recv_error f rerr = _ <- async $ do threadDelay 0.1 atomically $ putTMVar lock () - Mx.muxStart nullTracer (toApplication MinimalInitiatorContext { micConnectionId = connectionId } - ResponderContext { rcConnectionId = connectionId } - app) - bearer + mux <- Mx.newMux (toMiniProtocolInfos app) + let respCtx = ResponderContext connectionId + resOps <- sequence + [ Mx.runMiniProtocol + mux + miniProtocolNum + miniProtocolDir + Mx.StartEagerly + (\a -> do + r <- action a + return (r, Nothing) + ) + | MiniProtocol{miniProtocolNum, miniProtocolRun} + <- getOuroborosApplication app + , (miniProtocolDir, action) <- + case miniProtocolRun of + ResponderProtocolOnly initiator -> + [(Mx.ResponderDirectionOnly, void . runMiniProtocolCb initiator respCtx)] + ] + + withAsync (Mx.runMux nullTracer mux bearer) $ \aid -> do + _ <- atomically $ runFirstToFinish $ foldMap FirstToFinish resOps + Mx.stopMux mux + wait aid ) $ \muxAsync -> do @@ -502,12 +525,14 @@ prop_socket_client_connect_error _ xs = <- try $ False <$ connectToNode (socketSnocket iomgr) Mx.makeSocketBearer + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = NetworkConnectTracers activeMuxTracer nullTracer, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } (flip configureSocket Nothing) - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) (unversionedProtocol app) (Just $ Socket.addrAddress clientAddr) (Socket.addrAddress serverAddr) diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Subscription.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Subscription.hs index 14c4ca3b977..78249479458 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Subscription.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Subscription.hs @@ -626,13 +626,15 @@ prop_send_recv f xs _first = ioProperty $ withIOManager $ \iocp -> do } (\_ -> waitSiblingSTM siblingVar) (connectToNodeSocket - iocp - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) - (unversionedProtocol initiatorApp)) + iocp + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + (unversionedProtocol initiatorApp)) res <- atomically $ (,) <$> takeTMVar sv <*> takeTMVar cv return (res == L.mapAccumL f 0 xs) @@ -800,13 +802,15 @@ prop_send_recv_init_and_rsp f xs = ioProperty $ withIOManager $ \iocp -> do nullErrorPolicies (\_ -> waitSiblingSTM (rrcSiblingVar rrcfg)) (connectToNodeSocket - iocp - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) - (unversionedProtocol (appX rrcfg))) + iocp + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + (unversionedProtocol (appX rrcfg))) atomically $ (,) <$> takeTMVar (rrcServerVar rrcfg) <*> takeTMVar (rrcClientVar rrcfg) @@ -873,13 +877,15 @@ _demo = ioProperty $ withIOManager $ \iocp -> do } (connectToNodeSocket - iocp - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) - (unversionedProtocol appReq)) + iocp + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + (unversionedProtocol appReq)) threadDelay 130 -- bring the servers back again diff --git a/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs b/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs index 5c9f8f339f3..195f020f6dd 100644 --- a/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs @@ -47,7 +47,6 @@ import Text.Printf import Foreign.C.Error import GHC.IO.Exception -import Ouroboros.Network.Channel import Ouroboros.Network.ConnectionId import Ouroboros.Network.Driver.Simple import Ouroboros.Network.Snocket @@ -218,7 +217,7 @@ clientServerSimulation payloads = listen snocket fd accept snocket fd >>= acceptLoop threadsVar) `finally` do - threads <- atomically (readTVar threadsVar) + threads <- readTVarIO threadsVar traverse_ cancel threads where acceptLoop :: StrictTVar m (Set (Async m ())) @@ -242,13 +241,12 @@ clientServerSimulation payloads = handleConnection bearer remoteAddr = do labelThisThread "server-handler" bracket - (newMux (MiniProtocolBundle - [ MiniProtocolInfo { - miniProtocolNum = reqRespProtocolNum, - miniProtocolDir = ResponderDirectionOnly, - miniProtocolLimits = MiniProtocolLimits maxBound - } - ])) + (newMux [ MiniProtocolInfo { + miniProtocolNum = reqRespProtocolNum, + miniProtocolDir = ResponderDirectionOnly, + miniProtocolLimits = MiniProtocolLimits maxBound + } + ]) stopMux $ \mux -> do let connId = ConnectionId { @@ -265,7 +263,7 @@ clientServerSimulation payloads = ResponderDirectionOnly StartOnDemand (\channel -> runPeer tr codecReqResp - (fromChannel channel) + channel serverPeer) withAsync (do labelThisThread "server-mux" @@ -286,13 +284,12 @@ clientServerSimulation payloads = (close snocket) $ \fd -> do connect snocket fd serverAddr - mux <- newMux (MiniProtocolBundle - [ MiniProtocolInfo { - miniProtocolNum = reqRespProtocolNum, - miniProtocolDir = InitiatorDirectionOnly, - miniProtocolLimits = MiniProtocolLimits maxBound - } - ]) + mux <- newMux [ MiniProtocolInfo { + miniProtocolNum = reqRespProtocolNum, + miniProtocolDir = InitiatorDirectionOnly, + miniProtocolLimits = MiniProtocolLimits maxBound + } + ] localAddr <- getLocalAddr snocket fd let connId = ConnectionId { localAddress = localAddr, @@ -307,7 +304,7 @@ clientServerSimulation payloads = InitiatorDirectionOnly StartEagerly (\channel -> runPeer tr codecReqResp - (fromChannel channel) + channel clientPeer) bearer <- Mx.getBearer makeFDBearer 10 nullTracer fd diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Channel.hs b/ouroboros-network-framework/src/Ouroboros/Network/Channel.hs index 2259e2ebb00..361c10f6d38 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Channel.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Channel.hs @@ -6,110 +6,31 @@ module Ouroboros.Network.Channel ( Channel (..) , toChannel , fromChannel - , createPipeConnectedChannels - , hoistChannel - , isoKleisliChannel + , module Mx , fixedInputChannel - , mvarsAsChannel - , handlesAsChannel - , createConnectedChannels , createConnectedBufferedChannelsUnbounded , createConnectedBufferedChannels , createConnectedBufferedChannelsSTM , createPipelineTestChannels - , channelEffect - , delayChannel - , loggingChannel ) where -import Control.Monad ((>=>)) -import Control.Monad.Class.MonadSay -import Control.Monad.Class.MonadTimer.SI -import Data.ByteString qualified as BS import Data.ByteString.Lazy qualified as LBS -import Data.ByteString.Lazy.Internal (smallChunkSize) import Numeric.Natural -import System.IO qualified as IO (Handle, hFlush, hIsEOF) - import Control.Concurrent.Class.MonadSTM.Strict -import Network.Mux.Channel qualified as Mx - - --- | One end of a duplex channel. It is a reliable, ordered channel of some --- medium. The medium does not imply message boundaries, it can be just bytes. --- -data Channel m a = Channel { - - -- | Write output to the channel. - -- - -- It may raise exceptions (as appropriate for the monad and kind of - -- channel). - -- - send :: a -> m (), - - -- | Read some input from the channel, or @Nothing@ to indicate EOF. - -- - -- Note that having received EOF it is still possible to send. - -- The EOF condition is however monotonic. - -- - -- It may raise exceptions (as appropriate for the monad and kind of - -- channel). - -- - recv :: m (Maybe a) - } +import Network.Mux.Channel as Mx --- TODO: eliminate the second Channel type and these conversion functions. - -fromChannel :: Mx.Channel m +fromChannel :: Mx.ByteChannel m -> Channel m LBS.ByteString -fromChannel Mx.Channel { Mx.send, Mx.recv } = Channel { - send = send, - recv = recv - } +fromChannel = id +{-# DEPRECATED fromChannel "Not needed, use `id` instead." #-} toChannel :: Channel m LBS.ByteString - -> Mx.Channel m -toChannel Channel { send, recv } = Mx.Channel { - Mx.send = send, - Mx.recv = recv - } + -> Mx.ByteChannel m +toChannel = id +{-# DEPRECATED toChannel "Not needed, use `id` instead." #-} --- | Create a local pipe, with both ends in this process, and expose that as --- a pair of 'Channel's, one for each end. --- --- This is primarily for testing purposes since it does not allow actual IPC. --- -createPipeConnectedChannels :: IO (Channel IO LBS.ByteString, - Channel IO LBS.ByteString) -createPipeConnectedChannels = - (\(a, b) -> (fromChannel a, fromChannel b)) - <$> Mx.createPipeConnectedChannels - --- | Given an isomorphism between @a@ and @b@ (in Kleisli category), transform --- a @'Channel' m a@ into @'Channel' m b@. --- -isoKleisliChannel - :: forall a b m. Monad m - => (a -> m b) - -> (b -> m a) - -> Channel m a - -> Channel m b -isoKleisliChannel f finv Channel{send, recv} = Channel { - send = finv >=> send, - recv = recv >>= traverse f - } - - -hoistChannel - :: (forall x . m x -> n x) - -> Channel m a - -> Channel n a -hoistChannel nat channel = Channel - { send = nat . send channel - , recv = nat (recv channel) - } -- | A 'Channel' with a fixed input, and where all output is discarded. -- @@ -134,34 +55,6 @@ fixedInputChannel xs0 = do send _ = return () --- | Make a 'Channel' from a pair of 'TMVar's, one for reading and one for --- writing. --- -mvarsAsChannel :: MonadSTM m - => StrictTMVar m a - -> StrictTMVar m a - -> Channel m a -mvarsAsChannel bufferRead bufferWrite = - Channel{send, recv} - where - send x = atomically (putTMVar bufferWrite x) - recv = atomically (Just <$> takeTMVar bufferRead) - - --- | Create a pair of channels that are connected via one-place buffers. --- --- This is primarily useful for testing protocols. --- -createConnectedChannels :: MonadSTM m => m (Channel m a, Channel m a) -createConnectedChannels = do - -- Create two TMVars to act as the channel buffer (one for each direction) - -- and use them to make both ends of a bidirectional channel - bufferA <- newEmptyTMVarIO - bufferB <- newEmptyTMVarIO - - return (mvarsAsChannel bufferB bufferA, - mvarsAsChannel bufferA bufferB) - -- | Create a pair of channels that are connected via two unbounded buffers. -- @@ -172,8 +65,8 @@ createConnectedBufferedChannelsUnbounded :: forall m a. MonadSTM m createConnectedBufferedChannelsUnbounded = do -- Create two TQueues to act as the channel buffers (one for each -- direction) and use them to make both ends of a bidirectional channel - bufferA <- atomically $ newTQueue - bufferB <- atomically $ newTQueue + bufferA <- atomically newTQueue + bufferB <- atomically newTQueue return (queuesAsChannel bufferB bufferA, queuesAsChannel bufferA bufferB) @@ -259,97 +152,3 @@ createPipelineTestChannels sz = do failureMsg = "createPipelineTestChannels: " ++ "maximum pipeline depth exceeded: " ++ show sz - - --- | Make a 'Channel' from a pair of IO 'Handle's, one for reading and one --- for writing. --- --- The Handles should be open in the appropriate read or write mode, and in --- binary mode. Writes are flushed after each write, so it is safe to use --- a buffering mode. --- --- For bidirectional handles it is safe to pass the same handle for both. --- -handlesAsChannel :: IO.Handle -- ^ Read handle - -> IO.Handle -- ^ Write handle - -> Channel IO LBS.ByteString -handlesAsChannel hndRead hndWrite = - Channel{send, recv} - where - send :: LBS.ByteString -> IO () - send chunk = do - LBS.hPut hndWrite chunk - IO.hFlush hndWrite - - recv :: IO (Maybe LBS.ByteString) - recv = do - eof <- IO.hIsEOF hndRead - if eof - then return Nothing - else Just . LBS.fromStrict <$> BS.hGetSome hndRead smallChunkSize - - --- | Transform a channel to add an extra action before /every/ send and after --- /every/ receive. --- -channelEffect :: forall m a. - Monad m - => (a -> m ()) -- ^ Action before 'send' - -> (Maybe a -> m ()) -- ^ Action after 'recv' - -> Channel m a - -> Channel m a -channelEffect beforeSend afterRecv Channel{send, recv} = - Channel{ - send = \x -> do - beforeSend x - send x - - , recv = do - mx <- recv - afterRecv mx - return mx - } - --- | Delay a channel on the receiver end. --- --- This is intended for testing, as a crude approximation of network delays. --- More accurate models along these lines are of course possible. --- -delayChannel :: MonadDelay m - => DiffTime - -> Channel m a - -> Channel m a -delayChannel delay Channel{send, recv} = - Channel { send - , recv = threadDelay (delay / 2) - >> recv - <* threadDelay (delay / 2) - } - - - --- | Channel which logs sent and received messages. --- -loggingChannel :: ( MonadSay m - , Show id - , Show a - ) - => id - -> Channel m a - -> Channel m a -loggingChannel ident Channel{send,recv} = - Channel { - send = loggingSend, - recv = loggingRecv - } - where - loggingSend a = do - say (show ident ++ ":send:" ++ show a) - send a - - loggingRecv = do - msg <- recv - case msg of - Nothing -> return () - Just a -> say (show ident ++ ":recv:" ++ show a) - return msg diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs index eeac4eeb930..58ac5293a2a 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs @@ -318,7 +318,7 @@ makeConnectionHandler muxTracer singMuxMode <$> newTVarIO Continue <*> newTVarIO Continue <*> newTVarIO Continue - mux <- newMux (mkMiniProtocolBundle app) + mux <- newMux (mkMiniProtocolInfos app) let !handle = Handle { hMux = mux, hMuxBundle = app, @@ -385,7 +385,8 @@ makeConnectionHandler muxTracer singMuxMode <$> newTVarIO Continue <*> newTVarIO Continue <*> newTVarIO Continue - mux <- newMux (mkMiniProtocolBundle app) + mux <- newMux (mkMiniProtocolInfos app) + let !handle = Handle { hMux = mux, hMuxBundle = app, diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Driver/Limits.hs b/ouroboros-network-framework/src/Ouroboros/Network/Driver/Limits.hs index 7b227154fe1..3e9df4f19bf 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Driver/Limits.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Driver/Limits.hs @@ -22,14 +22,19 @@ module Ouroboros.Network.Driver.Limits -- * Normal peers , runPeerWithLimits , runPipelinedPeerWithLimits + , runPeerWithLimitsRnd + , runPipelinedPeerWithLimitsRnd , TraceSendRecv (..) -- * Driver utilities , driverWithLimits , runConnectedPeersWithLimits , runConnectedPipelinedPeersWithLimits + , runConnectedPeersWithLimitsRnd + , runConnectedPipelinedPeersWithLimitsRnd ) where import Data.Maybe (fromMaybe) +import System.Random import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadFork @@ -105,6 +110,65 @@ driverWithLimits tracer timeoutFn Nothing -> throwIO (ExceededTimeLimit tok) +driverWithLimitsRnd :: forall ps (pr :: PeerRole) failure bytes m. + ( MonadThrow m + , ShowProxy ps + , forall (st' :: ps) tok. tok ~ StateToken st' => Show tok + , Show failure + ) + => Tracer m (TraceSendRecv ps) + -> TimeoutFn m + -> StdGen + -> Codec ps failure m bytes + -> ProtocolSizeLimits ps bytes + -> (StdGen -> ProtocolTimeLimits ps) + -> Channel m bytes + -> Driver ps pr (Maybe bytes, StdGen) m +driverWithLimitsRnd tracer timeoutFn rnd0 + Codec{encode, decode} + ProtocolSizeLimits{sizeLimitForState, dataSize} + genProtocolTimeLimits + channel@Channel{send} = + Driver { sendMessage, recvMessage, initialDState = (Nothing, rnd0) } + where + sendMessage :: forall (st :: ps) (st' :: ps). + StateTokenI st + => ActiveState st + => WeHaveAgencyProof pr st + -> Message ps st st' + -> m () + sendMessage !_ msg = do + send (encode msg) + traceWith tracer (TraceSendMsg (AnyMessage msg)) + + + recvMessage :: forall (st :: ps). + StateTokenI st + => ActiveState st + => TheyHaveAgencyProof pr st + -> (Maybe bytes, StdGen) + -> m (SomeMessage st, (Maybe bytes, StdGen)) + recvMessage !_ (trailing, !rnd) = do + let tok = stateToken + decoder <- decode tok + let sizeLimit = sizeLimitForState @st stateToken + + let (rnd', rnd'') = split rnd + ProtocolTimeLimits{timeLimitForState} = genProtocolTimeLimits rnd'' + timeLimit = fromMaybe (-1) $ timeLimitForState @st stateToken + result <- timeoutFn timeLimit $ + runDecoderWithLimit sizeLimit dataSize + channel trailing decoder + + case result of + Just (Right (x@(SomeMessage msg), trailing')) -> do + traceWith tracer (TraceRecvMsg (AnyMessage msg)) + return (x, (trailing', rnd')) + Just (Left (Just failure)) -> throwIO (DecoderFailure tok failure) + Just (Left Nothing) -> throwIO (ExceededSizeLimit tok) + Nothing -> throwIO (ExceededTimeLimit tok) + + runDecoderWithLimit :: forall m bytes failure a. Monad m => Word @@ -152,7 +216,8 @@ runDecoderWithLimit limit size Channel{recv} = Just bs -> do let sz' = sz + size bs go sz' Nothing =<< k (Just bs) - +-- | Run a peer with limits. +-- runPeerWithLimits :: forall ps (st :: ps) pr failure bytes m a . ( MonadAsync m @@ -175,6 +240,37 @@ runPeerWithLimits tracer codec slimits tlimits channel peer = withTimeoutSerial $ \timeoutFn -> let driver = driverWithLimits tracer timeoutFn codec slimits tlimits channel in runPeerWithDriver driver peer + + +-- | Run a peer with limits. 'ProtocolTimeLimits' have access to +-- a pseudorandom generator. +-- +runPeerWithLimitsRnd + :: forall ps (st :: ps) pr failure bytes m a . + ( MonadAsync m + , MonadFork m + , MonadMask m + , MonadThrow (STM m) + , MonadTimer m + , ShowProxy ps + , forall (st' :: ps) stok. stok ~ StateToken st' => Show stok + , Show failure + ) + => Tracer m (TraceSendRecv ps) + -> StdGen + -> Codec ps failure m bytes + -> ProtocolSizeLimits ps bytes + -> (StdGen -> ProtocolTimeLimits ps) + -> Channel m bytes + -> Peer ps pr NonPipelined st m a + -> m (a, Maybe bytes) +runPeerWithLimitsRnd tracer rnd codec slimits tlimits channel peer = + withTimeoutSerial $ \timeoutFn -> + let driver = driverWithLimitsRnd tracer timeoutFn rnd codec slimits tlimits channel + in (\(a, (trailing, _)) -> (a, trailing)) + <$> runPeerWithDriver driver peer + + -- | Run a pipelined peer with the given channel via the given codec. -- -- This runs the peer to completion (if the protocol allows for termination). @@ -206,6 +302,35 @@ runPipelinedPeerWithLimits tracer codec slimits tlimits channel peer = in runPipelinedPeerWithDriver driver peer +-- | Like 'runPipelinedPeerWithLimits' but time limits have access to +-- a pseudorandom generator. +-- +runPipelinedPeerWithLimitsRnd + :: forall ps (st :: ps) pr failure bytes m a. + ( MonadAsync m + , MonadFork m + , MonadMask m + , MonadTimer m + , MonadThrow (STM m) + , ShowProxy ps + , forall (st' :: ps) stok. stok ~ StateToken st' => Show stok + , Show failure + ) + => Tracer m (TraceSendRecv ps) + -> StdGen + -> Codec ps failure m bytes + -> ProtocolSizeLimits ps bytes + -> (StdGen -> ProtocolTimeLimits ps) + -> Channel m bytes + -> PeerPipelined ps pr st m a + -> m (a, Maybe bytes) +runPipelinedPeerWithLimitsRnd tracer rnd codec slimits tlimits channel peer = + withTimeoutSerial $ \timeoutFn -> + let driver = driverWithLimitsRnd tracer timeoutFn rnd codec slimits tlimits channel + in (\(a, (trailing, _)) -> (a, trailing)) + <$> runPipelinedPeerWithDriver driver peer + + -- | Run two 'Peer's via a pair of connected 'Channel's and a common 'Codec'. -- The client side is using 'driverWithLimits'. -- @@ -248,6 +373,41 @@ runConnectedPeersWithLimits createChannels tracer codec slimits tlimits client s tracerServer = contramap ((,) Server) tracer +runConnectedPeersWithLimitsRnd + :: forall ps pr st failure bytes m a b. + ( MonadAsync m + , MonadFork m + , MonadMask m + , MonadTimer m + , MonadThrow (STM m) + , Exception failure + , ShowProxy ps + , forall (st' :: ps) sing. sing ~ StateToken st' => Show sing + ) + => m (Channel m bytes, Channel m bytes) + -> Tracer m (Role, TraceSendRecv ps) + -> StdGen + -> Codec ps failure m bytes + -> ProtocolSizeLimits ps bytes + -> (StdGen -> ProtocolTimeLimits ps) + -> Peer ps pr NonPipelined st m a + -> Peer ps (FlipAgency pr) NonPipelined st m b + -> m (a, b) +runConnectedPeersWithLimitsRnd createChannels tracer rnd codec slimits tlimits client server = + createChannels >>= \(clientChannel, serverChannel) -> + + (do labelThisThread "client" + fst <$> runPeerWithLimitsRnd + tracerClient rnd codec slimits tlimits + clientChannel client) + `concurrently` + (do labelThisThread "server" + fst <$> runPeer tracerServer codec serverChannel server) + where + tracerClient = contramap ((,) Client) tracer + tracerServer = contramap ((,) Server) tracer + + -- | Run two 'Peer's via a pair of connected 'Channel's and a common 'Codec'. -- The client side is using 'driverWithLimits'. -- @@ -286,3 +446,36 @@ runConnectedPipelinedPeersWithLimits createChannels tracer codec slimits tlimits where tracerClient = contramap ((,) Client) tracer tracerServer = contramap ((,) Server) tracer + + +runConnectedPipelinedPeersWithLimitsRnd + :: forall ps pr st failure bytes m a b. + ( MonadAsync m + , MonadFork m + , MonadMask m + , MonadTimer m + , MonadThrow (STM m) + , Exception failure + , ShowProxy ps + , forall (st' :: ps) sing. sing ~ StateToken st' => Show sing + ) + => m (Channel m bytes, Channel m bytes) + -> Tracer m (Role, TraceSendRecv ps) + -> StdGen + -> Codec ps failure m bytes + -> ProtocolSizeLimits ps bytes + -> (StdGen -> ProtocolTimeLimits ps) + -> PeerPipelined ps pr st m a + -> Peer ps (FlipAgency pr) NonPipelined st m b + -> m (a, b) +runConnectedPipelinedPeersWithLimitsRnd createChannels tracer rnd codec slimits tlimits client server = + createChannels >>= \(clientChannel, serverChannel) -> + + (fst <$> runPipelinedPeerWithLimitsRnd + tracerClient rnd codec slimits tlimits + clientChannel client) + `concurrently` + (fst <$> runPeer tracerServer codec serverChannel server) + where + tracerClient = contramap ((,) Client) tracer + tracerServer = contramap ((,) Server) tracer diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Mux.hs b/ouroboros-network-framework/src/Ouroboros/Network/Mux.hs index d3e48eef10f..bf4d754ab9b 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Mux.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Mux.hs @@ -1,11 +1,9 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE ExplicitNamespaces #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} @@ -45,12 +43,13 @@ module Ouroboros.Network.Mux -- * MiniProtocol bundle , OuroborosBundle , OuroborosBundleWithExpandedCtx + , OuroborosBundleWithMinimalCtx -- * Non-P2P API , OuroborosApplication (..) , OuroborosApplicationWithMinimalCtx - , toApplication - , mkMiniProtocolBundle + , mkMiniProtocolInfos , fromOuroborosBundle + , toMiniProtocolInfos , contramapInitiatorCtx -- * Re-exports -- | from "Network.Mux" @@ -78,11 +77,10 @@ import Network.TypedProtocol.Peer import Network.TypedProtocol.Stateful.Codec qualified as Stateful import Network.TypedProtocol.Stateful.Peer qualified as Stateful -import Network.Mux (HasInitiator, HasResponder, MiniProtocolBundle (..), - MiniProtocolInfo, MiniProtocolLimits (..), MiniProtocolNum, - MuxError (..), MuxErrorType (..), MuxMode (..)) +import Network.Mux (HasInitiator, HasResponder, MiniProtocolInfo, + MiniProtocolLimits (..), MiniProtocolNum, MuxError (..), + MuxErrorType (..), MuxMode (..)) import Network.Mux.Channel qualified as Mux -import Network.Mux.Compat qualified as Mux.Compat import Network.Mux.Types qualified as Mux import Ouroboros.Network.Channel @@ -231,6 +229,12 @@ type OuroborosBundleWithExpandedCtx (mode :: MuxMode) peerAddr bytes m a b = (ResponderContext peerAddr) bytes m a b +type OuroborosBundleWithMinimalCtx (mode :: MuxMode) peerAddr bytes m a b = + OuroborosBundle mode + (MinimalInitiatorContext peerAddr) + (ResponderContext peerAddr) + bytes m a b + -- | Each mini-protocol is represented by its -- @@ -245,6 +249,25 @@ data MiniProtocol (mode :: MuxMode) initiatorCtx responderCtx bytes m a b = miniProtocolRun :: !(RunMiniProtocol mode initiatorCtx responderCtx bytes m a b) } +mkMiniProtocolInfo :: MiniProtocol mode initiatorCtx responderCtx bytes m a b -> [MiniProtocolInfo mode] +mkMiniProtocolInfo MiniProtocol { + miniProtocolNum, + miniProtocolLimits, + miniProtocolRun + } + = + [ Mux.MiniProtocolInfo { + Mux.miniProtocolNum, + Mux.miniProtocolDir = dir, + Mux.miniProtocolLimits + } + | dir <- case miniProtocolRun of + InitiatorProtocolOnly{} -> [ Mux.InitiatorDirectionOnly ] + ResponderProtocolOnly{} -> [ Mux.ResponderDirectionOnly ] + InitiatorAndResponderProtocol{} -> [ Mux.InitiatorDirection + , Mux.ResponderDirection ] + ] + -- | 'MiniProtocol' type used in P2P. -- @@ -431,9 +454,9 @@ runMiniProtocolCb :: ( MonadAsync m ) => MiniProtocolCb ctx LBS.ByteString m a -> ctx - -> Mux.Channel m + -> Mux.ByteChannel m -> m (a, Maybe LBS.ByteString) -runMiniProtocolCb (MiniProtocolCb run) !ctx = run ctx . fromChannel +runMiniProtocolCb (MiniProtocolCb run) !ctx = run ctx runMiniProtocolCb (MuxPeer fn) !ctx = runMiniProtocolCb (mkMiniProtocolCbFromPeer fn) ctx runMiniProtocolCb (MuxPeerPipelined fn) !ctx = runMiniProtocolCb (mkMiniProtocolCbFromPeerPipelined fn) ctx @@ -450,7 +473,10 @@ contramapMiniProtocolCbCtx f (MuxPeerPipelined cb) = MuxPeerPipelined (cb . f) -- -- Note: Only used in some non-P2P contexts. newtype OuroborosApplication (mode :: MuxMode) initiatorCtx responderCtx bytes m a b = - OuroborosApplication [MiniProtocol mode initiatorCtx responderCtx bytes m a b] + OuroborosApplication { + getOuroborosApplication + :: [MiniProtocol mode initiatorCtx responderCtx bytes m a b] + } -- | 'OuroborosApplication' used in NonP2P mode. -- @@ -465,6 +491,12 @@ fromOuroborosBundle :: OuroborosBundle mode initiatorCtx responderCtx bytes fromOuroborosBundle = OuroborosApplication . fold +toMiniProtocolInfos :: OuroborosApplication mode initiatorCtx responderCtx bytes m a b + -> [MiniProtocolInfo mode] +toMiniProtocolInfos = + foldMap mkMiniProtocolInfo . getOuroborosApplication + + contramapInitiatorCtx :: (initiatorCtx' -> initiatorCtx) -> OuroborosApplication mode initiatorCtx responderCtx bytes m a b -> OuroborosApplication mode initiatorCtx' responderCtx bytes m a b @@ -482,61 +514,9 @@ contramapInitiatorCtx f (OuroborosApplication ptcls) = OuroborosApplication ] --- | Create non p2p mux application. --- --- Note that callbacks will always receive `IsNotBigLedgerPeer`. -toApplication :: forall mode initiatorCtx responderCtx m a b. - ( MonadAsync m - , MonadThrow m - ) - => initiatorCtx - -> responderCtx - -> OuroborosApplication mode initiatorCtx responderCtx LBS.ByteString m a b - -> Mux.Compat.MuxApplication mode m a b -toApplication initiatorContext responderContext (OuroborosApplication ptcls) = - Mux.Compat.MuxApplication - [ Mux.Compat.MuxMiniProtocol { - Mux.Compat.miniProtocolNum = miniProtocolNum ptcl, - Mux.Compat.miniProtocolLimits = miniProtocolLimits ptcl, - Mux.Compat.miniProtocolRun = toMuxRunMiniProtocol (miniProtocolRun ptcl) - } - | ptcl <- ptcls ] - where - toMuxRunMiniProtocol :: RunMiniProtocol mode initiatorCtx responderCtx LBS.ByteString m a b - -> Mux.Compat.RunMiniProtocol mode m a b - toMuxRunMiniProtocol (InitiatorProtocolOnly i) = - Mux.Compat.InitiatorProtocolOnly - (runMiniProtocolCb i initiatorContext) - toMuxRunMiniProtocol (ResponderProtocolOnly r) = - Mux.Compat.ResponderProtocolOnly - (runMiniProtocolCb r responderContext) - toMuxRunMiniProtocol (InitiatorAndResponderProtocol i r) = - Mux.Compat.InitiatorAndResponderProtocol - (runMiniProtocolCb i initiatorContext) - (runMiniProtocolCb r responderContext) - - -- | Make 'MiniProtocolBundle', which is used to create a mux interface with -- 'newMux'. -- -mkMiniProtocolBundle :: OuroborosBundle mode initiatorCtx responderCtx bytes m a b - -> MiniProtocolBundle mode -mkMiniProtocolBundle = MiniProtocolBundle . foldMap fn - where - fn :: [MiniProtocol mode initiatorCtx responderCtx bytes m a b] -> [MiniProtocolInfo mode] - fn ptcls = [ Mux.MiniProtocolInfo - { Mux.miniProtocolNum - , Mux.miniProtocolDir = dir - , Mux.miniProtocolLimits - } - | MiniProtocol { miniProtocolNum - , miniProtocolLimits - , miniProtocolRun - } - <- ptcls - , dir <- case miniProtocolRun of - InitiatorProtocolOnly{} -> [ Mux.InitiatorDirectionOnly ] - ResponderProtocolOnly{} -> [ Mux.ResponderDirectionOnly ] - InitiatorAndResponderProtocol{} -> [ Mux.InitiatorDirection - , Mux.ResponderDirection ] - ] +mkMiniProtocolInfos :: OuroborosBundle mode initiatorCtx responderCtx bytes m a b + -> [MiniProtocolInfo mode] +mkMiniProtocolInfos = foldMap (foldMap mkMiniProtocolInfo) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs index 2160cfe90af..e48ff082713 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs @@ -32,7 +32,6 @@ import Network.Mux.Trace import Network.Mux.Types import Network.TypedProtocol.Codec -import Ouroboros.Network.Channel import Ouroboros.Network.Driver.Limits import Ouroboros.Network.Protocol.Handshake.Client @@ -143,7 +142,7 @@ runHandshakeClient bearer haHandshakeCodec byteLimitsHandshake haTimeLimits - (fromChannel (muxBearerAsChannel bearer handshakeProtocolNum InitiatorDir)) + (muxBearerAsChannel bearer handshakeProtocolNum InitiatorDir) (handshakeClientPeer haVersionDataCodec haAcceptVersion versions)) @@ -181,7 +180,7 @@ runHandshakeServer bearer haHandshakeCodec byteLimitsHandshake haTimeLimits - (fromChannel (muxBearerAsChannel bearer handshakeProtocolNum ResponderDir)) + (muxBearerAsChannel bearer handshakeProtocolNum ResponderDir) (handshakeServerPeer haVersionDataCodec haAcceptVersion haQueryVersion versions)) -- | A 20s delay after query result was send back, before we close the diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Codec.hs b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Codec.hs index c42adb44c36..7512e037da9 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Codec.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Codec.hs @@ -99,6 +99,14 @@ byteLimitsHandshake = ProtocolSizeLimits stateToLimit (fromIntegral . BL.length) -- | Time limits. -- +-- +--------------------+-------------+ +-- | 'Handshake' state | timeout (s) | +-- +====================+=============+ +-- | `StPropose` | `shortWait` | +-- +--------------------+-------------+ +-- | `StConfirm` | `shortWait` | +-- +--------------------+-------------+ +-- timeLimitsHandshake :: forall vNumber. ProtocolTimeLimits (Handshake vNumber CBOR.Term) timeLimitsHandshake = ProtocolTimeLimits stateToLimit where @@ -123,10 +131,9 @@ noTimeLimitsHandshake = ProtocolTimeLimits stateToLimit -- | -- @'Handshake'@ codec. The @'MsgProposeVersions'@ encodes proposed map in --- ascending order and it expects to receive them in this order. This allows --- to construct the map in linear time. There is also another limiting factor --- to the number of versions on can present: the whole message must fit into --- a single TCP segment. +-- ascending order and it expects to receive them in this order. The whole +-- `MsgProposeVersions` message must fit into a single TCP segment which limits +-- number of versions that can be proposed. -- codecHandshake :: forall vNumber m failure. @@ -135,6 +142,7 @@ codecHandshake , Show failure ) => CodecCBORTerm (failure, Maybe Int) vNumber + -- ^ `CBOR.Term` codec for `vNumber` -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure m ByteString codecHandshake versionNumberCodec = mkCodecCborLazyBS encodeMsg decodeMsg where diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs index e90ad9b3314..ea356c9c97f 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs @@ -33,9 +33,12 @@ module Ouroboros.Network.Socket , ConnectionId (..) , withServerNode , withServerNode' + , ConnectToArgs (..) , connectToNode + , connectToNodeWithMux , connectToNodeSocket , connectToNode' + , connectToNodeWithMux' -- * Socket configuration , configureSocket , configureSystemdSocket @@ -80,8 +83,12 @@ import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadTime.SI import Control.Monad.Class.MonadTimer.SI import Control.Monad.STM qualified as STM +import Data.Bifunctor (first) import Data.ByteString.Lazy qualified as BL +import Data.Foldable (traverse_) +import Data.Functor (void) import Data.Hashable +import Data.Monoid.Synchronisation (FirstToFinish (..)) import Data.Typeable (Typeable) import Data.Void import Data.Word (Word16) @@ -95,8 +102,8 @@ import Network.Socket qualified as Socket import Control.Tracer +import Network.Mux qualified as Mx import Network.Mux.Bearer qualified as Mx -import Network.Mux.Compat qualified as Mx import Network.Mux.DeltaQ.TraceTransformer import Network.TypedProtocol.Codec hiding (decode, encode) @@ -245,53 +252,91 @@ sduTimeout = 30 sduHandshakeTimeout :: DiffTime sduHandshakeTimeout = 10 +-- | Common arguments of various variants of `connectToNode`. +-- +data ConnectToArgs fd addr vNumber vData = ConnectToArgs { + ctaHandshakeCodec :: Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString, + ctaHandshakeTimeLimits :: ProtocolTimeLimits (Handshake vNumber CBOR.Term), + ctaVersionDataCodec :: VersionDataCodec CBOR.Term vNumber vData, + ctaConnectTracers :: NetworkConnectTracers addr vNumber, + ctaHandshakeCallbacks :: HandshakeCallbacks vData + } --- | --- Connect to a remote node. It is using bracket to enclose the underlying + +-- | Connect to a remote node. It is using bracket to enclose the underlying -- socket acquisition. This implies that when the continuation exits the -- underlying bearer will get closed. -- -- The connection will start with handshake protocol sending @Versions@ to the -- remote peer. It must fit into @'maxTransmissionUnit'@ (~5k bytes). -- --- Exceptions thrown by @'MuxApplication'@ are rethrown by @'connectTo'@. +-- Exceptions thrown by 'MuxApplication' are rethrown by 'connectToNode'. connectToNode - :: forall appType vNumber vData fd addr a b. + :: forall muxMode vNumber vData fd addr a b. ( Ord vNumber , Typeable vNumber , Show vNumber - , Mx.HasInitiator appType ~ True + , Mx.HasInitiator muxMode ~ True ) => Snocket IO fd addr -> Mx.MakeBearer IO fd - -> (fd -> IO ()) -- ^ configure a socket - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString - -> ProtocolTimeLimits (Handshake vNumber CBOR.Term) - -> VersionDataCodec CBOR.Term vNumber vData - -> NetworkConnectTracers addr vNumber - -> HandshakeCallbacks vData - -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx appType addr BL.ByteString IO a b) - -- ^ application to run over the connection + -> ConnectToArgs fd addr vNumber vData + -> (fd -> IO ()) -- ^ configure socket + -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx muxMode addr BL.ByteString IO a b) -> Maybe addr -- ^ local address; the created socket will bind to it -> addr -- ^ remote address - -> IO () -connectToNode sn makeBearer configureSock handshakeCodec handshakeTimeLimits versionDataCodec tracers handshakeCallbacks versions localAddr remoteAddr = - bracket - (Snocket.openToConnect sn remoteAddr) - (Snocket.close sn) - (\sd -> do - configureSock sd - case localAddr of - Just addr -> Snocket.bind sn sd addr - Nothing -> return () - Snocket.connect sn sd remoteAddr - connectToNode' sn makeBearer handshakeCodec handshakeTimeLimits versionDataCodec tracers handshakeCallbacks versions sd - ) + -> IO (Either SomeException (Either a b)) +connectToNode sn mkBearer args configureSock versions localAddr remoteAddr = + connectToNodeWithMux sn mkBearer args configureSock versions localAddr remoteAddr simpleMuxCallback --- | --- Connect to a remote node using an existing socket. It is up to to caller to + +-- | A version `connectToNode` which allows one to control which mini-protocols +-- to execute on a given connection. +connectToNodeWithMux + :: forall muxMode vNumber vData fd addr a b x. + ( Ord vNumber + , Typeable vNumber + , Show vNumber + , Mx.HasInitiator muxMode ~ True + ) + => Snocket IO fd addr + -> Mx.MakeBearer IO fd + -> ConnectToArgs fd addr vNumber vData + -> (fd -> IO ()) -- ^ configure socket + -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx muxMode addr BL.ByteString IO a b) + -- ^ application to run over the connection + -- ^ remote address + -> Maybe addr + -> addr + -> ( ConnectionId addr + -> vNumber + -> vData + -> OuroborosApplicationWithMinimalCtx muxMode addr BL.ByteString IO a b + -> Mx.Mux muxMode IO + -> Async () + -> IO x) + -- ^ callback which has access to ConnectionId, negotiated protocols, mux + -- handle created for that connection and an `Async` handle to the thread + -- which runs `Mx.runMux`. The `Mux` handle allows schedule mini-protocols. + -- + -- NOTE: when the callback returns or errors, the mux thread will be killed. + -> IO x +connectToNodeWithMux sn mkBearer args configureSock versions localAddr remoteAddr k + = + bracket + (Snocket.openToConnect sn remoteAddr) + (Snocket.close sn) + (\sd -> do + configureSock sd + traverse_ (Snocket.bind sn sd) localAddr + Snocket.connect sn sd remoteAddr + connectToNodeWithMux' sn mkBearer args versions sd k + ) + + +-- | Connect to a remote node using an existing socket. It is up to to caller to -- ensure that the socket is closed in case of an exception. -- -- The connection will start with handshake protocol sending @Versions@ to the @@ -299,25 +344,65 @@ connectToNode sn makeBearer configureSock handshakeCodec handshakeTimeLimits ver -- -- Exceptions thrown by @'MuxApplication'@ are rethrown by @'connectTo'@. connectToNode' - :: forall appType vNumber vData fd addr a b. + :: forall muxMode vNumber vData fd addr a b. ( Ord vNumber , Typeable vNumber , Show vNumber - , Mx.HasInitiator appType ~ True + , Mx.HasInitiator muxMode ~ True ) => Snocket IO fd addr -> Mx.MakeBearer IO fd - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString - -> ProtocolTimeLimits (Handshake vNumber CBOR.Term) - -> VersionDataCodec CBOR.Term vNumber vData - -> NetworkConnectTracers addr vNumber - -> HandshakeCallbacks vData - -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx appType addr BL.ByteString IO a b) + -> ConnectToArgs fd addr vNumber vData + -- ^ a configured socket to use to connect to a remote service provider + -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx muxMode addr BL.ByteString IO a b) -- ^ application to run over the connection -> fd + -> IO (Either SomeException (Either a b)) +connectToNode' sn mkBearer args versions as = + connectToNodeWithMux' sn mkBearer args versions as simpleMuxCallback + + +connectToNodeWithMux' + :: forall muxMode vNumber vData fd addr a b x. + ( Ord vNumber + , Typeable vNumber + , Show vNumber + , Mx.HasInitiator muxMode ~ True + ) + => Snocket IO fd addr + -> Mx.MakeBearer IO fd + -> ConnectToArgs fd addr vNumber vData + -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx muxMode addr BL.ByteString IO a b) + -- ^ application to run over the connection -- ^ a configured socket to use to connect to a remote service provider - -> IO () -connectToNode' sn makeBearer handshakeCodec handshakeTimeLimits versionDataCodec NetworkConnectTracers {nctMuxTracer, nctHandshakeTracer } handshakeCallbacks versions sd = do + -> fd + -> ( ConnectionId addr + -> vNumber + -> vData + -> OuroborosApplicationWithMinimalCtx muxMode addr BL.ByteString IO a b + -> Mx.Mux muxMode IO + -> Async () + -> IO x) + -- ^ callback which has access to ConnectionId, negotiated protocols, mux + -- handle created for that connection and an `Async` handle to the thread + -- which runs `Mx.runMux`. The `Mux` handle allows schedule mini-protocols. + -- + -- NOTE: when the callback returns or errors, the mux thread will be killed. + -> IO x +connectToNodeWithMux' + sn makeBearer + ConnectToArgs { + ctaHandshakeCodec = handshakeCodec, + ctaHandshakeTimeLimits = handshakeTimeLimits, + ctaVersionDataCodec = versionDataCodec, + ctaConnectTracers = + NetworkConnectTracers { + nctMuxTracer, + nctHandshakeTracer + }, + ctaHandshakeCallbacks = handshakeCallbacks + } + versions sd k = do connectionId <- (\localAddress remoteAddress -> ConnectionId { localAddress, remoteAddress }) <$> Snocket.getLocalAddr sn sd <*> Snocket.getRemoteAddr sn sd muxTracer <- initDeltaQTracer' $ Mx.WithMuxBearer connectionId `contramap` nctMuxTracer @@ -340,56 +425,91 @@ connectToNode' sn makeBearer handshakeCodec handshakeTimeLimits versionDataCodec versions ts_end <- getMonotonicTime case app_e of - Left (HandshakeProtocolLimit err) -> do - traceWith muxTracer $ Mx.MuxTraceHandshakeClientError err (diffTime ts_end ts_start) - throwIO err + Left (HandshakeProtocolLimit err) -> do + traceWith muxTracer $ Mx.MuxTraceHandshakeClientError err (diffTime ts_end ts_start) + throwIO err + + Left (HandshakeProtocolError err) -> do + traceWith muxTracer $ Mx.MuxTraceHandshakeClientError err (diffTime ts_end ts_start) + throwIO err - Left (HandshakeProtocolError err) -> do - traceWith muxTracer $ Mx.MuxTraceHandshakeClientError err (diffTime ts_end ts_start) - throwIO err + Right (HandshakeNegotiationResult app versionNumber agreedOptions) -> do + traceWith muxTracer $ Mx.MuxTraceHandshakeClientEnd (diffTime ts_end ts_start) + bearer <- Mx.getBearer makeBearer sduTimeout muxTracer sd + mux <- Mx.newMux (toMiniProtocolInfos app) + withAsync (Mx.runMux muxTracer mux bearer) $ \aid -> + k connectionId versionNumber agreedOptions app mux aid - Right (HandshakeNegotiationResult app _versionNumber _agreedOptions) -> do - traceWith muxTracer $ Mx.MuxTraceHandshakeClientEnd (diffTime ts_end ts_start) - bearer <- Mx.getBearer makeBearer sduTimeout muxTracer sd - Mx.muxStart - muxTracer - (toApplication MinimalInitiatorContext { micConnectionId = connectionId } - ResponderContext { rcConnectionId = connectionId } - app) - bearer + Right (HandshakeQueryResult _vMap) -> do + traceWith muxTracer $ Mx.MuxTraceHandshakeClientEnd (diffTime ts_end ts_start) + throwIO (QueryNotSupported @vNumber) - Right (HandshakeQueryResult _vMap) -> do - traceWith muxTracer $ Mx.MuxTraceHandshakeClientEnd (diffTime ts_end ts_start) - throwIO (QueryNotSupported @vNumber) + +-- | An internal mux callback which starts all mini-protocols and blocks +-- until the first one terminates. It returns the result (or error) of the +-- first terminated mini-protocol. +-- +simpleMuxCallback + :: ConnectionId addr + -> vNumber + -> vData + -> OuroborosApplicationWithMinimalCtx muxMode addr BL.ByteString IO a b + -> Mx.Mux muxMode IO + -> Async () + -> IO (Either SomeException (Either a b)) +simpleMuxCallback connectionId _ _ app mux aid = do + let initCtx = MinimalInitiatorContext connectionId + respCtx = ResponderContext connectionId + + resOps <- sequence + [ Mx.runMiniProtocol + mux + miniProtocolNum + miniProtocolDir + Mx.StartEagerly + action + | MiniProtocol{miniProtocolNum, miniProtocolRun} + <- getOuroborosApplication app + , (miniProtocolDir, action) <- + case miniProtocolRun of + InitiatorProtocolOnly initiator -> + [(Mx.InitiatorDirectionOnly, fmap (first Left) . runMiniProtocolCb initiator initCtx)] + ResponderProtocolOnly responder -> + [(Mx.ResponderDirectionOnly, fmap (first Right) . runMiniProtocolCb responder respCtx)] + InitiatorAndResponderProtocol initiator responder -> + [(Mx.InitiatorDirection, fmap (first Left) . runMiniProtocolCb initiator initCtx) + ,(Mx.ResponderDirection, fmap (first Right) . runMiniProtocolCb responder respCtx)] + ] + + -- Wait for the first MuxApplication to finish, then stop the mux. + r <- waitOnAny resOps + Mx.stopMux mux + wait aid + return r + where + waitOnAny :: [STM IO (Either SomeException x)] -> IO (Either SomeException x) + waitOnAny = atomically . runFirstToFinish . foldMap FirstToFinish -- Wraps a Socket inside a Snocket and calls connectToNode' connectToNodeSocket - :: forall appType vNumber vData a b. + :: forall muxMode vNumber vData a b. ( Ord vNumber , Typeable vNumber , Show vNumber - , Mx.HasInitiator appType ~ True + , Mx.HasInitiator muxMode ~ True ) => IOManager - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString - -> ProtocolTimeLimits (Handshake vNumber CBOR.Term) - -> VersionDataCodec CBOR.Term vNumber vData - -> NetworkConnectTracers Socket.SockAddr vNumber - -> HandshakeCallbacks vData - -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx appType Socket.SockAddr BL.ByteString IO a b) + -> ConnectToArgs Socket.Socket Socket.SockAddr vNumber vData + -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx muxMode Socket.SockAddr BL.ByteString IO a b) -- ^ application to run over the connection -> Socket.Socket - -> IO () -connectToNodeSocket iocp handshakeCodec handshakeTimeLimits versionDataCodec tracers handshakeCallbacks versions sd = + -> IO (Either SomeException (Either a b)) +connectToNodeSocket iocp args versions sd = connectToNode' (Snocket.socketSnocket iocp) Mx.makeSocketBearer - handshakeCodec - handshakeTimeLimits - versionDataCodec - tracers - handshakeCallbacks + args versions sd @@ -398,9 +518,9 @@ connectToNodeSocket iocp handshakeCodec handshakeTimeLimits versionDataCodec tra -- data SomeResponderApplication addr bytes m b where SomeResponderApplication - :: forall appType addr bytes m a b. - Mx.HasResponder appType ~ True - => (OuroborosApplicationWithMinimalCtx appType addr bytes m a b) + :: forall muxMode addr bytes m a b. + Mx.HasResponder muxMode ~ True + => (OuroborosApplicationWithMinimalCtx muxMode addr bytes m a b) -> SomeResponderApplication addr bytes m b -- | @@ -429,9 +549,8 @@ data AcceptConnection st vNumber vData peerid m bytes where -> AcceptConnection st vNumber vData peerid m bytes --- | --- Accept or reject incoming connection based on the current state and address --- of the incoming connection. +-- | Accept or reject incoming connection based on the current state and +-- address of the incoming connection. -- beginConnection :: forall vNumber vData addr st fd. @@ -481,15 +600,12 @@ beginConnection makeBearer muxTracer handshakeTracer handshakeCodec handshakeTim traceWith muxTracer' $ Mx.MuxTraceHandshakeServerError err throwIO err - Right (HandshakeNegotiationResult (SomeResponderApplication app) _versionNumber _agreedOptions) -> do - traceWith muxTracer' $ Mx.MuxTraceHandshakeServerEnd + Right (HandshakeNegotiationResult (SomeResponderApplication app) versionNumber agreedOptions) -> do + traceWith muxTracer' Mx.MuxTraceHandshakeServerEnd bearer <- Mx.getBearer makeBearer sduTimeout muxTracer' sd - Mx.muxStart - muxTracer' - (toApplication MinimalInitiatorContext { micConnectionId = connectionId } - ResponderContext { rcConnectionId = connectionId } - app) - bearer + mux <- Mx.newMux (toMiniProtocolInfos app) + withAsync (Mx.runMux muxTracer' mux bearer) $ \aid -> + void $ simpleMuxCallback connectionId versionNumber agreedOptions app mux aid Right (HandshakeQueryResult _vMap) -> do traceWith muxTracer' Mx.MuxTraceHandshakeServerEnd @@ -498,6 +614,7 @@ beginConnection makeBearer muxTracer handshakeTracer handshakeCodec handshakeTim RejectConnection st' _peerid -> pure $ Server.Reject st' + mkListeningSocket :: Snocket IO fd addr -> (fd -> addr -> IO ()) diff --git a/ouroboros-network-protocols/CHANGELOG.md b/ouroboros-network-protocols/CHANGELOG.md index 115e2044261..cc7f2a20444 100644 --- a/ouroboros-network-protocols/CHANGELOG.md +++ b/ouroboros-network-protocols/CHANGELOG.md @@ -6,6 +6,11 @@ * Adapt the `versionNumber` cddl definition to account for `NodeToClientVersionV18`. * Use `typed-protocols-0.3.0.0`. +* `Ouroboros.Network.Protocols.TxSubmission2.Codec.{encode,decode}TxSubmission2` are no longer exported. + +### Non-breaking changes + +* Improved haddocks of `node-to-node` mini-protocol codecs. ## 0.10.0.2 -- 2024-08-27 diff --git a/ouroboros-network-protocols/ouroboros-network-protocols.cabal b/ouroboros-network-protocols/ouroboros-network-protocols.cabal index 34da18afc30..62abadb50e1 100644 --- a/ouroboros-network-protocols/ouroboros-network-protocols.cabal +++ b/ouroboros-network-protocols/ouroboros-network-protocols.cabal @@ -105,6 +105,7 @@ library nothunks, ouroboros-network-api ^>=0.9.0, quiet, + random, serialise, si-timers, singletons, diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/BlockFetch/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/BlockFetch/Codec.hs index fbca6badfb8..17304387754 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/BlockFetch/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/BlockFetch/Codec.hs @@ -1,7 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -18,20 +17,21 @@ import Control.Monad.Class.MonadST import Control.Monad.Class.MonadTime.SI import Data.ByteString.Lazy qualified as LBS +import Data.Kind (Type) import Codec.CBOR.Decoding qualified as CBOR import Codec.CBOR.Encoding qualified as CBOR import Codec.CBOR.Read qualified as CBOR import Text.Printf -import Network.TypedProtocol.Codec.CBOR +import Network.TypedProtocol.Codec.CBOR hiding (decode, encode) import Ouroboros.Network.Protocol.BlockFetch.Type import Ouroboros.Network.Protocol.Limits -- | Byte Limit. -byteLimitsBlockFetch :: forall bytes block point. - (bytes -> Word) +byteLimitsBlockFetch :: forall bytes (block :: Type) (point :: Type). + (bytes -> Word) -- ^ compute size of bytes -> ProtocolSizeLimits (BlockFetch block point) bytes byteLimitsBlockFetch = ProtocolSizeLimits stateToLimit where @@ -44,10 +44,17 @@ byteLimitsBlockFetch = ProtocolSizeLimits stateToLimit -- | Time Limits -- --- `TokIdle' No timeout --- `TokBusy` `longWait` timeout --- `TokStreaming` `longWait` timeout -timeLimitsBlockFetch :: forall block point. +-- +------------------+---------------+ +-- | BlockFetch state | timeout (s) | +-- +==================+===============+ +-- | `BFIdle` | `waitForever` | +-- +------------------+---------------+ +-- | `BFBusy` | `longWait` | +-- +------------------+---------------+ +-- | `BFStreaming` | `longWait` | +-- +------------------+---------------+ +-- +timeLimitsBlockFetch :: forall (block :: Type) (point :: Type). ProtocolTimeLimits (BlockFetch block point) timeLimitsBlockFetch = ProtocolTimeLimits stateToLimit where @@ -58,17 +65,21 @@ timeLimitsBlockFetch = ProtocolTimeLimits stateToLimit stateToLimit SingBFStreaming = longWait stateToLimit a@SingBFDone = notActiveState a --- | Codec for chain sync that encodes/decodes blocks +-- | Codec for chain sync that encodes/decodes blocks and points. -- --- NOTE: See 'wrapCBORinCBOR' and 'unwrapCBORinCBOR' if you want to use this +-- /NOTE:/ See 'wrapCBORinCBOR' and 'unwrapCBORinCBOR' if you want to use this -- with a block type that has annotations. codecBlockFetch :: forall block point m. MonadST m => (block -> CBOR.Encoding) + -- ^ encode block -> (forall s. CBOR.Decoder s block) + -- ^ decode block -> (point -> CBOR.Encoding) + -- ^ encode point -> (forall s. CBOR.Decoder s point) + -- ^ decode point -> Codec (BlockFetch block point) CBOR.DeserialiseFailure m LBS.ByteString codecBlockFetch encodeBlock decodeBlock encodePoint decodePoint = @@ -128,7 +139,7 @@ codecBlockFetch encodeBlock decodeBlock codecBlockFetchId - :: forall block point m. + :: forall (block :: Type) (point :: Type) m. Monad m => Codec (BlockFetch block point) CodecFailure m (AnyMessage (BlockFetch block point)) diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/ChainSync/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/ChainSync/Codec.hs index 1391ceba433..3d9d807143a 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/ChainSync/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/ChainSync/Codec.hs @@ -1,7 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -11,19 +10,22 @@ module Ouroboros.Network.Protocol.ChainSync.Codec , codecChainSyncId , byteLimitsChainSync , timeLimitsChainSync - , ChainSyncTimeout (..) + , maxChainSyncTimeout + , minChainSyncTimeout ) where import Control.Monad.Class.MonadST import Control.Monad.Class.MonadTime.SI -import Network.TypedProtocol.Codec.CBOR +import Network.TypedProtocol.Codec.CBOR hiding (decode, encode) import Ouroboros.Network.Protocol.ChainSync.Type import Ouroboros.Network.Protocol.Limits import Data.ByteString.Lazy qualified as LBS +import Data.Kind (Type) import Data.Singletons (withSingI) +import System.Random (StdGen, randomR) import Codec.CBOR.Decoding (decodeListLen, decodeWord) import Codec.CBOR.Decoding qualified as CBOR @@ -34,8 +36,8 @@ import Text.Printf -- | Byte Limits -byteLimitsChainSync :: forall bytes header point tip . - (bytes -> Word) +byteLimitsChainSync :: forall bytes (header :: Type) (point :: Type) (tip :: Type) . + (bytes -> Word) -- ^ compute size of `bytes` -> ProtocolSizeLimits (ChainSync header point tip) bytes byteLimitsChainSync = ProtocolSizeLimits stateToLimit where @@ -47,60 +49,86 @@ byteLimitsChainSync = ProtocolSizeLimits stateToLimit stateToLimit SingIntersect = smallByteLimit stateToLimit a@SingDone = notActiveState a --- | Configurable timeouts --- --- These are configurable for at least the following reasons. + +-- | Chain sync `mustReplayTimeout` lower bound. -- --- o So that deployment and testing can use different values. +minChainSyncTimeout :: DiffTime +minChainSyncTimeout = 135 + + +-- | Chain sync `mustReplayTimeout` upper bound. -- --- o So that a net running Praos can better cope with streaks of empty slots. --- (See @intersectmbo/ouroboros-network#2245@.) -data ChainSyncTimeout = ChainSyncTimeout - { canAwaitTimeout :: Maybe DiffTime - , intersectTimeout :: Maybe DiffTime - , mustReplyTimeout :: Maybe DiffTime - , idleTimeout :: Maybe DiffTime - } +maxChainSyncTimeout :: DiffTime +maxChainSyncTimeout = 269 + -- | Time Limits -- --- > 'TokIdle' 'waitForever' (ie never times out) --- > 'TokNext TokCanAwait' the given 'canAwaitTimeout' --- > 'TokNext TokMustReply' the given 'mustReplyTimeout' --- > 'TokIntersect' the given 'intersectTimeout' -timeLimitsChainSync :: forall header point tip. - ChainSyncTimeout +-- +----------------------------+-------------------------------------------------------------+ +-- | ChainSync State | timeout (s) | +-- +============================+=============================================================+ +-- | @'StIdle'@ | 'waitForever' (i.e. never times out) | +-- +----------------------------+-------------------------------------------------------------+ +-- | @'StNext' 'StCanAwait'@ | 'shortWait' | +-- +----------------------------+-------------------------------------------------------------+ +-- | @'StNext' 'StMustReply'@ | randomly picked using uniform distribution from | +-- | | the range @('minChainSyncTimeout', 'maxChainSyncTimeout')@, | +-- | | which corresponds to a chance of an empty streak of slots | +-- | | between `0.0001%` and `1%` probability. | +-- +----------------------------+-------------------------------------------------------------+ +-- | @'StIntersect'@ | 'shortWait' | +-- +----------------------------+-------------------------------------------------------------+ +-- +timeLimitsChainSync :: forall (header :: Type) (point :: Type) (tip :: Type). + StdGen -> ProtocolTimeLimits (ChainSync header point tip) -timeLimitsChainSync csTimeouts = ProtocolTimeLimits stateToLimit +timeLimitsChainSync rnd = ProtocolTimeLimits stateToLimit where - ChainSyncTimeout - { canAwaitTimeout - , intersectTimeout - , mustReplyTimeout - , idleTimeout - } = csTimeouts - stateToLimit :: forall (st :: ChainSync header point tip). ActiveState st => StateToken st -> Maybe DiffTime - stateToLimit SingIdle = idleTimeout - stateToLimit (SingNext SingCanAwait) = canAwaitTimeout - stateToLimit (SingNext SingMustReply) = mustReplyTimeout - stateToLimit SingIntersect = intersectTimeout - stateToLimit a@SingDone = notActiveState a - --- | Codec for chain sync that encodes/decodes headers + stateToLimit SingIdle = Just 3673 + stateToLimit SingIntersect = shortWait + stateToLimit (SingNext SingCanAwait) = shortWait + stateToLimit (SingNext SingMustReply) = + -- We draw from a range for which streaks of empty slots ranges + -- from 0.0001% up to 1% probability. + -- t = T_s [log (1-Y) / log (1-f)] + -- Y = [0.99, 0.999...] + -- T_s = slot length of 1s. + -- f = 0.05 + -- The timeout is randomly picked per state to avoid all peers go down at + -- the same time in case of a long streak of empty slots, and thus to + -- avoid global synchronisation. The timeout is picked uniformly from + -- the interval 135 - 269, which corresponds to 99.9% to + -- 99.9999% thresholds. + let timeout :: DiffTime + timeout = realToFrac . fst + . randomR ( realToFrac minChainSyncTimeout :: Double + , realToFrac maxChainSyncTimeout :: Double + ) + $ rnd + in Just timeout + stateToLimit a@SingDone = notActiveState a + +-- | Codec for chain sync that encodes/decodes headers, points & tips. -- --- NOTE: See 'wrapCBORinCBOR' and 'unwrapCBORinCBOR' if you want to use this +-- /NOTE:/ See 'wrapCBORinCBOR' and 'unwrapCBORinCBOR' if you want to use this -- with a header type that has annotations. codecChainSync :: forall header point tip m. (MonadST m) => (header -> CBOR.Encoding) + -- ^ encode header -> (forall s . CBOR.Decoder s header) + -- ^ decode header -> (point -> CBOR.Encoding) + -- ^ encode point -> (forall s . CBOR.Decoder s point) + -- ^ decode point -> (tip -> CBOR.Encoding) + -- ^ encode tip -> (forall s. CBOR.Decoder s tip) + -- ^ decode tip -> Codec (ChainSync header point tip) CBOR.DeserialiseFailure m LBS.ByteString codecChainSync encodeHeader decodeHeader @@ -222,7 +250,7 @@ decodeList dec = do -- | An identity 'Codec' for the 'ChainSync' protocol. It does not do any -- serialisation. It keeps the typed messages, wrapped in 'AnyMessage'. -- -codecChainSyncId :: forall header point tip m. Monad m +codecChainSyncId :: forall (header :: Type) (point :: Type) (tip :: Type) m. Monad m => Codec (ChainSync header point tip) CodecFailure m (AnyMessage (ChainSync header point tip)) codecChainSyncId = Codec encode decode diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/KeepAlive/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/KeepAlive/Codec.hs index 687f9843c97..21c7436bf4a 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/KeepAlive/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/KeepAlive/Codec.hs @@ -86,6 +86,16 @@ byteLimitsKeepAlive = ProtocolSizeLimits sizeLimitForState sizeLimitForState a@SingDone = notActiveState a +-- | 'KeepAlive' time limits. +-- +-- +--------------------+---------------+ +-- | 'KeepAlive' state | timeout (s) | +-- +====================+===============+ +-- | `StClient` | @Just 97@ | +-- +--------------------+---------------+ +-- | `StServer` | @Just 60@ | +-- +--------------------+---------------+ +-- timeLimitsKeepAlive :: ProtocolTimeLimits KeepAlive timeLimitsKeepAlive = ProtocolTimeLimits { timeLimitForState } where diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/PeerSharing/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/PeerSharing/Codec.hs index 813eb2ab22a..2c15fb4c3c2 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/PeerSharing/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/PeerSharing/Codec.hs @@ -6,12 +6,18 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -module Ouroboros.Network.Protocol.PeerSharing.Codec where +module Ouroboros.Network.Protocol.PeerSharing.Codec + ( codecPeerSharing + , codecPeerSharingId + , byteLimitsPeerSharing + , timeLimitsPeerSharing + ) where import Control.Monad.Class.MonadST import Control.Monad.Class.MonadTime.SI (DiffTime) import Data.ByteString.Lazy (ByteString) +import Data.Kind (Type) import Codec.CBOR.Decoding qualified as CBOR import Codec.CBOR.Encoding qualified as CBOR @@ -23,14 +29,16 @@ import Network.TypedProtocol.Codec.CBOR import Ouroboros.Network.Protocol.Limits import Ouroboros.Network.Protocol.PeerSharing.Type -codecPeerSharing :: forall m peerAddress. +codecPeerSharing :: forall m (peerAddress :: Type). MonadST m => (peerAddress -> CBOR.Encoding) + -- ^ encode 'peerAddress' -> (forall s . CBOR.Decoder s peerAddress) + -- ^ decode 'peerAddress' -> Codec (PeerSharing peerAddress) - CBOR.DeserialiseFailure - m - ByteString + CBOR.DeserialiseFailure + m + ByteString codecPeerSharing encodeAddress decodeAddress = mkCodecCborLazyBS encodeMsg decodeMsg where encodeMsg :: Message (PeerSharing peerAddress) st st' @@ -85,7 +93,7 @@ codecPeerSharing encodeAddress decodeAddress = mkCodecCborLazyBS encodeMsg decod Just n -> CBOR.decodeSequenceLenN (flip (:)) [] reverse n dec codecPeerSharingId - :: forall peerAddress m. + :: forall (peerAddress :: Type) m. Monad m => Codec (PeerSharing peerAddress) CodecFailure m (AnyMessage (PeerSharing peerAddress)) codecPeerSharingId = Codec encodeMsg decodeMsg @@ -119,8 +127,8 @@ codecPeerSharingId = Codec encodeMsg decodeMsg maxTransmissionUnit :: Word maxTransmissionUnit = 4 * 1440 -byteLimitsPeerSharing :: forall peerAddress bytes. - (bytes -> Word) +byteLimitsPeerSharing :: forall (peerAddress :: Type) bytes. + (bytes -> Word) -- ^ compute size of bytes -> ProtocolSizeLimits (PeerSharing peerAddress) bytes byteLimitsPeerSharing = ProtocolSizeLimits sizeLimitForState where @@ -132,7 +140,17 @@ byteLimitsPeerSharing = ProtocolSizeLimits sizeLimitForState sizeLimitForState a@SingDone = notActiveState a -timeLimitsPeerSharing :: forall peerAddress. ProtocolTimeLimits (PeerSharing peerAddress) +-- | 'PeerSharing' timeouts. +-- +-- +----------------------+---------------+ +-- | 'PeerSharing' state | timeout (s) | +-- +======================+===============+ +-- | `StIdle` | `waitForever` | +-- +----------------------+---------------+ +-- | `StBusy` | `longWait` | +-- +----------------------+---------------+ +-- +timeLimitsPeerSharing :: forall (peerAddress :: Type). ProtocolTimeLimits (PeerSharing peerAddress) timeLimitsPeerSharing = ProtocolTimeLimits { timeLimitForState } where timeLimitForState :: forall (st :: PeerSharing peerAddress). diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs index 6302cf1de49..b97528369ad 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs @@ -10,21 +10,20 @@ module Ouroboros.Network.Protocol.TxSubmission2.Codec ( codecTxSubmission2 , codecTxSubmission2Id - , encodeTxSubmission2 - , decodeTxSubmission2 , byteLimitsTxSubmission2 , timeLimitsTxSubmission2 ) where import Control.Monad.Class.MonadST import Control.Monad.Class.MonadTime.SI +import Data.ByteString.Lazy (ByteString) +import Data.Kind (Type) import Data.List.NonEmpty qualified as NonEmpty +import Text.Printf import Codec.CBOR.Decoding qualified as CBOR import Codec.CBOR.Encoding qualified as CBOR import Codec.CBOR.Read qualified as CBOR -import Data.ByteString.Lazy (ByteString) -import Text.Printf import Network.TypedProtocol.Codec.CBOR @@ -47,13 +46,23 @@ byteLimitsTxSubmission2 = ProtocolSizeLimits stateToLimit stateToLimit a@SingDone = notActiveState a --- | Time Limits. +-- | 'TxSubmission2' time limits. -- --- `SingTxIds SingBlocking` No timeout --- `SingTxIds SingNonBlocking` `shortWait` timeout --- `SingTxs` `shortWait` timeout --- `SingIdle` `shortWait` timeout -timeLimitsTxSubmission2 :: forall txid tx. ProtocolTimeLimits (TxSubmission2 txid tx) +-- +-----------------------------+---------------+ +-- | 'TxSubmission2' state | timeout (s) | +-- +=============================+===============+ +-- | `StInit` | `waitForever` | +-- +-----------------------------+---------------+ +-- | `StIdle` | `waitForever` | +-- +-----------------------------+---------------+ +-- | @'StTxIds' 'StBlocking'@ | `waitForever` | +-- +-----------------------------+---------------+ +-- | @'StTxIds' 'StNonBlocking'@ | `shortWait` | +-- +-----------------------------+---------------+ +-- | `StTxs` | `shortWait` | +-- +-----------------------------+---------------+ +-- +timeLimitsTxSubmission2 :: forall (txid :: Type) (tx :: Type). ProtocolTimeLimits (TxSubmission2 txid tx) timeLimitsTxSubmission2 = ProtocolTimeLimits stateToLimit where stateToLimit :: forall (st :: TxSubmission2 txid tx). @@ -67,12 +76,16 @@ timeLimitsTxSubmission2 = ProtocolTimeLimits stateToLimit codecTxSubmission2 - :: forall txid tx m. + :: forall (txid :: Type) (tx :: Type) m. MonadST m => (txid -> CBOR.Encoding) + -- ^ encode 'txid' -> (forall s . CBOR.Decoder s txid) + -- ^ decode 'txid' -> (tx -> CBOR.Encoding) + -- ^ encode transaction -> (forall s . CBOR.Decoder s tx) + -- ^ decode transaction -> Codec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m ByteString codecTxSubmission2 encodeTxId decodeTxId encodeTx decodeTx = @@ -90,16 +103,17 @@ codecTxSubmission2 encodeTxId decodeTxId decodeTxSubmission2 decodeTxId decodeTx stok len key encodeTxSubmission2 - :: forall txid tx. + :: forall (txid :: Type) (tx :: Type) (st :: TxSubmission2 txid tx) (st' :: TxSubmission2 txid tx). (txid -> CBOR.Encoding) + -- ^ encode 'txid' -> (tx -> CBOR.Encoding) - -> (forall (st :: TxSubmission2 txid tx) (st' :: TxSubmission2 txid tx). - Message (TxSubmission2 txid tx) st st' - -> CBOR.Encoding) + -- ^ encode 'tx' + -> Message (TxSubmission2 txid tx) st st' + -> CBOR.Encoding encodeTxSubmission2 encodeTxId encodeTx = encode where - encode :: forall st st'. - Message (TxSubmission2 txid tx) st st' + encode :: forall st0 st1. + Message (TxSubmission2 txid tx) st0 st1 -> CBOR.Encoding encode MsgInit = CBOR.encodeListLen 1 @@ -148,23 +162,24 @@ encodeTxSubmission2 encodeTxId encodeTx = encode decodeTxSubmission2 - :: forall txid tx. - (forall s . CBOR.Decoder s txid) - -> (forall s . CBOR.Decoder s tx) - -> (forall (st :: TxSubmission2 txid tx) s. - ActiveState st - => StateToken st - -> Int - -> Word - -> CBOR.Decoder s (SomeMessage st)) + :: forall (txid :: Type) (tx :: Type) (st :: TxSubmission2 txid tx) s. + ActiveState st + => (forall s'. CBOR.Decoder s' txid) + -- ^ decode 'txid' + -> (forall s'. CBOR.Decoder s' tx) + -- ^ decode transaction + -> StateToken st + -> Int + -> Word + -> CBOR.Decoder s (SomeMessage st) decodeTxSubmission2 decodeTxId decodeTx = decode where - decode :: forall s (st :: TxSubmission2 txid tx). - ActiveState st - => StateToken st + decode :: forall (st' :: TxSubmission2 txid tx). + ActiveState st' + => StateToken st' -> Int -> Word - -> CBOR.Decoder s (SomeMessage st) + -> CBOR.Decoder s (SomeMessage st') decode stok len key = do case (stok, len, key) of (SingInit, 1, 6) -> diff --git a/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/Handshake/Test.hs b/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/Handshake/Test.hs index ad66d4dab8d..5fc856f6fa6 100644 --- a/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/Handshake/Test.hs +++ b/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/Handshake/Test.hs @@ -1342,10 +1342,8 @@ prop_channel_simultaneous_open_sim codec versionDataCodec nullTracer -- (("server",) `contramap` Tracer Debug.traceShowM) fdConn' - let chann = fromChannel - $ muxBearerAsChannel bearer (MiniProtocolNum 0) InitiatorDir - chann' = fromChannel - $ muxBearerAsChannel bearer' (MiniProtocolNum 0) InitiatorDir + let chann = muxBearerAsChannel bearer (MiniProtocolNum 0) InitiatorDir + chann' = muxBearerAsChannel bearer' (MiniProtocolNum 0) InitiatorDir res <- prop_channel_simultaneous_open (pure (chann, chann')) codec diff --git a/ouroboros-network/CHANGELOG.md b/ouroboros-network/CHANGELOG.md index 6410d7b4375..b761225ce03 100644 --- a/ouroboros-network/CHANGELOG.md +++ b/ouroboros-network/CHANGELOG.md @@ -24,6 +24,14 @@ It is used by `outboundConnectionsState` when signaling trust state when syncing in Genesis mode. Default value is provided by the Configuration module. * Using `typed-protocols-0.2.0.0`. +* `Ouroboros.Network.NodeToClient.connectTo` takes + `OuroborosApplicationWithMinimalCtx` which is using `Void` type for responder + protocols. It anyway only accepts `InitiatorMode`, and thus no responder + protocols can be specified, nontheless this might require changing type + signature of the applications passed to it. `connectTo` returns now either + an error or the result of the first terminated mini-protocol. +* `Ouroboros.Network.NodeToNode.connectTo` returns either an error or the + result of the first terminated mini-protocol. ### Non-Breaking changes @@ -41,6 +49,7 @@ a mismatch detected. * Added `defaultDeadlineChurnInterval` and `defaultBulkChurnInterval` to Configuration module. Previously these were hard coded in node. +* Updated tests for `network-mux` changes. ## 0.17.1.1 -- 2024-08-27 diff --git a/ouroboros-network/demo/chain-sync.hs b/ouroboros-network/demo/chain-sync.hs index 5753d232049..402d9c25d67 100644 --- a/ouroboros-network/demo/chain-sync.hs +++ b/ouroboros-network/demo/chain-sync.hs @@ -230,15 +230,17 @@ clientChainSync :: [FilePath] clientChainSync sockPaths maxSlotNo = withIOManager $ \iocp -> forConcurrently_ (zip [0..] sockPaths) $ \(index, sockPath) -> do threadDelay (50000 * index) - connectToNode + void $ connectToNode (localSnocket iocp) makeLocalBearer + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } mempty - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) (simpleSingletonVersions UnversionedProtocol UnversionedProtocolData @@ -472,16 +474,18 @@ clientBlockFetch sockAddrs maxSlotNo = withIOManager $ \iocp -> do chainSelection fingerprint' peerAsyncs <- sequence - [ async $ + [ async . void $ connectToNode (localSnocket iocp) makeLocalBearer + ConnectToArgs { + ctaHandshakeCodec = unversionedHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = unversionedProtocolDataCodec, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } mempty - unversionedHandshakeCodec - noTimeLimitsHandshake - unversionedProtocolDataCodec - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) (simpleSingletonVersions UnversionedProtocol UnversionedProtocolData diff --git a/ouroboros-network/io-tests/Test/Ouroboros/Network/Pipe.hs b/ouroboros-network/io-tests/Test/Ouroboros/Network/Pipe.hs index 5f5037ff5d0..50318fa6a7a 100644 --- a/ouroboros-network/io-tests/Test/Ouroboros/Network/Pipe.hs +++ b/ouroboros-network/io-tests/Test/Ouroboros/Network/Pipe.hs @@ -1,6 +1,8 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} @@ -17,6 +19,7 @@ import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadFork import Control.Monad.Class.MonadTimer.SI import Data.ByteString.Lazy qualified as BL +import Data.Monoid.Synchronisation import Data.Void (Void) import Test.ChainGenerators (TestBlockChainAndUpdates (..)) import Test.QuickCheck @@ -25,9 +28,9 @@ import Test.Tasty.QuickCheck (testProperty) import Control.Tracer +import Network.Mux qualified as Mx import Network.Mux.Bearer qualified as Mx import Network.Mux.Bearer.Pipe qualified as Mx -import Network.Mux.Compat qualified as Mx (muxStart) import Ouroboros.Network.Mux #if defined(mingw32_HOST_OS) @@ -152,9 +155,9 @@ demo chain0 updates = do let chan1 = Mx.pipeChannelFromHandles hndRead1 hndWrite2 chan2 = Mx.pipeChannelFromHandles hndRead2 hndWrite1 #endif - producerVar <- atomically $ newTVar (CPS.initChainProducerState chain0) - consumerVar <- atomically $ newTVar chain0 - done <- atomically newEmptyTMVar + producerVar <- newTVarIO (CPS.initChainProducerState chain0) + consumerVar <- newTVarIO chain0 + done <- newEmptyTMVarIO let Just expectedChain = Chain.applyChainUpdates updates chain0 target = Chain.headPoint expectedChain @@ -193,22 +196,55 @@ demo chain0 updates = do clientBearer <- Mx.getBearer Mx.makePipeChannelBearer (-1) activeTracer chan1 serverBearer <- Mx.getBearer Mx.makePipeChannelBearer (-1) activeTracer chan2 - _ <- async $ - Mx.muxStart - activeTracer - (toApplication - MinimalInitiatorContext { micConnectionId = ConnectionId "producer" "consumer" } - ResponderContext { rcConnectionId = ConnectionId "producer" "consumer" } - producerApp) - clientBearer - _ <- async $ - Mx.muxStart - activeTracer - (toApplication - MinimalInitiatorContext { micConnectionId = ConnectionId "consumer" "producer" } - ResponderContext { rcConnectionId = ConnectionId "consumer" "producer" } - consumerApp) - serverBearer + _ <- async $ do + clientMux <- Mx.newMux (toMiniProtocolInfos consumerApp) + let initCtx = MinimalInitiatorContext (ConnectionId "consumer" "producer") + resOps <- sequence + [ Mx.runMiniProtocol + clientMux + miniProtocolNum + miniProtocolDir + Mx.StartEagerly + (\a -> do + r <- action a + return (r, Nothing) + ) + | MiniProtocol{miniProtocolNum, miniProtocolRun} + <- getOuroborosApplication consumerApp + , (miniProtocolDir, action) <- + case miniProtocolRun of + InitiatorProtocolOnly initiator -> + [(Mx.InitiatorDirectionOnly, void . runMiniProtocolCb initiator initCtx)] + ] + withAsync (Mx.runMux nullTracer clientMux clientBearer) $ \aid -> do + _ <- atomically $ runFirstToFinish $ foldMap FirstToFinish resOps + Mx.stopMux clientMux + wait aid + + _ <- async $ do + serverMux <- Mx.newMux (toMiniProtocolInfos producerApp) + let respCtx = ResponderContext (ConnectionId "consumer" "producer") + resOps <- sequence + [ Mx.runMiniProtocol + serverMux + miniProtocolNum + miniProtocolDir + Mx.StartEagerly + (\a -> do + r <- action a + return (r, Nothing) + ) + | MiniProtocol{miniProtocolNum, miniProtocolRun} + <- getOuroborosApplication producerApp + , (miniProtocolDir, action) <- + case miniProtocolRun of + ResponderProtocolOnly responder -> + [(Mx.ResponderDirectionOnly, void . runMiniProtocolCb responder respCtx)] + ] + withAsync (Mx.runMux nullTracer serverMux serverBearer) $ \aid -> do + _ <- atomically $ runFirstToFinish $ foldMap FirstToFinish resOps + Mx.stopMux serverMux + wait aid void $ forkIO $ sequence_ [ do threadDelay 10e-4 -- 1 milliseconds, just to provide interest diff --git a/ouroboros-network/io-tests/Test/Ouroboros/Network/Socket.hs b/ouroboros-network/io-tests/Test/Ouroboros/Network/Socket.hs index a3a60bd24e4..a317a22bf5e 100644 --- a/ouroboros-network/io-tests/Test/Ouroboros/Network/Socket.hs +++ b/ouroboros-network/io-tests/Test/Ouroboros/Network/Socket.hs @@ -180,12 +180,14 @@ demo chain0 updates = withIOManager $ \iocp -> do (connectToNode (socketSnocket iocp) makeSocketBearer - (flip configureSocket Nothing) - nodeToNodeHandshakeCodec - noTimeLimitsHandshake - (cborTermVersionDataCodec nodeToNodeCodecCBORTerm) - nullNetworkConnectTracers - (HandshakeCallbacks acceptableVersion queryVersion) + ConnectToArgs { + ctaHandshakeCodec = nodeToNodeHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = cborTermVersionDataCodec nodeToNodeCodecCBORTerm, + ctaConnectTracers = nullNetworkConnectTracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + (`configureSocket` Nothing) (simpleSingletonVersions NodeToNodeV_7 (NodeToNodeVersionData { diff --git a/ouroboros-network/ouroboros-network.cabal b/ouroboros-network/ouroboros-network.cabal index 6d097c65629..3bcd4d816b7 100644 --- a/ouroboros-network/ouroboros-network.cabal +++ b/ouroboros-network/ouroboros-network.cabal @@ -315,6 +315,7 @@ test-suite io-tests bytestring, contra-tracer, io-classes, + monoidal-synchronisation, network, network-mux, ouroboros-network, diff --git a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Diffusion/Node/MiniProtocols.hs b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Diffusion/Node/MiniProtocols.hs index 496aec18135..155a24f6cf5 100644 --- a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Diffusion/Node/MiniProtocols.hs +++ b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Diffusion/Node/MiniProtocols.hs @@ -33,7 +33,7 @@ import Data.ByteString.Lazy (ByteString) import Data.Functor (($>)) import Data.Maybe (fromMaybe) import Data.Void (Void) -import System.Random (RandomGen, StdGen) +import System.Random (RandomGen, StdGen, mkStdGen) import Codec.CBOR.Read qualified as CBOR import Codec.Serialise qualified as Serialise @@ -137,7 +137,7 @@ data LimitsAndTimeouts header block = LimitsAndTimeouts :: ProtocolSizeLimits (ChainSync header (Point block) (Tip block)) ByteString , chainSyncTimeLimits - :: ProtocolTimeLimits (ChainSync header (Point block) (Tip block)) + :: StdGen -> ProtocolTimeLimits (ChainSync header (Point block) (Tip block)) -- block-fetch , blockFetchLimits @@ -394,8 +394,9 @@ applications debugTracer nodeKernel bracket (registerClientChains nodeKernel (remoteAddress connId)) (\_ -> unregisterClientChains nodeKernel (remoteAddress connId)) (\chainVar -> - runPeerWithLimits + runPeerWithLimitsRnd nullTracer + (mkStdGen 0) -- TODO chainSyncCodec (chainSyncSizeLimits limits) (chainSyncTimeLimits limits) @@ -410,8 +411,9 @@ applications debugTracer nodeKernel :: MiniProtocolCb (ResponderContext NtNAddr) ByteString m () chainSyncResponder = MiniProtocolCb $ \_ctx channel -> do labelThisThread "ChainSyncServer" - runPeerWithLimits + runPeerWithLimitsRnd nullTracer + (mkStdGen 0) chainSyncCodec (chainSyncSizeLimits limits) (chainSyncTimeLimits limits) diff --git a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Mux.hs b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Mux.hs index 21e87202d7b..cdd872cf987 100644 --- a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Mux.hs +++ b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Mux.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} @@ -9,6 +10,8 @@ module Test.Ouroboros.Network.Mux (tests) where import Codec.Serialise (Serialise (..)) +import Data.Functor (void) +import Data.Monoid.Synchronisation (FirstToFinish (..)) import Control.Applicative (Alternative) import Control.Concurrent.Class.MonadSTM.Strict @@ -43,9 +46,9 @@ import Ouroboros.Network.Protocol.ChainSync.Server qualified as ChainSync import Ouroboros.Network.Protocol.ChainSync.Type qualified as ChainSync import Ouroboros.Network.Util.ShowProxy +import Network.Mux qualified as Mx import Network.Mux.Bearer qualified as Mx import Network.Mux.Bearer.Queues qualified as Mx -import Network.Mux.Compat qualified as Mx (muxStart) import Ouroboros.Network.Mux as Mx @@ -105,9 +108,9 @@ demo chain0 updates delay = do client_r <- atomically $ newTBQueue 10 let server_w = client_r server_r = client_w - producerVar <- atomically $ newTVar (CPS.initChainProducerState chain0) - consumerVar <- atomically $ newTVar chain0 - done <- atomically newEmptyTMVar + producerVar <- newTVarIO (CPS.initChainProducerState chain0) + consumerVar <- newTVarIO chain0 + done <- newEmptyTMVarIO let Just expectedChain = Chain.applyChainUpdates updates chain0 target = Chain.headPoint expectedChain @@ -160,22 +163,55 @@ demo chain0 updates delay = do Mx.readQueue = server_r } - clientAsync <- async $ - Mx.muxStart - activeTracer - (Mx.toApplication - MinimalInitiatorContext { micConnectionId = ConnectionId "client" "server" } - ResponderContext { rcConnectionId = ConnectionId "client" "server" } - consumerApp) - clientBearer - serverAsync <- async $ - Mx.muxStart - activeTracer - (Mx.toApplication - MinimalInitiatorContext { micConnectionId = ConnectionId "server" "client" } - ResponderContext { rcConnectionId = ConnectionId "server" "client" } - producerApp) - serverBearer + clientAsync <- async $ do + clientMux <- Mx.newMux (toMiniProtocolInfos consumerApp) + let initCtx = MinimalInitiatorContext (ConnectionId "consumer" "producer") + resOps <- sequence + [ Mx.runMiniProtocol + clientMux + miniProtocolNum + miniProtocolDir + Mx.StartEagerly + (\a -> do + r <- action a + return (r, Nothing) + ) + | MiniProtocol{miniProtocolNum, miniProtocolRun} + <- getOuroborosApplication consumerApp + , (miniProtocolDir, action) <- + case miniProtocolRun of + InitiatorProtocolOnly initiator -> + [(Mx.InitiatorDirectionOnly, void . runMiniProtocolCb initiator initCtx)] + ] + withAsync (Mx.runMux nullTracer clientMux clientBearer) $ \aid -> do + _ <- atomically $ runFirstToFinish $ foldMap FirstToFinish resOps + Mx.stopMux clientMux + wait aid + + serverAsync <- async $ do + serverMux <- Mx.newMux (toMiniProtocolInfos producerApp) + let respCtx = ResponderContext (ConnectionId "producer" "consumer") + resOps <- sequence + [ Mx.runMiniProtocol + serverMux + miniProtocolNum + miniProtocolDir + Mx.StartEagerly + (\a -> do + r <- action a + return (r, Nothing) + ) + | MiniProtocol{miniProtocolNum, miniProtocolRun} + <- getOuroborosApplication producerApp + , (miniProtocolDir, action) <- + case miniProtocolRun of + ResponderProtocolOnly responder -> + [(Mx.ResponderDirectionOnly, void . runMiniProtocolCb responder respCtx)] + ] + withAsync (Mx.runMux nullTracer serverMux serverBearer) $ \aid -> do + _ <- atomically $ runFirstToFinish $ foldMap FirstToFinish resOps + Mx.stopMux serverMux + wait aid updateAid <- async $ sequence_ [ do diff --git a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs index 324ca9588d4..a17c3b2f0a4 100644 --- a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs +++ b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs @@ -314,7 +314,6 @@ unit_cm_valid_transitions = [ ( NodeArgs (-2) InitiatorAndResponderDiffusionMode - (Just 269) (Map.fromList [(RelayAccessAddress "0:71:0:1:0:1:0:1" 65534, DoAdvertisePeer)]) GenesisMode @@ -358,7 +357,6 @@ unit_cm_valid_transitions = , ( NodeArgs 0 InitiatorAndResponderDiffusionMode - (Just 90) Map.empty GenesisMode (Script (DontUseBootstrapPeers :| [])) @@ -871,7 +869,7 @@ unit_4177 = prop_inbound_governor_transitions_coverage absNoAttenuation script script = DiffusionScript (SimArgs 1 10) (singletonTimedScript Map.empty) - [ ( NodeArgs (-6) InitiatorAndResponderDiffusionMode (Just 180) + [ ( NodeArgs (-6) InitiatorAndResponderDiffusionMode (Map.fromList [(RelayAccessDomain "test2" 65535, DoAdvertisePeer)]) PraosMode (Script ((UseBootstrapPeers [RelayAccessDomain "bootstrap" 00000]) :| [])) @@ -904,7 +902,7 @@ unit_4177 = prop_inbound_governor_transitions_coverage absNoAttenuation script ,Reconfigure 4.870967741935 [(1,1,Map.fromList [(RelayAccessDomain "test2" 65535,(DoAdvertisePeer, IsNotTrustable))])] ] ) - , ( NodeArgs (1) InitiatorAndResponderDiffusionMode (Just 135) + , ( NodeArgs (1) InitiatorAndResponderDiffusionMode (Map.fromList [(RelayAccessAddress "0:7:0:7::" 65533, DoAdvertisePeer)]) PraosMode (Script ((UseBootstrapPeers [RelayAccessDomain "bootstrap" 00000]) :| [])) @@ -1465,7 +1463,6 @@ unit_4191 = testWithIOSim prop_diffusion_dns_can_recover 125000 absInfo script [(NodeArgs 16 InitiatorAndResponderDiffusionMode - (Just 224) Map.empty PraosMode (Script ((UseBootstrapPeers [RelayAccessDomain "bootstrap" 00000]) :| [])) @@ -2450,7 +2447,6 @@ async_demotion_network_script = common = NodeArgs { naSeed = 10, naDiffusionMode = InitiatorAndResponderDiffusionMode, - naMbTime = Just 1, naPublicRoots = Map.empty, naConsensusMode = PraosMode, naBootstrapPeers = (Script ((UseBootstrapPeers [RelayAccessDomain "bootstrap" 00000]) :| [])), @@ -2948,7 +2944,7 @@ prop_unit_4258 = diffScript = DiffusionScript (SimArgs 1 10) (singletonTimedScript Map.empty) - [( NodeArgs (-3) InitiatorAndResponderDiffusionMode (Just 224) + [( NodeArgs (-3) InitiatorAndResponderDiffusionMode (Map.fromList []) PraosMode (Script ((UseBootstrapPeers [RelayAccessDomain "bootstrap" 00000]) :| [])) @@ -2983,7 +2979,7 @@ prop_unit_4258 = Reconfigure 4.190476190476 [] ] ), - ( NodeArgs (-5) InitiatorAndResponderDiffusionMode (Just 269) + ( NodeArgs (-5) InitiatorAndResponderDiffusionMode (Map.fromList [(RelayAccessAddress "0.0.0.4" 9, DoAdvertisePeer)]) PraosMode (Script ((UseBootstrapPeers [RelayAccessDomain "bootstrap" 00000]) :| [])) @@ -3056,7 +3052,6 @@ prop_unit_reconnect = [(NodeArgs (-3) InitiatorAndResponderDiffusionMode - (Just 224) Map.empty PraosMode (Script (DontUseBootstrapPeers :| [])) @@ -3088,7 +3083,6 @@ prop_unit_reconnect = , (NodeArgs (-1) InitiatorAndResponderDiffusionMode - (Just 2) Map.empty PraosMode (Script (DontUseBootstrapPeers :| [])) @@ -3506,7 +3500,6 @@ unit_peer_sharing = defaultNodeArgs naConsensusMode = NodeArgs { naSeed = 0, naDiffusionMode = InitiatorAndResponderDiffusionMode, - naMbTime = Nothing, naPublicRoots = mempty, naBootstrapPeers = singletonScript DontUseBootstrapPeers, naAddr = undefined, diff --git a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet/Simulation/Node.hs b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet/Simulation/Node.hs index be45de816e6..3578c8c7ec5 100644 --- a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet/Simulation/Node.hs +++ b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet/Simulation/Node.hs @@ -90,8 +90,8 @@ import Ouroboros.Network.PeerSelection.PeerStateActions (PeerSelectionActionsTrace) import Ouroboros.Network.Protocol.BlockFetch.Codec (byteLimitsBlockFetch, timeLimitsBlockFetch) -import Ouroboros.Network.Protocol.ChainSync.Codec (ChainSyncTimeout (..), - byteLimitsChainSync, timeLimitsChainSync) +import Ouroboros.Network.Protocol.ChainSync.Codec (byteLimitsChainSync, + timeLimitsChainSync) import Ouroboros.Network.Protocol.Handshake.Version (Accept (Accept)) import Ouroboros.Network.Protocol.KeepAlive.Codec (byteLimitsKeepAlive, timeLimitsKeepAlive) @@ -192,8 +192,6 @@ data NodeArgs = { naSeed :: Int -- ^ 'randomBlockGenerationArgs' seed argument , naDiffusionMode :: DiffusionMode - , naMbTime :: Maybe DiffTime - -- ^ 'LimitsAndTimeouts' argument , naPublicRoots :: Map RelayAccessPoint PeerAdvertise -- ^ 'Interfaces' relays auxiliary value , naConsensusMode :: ConsensusMode @@ -222,7 +220,7 @@ data NodeArgs = } instance Show NodeArgs where - show NodeArgs { naSeed, naDiffusionMode, naMbTime, naBootstrapPeers, naPublicRoots, + show NodeArgs { naSeed, naDiffusionMode, naBootstrapPeers, naPublicRoots, naAddr, naPeerSharing, naLocalRootPeers, naPeerTargets, naDNSTimeoutScript, naDNSLookupDelayScript, naChainSyncExitOnBlockNo, naChainSyncEarlyExit, naFetchModeScript, naConsensusMode } = @@ -230,7 +228,6 @@ instance Show NodeArgs where , "(" ++ show naSeed ++ ")" , show naDiffusionMode , show naConsensusMode - , "(" ++ show naMbTime ++ ")" , "(" ++ show naPublicRoots ++ ")" , "(" ++ show naBootstrapPeers ++ ")" , "(" ++ show naAddr ++ ")" @@ -378,20 +375,6 @@ genNodeArgs relays minConnected localRootPeers relay = flip suchThat hasUpstream , (3, pure InitiatorAndResponderDiffusionMode) ] - -- These values approximately correspond to false positive - -- thresholds for streaks of empty slots with 99% probability, - -- 99.9% probability up to 99.999% probability. - -- t = T_s [log (1-Y) / log (1-f)] - -- Y = [0.99, 0.999...] - -- - -- T_s = slot length of 1s. - -- f = 0.05 - -- The timeout is randomly picked per bearer to avoid all bearers - -- going down at the same time in case of a long streak of empty - -- slots. TODO: workaround until peer selection governor. - -- Taken from ouroboros-consensus/src/Ouroboros/Consensus/Node.hs - mustReplyTimeout <- Just <$> oneof (pure <$> [90, 135, 180, 224, 269]) - -- Make sure our targets for active peers cover the maximum of peers -- one generated SmallTargets deadlineTargets <- resize (length relays * 2) arbitrary @@ -441,7 +424,6 @@ genNodeArgs relays minConnected localRootPeers relay = flip suchThat hasUpstream $ NodeArgs { naSeed = seed , naDiffusionMode = diffusionMode - , naMbTime = mustReplyTimeout , naPublicRoots = publicRoots -- TODO: we haven't been using public root peers so far because we set -- `UseLedgerPeers 0`! @@ -1058,7 +1040,6 @@ diffusionSimulation } NodeArgs { naSeed = seed - , naMbTime = mustReplyTimeout , naPublicRoots = publicRoots , naConsensusMode = consensusMode , naBootstrapPeers = bootstrapPeers @@ -1098,22 +1079,12 @@ diffusionSimulation bgaRng quota - stdChainSyncTimeout :: ChainSyncTimeout - stdChainSyncTimeout = do - ChainSyncTimeout - { canAwaitTimeout = shortWait - , intersectTimeout = shortWait - , mustReplyTimeout - , idleTimeout = Nothing - } - limitsAndTimeouts :: NodeKernel.LimitsAndTimeouts BlockHeader Block limitsAndTimeouts = NodeKernel.LimitsAndTimeouts { NodeKernel.chainSyncLimits = defaultMiniProtocolsLimit , NodeKernel.chainSyncSizeLimits = byteLimitsChainSync (const 0) - , NodeKernel.chainSyncTimeLimits = - timeLimitsChainSync stdChainSyncTimeout + , NodeKernel.chainSyncTimeLimits = timeLimitsChainSync , NodeKernel.blockFetchLimits = defaultMiniProtocolsLimit , NodeKernel.blockFetchSizeLimits = byteLimitsBlockFetch (const 0) , NodeKernel.blockFetchTimeLimits = timeLimitsBlockFetch diff --git a/ouroboros-network/src/Ouroboros/Network/Diffusion/Configuration.hs b/ouroboros-network/src/Ouroboros/Network/Diffusion/Configuration.hs index 9a94ff19ac4..42ceb03a2f9 100644 --- a/ouroboros-network/src/Ouroboros/Network/Diffusion/Configuration.hs +++ b/ouroboros-network/src/Ouroboros/Network/Diffusion/Configuration.hs @@ -11,7 +11,6 @@ module Ouroboros.Network.Diffusion.Configuration , defaultDiffusionMode , defaultPeerSharing , defaultBlockFetchConfiguration - , defaultChainSyncTimeout , defaultDeadlineTargets , defaultSyncTargets , defaultDeadlineChurnInterval @@ -19,7 +18,6 @@ module Ouroboros.Network.Diffusion.Configuration -- re-exports , AcceptedConnectionsLimit (..) , BlockFetchConfiguration (..) - , ChainSyncTimeout (..) , ConsensusModePeerTargets (..) , DiffusionMode (..) , MiniProtocolParameters (..) @@ -42,7 +40,6 @@ module Ouroboros.Network.Diffusion.Configuration ) where import Control.Monad.Class.MonadTime.SI -import System.Random (randomRIO) import Ouroboros.Network.BlockFetch (BlockFetchConfiguration (..)) import Ouroboros.Network.ConnectionManager.Core (defaultProtocolIdleTimeout, @@ -50,8 +47,7 @@ import Ouroboros.Network.ConnectionManager.Core (defaultProtocolIdleTimeout, import Ouroboros.Network.ConsensusMode import Ouroboros.Network.Diffusion (P2P (..)) import Ouroboros.Network.Diffusion.Policies (closeConnectionTimeout, - deactivateTimeout, maxChainSyncTimeout, minChainSyncTimeout, - peerMetricsConfiguration) + deactivateTimeout, peerMetricsConfiguration) import Ouroboros.Network.NodeToNode (DiffusionMode (..), MiniProtocolParameters (..), defaultMiniProtocolParameters) import Ouroboros.Network.PeerSelection.Governor.Types @@ -61,9 +57,7 @@ import Ouroboros.Network.PeerSelection.LedgerPeers.Type import Ouroboros.Network.PeerSelection.PeerSharing (PeerSharing (..)) import Ouroboros.Network.PeerSharing (ps_POLICY_PEER_SHARE_MAX_PEERS, ps_POLICY_PEER_SHARE_STICKY_TIME) -import Ouroboros.Network.Protocol.ChainSync.Codec (ChainSyncTimeout (..)) import Ouroboros.Network.Protocol.Handshake (handshake_QUERY_SHUTDOWN_DELAY) -import Ouroboros.Network.Protocol.Limits (shortWait) import Ouroboros.Network.Server.RateLimiting (AcceptedConnectionsLimit (..)) @@ -147,33 +141,6 @@ defaultBlockFetchConfiguration bfcSalt = bfcDecisionLoopInterval = 0.01, -- 10ms bfcSalt } -defaultChainSyncTimeout :: IO ChainSyncTimeout -defaultChainSyncTimeout = do - -- These values approximately correspond to false positive - -- thresholds for streaks of empty slots with 99% probability, - -- 99.9% probability up to 99.999% probability. - -- t = T_s [log (1-Y) / log (1-f)] - -- Y = [0.99, 0.999...] - -- T_s = slot length of 1s. - -- f = 0.05 - -- The timeout is randomly picked per bearer to avoid all bearers - -- going down at the same time in case of a long streak of empty - -- slots. - -- To avoid global synchronosation the timeout is picked uniformly - -- from the interval 135 - 269, corresponds to the a 99.9% to - -- 99.9999% thresholds. - -- TODO: The timeout should be drawn at random everytime chainsync - -- enters the must reply state. A static per connection timeout - -- leads to selection preassure for connections with a large - -- timeout, see #4244. - mustReplyTimeout <- Just . realToFrac <$> randomRIO ( realToFrac minChainSyncTimeout :: Double - , realToFrac maxChainSyncTimeout :: Double - ) - return ChainSyncTimeout { canAwaitTimeout = shortWait, - intersectTimeout = shortWait, - mustReplyTimeout, - idleTimeout = Just 3673 } - defaultDeadlineChurnInterval :: DiffTime defaultDeadlineChurnInterval = 3300 diff --git a/ouroboros-network/src/Ouroboros/Network/Diffusion/Policies.hs b/ouroboros-network/src/Ouroboros/Network/Diffusion/Policies.hs index d4f6c0a65db..35b68d34f54 100644 --- a/ouroboros-network/src/Ouroboros/Network/Diffusion/Policies.hs +++ b/ouroboros-network/src/Ouroboros/Network/Diffusion/Policies.hs @@ -41,18 +41,6 @@ deactivateTimeout = 300 closeConnectionTimeout :: DiffTime closeConnectionTimeout = 120 - --- | Chain sync `mustReplayTimeout` lower bound. --- -minChainSyncTimeout :: DiffTime -minChainSyncTimeout = 135 - - --- | Chain sync `mustReplayTimeout` upper bound. --- -maxChainSyncTimeout :: DiffTime -maxChainSyncTimeout = 269 - -- | Churn timeouts after 60s trying to establish a connection. -- -- This doesn't mean the connection is terminated after it, just churns moves diff --git a/ouroboros-network/src/Ouroboros/Network/NodeToClient.hs b/ouroboros-network/src/Ouroboros/Network/NodeToClient.hs index 2e52b46a56a..dbae7602edc 100644 --- a/ouroboros-network/src/Ouroboros/Network/NodeToClient.hs +++ b/ouroboros-network/src/Ouroboros/Network/NodeToClient.hs @@ -18,6 +18,7 @@ module Ouroboros.Network.NodeToClient , NetworkConnectTracers (..) , nullNetworkConnectTracers , connectTo + , connectToWithMux , NetworkServerTracers (..) , nullNetworkServerTracers , NetworkMutableState (..) @@ -75,18 +76,20 @@ module Ouroboros.Network.NodeToClient import Cardano.Prelude (FatalError) import Control.Concurrent.Async qualified as Async -import Control.Exception (ErrorCall, IOException) +import Control.Exception (ErrorCall, IOException, SomeException) import Control.Monad (forever) import Control.Monad.Class.MonadTimer.SI import Codec.CBOR.Term qualified as CBOR import Data.ByteString.Lazy qualified as BL +import Data.Functor (void) import Data.Functor.Contravariant (contramap) import Data.Functor.Identity (Identity (..)) import Data.Kind (Type) -import Data.Void (Void) +import Data.Void (Void, absurd) import Network.Mux (WithMuxBearer (..)) +import Network.Mux qualified as Mx import Network.Mux.Types (MuxRuntimeError (..)) import Network.TypedProtocol.Peer.Client import Network.TypedProtocol.Stateful.Peer.Client qualified as Stateful @@ -237,25 +240,80 @@ connectTo -> Versions NodeToClientVersion NodeToClientVersionData (OuroborosApplicationWithMinimalCtx - InitiatorMode LocalAddress BL.ByteString IO a b) + InitiatorMode LocalAddress BL.ByteString IO a Void) -- ^ A dictionary of protocol versions & applications to run on an established -- connection. The application to run will be chosen by initial handshake -- protocol (the highest shared version will be chosen). -> FilePath -- ^ path of the unix socket or named pipe - -> IO () + -> IO (Either SomeException a) connectTo snocket tracers versions path = - connectToNode snocket - makeLocalBearer - mempty - nodeToClientHandshakeCodec - noTimeLimitsHandshake - (cborTermVersionDataCodec nodeToClientCodecCBORTerm) - tracers - (HandshakeCallbacks acceptableVersion queryVersion) - versions - Nothing - (localAddressFromPath path) + fmap fn <$> + connectToNode + snocket + makeLocalBearer + ConnectToArgs { + ctaHandshakeCodec = nodeToClientHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = cborTermVersionDataCodec nodeToClientCodecCBORTerm, + ctaConnectTracers = tracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + mempty + versions + Nothing + (localAddressFromPath path) + where + fn :: forall x. Either x Void -> x + fn = either id absurd + +-- | A version of `connectTo` which exposes `Mx.Mux` interfaces which allows to +-- run mini-protocols and handle their termination (e.g. restart them when they +-- terminate or error). +-- +connectToWithMux + :: LocalSnocket + -- ^ callback constructed by 'Ouroboros.Network.IOManager.withIOManager' + -> NetworkConnectTracers LocalAddress NodeToClientVersion + -> Versions NodeToClientVersion + NodeToClientVersionData + (OuroborosApplicationWithMinimalCtx + InitiatorMode LocalAddress BL.ByteString IO a b) + -- ^ A dictionary of protocol versions & applications to run on an established + -- connection. The application to run will be chosen by initial handshake + -- protocol (the highest shared version will be chosen). + -> FilePath + -- ^ path of the unix socket or named pipe + -> ( ConnectionId LocalAddress + -> NodeToClientVersion + -> NodeToClientVersionData + -> OuroborosApplicationWithMinimalCtx InitiatorMode LocalAddress BL.ByteString IO a b + -> Mx.Mux InitiatorMode IO + -> Async.Async () + -> IO x) + -- ^ callback which has access to negotiated protocols and mux handle created for + -- that connection. The `Async` is a handle the the thread which runs + -- `Mx.runMux`. The `Mux` handle allows schedule mini-protocols. + -- + -- NOTE: when the callback returns or errors, the mux thread will be killed. + -> IO x +connectToWithMux snocket tracers versions path k = + connectToNodeWithMux + snocket + makeLocalBearer + ConnectToArgs { + ctaHandshakeCodec = nodeToClientHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = cborTermVersionDataCodec nodeToClientCodecCBORTerm, + ctaConnectTracers = tracers, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + mempty + versions + Nothing + (localAddressFromPath path) + k + -- | A specialised version of 'Ouroboros.Network.Socket.withServerNode'. @@ -327,14 +385,16 @@ ncSubscriptionWorker nsErrorPolicyTracer networkState subscriptionParams - (connectToNode' + (void . connectToNode' sn makeLocalBearer - nodeToClientHandshakeCodec - noTimeLimitsHandshake - (cborTermVersionDataCodec nodeToClientCodecCBORTerm) - (NetworkConnectTracers nsMuxTracer nsHandshakeTracer) - (HandshakeCallbacks acceptableVersion queryVersion) + ConnectToArgs { + ctaHandshakeCodec = nodeToClientHandshakeCodec, + ctaHandshakeTimeLimits = noTimeLimitsHandshake, + ctaVersionDataCodec = cborTermVersionDataCodec nodeToClientCodecCBORTerm, + ctaConnectTracers = NetworkConnectTracers nsMuxTracer nsHandshakeTracer, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } versions) -- | 'ErrorPolicies' for client application. Additional rules can be added by diff --git a/ouroboros-network/src/Ouroboros/Network/NodeToNode.hs b/ouroboros-network/src/Ouroboros/Network/NodeToNode.hs index b3a0871ee2d..f293418ea9b 100644 --- a/ouroboros-network/src/Ouroboros/Network/NodeToNode.hs +++ b/ouroboros-network/src/Ouroboros/Network/NodeToNode.hs @@ -107,12 +107,13 @@ module Ouroboros.Network.NodeToNode ) where import Control.Concurrent.Async qualified as Async -import Control.Exception (IOException) +import Control.Exception (IOException, SomeException) import Control.Monad.Class.MonadTime.SI (DiffTime) import Codec.CBOR.Read qualified as CBOR import Codec.CBOR.Term qualified as CBOR import Data.ByteString.Lazy qualified as BL +import Data.Functor (void) import Data.Void (Void) import Data.Word import Network.Mux (WithMuxBearer (..)) @@ -452,11 +453,17 @@ connectTo InitiatorMode Socket.SockAddr BL.ByteString IO a b) -> Maybe Socket.SockAddr -> Socket.SockAddr - -> IO () + -> IO (Either SomeException (Either a b)) connectTo sn tr = - connectToNode sn makeSocketBearer configureOutboundSocket nodeToNodeHandshakeCodec timeLimitsHandshake - (cborTermVersionDataCodec nodeToNodeCodecCBORTerm) - tr (HandshakeCallbacks acceptableVersion queryVersion) + connectToNode sn makeSocketBearer + ConnectToArgs { + ctaHandshakeCodec = nodeToNodeHandshakeCodec, + ctaHandshakeTimeLimits = timeLimitsHandshake, + ctaVersionDataCodec = cborTermVersionDataCodec nodeToNodeCodecCBORTerm, + ctaConnectTracers = tr, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } + configureOutboundSocket where configureOutboundSocket :: Socket -> IO () configureOutboundSocket sock = do @@ -539,14 +546,16 @@ ipSubscriptionWorker nsErrorPolicyTracer networkState subscriptionParams - (connectToNode' + (void . connectToNode' sn makeSocketBearer - nodeToNodeHandshakeCodec - timeLimitsHandshake - (cborTermVersionDataCodec nodeToNodeCodecCBORTerm) - (NetworkConnectTracers nsMuxTracer nsHandshakeTracer) - (HandshakeCallbacks acceptableVersion queryVersion) + ConnectToArgs { + ctaHandshakeCodec = nodeToNodeHandshakeCodec, + ctaHandshakeTimeLimits = timeLimitsHandshake, + ctaVersionDataCodec = cborTermVersionDataCodec nodeToNodeCodecCBORTerm, + ctaConnectTracers = NetworkConnectTracers nsMuxTracer nsHandshakeTracer, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } versions) @@ -585,14 +594,16 @@ dnsSubscriptionWorker ndstErrorPolicyTracer networkState subscriptionParams - (connectToNode' + (void . connectToNode' sn makeSocketBearer - nodeToNodeHandshakeCodec - timeLimitsHandshake - (cborTermVersionDataCodec nodeToNodeCodecCBORTerm) - (NetworkConnectTracers ndstMuxTracer ndstHandshakeTracer) - (HandshakeCallbacks acceptableVersion queryVersion) + ConnectToArgs { + ctaHandshakeCodec = nodeToNodeHandshakeCodec, + ctaHandshakeTimeLimits = timeLimitsHandshake, + ctaVersionDataCodec = cborTermVersionDataCodec nodeToNodeCodecCBORTerm, + ctaConnectTracers = NetworkConnectTracers ndstMuxTracer ndstHandshakeTracer, + ctaHandshakeCallbacks = HandshakeCallbacks acceptableVersion queryVersion + } versions)