Back to WebSockets fundamentals in ASP.NET Core - Lifetime & Prematurely closed connections

I've written about WebSockets in ASP.NET Core several times before (touching on subjects such as subprotocol negotiation, Cross-Site WebSocket Hijacking and per-message compression), but I've never written a post about the fundamentals of managing WebSocket lifetime. In this post I intend to fix that.

The documentation does a pretty good job of explaining the key aspects and showing the most basic scenario. Unfortunately, real-life scenarios are usually not that simple. In most cases we want to access the WebSocket outside of the middleware responsible for accepting the request. This leads to handing the WebSocket off to a dedicated service or manager. When doing this, a couple of rules need to be followed to ensure correct lifetime handling.

WebSocket "handing off" middleware flow

The most important thing to remember when writing a middleware for accepting WebSocket requests is that the middleware should not return before the WebSocket is closed. Allowing the middleware to complete earlier results in the request being completed and the underlying connection being closed.

It is often also worth designing the middleware to be the last one in the pipeline. This is because the middleware is bound to a specific ws:// or wss:// URI, so unless the goal is to handle different protocols under the same URI, anything other than a WebSocket request is an error.

The general flow of a middleware which accepts a WebSocket request and then hands it off is shown in the diagram below.

WebSocket 'handing off' middleware flow

There are a number of ways to implement this flow, so the next section should be considered my opinionated approach (here you can read about a different one, which leverages inheritance).

Opinionated implementation

I will first show the code of the middleware and then discuss its details.

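using System;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;

public class WebSocketConnectionsMiddleware
{
    private readonly IWebSocketConnectionsService _connectionsService;

    // The next delegate is not stored: this middleware is meant to be the last in the pipeline.
    public WebSocketConnectionsMiddleware(RequestDelegate next,
        IWebSocketConnectionsService connectionsService)
    {
        _connectionsService = connectionsService
                              ?? throw new ArgumentNullException(nameof(connectionsService));
    }

    public async Task Invoke(HttpContext context)
    {
        if (context.WebSockets.IsWebSocketRequest)
        {
            WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync();

            // Wrap the raw WebSocket and hand it off to the connections service.
            WebSocketConnection webSocketConnection = new WebSocketConnection(webSocket);

            _connectionsService.AddConnection(webSocketConnection);

            // Keep the middleware (and therefore the request) alive until a Close message arrives.
            await webSocketConnection.ReceiveMessagesUntilCloseAsync();

            // Complete the close handshake with the status received from the client.
            await webSocket.CloseAsync(webSocketConnection.CloseStatus.Value,
                webSocketConnection.CloseStatusDescription, CancellationToken.None);

            _connectionsService.RemoveConnection(webSocketConnection.Id);
        }
        else
        {
            // Anything other than a WebSocket request is an error at this endpoint.
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
        }
    }
}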

The code maps easily onto the diagram; the important parts are IWebSocketConnectionsService and WebSocketConnection.

There is nothing very specific about the IWebSocketConnectionsService implementation. It is supposed to store the connections (in my demo project it uses a ConcurrentDictionary for that purpose) and expose the high-level operations needed by other parts of the application.
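To make the storage part concrete, here is a minimal sketch of such a service, assuming connections are keyed by their Id in a ConcurrentDictionary. StubConnection and InMemoryConnectionsService are hypothetical names used only for illustration; the demo project's actual implementation may look different.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;

// Hypothetical stand-in for WebSocketConnection: only a stable Id is needed here.
public sealed class StubConnection
{
    public Guid Id { get; } = Guid.NewGuid();
}

// A minimal sketch of the storage part of an IWebSocketConnectionsService-like
// service, assuming connections are keyed by their Id.
public class InMemoryConnectionsService
{
    private readonly ConcurrentDictionary<Guid, StubConnection> _connections =
        new ConcurrentDictionary<Guid, StubConnection>();

    public void AddConnection(StubConnection connection)
    {
        _connections.TryAdd(connection.Id, connection);
    }

    public void RemoveConnection(Guid connectionId)
    {
        _connections.TryRemove(connectionId, out _);
    }

    // Snapshot of the current connections, e.g. as input for a broadcast operation.
    public IReadOnlyList<StubConnection> GetConnections()
    {
        return new List<StubConnection>(_connections.Values);
    }
}
```

RemoveConnection can only find the entry if Id returns the same value on every access, so the identifier should be assigned once, when the connection object is created.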

The WebSocketConnection, which serves as a WebSocket abstraction, is worth taking a closer look at. Its public API needs to bridge two worlds. It must allow for sending and receiving messages, so it is usable by other parts of the application. It also needs to provide the receiving loop for the middleware to wait on, together with the information required to complete the close handshake.

public class WebSocketConnection
{
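    private readonly WebSocket _webSocket;

    // Assigned once, so every access returns the same Id (RemoveConnection depends on this).
    public Guid Id { get; } = Guid.NewGuid();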

    public WebSocketCloseStatus? CloseStatus { get; private set; } = null;

    public string CloseStatusDescription { get; private set; } = null;

    public event EventHandler<string> ReceiveText;

    public event EventHandler<byte[]> ReceiveBinary;

    public WebSocketConnection(WebSocket webSocket)
    {
        _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket));
    }

    public Task SendAsync(string message, CancellationToken cancellationToken)
    {
        ...
    }

    public Task SendAsync(byte[] message, CancellationToken cancellationToken)
    {
        ...
    }

    public async Task ReceiveMessagesUntilCloseAsync()
    {
        ...
    }
}

I will skip the details of the SendAsync methods (typically they check the WebSocket state and call SendAsync on it). What I will focus on is the receiving loop hidden in ReceiveMessagesUntilCloseAsync. It has two tasks. The first is to wait for a Close message and, when it arrives, return control to the waiting middleware (the middleware needs CloseStatus and CloseStatusDescription from that message to complete the handshake). The second is to trigger the ReceiveText or ReceiveBinary event when a Text or Binary message arrives.
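For completeness, a minimal sketch of what those SendAsync overloads could look like, assuming that a socket which is no longer Open should simply be skipped rather than throw. WebSocketSender is a hypothetical name; in the post these methods live on WebSocketConnection.

```csharp
using System;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical holder for the send side of WebSocketConnection.
public class WebSocketSender
{
    private readonly WebSocket _webSocket;

    public WebSocketSender(WebSocket webSocket)
    {
        _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket));
    }

    public Task SendAsync(string message, CancellationToken cancellationToken)
    {
        // Text messages carry UTF-8 encoded payload.
        return SendMessageAsync(Encoding.UTF8.GetBytes(message), WebSocketMessageType.Text, cancellationToken);
    }

    public Task SendAsync(byte[] message, CancellationToken cancellationToken)
    {
        return SendMessageAsync(message, WebSocketMessageType.Binary, cancellationToken);
    }

    private Task SendMessageAsync(byte[] message, WebSocketMessageType messageType,
        CancellationToken cancellationToken)
    {
        // Assumption: sockets which are closing or closed are silently skipped.
        if (_webSocket.State != WebSocketState.Open)
        {
            return Task.CompletedTask;
        }

        // Send the whole payload as a single message (endOfMessage: true).
        return _webSocket.SendAsync(new ArraySegment<byte>(message), messageType, true, cancellationToken);
    }
}
```

One design caveat: a WebSocket supports only one outstanding send at a time, so if several parts of the application may send to the same connection concurrently, the calls would need to be serialized (for example with a SemaphoreSlim).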

public class WebSocketConnection
{
    ...
    private int _receivePayloadBufferSize = 4 * 1024;

    ...

    public async Task ReceiveMessagesUntilCloseAsync()
    {
        byte[] receivePayloadBuffer = new byte[_receivePayloadBufferSize];
        WebSocketReceiveResult webSocketReceiveResult =
            await _webSocket.ReceiveAsync(new ArraySegment<byte>(receivePayloadBuffer),
                CancellationToken.None);

        while (webSocketReceiveResult.MessageType != WebSocketMessageType.Close)
        {
            if (webSocketReceiveResult.MessageType == WebSocketMessageType.Binary)
            {
                byte[] webSocketMessage = await ReceiveMessagePayloadAsync(webSocketReceiveResult,
                    receivePayloadBuffer);
                ReceiveBinary?.Invoke(this, webSocketMessage);
            }
            else
            {
                byte[] webSocketMessage = await ReceiveMessagePayloadAsync(webSocketReceiveResult,
                    receivePayloadBuffer);
                ReceiveText?.Invoke(this, Encoding.UTF8.GetString(webSocketMessage));
            }

            webSocketReceiveResult =
                await _webSocket.ReceiveAsync(new ArraySegment<byte>(receivePayloadBuffer),
                    CancellationToken.None);
        }

        CloseStatus = webSocketReceiveResult.CloseStatus.Value;
        CloseStatusDescription = webSocketReceiveResult.CloseStatusDescription;
    }

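    // Instance method: reassembling a message that doesn't fit in the buffer requires
    // further reads from the _webSocket field.
    private async Task<byte[]> ReceiveMessagePayloadAsync(
        WebSocketReceiveResult webSocketReceiveResult, byte[] receivePayloadBuffer)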
    {
        byte[] messagePayload = null;

        if (webSocketReceiveResult.EndOfMessage)
        {
            messagePayload = new byte[webSocketReceiveResult.Count];
            Array.Copy(receivePayloadBuffer, messagePayload, webSocketReceiveResult.Count);
        }
        else
        {
            using (MemoryStream messagePayloadStream = new MemoryStream())
            {
                messagePayloadStream.Write(receivePayloadBuffer, 0, webSocketReceiveResult.Count);
                while (!webSocketReceiveResult.EndOfMessage)
                {
                    webSocketReceiveResult =
                        await _webSocket.ReceiveAsync(new ArraySegment<byte>(receivePayloadBuffer),
                            CancellationToken.None);
                    messagePayloadStream.Write(receivePayloadBuffer, 0, webSocketReceiveResult.Count);
                }

                messagePayload = messagePayloadStream.ToArray();
            }
        }

        return messagePayload;
    }
}
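The reassembly branch is easy to miss when testing with short messages. One way to exercise it without a browser (a testing approach assumed here for illustration, not taken from the demo project) is to create a client-side WebSocket with WebSocket.CreateFromStream over a MemoryStream pre-filled with a hand-built frame, and read it with a buffer smaller than the payload. A client-side socket is used because it expects unmasked frames, which keeps the hand-built bytes simple. FragmentedReceiveDemo and its helpers are hypothetical, and BuildTextFrame only handles payloads shorter than 126 bytes.

```csharp
using System;
using System.IO;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

public static class FragmentedReceiveDemo
{
    // Same technique as ReceiveMessagePayloadAsync: keep reading into one buffer
    // and append each chunk until EndOfMessage is reported.
    public static async Task<byte[]> ReceiveWholeMessageAsync(WebSocket webSocket, byte[] buffer)
    {
        using (var messagePayloadStream = new MemoryStream())
        {
            WebSocketReceiveResult result;
            do
            {
                result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
                messagePayloadStream.Write(buffer, 0, result.Count);
            }
            while (!result.EndOfMessage);

            return messagePayloadStream.ToArray();
        }
    }

    // Builds an unmasked (server-to-client) text frame with a 7-bit payload length.
    public static byte[] BuildTextFrame(string text)
    {
        byte[] payload = Encoding.UTF8.GetBytes(text);
        if (payload.Length > 125)
        {
            throw new ArgumentException("Payload too long for a 7-bit length.", nameof(text));
        }

        byte[] frame = new byte[2 + payload.Length];
        frame[0] = 0x81;                 // FIN bit + Text opcode
        frame[1] = (byte)payload.Length; // mask bit not set
        Array.Copy(payload, 0, frame, 2, payload.Length);
        return frame;
    }

    // Client-side WebSocket whose "network" is a pre-filled, read-only-in-practice stream.
    public static WebSocket CreateClientSocket(byte[] incomingBytes) =>
        WebSocket.CreateFromStream(new MemoryStream(incomingBytes), false, null, TimeSpan.Zero);
}
```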

This will work nicely, until something unexpected happens...

Prematurely closed connection

There is one unexpected situation which should always be expected: the client closing the connection prematurely. Most of the time this means that the client has crashed (the easiest way to simulate this is to kill the browser process while it's connected to the application). A prematurely closed connection manifests itself as a WebSocketException (with WebSocketError.ConnectionClosedPrematurely as the WebSocketErrorCode value) thrown from the receiving loop. The WebSocketConnection should handle this exception gracefully, so that the middleware can remove the connection from the manager and complete.
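This failure mode can also be reproduced without killing a browser. The sketch below assumes that WebSocket.CreateFromStream over an empty MemoryStream is an acceptable stand-in for a connection whose peer vanished: with the managed WebSocket implementation in .NET, reaching the end of the stream before a Close frame is reported as ConnectionClosedPrematurely. PrematureCloseDemo is a hypothetical name.

```csharp
using System;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;

public static class PrematureCloseDemo
{
    // Receives until a Close message arrives. Returns true when the transport
    // ended first (ConnectionClosedPrematurely), false on a regular Close.
    public static async Task<bool> ReceiveUntilCloseAsync(WebSocket webSocket)
    {
        byte[] buffer = new byte[4 * 1024];
        try
        {
            WebSocketReceiveResult result;
            do
            {
                result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
            }
            while (result.MessageType != WebSocketMessageType.Close);

            return false;
        }
        catch (WebSocketException wsex)
            when (wsex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
        {
            return true;
        }
    }
}
```

In the test below, the normal-close case uses a client-side socket fed an unmasked Close frame with status 1000 (bytes 0x88 0x02 0x03 0xE8).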

public class WebSocketConnection
{
    ...

    public async Task ReceiveMessagesUntilCloseAsync()
    {
        try
        {
            ...
        }
        catch (WebSocketException wsex)
            when (wsex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
        {
            // Perform some logging
            ...
        }
    }

    ...
}

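A small change to the middleware is needed, as in this case the close handshake shouldn't be completed (it was never started). Because the exception path never sets CloseStatus, the middleware can use CloseStatus.HasValue to decide whether to call CloseAsync.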

public class WebSocketConnectionsMiddleware
{
    ...

    public async Task Invoke(HttpContext context)
    {
        if (context.WebSockets.IsWebSocketRequest)
        {
            ...

            if (webSocketConnection.CloseStatus.HasValue)
            {
                await webSocket.CloseAsync(webSocketConnection.CloseStatus.Value,
                    webSocketConnection.CloseStatusDescription, CancellationToken.None);
            }

            _connectionsService.RemoveConnection(webSocketConnection.Id);
        }
        else
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
        }
    }
}

This way, prematurely closed connections are handled correctly.

The demo project available on GitHub contains a modified version of this code, as it also demonstrates the features described in my previous posts.