diff --git a/src/Network/WebSockets/Simple/AckProtocol.hs b/src/Network/WebSockets/Simple/AckProtocol.hs new file mode 100644 index 0000000..e551c73 --- /dev/null +++ b/src/Network/WebSockets/Simple/AckProtocol.hs @@ -0,0 +1,72 @@ +module Network.WebSockets.Simple.AckProtocol (AckProtocol (..), resendTimedoutEvents) where + +import Control.Concurrent (threadDelay) +import Control.Monad (forM_) +import Control.Monad.IO.Class (MonadIO, liftIO) +import Control.Monad.Reader (asks) +import Data.HashMap.Strict qualified as HashMap +import Data.IORef (atomicModifyIORef', readIORef) +import Data.Time.Clock (addUTCTime, getCurrentTime, secondsToNominalDiffTime) +import GHC.Generics (Generic) +import Network.WebSockets.Simple.Session qualified as Session + +-- inspired by https://socket.io/docs/v4/socket-io-protocol/#exchange-protocol +data AckProtocol message + = Send message + | Event Integer message + | EventAck Integer + deriving (Show, Generic) + +instance + ( MonadIO m, + Session.Codec (AckProtocol send), + Session.Codec (AckProtocol receive), + Session.Codec send, + Session.Codec receive + ) => + Session.SessionProtocol m (AckProtocol send) (AckProtocol receive) + where + send (Send msg) = do + timestamp <- liftIO getCurrentTime + ackProtocol <- asks Session.ackProtocol + id_ <- liftIO $ atomicModifyIORef' ackProtocol $ \(current, hashMap) -> + let next = current + 1 + -- inefficient since we're converting to bytestring twice and on each retry + newHashMap = HashMap.insert next (timestamp, Session.toByteString msg) hashMap + in ((next, newHashMap), next) + Session.send $ Event id_ msg + send (Event _ _) = error "send: unexpected Event message" + send (EventAck _) = error "send: unexpected EventAck message" + + receive = do + msg <- Session.receive + case msg of + EventAck id_ -> do + ackProtocol <- asks Session.ackProtocol + _ <- liftIO $ atomicModifyIORef' ackProtocol $ \(current, hashMap) -> + ((current, HashMap.delete id_ hashMap), ()) + return $ EventAck id_ + Event id_ msg2 -> do + Session.send $ EventAck id_ + return $ Event id_ msg2 + Send _ -> error "receive: unexpected Send message" + +resendTimedoutEvents :: + ( MonadIO m, + Session.Codec (AckProtocol send), + Session.Codec (AckProtocol receive), + Session.Codec send, + Session.Codec receive + ) => + Session.Session m (AckProtocol send) (AckProtocol receive) () +resendTimedoutEvents = do + ackProtocol <- asks Session.ackProtocol + (_, hashMap) <- liftIO $ readIORef ackProtocol + currentTime <- liftIO getCurrentTime + let timedout = HashMap.filter (\(msgTimestamp, _) -> addUTCTime (secondsToNominalDiffTime $ fromIntegral interval) msgTimestamp < currentTime) hashMap + forM_ (HashMap.toList timedout) $ \(id_, (_, msg)) -> + Session.send $ Event id_ $ Session.fromByteString msg + liftIO $ threadDelay (fromIntegral interval * 1000 * 1000) + resendTimedoutEvents + where + interval = 10 diff --git a/src/Network/WebSockets/Simple/Client.hs b/src/Network/WebSockets/Simple/Client.hs index 76945ac..77cb62e 100644 --- a/src/Network/WebSockets/Simple/Client.hs +++ b/src/Network/WebSockets/Simple/Client.hs @@ -8,8 +8,10 @@ module Network.WebSockets.Simple.Client ) where +import Control.Monad (when) import Data.ByteString (ByteString) import Data.ByteString.Char8 (unpack) +import Data.Maybe (isJust) import Network.WebSockets qualified as WS import Network.WebSockets.Connection.PingPong qualified as PingPong import Network.WebSockets.Simple.Session qualified as Session @@ -20,7 +22,8 @@ import Wuss qualified data Options = Options { headers :: WS.Headers, messageLimit :: Int, - staminaSettings :: Stamina.RetrySettings + staminaSettings :: Stamina.RetrySettings, + staminaRetry :: Stamina.RetryStatus -> IO () } defaultOptions :: Options @@ -28,13 +31,16 @@ defaultOptions = Options { headers = [], messageLimit = 10000, - staminaSettings = Stamina.defaults + staminaSettings = Stamina.defaults, + staminaRetry = const $ return () } run :: (Session.Codec send, Session.Codec receive) => ByteString -> Options -> Session.Session IO send receive () -> (receive -> Session.Session IO send receive ()) -> IO () run uriBS options app receiveApp = do (isSecure, host, port, path) <- Utils.parseURI uriBS - Stamina.retry (staminaSettings options) $ \retryStatus -> + Stamina.retry (staminaSettings options) $ \retryStatus -> do + when (isJust $ Stamina.lastException retryStatus) $ + staminaRetry options retryStatus if isSecure then Wuss.runSecureClientWith (unpack host) (fromIntegral port) (unpack path) connectionOptions (headers options) (go retryStatus) else WS.runClientWith (unpack host) (fromIntegral port) (unpack path) connectionOptions (headers options) (go retryStatus) diff --git a/src/Network/WebSockets/Simple/Session.hs b/src/Network/WebSockets/Simple/Session.hs index f0c90bf..f3a3e19 100644 --- a/src/Network/WebSockets/Simple/Session.hs +++ b/src/Network/WebSockets/Simple/Session.hs @@ -3,6 +3,7 @@ module Network.WebSockets.Simple.Session run, Session (..), SessionProtocol (..), + ackProtocol, ) where @@ -13,6 +14,9 @@ import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Reader (MonadReader, ReaderT, asks, runReaderT) import Data.ByteString (ByteString, toStrict) +import Data.HashMap.Strict qualified as HashMap +import Data.IORef (IORef, newIORef) +import Data.Time.Clock (UTCTime) import Network.WebSockets qualified as WS -- Allows decoding from ByteString to any format like JSON or CBOR. @@ -23,7 +27,10 @@ class Codec a where -- State for the session data SessionEnv = SessionEnv { sendChan :: Unagi.InChan ByteString, - receiveChan :: Unagi.OutChan ByteString + receiveChan :: Unagi.OutChan ByteString, + -- TODO: ideally we'd implement a way for each WebsocketMonad instance to specify how env is created + -- maybe order by timestamp? + ackProtocol :: IORef (Integer, HashMap.HashMap Integer (UTCTime, ByteString)) } newtype Session m send receive a = Session (ReaderT SessionEnv m a) @@ -36,7 +43,7 @@ class (MonadIO m, Codec send, Codec receive) => SessionProtocol m send receive w send :: send -> Session m send receive () receive :: Session m send receive receive -instance (MonadIO m, Codec send, Codec receive) => SessionProtocol m send receive where +instance {-# OVERLAPPABLE #-} (MonadIO m, Codec send, Codec receive) => SessionProtocol m send receive where send msg = do sendChanWrite <- asks sendChan liftIO $ Unagi.writeChan sendChanWrite $ toByteString msg @@ -50,7 +57,8 @@ run :: (Codec send, Codec receive) => Int -> WS.Connection -> Session IO send re run limit conn sendApp receiveApp = do (sendChanWrite, sendChanRead) <- liftIO $ Unagi.newChan limit (receiveChanWrite, receiveChanRead) <- liftIO $ Unagi.newChan limit - let clientEnv = SessionEnv sendChanWrite receiveChanRead + ackProtocol <- liftIO $ newIORef (0, HashMap.empty) + let clientEnv = SessionEnv sendChanWrite receiveChanRead ackProtocol -- Use async to queue the send and receive channels in parallel sendAsync <- liftIO $ async $ forever $ do diff --git a/websockets-simple.cabal b/websockets-simple.cabal index aa132c0..210f2b9 100644 --- a/websockets-simple.cabal +++ b/websockets-simple.cabal @@ -27,7 +27,9 @@ common common unliftio-core, bytestring, exceptions, - stamina + stamina, + time, + unordered-containers default-extensions: OverloadedStrings @@ -36,6 +38,7 @@ library exposed-modules: Network.WebSockets.Simple.Server Network.WebSockets.Simple.Client + Network.WebSockets.Simple.AckProtocol other-modules: Network.WebSockets.Simple.Session Network.WebSockets.Simple.Utils