In one of the projects I'm involved in, we had to configure a limit for the number of concurrent requests being processed. The project is quite specific: processing a request is heavy and requires considerable resources. Because of that, when there was a high number of concurrent requests to process, the server became unresponsive. I did some research on best practices in this area. "Stress your server and see" seems to be the best advice, but the number of possible approaches got me interested. To satisfy my curiosity, I decided to implement a couple of them myself.

General remark

Limiting concurrent requests is a server responsibility; there shouldn't be a need for the application to handle it. This is one of the reasons behind the ASP.NET Team's suggestion to run Kestrel behind a fully featured web server like IIS or nginx (some support for limits arrived in Kestrel with 2.0). Despite that, here I'm implementing a middleware. This is not that bad: registering such a middleware at the beginning of the pipeline should stop heavy processing early enough. Still, I'm doing this just for fun. Don't do it in production unless you have an unquestionable reason to do so.

Testing scenario

For testing purposes I decided to set up a couple of integration tests using xUnit and Microsoft.AspNetCore.TestHost. The general setup is very well described in the documentation. As I intended to spawn multiple requests, I wanted to capture some rough timings for them, so I prepared the following extension method.

using System;
using System.Diagnostics;
using System.Net.Http;
using System.Threading.Tasks;

internal class HttpResponseMessageWithTiming
{
    internal HttpResponseMessage Response { get; set; }

    internal TimeSpan Timing { get; set; }
}

internal static class HttpClientExtensions
{
    internal static async Task<HttpResponseMessageWithTiming> GetWithTimingAsync(this HttpClient client,
        string requestUri)
    {
        Stopwatch stopwatch = Stopwatch.StartNew();

        HttpResponseMessage response = await client.GetAsync(requestUri);

        // Stop the measurement as soon as the response arrives, then read the elapsed time
        stopwatch.Stop();
        TimeSpan timing = stopwatch.Elapsed;

        return new HttpResponseMessageWithTiming
        {
            Response = response,
            Timing = timing
        };
    }
}

This is not perfect (the timings might not be fully accurate, as they also include client-side overhead), but it should be good enough. I also wanted easy access to the status code and timing during debugging, so I introduced another intermediate representation.

private struct HttpResponseInformation
{
    public HttpStatusCode StatusCode { get; set; }

    public TimeSpan Timing { get; set; }

    public override string ToString()
    {
        return $"StatusCode: {StatusCode} | Timing: {Timing}";
    }
}

I also created a method which prepares the SUT (system under test) by instantiating a TestServer.

private TestServer PrepareTestServer(IEnumerable<KeyValuePair<string, string>> configuration = null)
{
    IWebHostBuilder webHostBuilder = new WebHostBuilder()
        .UseStartup<Startup>();

    if (configuration != null)
    {
        ConfigurationBuilder configurationBuilder = new ConfigurationBuilder();
        configurationBuilder.AddInMemoryCollection(configuration);
        IConfiguration builtConfiguration = configurationBuilder.Build();

        webHostBuilder.UseConfiguration(builtConfiguration);
        webHostBuilder.ConfigureServices((services) =>
        {
            services.Configure<MaxConcurrentRequestsOptions>(options =>
                builtConfiguration.GetSection("MaxConcurrentRequestsOptions").Bind(options)
            );
        });
    }

    return new TestServer(webHostBuilder);
}

The MaxConcurrentRequestsOptions class (empty at this point) will be used to control the behavior of the middleware. The Startup looks like this (the delay simulates long request processing):

public class Startup
{
    public void Configure(IApplicationBuilder app)
    {
        app.Run(async (context) =>
        {
            await Task.Delay(500);

            await context.Response.WriteAsync("-- Demo.AspNetCore.MaxConcurrentConnections --");
        });
    }
}

With all those elements in place, I created a general method to be reused by all tests.

private HttpResponseInformation[] GetResponseInformation(Dictionary<string, string> configuration,
    int concurrentRequestsCount)
{
    HttpResponseInformation[] responseInformation;

    using (TestServer server = PrepareTestServer(configuration))
    {
        List<HttpClient> clients = new List<HttpClient>();
        for (int i = 0; i < concurrentRequestsCount; i++)
        {
            clients.Add(server.CreateClient());
        }

        List<Task<HttpResponseMessageWithTiming>> responsesWithTimingsTasks =
            new List<Task<HttpResponseMessageWithTiming>>();
        foreach (HttpClient client in clients)
        {
            responsesWithTimingsTasks.Add(Task.Run(() => client.GetWithTimingAsync("/")));
        }
        Task.WaitAll(responsesWithTimingsTasks.ToArray());

        clients.ForEach(client => client.Dispose());

        responseInformation = responsesWithTimingsTasks.Select(task => new HttpResponseInformation
        {
            StatusCode = task.Result.Response.StatusCode,
            Timing = task.Result.Timing
        }).ToArray();
    }

    return responseInformation;
}

By providing different configuration and concurrentRequestsCount values, this method allows me to test the different approaches I'm going to play with.

Limiting concurrent requests through middleware

The first and most important thing the middleware needs to support is a hard limit. In theory, the functionality is very simple. If the application is already processing the maximum number of concurrent requests, every incoming request should immediately result in 503 Service Unavailable.

public class MaxConcurrentRequestsMiddleware
{
    private int _concurrentRequestsCount;

    private readonly RequestDelegate _next;
    private readonly MaxConcurrentRequestsOptions _options;

    public MaxConcurrentRequestsMiddleware(RequestDelegate next,
        IOptions<MaxConcurrentRequestsOptions> options)
    {
        _concurrentRequestsCount = 0;

        _next = next ?? throw new ArgumentNullException(nameof(next));
        _options = options?.Value ?? throw new ArgumentNullException(nameof(options));
    }

    public async Task Invoke(HttpContext context)
    {
        if (CheckLimitExceeded())
        {
            IHttpResponseFeature responseFeature = context.Features.Get<IHttpResponseFeature>();

            responseFeature.StatusCode = StatusCodes.Status503ServiceUnavailable;
            responseFeature.ReasonPhrase = "Concurrent request limit exceeded.";
        }
        else
        {
            await _next(context);

            // TODO: Decrement concurrent requests count
        }
    }

    private bool CheckLimitExceeded()
    {
        bool limitExceeded = false;

        // TODO: Check and increment concurrent requests count

        return limitExceeded;
    }
}

The challenge lies in maintaining the count of concurrent requests. It must be incremented and decremented in a thread-safe way while affecting performance as little as possible. The Interlocked class, with its atomic operations on shared variables, seems perfect for the job. I decided to use the Interlocked.CompareExchange "do-while pattern" for the check and increment. The incremented value is written only if the counter hasn't changed since it was read (otherwise the loop retries), which ensures that the limit will not be exceeded.

public class MaxConcurrentRequestsMiddleware
{
    ...

    public async Task Invoke(HttpContext context)
    {
        if (CheckLimitExceeded())
        {
            ...
        }
        else
        {
            await _next(context);

            Interlocked.Decrement(ref _concurrentRequestsCount);
        }
    }

    private bool CheckLimitExceeded()
    {
        bool limitExceeded;

        int initialConcurrentRequestsCount, incrementedConcurrentRequestsCount;
        do
        {
            limitExceeded = true;

            initialConcurrentRequestsCount = _concurrentRequestsCount;
            if (initialConcurrentRequestsCount >= _options.Limit)
            {
                break;
            }

            limitExceeded = false;
            incrementedConcurrentRequestsCount = initialConcurrentRequestsCount + 1;
        }
        while (initialConcurrentRequestsCount != Interlocked.CompareExchange(
            ref _concurrentRequestsCount, incrementedConcurrentRequestsCount, initialConcurrentRequestsCount));

        return limitExceeded;
    }
}
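The following minimal standalone sketch shows the check-and-increment loop in isolation. BoundedCounter is a hypothetical name. It reads the counter with Volatile.Read, which is a small deviation from the snippet above, but the CompareExchange retry loop is the same idea.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical standalone helper (not part of the middleware) isolating the
// check-and-increment loop so it can be exercised without a TestServer.
public sealed class BoundedCounter
{
    private readonly int _limit;
    private int _count;

    public BoundedCounter(int limit)
    {
        _limit = limit;
    }

    public int Count => Volatile.Read(ref _count);

    // Increments the counter unless the limit has been reached.
    // Returns false (leaving the counter unchanged) when the limit is already reached.
    public bool TryIncrement()
    {
        int initial;
        do
        {
            initial = Volatile.Read(ref _count);
            if (initial >= _limit)
            {
                return false;
            }
        }
        // CompareExchange writes initial + 1 only if nobody changed the value since it was read;
        // otherwise it returns the newer value and the loop retries.
        while (Interlocked.CompareExchange(ref _count, initial + 1, initial) != initial);

        return true;
    }

    public void Decrement()
    {
        Interlocked.Decrement(ref _count);
    }
}
```

In the middleware, TryIncrement returning false corresponds to CheckLimitExceeded returning true. Decrement corresponds to the Interlocked.Decrement call after _next(context) completes.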

After adding the middleware to the pipeline, I set up a test with 30 concurrent requests and MaxConcurrentRequestsOptions.Limit set to 10. The result is best observed through the temporary responseInformation value.

Visual Studio Locals Window - Content of responseInformation array for hard limit scenario

There are 10 requests which resulted in 200 OK and 20 requests which resulted in 503 Service Unavailable (red frames), which is the desired effect.

Queueing additional requests

The hard limit approach gets the job done, but there are more flexible approaches. In general, they are based on two assumptions: the application can cheaply hold more requests in memory than it's currently processing, and the client can wait a little bit longer for the response. Additional requests can wait in a queue (typically a FIFO one) for resources to become available. The queue should have a size limit. Otherwise, the application might end up processing only queued requests, with constantly growing latency.

Because of the size limit requirement, this can't simply be implemented with ConcurrentQueue. A custom (again thread-safe) solution is needed. The middleware should also be able to await the queue. At a high level, a class like the one below should provide the desired interface.

internal class MaxConcurrentRequestsEnqueuer
{
    private static readonly Task<bool> _enqueueFailedTask = Task.FromResult(false);
    private readonly int _maxQueueLength;

    public MaxConcurrentRequestsEnqueuer(int maxQueueLength)
    {
        _maxQueueLength = maxQueueLength;
    }

    public Task<bool> EnqueueAsync()
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        // TODO: Check the size and enqueue

        return enqueueTask;
    }

    public bool Dequeue()
    {
        bool dequeued = false;

        // TODO: Dequeue

        return dequeued;
    }
}

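EnqueueAsync returns a Task<bool>, so the middleware can await it and use the result as an indicator of whether the request should be processed or not.

Internally, the class maintains a Queue<TaskCompletionSource<bool>>. Each waiting request gets its own TaskCompletionSource<bool>, and Dequeue releases the oldest one by completing its task with true.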

internal class MaxConcurrentRequestsEnqueuer
{
    ...
    private readonly object _lock = new object();
    private readonly Queue<TaskCompletionSource<bool>> _queue = new Queue<TaskCompletionSource<bool>>();

    public Task<bool> EnqueueAsync()
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        if (_maxQueueLength > 0)
        {
            lock (_lock)
            {
                if (_queue.Count < _maxQueueLength)
                {
                    TaskCompletionSource<bool> enqueueTaskCompletionSource =
                        new TaskCompletionSource<bool>();

                    _queue.Enqueue(enqueueTaskCompletionSource);

                    enqueueTask = enqueueTaskCompletionSource.Task;
                }
            }
        }

        return enqueueTask;
    }

    public bool Dequeue()
    {
        bool dequeued = false;

        lock (_lock)
        {
            if (_queue.Count > 0)
            {
                _queue.Dequeue().SetResult(true);

                dequeued = true;
            }
        }

        return dequeued;
    }
}

Yes, this is locking. It's not perfect, but a request which ends up here is going to wait anyway, so it can be considered acceptable. Some consolation is the fact that this class can be used nicely from the middleware.

public class MaxConcurrentRequestsMiddleware
{
    ...
    private readonly MaxConcurrentRequestsEnqueuer _enqueuer;

    public MaxConcurrentRequestsMiddleware(RequestDelegate next,
        IOptions<MaxConcurrentRequestsOptions> options)
    {
        ...

        if (_options.LimitExceededPolicy != MaxConcurrentRequestsLimitExceededPolicy.Drop)
        {
            _enqueuer = new MaxConcurrentRequestsEnqueuer(_options.MaxQueueLength);
        }
    }

    public async Task Invoke(HttpContext context)
    {
        if (CheckLimitExceeded() && !(await TryWaitInQueue()))
        {
            ...
        }
        else
        {
            await _next(context);

            if (ShouldDecrementConcurrentRequestsCount())
            {
                Interlocked.Decrement(ref _concurrentRequestsCount);
            }
        }
    }

    ...

    private async Task<bool> TryWaitInQueue()
    {
        return (_enqueuer != null) && (await _enqueuer.EnqueueAsync());
    }

    private bool ShouldDecrementConcurrentRequestsCount()
    {
        return (_enqueuer == null) || !_enqueuer.Dequeue();
    }
}
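One detail in the enqueuer deserves attention. A TaskCompletionSource created with the default constructor may run awaiting continuations synchronously when SetResult is called. Because Dequeue calls SetResult inside lock (_lock), the released request may resume its processing on the thread of the request that is finishing, while the lock is still held. The minimal drop-tail sketch below avoids this. DropTailQueueSketch is a hypothetical name. The sketch passes TaskCreationOptions.RunContinuationsAsynchronously and completes the task outside the lock.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Hypothetical simplified drop-tail queue mirroring MaxConcurrentRequestsEnqueuer, with two changes:
// continuations run asynchronously and SetResult is called outside the lock.
public sealed class DropTailQueueSketch
{
    private static readonly Task<bool> EnqueueFailed = Task.FromResult(false);

    private readonly object _lock = new object();
    private readonly Queue<TaskCompletionSource<bool>> _queue = new Queue<TaskCompletionSource<bool>>();
    private readonly int _maxQueueLength;

    public DropTailQueueSketch(int maxQueueLength)
    {
        _maxQueueLength = maxQueueLength;
    }

    public Task<bool> EnqueueAsync()
    {
        lock (_lock)
        {
            if (_queue.Count >= _maxQueueLength)
            {
                // Queue full (or length 0): the newcomer is dropped
                return EnqueueFailed;
            }

            // Whoever completes this task won't run the awaiting request's continuation inline
            var waiter = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
            _queue.Enqueue(waiter);

            return waiter.Task;
        }
    }

    public bool Dequeue()
    {
        TaskCompletionSource<bool> waiter;

        lock (_lock)
        {
            if (_queue.Count == 0)
            {
                return false;
            }

            waiter = _queue.Dequeue();
        }

        // Completed outside the lock
        waiter.SetResult(true);

        return true;
    }
}
```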

Below is the result of another test. The number of concurrent requests is again 30, and the maximum number of concurrent requests is 10. The queue size has also been set to 10. The requests which waited in the queue are the ones in orange.

Visual Studio Locals Window - Content of responseInformation array for hard limit with FIFO queue scenario
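The middleware and the tests rely on a few members of MaxConcurrentRequestsOptions and on the MaxConcurrentRequestsLimitExceededPolicy enum, neither of which is shown in full here. Below is a minimal sketch of what they could look like at this stage. The member names come from the snippets above, while the property types and default values are assumptions. The sketch also includes a hypothetical helper, TestConfigurationSketch, which builds the configuration dictionary for GetResponseInformation. Its keys use the "Section:Key" convention of Microsoft.Extensions.Configuration, so they should match the GetSection("MaxConcurrentRequestsOptions").Bind(options) call in PrepareTestServer.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch; member names follow the snippets above, types and defaults are assumptions.
public enum MaxConcurrentRequestsLimitExceededPolicy
{
    Drop,
    FifoQueueDropTail,
    FifoQueueDropHead
}

public class MaxConcurrentRequestsOptions
{
    // Maximum number of requests processed at the same time
    public int Limit { get; set; } = Int32.MaxValue;

    // What happens to a request arriving when Limit has been reached
    public MaxConcurrentRequestsLimitExceededPolicy LimitExceededPolicy { get; set; }
        = MaxConcurrentRequestsLimitExceededPolicy.Drop;

    // Maximum number of requests waiting in the queue (not used by Drop)
    public int MaxQueueLength { get; set; } = 0;
}

public static class TestConfigurationSketch
{
    // Configuration for the "hard limit with FIFO queue" scenario
    public static Dictionary<string, string> QueueScenario(int limit, int maxQueueLength)
    {
        return new Dictionary<string, string>
        {
            { "MaxConcurrentRequestsOptions:Limit", limit.ToString() },
            { "MaxConcurrentRequestsOptions:LimitExceededPolicy",
                nameof(MaxConcurrentRequestsLimitExceededPolicy.FifoQueueDropTail) },
            { "MaxConcurrentRequestsOptions:MaxQueueLength", maxQueueLength.ToString() }
        };
    }
}
```

With these helpers, the scenario above would be GetResponseInformation(TestConfigurationSketch.QueueScenario(10, 10), 30).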

This approach is called drop tail, and it has an alternative in the form of drop head. With drop head, when a new request arrives and the queue is full, the first (oldest) request in the queue is dropped. This can often result in better latency for waiting requests, at the price of dropped requests having had to wait as well.

The MaxConcurrentRequestsEnqueuer can be easily modified to support drop head.

internal class MaxConcurrentRequestsEnqueuer
{
    public enum DropMode
    {
        // Enum member initializers need an explicit cast from another enum type
        Tail = (int)MaxConcurrentRequestsLimitExceededPolicy.FifoQueueDropTail,
        Head = (int)MaxConcurrentRequestsLimitExceededPolicy.FifoQueueDropHead
    }

    ...
    private readonly DropMode _dropMode;

    public MaxConcurrentRequestsEnqueuer(int maxQueueLength, DropMode dropMode)
    {
        ...
        _dropMode = dropMode;
    }

    public Task<bool> EnqueueAsync()
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        if (_maxQueueLength > 0)
        {
            lock (_lock)
            {
                if (_queue.Count < _maxQueueLength)
                {
                    enqueueTask = InternalEnqueueAsync();
                }
                else if (_dropMode == DropMode.Head)
                {
                    _queue.Dequeue().SetResult(false);

                    enqueueTask = InternalEnqueueAsync();
                }
            }
        }

        return enqueueTask;
    }

    ...

    private Task<bool> InternalEnqueueAsync()
    {
        TaskCompletionSource<bool> enqueueTaskCompletionSource = new TaskCompletionSource<bool>();

        _queue.Enqueue(enqueueTaskCompletionSource);

        return enqueueTaskCompletionSource.Task;
    }
}

My test scenario is not sophisticated enough (all requests arrive at roughly the same time) to show the difference between drop tail and drop head, but this article has a nice comparison.
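The timing-based test can't show the difference between the two drop modes, but it is easy to see deterministically. The sketch below uses a hypothetical DropModeSketch class that mirrors the enqueuer's logic without the lock, so it is not thread-safe. With a queue length of 2 and three arrivals, drop tail rejects the newcomer, while drop head rejects the oldest waiter.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Hypothetical, deliberately single-threaded sketch of the two drop modes (no lock),
// used only to show which waiter is rejected when the queue is full.
public sealed class DropModeSketch
{
    private static readonly Task<bool> Rejected = Task.FromResult(false);

    private readonly Queue<TaskCompletionSource<bool>> _queue = new Queue<TaskCompletionSource<bool>>();
    private readonly int _maxQueueLength;
    private readonly bool _dropHead;

    public DropModeSketch(int maxQueueLength, bool dropHead)
    {
        _maxQueueLength = maxQueueLength;
        _dropHead = dropHead;
    }

    public Task<bool> Enqueue()
    {
        if (_queue.Count < _maxQueueLength)
        {
            return Add();
        }

        if (_dropHead && _queue.Count > 0)
        {
            // Drop head: the oldest waiter is rejected to make room for the newcomer
            _queue.Dequeue().SetResult(false);
            return Add();
        }

        // Drop tail: the newcomer is rejected
        return Rejected;
    }

    private Task<bool> Add()
    {
        var waiter = new TaskCompletionSource<bool>();
        _queue.Enqueue(waiter);
        return waiter.Task;
    }
}
```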

Further improvements

There are two more things which would be "nice to have" for the queue. The first is support for request aborting. Right now, if a request gets cancelled by the client, it will still wait in the queue, taking a spot which one of the incoming requests could take.

In order to support request aborting, the internals of MaxConcurrentRequestsEnqueuer need to be changed in a way which enables removing a specific request. To achieve that, Queue<TaskCompletionSource<bool>> can no longer be used. It will be replaced with LinkedList<TaskCompletionSource<bool>>. This allows enqueue and dequeue to remain O(1) operations and adds the possibility of O(n) removal. O(n) is a pessimistic value: in reality, the aborted request will typically be at the beginning of the list, so it will be found quickly. Only a couple of changes to MaxConcurrentRequestsEnqueuer are needed.

internal class MaxConcurrentRequestsEnqueuer
{
    ...
    private readonly LinkedList<TaskCompletionSource<bool>> _queue =
        new LinkedList<TaskCompletionSource<bool>>();

    ...

    public Task<bool> EnqueueAsync()
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        if (_maxQueueLength > 0)
        {
            lock (_lock)
            {
                if (_queue.Count < _maxQueueLength)
                {
                    enqueueTask = InternalEnqueueAsync();
                }
                else if (_dropMode == DropMode.Head)
                {
                    InternalDequeue(false);

                    enqueueTask = InternalEnqueueAsync();
                }
            }
        }

        return enqueueTask;
    }

    public bool Dequeue()
    {
        bool dequeued = false;

        lock (_lock)
        {
            if (_queue.Count > 0)
            {
                InternalDequeue(true);

                dequeued = true;
            }
        }

        return dequeued;
    }

    private Task<bool> InternalEnqueueAsync()
    {
        TaskCompletionSource<bool> enqueueTaskCompletionSource = new TaskCompletionSource<bool>();

        _queue.AddLast(enqueueTaskCompletionSource);

        return enqueueTaskCompletionSource.Task;
    }

    private void InternalDequeue(bool result)
    {
        TaskCompletionSource<bool> enqueueTaskCompletionSource = _queue.First.Value;

        _queue.RemoveFirst();

        enqueueTaskCompletionSource.SetResult(result);
    }
}
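LinkedList<T>.AddLast returns the LinkedListNode<T> it created, and LinkedList<T>.Remove(LinkedListNode<T>) runs in O(1). Keeping the node for each waiting request is therefore an alternative to the O(n) Remove(value) search. The minimal sketch below uses a hypothetical NodeRemovalSketch class. It relies on the fact that a removed node's List property becomes null, which turns a second removal attempt into a no-op.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Hypothetical sketch: keeps the LinkedListNode returned by AddLast, so a specific waiter
// can be removed in O(1) with Remove(node) instead of the O(n) Remove(value) search.
public sealed class NodeRemovalSketch
{
    private readonly object _lock = new object();
    private readonly LinkedList<TaskCompletionSource<bool>> _queue =
        new LinkedList<TaskCompletionSource<bool>>();

    public int Count
    {
        get { lock (_lock) { return _queue.Count; } }
    }

    public LinkedListNode<TaskCompletionSource<bool>> Enqueue()
    {
        var waiter = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

        lock (_lock)
        {
            return _queue.AddLast(waiter);
        }
    }

    // Returns true only for the call which actually removed the node, so the waiter's task
    // is completed at most once even if two removal attempts race.
    public bool TryRemove(LinkedListNode<TaskCompletionSource<bool>> node)
    {
        lock (_lock)
        {
            // A removed node no longer belongs to any list (its List property is null)
            if (node.List != _queue)
            {
                return false;
            }

            _queue.Remove(node);
        }

        node.Value.SetResult(false);

        return true;
    }
}
```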

Now support for CancellationToken (which carries the information about the request being aborted) can be added.

internal class MaxConcurrentRequestsEnqueuer
{
    ...

    public Task<bool> EnqueueAsync(CancellationToken cancellationToken)
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        if (_maxQueueLength > 0)
        {
            lock (_lock)
            {
                if (_queue.Count < _maxQueueLength)
                {
                    enqueueTask = InternalEnqueueAsync(cancellationToken);
                }
                else if (_dropMode == DropMode.Head)
                {
                    InternalDequeue(false);

                    enqueueTask = InternalEnqueueAsync(cancellationToken);
                }
            }
        }

        return enqueueTask;
    }

    ...

    private Task<bool> InternalEnqueueAsync(CancellationToken cancellationToken)
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        TaskCompletionSource<bool> enqueueTaskCompletionSource = new TaskCompletionSource<bool>();

        cancellationToken.Register(CancelEnqueue, enqueueTaskCompletionSource);

        if (!cancellationToken.IsCancellationRequested)
        {
            _queue.AddLast(enqueueTaskCompletionSource);
            enqueueTask = enqueueTaskCompletionSource.Task;
        }

        return enqueueTask;
    }

    private void CancelEnqueue(object state)
    {
        bool removed = false;

        TaskCompletionSource<bool> enqueueTaskCompletionSource = ((TaskCompletionSource<bool>)state);
        lock (_lock)
        {
            removed = _queue.Remove(enqueueTaskCompletionSource);
        }

        if (removed)
        {
            enqueueTaskCompletionSource.SetResult(false);
        }
    }

    ...
}

The middleware can pass HttpContext.RequestAborted as the CancellationToken (it is also a good idea to wrap the code setting the response status in an IsCancellationRequested check). This should ensure that all aborted requests are removed from the queue quickly.
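The snippets above discard the registration returned by CancellationToken.Register. HttpContext.RequestAborted lives as long as the request, so the callback stays attached after the request has left the queue. An abort during processing would still run CancelEnqueue and scan the list under the lock. The minimal sketch below uses a hypothetical CancellableWaitSketch helper. It disposes the registration once the wait is over and uses TrySetResult, so a late cancellation cannot complete the task twice. Removing the waiter from the queue, which CancelEnqueue does above, is left out.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class CancellableWaitSketch
{
    // Awaits the waiter's task; a cancellation completes it with false instead.
    public static async Task<bool> WaitAsync(TaskCompletionSource<bool> waiter,
        CancellationToken cancellationToken)
    {
        // Disposing the registration detaches the callback once the wait is over,
        // so a later abort (during request processing) no longer triggers it.
        using (cancellationToken.Register(() => waiter.TrySetResult(false)))
        {
            return await waiter.Task.ConfigureAwait(false);
        }
    }
}
```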

The second useful thing is limiting the time a request can spend in the queue. This is just additional protection for cases when the queue gets stuck and clients don't implement their own timeouts. As cancellation is already supported, this can be nicely introduced with the help of CancellationTokenSource.CancelAfter and CancellationTokenSource.CreateLinkedTokenSource.

internal class MaxConcurrentRequestsEnqueuer
{
    ...
    private readonly int _maxTimeInQueue;

    public MaxConcurrentRequestsEnqueuer(int maxQueueLength, DropMode dropMode, int maxTimeInQueue)
    {
        ...
        _maxTimeInQueue = maxTimeInQueue;
    }

    ...

    private Task<bool> InternalEnqueueAsync(CancellationToken cancellationToken)
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        TaskCompletionSource<bool> enqueueTaskCompletionSource = new TaskCompletionSource<bool>();

        CancellationToken enqueueCancellationToken =
            GetEnqueueCancellationToken(enqueueTaskCompletionSource, cancellationToken);

        if (!enqueueCancellationToken.IsCancellationRequested)
        {
            _queue.AddLast(enqueueTaskCompletionSource);
            enqueueTask = enqueueTaskCompletionSource.Task;
        }

        return enqueueTask;
    }

    private CancellationToken GetEnqueueCancellationToken(
        TaskCompletionSource<bool> enqueueTaskCompletionSource, CancellationToken cancellationToken)
    {
        CancellationToken enqueueCancellationToken = CancellationTokenSource.CreateLinkedTokenSource(
            cancellationToken,
            GetTimeoutToken(enqueueTaskCompletionSource)
        ).Token;

        enqueueCancellationToken.Register(CancelEnqueue, enqueueTaskCompletionSource);

        return enqueueCancellationToken;
    }

    private CancellationToken GetTimeoutToken(TaskCompletionSource<bool> enqueueTaskCompletionSource)
    {
        CancellationToken timeoutToken = CancellationToken.None;

        if (_maxTimeInQueue != MaxConcurrentRequestsOptions.MaxTimeInQueueUnlimited)
        {
            CancellationTokenSource timeoutTokenSource = new CancellationTokenSource();

            timeoutToken = timeoutTokenSource.Token;
            timeoutToken.Register(CancelEnqueue, enqueueTaskCompletionSource);

            timeoutTokenSource.CancelAfter(_maxTimeInQueue);
        }

        return timeoutToken;
    }

    ...
}

The timeout can be clearly shown in a test by setting the max time in queue to a value lower than the processing time (for example, 300 ms).

Visual Studio Locals Window - Content of responseInformation array for hard limit with FIFO queue and max time in queue scenario

As expected, there are two groups of dropped requests: the ones which were dropped immediately and the ones which were dropped after the max time in queue.
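The timeout code above has three quirks. It creates two CancellationTokenSource instances. It registers CancelEnqueue on both the timeout token and the linked token, which is harmless because CancelEnqueue only completes the task when it actually removes it. It also never disposes either source, so each CancelAfter timer stays alive until it fires. A possible simplification is to call CancelAfter directly on the linked source and dispose everything when the wait ends. The sketch below assumes a hypothetical QueueTimeoutSketch helper and a TimeSpan-based timeout, and leaves out removal from the queue.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class QueueTimeoutSketch
{
    // Waits for the waiter to be released, giving up (returning false) when either
    // the request is aborted or maxTimeInQueue elapses.
    public static async Task<bool> WaitAsync(TaskCompletionSource<bool> waiter, TimeSpan maxTimeInQueue,
        CancellationToken requestAborted)
    {
        // One linked source covers both the abort and the timeout
        using (CancellationTokenSource linkedSource = CancellationTokenSource.CreateLinkedTokenSource(requestAborted))
        {
            linkedSource.CancelAfter(maxTimeInQueue);

            // TrySetResult: whichever of release, abort or timeout comes first wins
            using (linkedSource.Token.Register(() => waiter.TrySetResult(false)))
            {
                return await waiter.Task.ConfigureAwait(false);
            }
        }
    }
}
```

For an unlimited time in queue, Timeout.InfiniteTimeSpan can be passed, because CancelAfter treats it as "never cancel".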

There is more

This (quite long) post doesn't fully explore the subject. There are other approaches, for example ones based on a LIFO queue. If this post got you interested in the subject, I suggest further research.

A more "polished" version of MaxConcurrentRequestsEnqueuer and MaxConcurrentRequestsMiddleware, with a demo project and tests, can be found on GitHub.

I received a question from a user of my Server-Sent Events Middleware. The user was asking whether it can be used together with the Response Compression Middleware, because he had a hard time making it work. As there shouldn't be any technical limitation to such a scenario, I decided to quickly test it in my simple demo.

public class Startup
{
    public void ConfigureServices(IServiceCollection services)
    {
        services.AddResponseCompression(options =>
        {
            options.MimeTypes = ResponseCompressionDefaults.MimeTypes.Concat(new[]
            {
                "text/event-stream"
            });
        });

        services.AddServerSentEvents();
        ...
    }

    public void Configure(IApplicationBuilder app)
    {
        app.UseResponseCompression()
            .MapServerSentEvents("/see-heartbeat")
            ...;

        ...
    }
}

It worked without any issues.

Chrome Developer Tools Network Tab - Server-Sent Events with gzip

I quickly wrote a response including my snippet. Unfortunately, the user was doing exactly the same thing. While for me it was working perfectly, for him it seemed to do nothing (there was no error or any other side effect, the events just weren't arriving at the client). We exchanged a couple of emails until we finally discovered a difference between our scenarios: my target framework was netcoreapp1.0, while his was net451. I switched my project to net451 and could observe the same behavior.

Difference between .NET Framework and .NET Core

I started looking for the root cause. It was obviously somehow related to response compression, but I had no idea what it could be until I found the following fragment inside GzipCompressionProvider.

public bool SupportsFlush
{
    get
    {
#if NET451
        return false;
#elif NETSTANDARD1_3
        return true;
#else
        // Not implemented, compiler break
#endif
    }
}

This clearly shows that GZipStream (which is used internally by GzipCompressionProvider) doesn't support flushing on .NET Framework, and my Server-Sent Events implementation flushes each event once it has been completely written. This difference is not documented. In fact, the documentation states that GZipStream.Flush has no functionality regardless of the implementation. I was able to find this issue, which sheds some light on how GZipStream started actually flushing. The bottom line is that when running on .NET Framework, the Response Compression Middleware also buffers the response.
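One quick way to check how GZipStream.Flush behaves on a given runtime is to flush after a write, without closing the stream, and then try to decompress whatever has reached the underlying stream so far. Below is a minimal sketch of such a check; GzipFlushCheck is a hypothetical name. Going by the SupportsFlush fragment above, on .NET Core it should give the text back, while on .NET Framework the text is not expected to come back.

```csharp
using System;
using System.IO;
using System.IO.Compression;
using System.Text;

// Hypothetical check: writes text into a GZipStream, calls Flush without closing the stream,
// then tries to decompress only the bytes which have reached the underlying stream so far.
public static class GzipFlushCheck
{
    public static string RoundTripAfterFlush(string text)
    {
        byte[] payload = Encoding.UTF8.GetBytes(text);
        byte[] flushedBytes;

        var compressed = new MemoryStream();
        using (var gzip = new GZipStream(compressed, CompressionMode.Compress, leaveOpen: true))
        {
            gzip.Write(payload, 0, payload.Length);
            gzip.Flush();

            // Captured before Dispose, which would write any remaining data and the gzip trailer
            flushedBytes = compressed.ToArray();
        }

        using (var decompressor = new GZipStream(new MemoryStream(flushedBytes), CompressionMode.Decompress))
        {
            var buffer = new byte[payload.Length];
            int total = 0;

            // Read only as many bytes as were written, so the missing trailer is never reached
            while (total < buffer.Length)
            {
                int read = decompressor.Read(buffer, total, buffer.Length - total);
                if (read == 0)
                {
                    break;
                }

                total += read;
            }

            return Encoding.UTF8.GetString(buffer, 0, total);
        }
    }
}
```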

Response Buffering in ASP.NET Core

In ASP.NET Core, a component which provides response buffering capabilities should implement IHttpBufferingFeature. The Response Compression Middleware does this through BodyWrapperStream, in which it wraps the original body stream. This means it should be possible to disable the buffering with the following code.

private void DisableResponseBuffering(HttpContext context)
{
    IHttpBufferingFeature bufferingFeature = context.Features.Get<IHttpBufferingFeature>();
    if (bufferingFeature != null)
    {
        bufferingFeature.DisableResponseBuffering();
    }
}

I've added this to ServerSentEventsMiddleware.

public class ServerSentEventsMiddleware
{
    ...

    public async Task Invoke(HttpContext context)
    {
        if (context.Request.Headers[Constants.ACCEPT_HTTP_HEADER] == Constants.SSE_CONTENT_TYPE)
        {
            DisableResponseBuffering(context);

            ...
        }
        else
        {
            await _next(context);
        }
    }

    ...
}

After this change, running the demo application with net451 as the target framework resulted in events correctly reaching the client. The difference was that the response wasn't compressed.

Chrome Developer Tools Network Tab - Server-Sent Events no compression for .NET Framework

This is how the Response Compression Middleware handles the DisableResponseBuffering method: if the compression provider which is supposed to be used doesn't support flushing (the SupportsFlush property above), the middleware disables compression.

This is an interesting difference in behavior between .NET Framework and .NET Core, and one worth remembering.

This is my third post about the WebSocket protocol in ASP.NET Core. Previously I wrote about subprotocol negotiation and Cross-Site WebSocket Hijacking. Here I'm focusing on per-message compression, which is supported out of the box by Chrome, Firefox and other browsers.

WebSocket Extensions

The WebSocket protocol has a concept of extensions, which can provide new capabilities. An extension can define any additional functionality which is able to work on top of the WebSocket framing layer. The specification reserves three header bits (RSV1, RSV2 and RSV3) and all opcodes from 3 to 7 and from 11 to 15 for use by extensions. It also allows using the reserved bits to create additional opcodes, or even using some of the payload data for that purpose. Extensions, similarly to subprotocols, are negotiated through a dedicated header (Sec-WebSocket-Extensions) as part of the handshake. A client advertises the extensions it supports by listing them in the header, and the server accepts one or more of them in exactly the same way.
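To make the negotiation more concrete, here is a minimal sketch of splitting a Sec-WebSocket-Extensions header value into offers and their parameters. It follows the extension-list grammar from RFC 6455 (section 9.1). ExtensionsHeaderSketch is a hypothetical helper, and it deliberately doesn't handle quoted parameter values containing ',' or ';'.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper splitting a Sec-WebSocket-Extensions header value, following RFC 6455:
//   extension-list = 1#extension
//   extension      = extension-token *( ";" extension-param )
//   extension-param = token [ "=" (token | quoted-string) ]
// Simplification: quoted-string values containing ',' or ';' are not supported.
public static class ExtensionsHeaderSketch
{
    public static List<(string Name, Dictionary<string, string> Parameters)> Parse(string headerValue)
    {
        var offers = new List<(string Name, Dictionary<string, string> Parameters)>();

        foreach (string rawOffer in headerValue.Split(','))
        {
            string[] parts = rawOffer.Split(';');
            string name = parts[0].Trim();
            if (name.Length == 0)
            {
                continue;
            }

            var parameters = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
            for (int i = 1; i < parts.Length; i++)
            {
                string[] pair = parts[i].Split(new[] { '=' }, 2);
                string key = pair[0].Trim();
                if (key.Length == 0)
                {
                    continue;
                }

                // A parameter without a value (e.g. client_max_window_bits) maps to null
                parameters[key] = pair.Length == 2 ? pair[1].Trim().Trim('"') : null;
            }

            offers.Add((name, parameters));
        }

        return offers;
    }
}
```

The parameter names used in the test (server_no_context_takeover, client_max_window_bits) are the ones defined for permessage-deflate in RFC 7692.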

There are two extensions I have heard of: A Multiplexing Extension for WebSockets and Compression Extensions for WebSocket. The first one never went beyond a draft, but the second became a standard (RFC 7692) and has been adopted by several browsers.

WebSocket Per-Message Compression Extensions

The Compression Extensions for WebSocket standard defines two things. The first is a framework for adding compression functionality to the WebSocket protocol. The framework is really simple; it states only two things:

  • A Per-Message Compression Extension operates only on message data (so compression takes place before splitting into frames, and decompression takes place after all frames have been received).
  • A Per-Message Compression Extension allocates the RSV1 bit and calls it the Per-Message Compressed bit. The bit is set to 1 on the first frame of a compressed message.
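Taken together, the two rules mean that in a fragmented message only the first frame carries the opcode and the Per-Message Compressed bit, while the FIN bit is set only on the last frame. Below is a minimal sketch computing the first header byte of each frame; FragmentedMessageSketch is a hypothetical helper. It uses the RFC 6455 bit positions: FIN = 0x80, RSV1 = 0x40, opcode in the low four bits. For a compressed text message split into three frames, it yields 0x41, 0x00 and 0x80.

```csharp
using System;

// Hypothetical helper: first header byte of each frame of a single (possibly fragmented) message.
public static class FragmentedMessageSketch
{
    private const byte FinBit = 0x80;
    private const byte Rsv1Bit = 0x40;            // Per-Message Compressed bit
    private const byte ContinuationOpcode = 0x0;

    public static byte[] FirstHeaderBytes(byte opcode, bool compressed, int frameCount)
    {
        if (frameCount < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(frameCount));
        }

        var result = new byte[frameCount];
        for (int i = 0; i < frameCount; i++)
        {
            bool first = i == 0;
            bool last = i == frameCount - 1;

            // Only the first frame carries the message opcode; the rest are continuations
            byte value = first ? opcode : ContinuationOpcode;

            // The compressed bit goes on the first frame only
            if (compressed && first)
            {
                value |= Rsv1Bit;
            }

            if (last)
            {
                value |= FinBit;
            }

            result[i] = value;
        }

        return result;
    }
}
```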

The challenging part is the allocation of the RSV1 bit. The WebSocket API available in ASP.NET Core doesn't give access to the RSV bits, which makes it impossible to implement per-message compression on top of it. Because of that, I decided to roll my own implementation of IHttpWebSocketFeature. It is very similar to the one provided by Microsoft.AspNetCore.WebSockets. The underlying WebSocket implementation is based so closely on ManagedWebSocket that the needed changes can be described in its context. The key difference is that my implementation is stripped of the client-specific logic, as it is not needed.

From the public API perspective, there must be a way to set and get the information that a message is compressed. Setting it can be achieved with an overload of the SendAsync method (or rather by extending the current SendAsync with one more parameter and providing an overload which doesn't need it).

internal class CompressionWebSocket : WebSocket
{
    public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType,
        bool endOfMessage, CancellationToken cancellationToken)
    {
        return SendAsync(buffer, messageType, false, endOfMessage, cancellationToken);
    }

    public Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool compressed,
        bool endOfMessage, CancellationToken cancellationToken)
    {
        ...;
    }
}

The information about a received message can be exposed through a derived WebSocketReceiveResult.

public class CompressionWebSocketReceiveResult : WebSocketReceiveResult
{
    public bool Compressed { get; }

    public CompressionWebSocketReceiveResult(int count, WebSocketMessageType messageType,
        bool compressed, bool endOfMessage)
        : base(count, messageType, endOfMessage)
    {
        Compressed = compressed;
    }

    public CompressionWebSocketReceiveResult(int count, WebSocketMessageType messageType,
        bool endOfMessage, WebSocketCloseStatus? closeStatus, string closeStatusDescription)
        : base(count, messageType, endOfMessage, closeStatus, closeStatusDescription)
    {
        Compressed = false;
    }
}

The next step is adjusting the internals of the WebSocket implementation to properly write and read the RSV1 bit. The writing part is handled by the WriteHeader method. This method needs to be changed so that it sets the RSV1 bit when the message is compressed and the current frame is not a continuation.

private static int WriteHeader(WebSocketMessageOpcode opcode, byte[] sendBuffer,
    ArraySegment<byte> payload, bool compressed, bool endOfMessage)
{
    sendBuffer[0] = (byte)opcode;

    if (compressed && (opcode != WebSocketMessageOpcode.Continuation))
    {
        sendBuffer[0] |= 0x40;
    }

    if (endOfMessage)
    {
        sendBuffer[0] |= 0x80;
    }

    ...
}

After this change, all the paths leading to the WriteHeader method must be updated to provide either the value of the compressed parameter passed down from SendAsync, or false.

The receiving flow has a corresponding method, TryParseMessageHeaderFromReceiveBuffer, which fills out a MessageHeader struct. A different version of that struct, with an additional Compressed field, is needed.

[StructLayout(LayoutKind.Auto)]
internal struct CompressionWebSocketMessageHeader
{
    internal WebSocketMessageOpcode Opcode { get; set; }

    internal bool Compressed { get; set; }

    internal bool Fin { get; set; }

    internal long PayloadLength { get; set; }

    internal int Mask { get; set; }
}

The TryParseMessageHeaderFromReceiveBuffer method requires two changes. The first reads the RSV1 bit. The second changes the validation of the RSV bits, because per the protocol specification an invalid combination of RSV bits must fail the connection.

private bool TryParseMessageHeaderFromReceiveBuffer(out CompressionWebSocketMessageHeader resultHeader)
{
    var header = new CompressionWebSocketMessageHeader();

    header.Opcode = (WebSocketMessageOpcode)(_receiveBuffer[_receiveBufferOffset] & 0xF);
    header.Compressed = (_receiveBuffer[_receiveBufferOffset] & 0x40) != 0;
    header.Fin = (_receiveBuffer[_receiveBufferOffset] & 0x80) != 0;

    bool reservedSet = (_receiveBuffer[_receiveBufferOffset] & 0x70) != 0;
    bool reservedExceptCompressedSet = (_receiveBuffer[_receiveBufferOffset] & 0x30) != 0;

    ...

    bool shouldFail = (!header.Compressed && reservedSet) || reservedExceptCompressedSet;

    ...
}
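The same bits can be read back as in the JavaScript sketch below, which uses hypothetical names. It assumes compression has been negotiated, so only RSV2 and RSV3 are invalid. With RSV1 allowed, the `shouldFail` condition in the method above reduces to exactly that check: if RSV1 is clear while any reserved bit is set, RSV2 or RSV3 must be the one that is set.

```javascript
// Sketch: decoding the first header byte and validating the reserved bits,
// assuming permessage compression was negotiated (so RSV1 is allowed).
function parseFirstHeaderByte(b) {
  const header = {
    opcode: b & 0x0f,
    compressed: (b & 0x40) !== 0, // RSV1
    fin: (b & 0x80) !== 0,
  };
  // RSV2 (0x20) and RSV3 (0x10) have no meaning here and must fail the connection.
  const rsv2OrRsv3Set = (b & 0x30) !== 0;
  return { header, shouldFail: rsv2OrRsv3Set };
}

console.log(parseFirstHeaderByte(0xc1)); // final, compressed text frame
console.log(parseFirstHeaderByte(0xe1)); // RSV2 set as well: must fail
```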

The last step is to modify the InternalReceiveAsync method. It must skip UTF-8 validation for compressed messages, because their payload is deflate output rather than UTF-8 text, and it must properly create a CompressionWebSocketReceiveResult.

private async Task<WebSocketReceiveResult> InternalReceiveAsync(ArraySegment<byte> payloadBuffer,
    CancellationToken cancellationToken)
{
    ...

    try
    {
        while (true)
        {
            ...

            if ((header.Opcode == WebSocketMessageOpcode.Text) && !header.Compressed
                && !TryValidateUtf8(
                    new ArraySegment<byte>(payloadBuffer.Array, payloadBuffer.Offset, bytesToCopy),
                    header.Fin, _utf8TextState))
            {
                await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData,
                    WebSocketError.Faulted, cancellationToken).ConfigureAwait(false);
            }

            _lastReceiveHeader = header;
            return new CompressionWebSocketReceiveResult(
                bytesToCopy,
                header.Opcode == WebSocketMessageOpcode.Text ?
                    WebSocketMessageType.Text : WebSocketMessageType.Binary,
                header.Compressed,
                bytesToCopy == 0 || (header.Fin && header.PayloadLength == 0));
        }
    }
    catch (Exception ex)
    {
        ...
    }
    finally
    {
        ...
    }
}

With those changes in place, the WebSocket implementation supports the per-message compression framework, and support for a specific compression extension can be implemented on top of it.

Deflate based PMCE

The second thing the Compression Extensions for WebSocket standard defines is the permessage-deflate compression extension. This extension specifies a way of compressing the message payload using the DEFLATE algorithm with the help of the byte boundary alignment method. But first it is worth implementing the concepts shared by all potential compression extensions: receiving and sending the message payload. The methods responsible for these operations should be able to properly reassemble a message from frames and split a message into frames.

public abstract class WebSocketCompressionProviderBase
{
    private readonly int? _sendSegmentSize;

    ...

    protected async Task SendMessageAsync(WebSocket webSocket, byte[] message,
        WebSocketMessageType messageType, bool compressed, CancellationToken cancellationToken)
    {
        if (webSocket.State == WebSocketState.Open)
        {
            if (_sendSegmentSize.HasValue && (_sendSegmentSize.Value < message.Length))
            {
                int messageOffset = 0;
                int messageBytesToSend = message.Length;

                while (messageBytesToSend > 0)
                {
                    int messageSegmentSize = Math.Min(_sendSegmentSize.Value, messageBytesToSend);
                    ArraySegment<byte> messageSegment = new ArraySegment<byte>(message, messageOffset,
                        messageSegmentSize);

                    messageOffset += messageSegmentSize;
                    messageBytesToSend -= messageSegmentSize;

                    await SendAsync(webSocket, messageSegment, messageType, compressed,
                        (messageBytesToSend == 0), cancellationToken);
                }
            }
            else
            {
                ArraySegment<byte> messageSegment = new ArraySegment<byte>(message, 0, message.Length);

                await SendAsync(webSocket, messageSegment, messageType, compressed, true,
                    cancellationToken);
            }
        }
    }

    private Task SendAsync(WebSocket webSocket, ArraySegment<byte> messageSegment,
        WebSocketMessageType messageType, bool compressed, bool endOfMessage,
        CancellationToken cancellationToken)
    {
        if (compressed)
        {
            CompressionWebSocket compressionWebSocket = (webSocket as CompressionWebSocket)
            ?? throw new InvalidOperationException($"Used WebSocket must be CompressionWebSocket.");

            return compressionWebSocket.SendAsync(messageSegment, messageType, true, endOfMessage,
                cancellationToken);
        }
        else
        {
            return webSocket.SendAsync(messageSegment, messageType, endOfMessage, cancellationToken);
        }
    }

    protected async Task<byte[]> ReceiveMessagePayloadAsync(WebSocket webSocket,
        WebSocketReceiveResult webSocketReceiveResult, byte[] receivePayloadBuffer)
    {
        byte[] messagePayload = null;

        if (webSocketReceiveResult.EndOfMessage)
        {
            messagePayload = new byte[webSocketReceiveResult.Count];
            Array.Copy(receivePayloadBuffer, messagePayload, webSocketReceiveResult.Count);
        }
        else
        {
            // The receive buffer is reused for every frame, so the received bytes (only the
            // first Count of them) are copied out before the next ReceiveAsync overwrites them.
            // MemoryStream requires System.IO.
            using (MemoryStream messagePayloadStream = new MemoryStream())
            {
                messagePayloadStream.Write(receivePayloadBuffer, 0, webSocketReceiveResult.Count);

                while (!webSocketReceiveResult.EndOfMessage)
                {
                    webSocketReceiveResult = await webSocket.ReceiveAsync(
                        new ArraySegment<byte>(receivePayloadBuffer), CancellationToken.None);
                    messagePayloadStream.Write(receivePayloadBuffer, 0, webSocketReceiveResult.Count);
                }

                messagePayload = messagePayloadStream.ToArray();
            }
        }

        return messagePayload;
    }
}
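The splitting and reassembling logic can be sketched independently of any WebSocket API. This JavaScript sketch assumes frames are modelled as `{ data, endOfMessage }` objects. It also assumes the receive function fills a caller-provided buffer and reports `{ count, endOfMessage }`, similar to ReceiveAsync. All names are hypothetical, and `segmentSize` plays the role of `_sendSegmentSize`.

```javascript
// Sketch: split an outgoing payload into frames of at most `segmentSize` bytes.
function splitIntoFrames(payload, segmentSize) {
  // No (valid) segment size, or the payload fits: send it as a single final frame.
  if (!(segmentSize > 0) || segmentSize >= payload.length) {
    return [{ data: payload, endOfMessage: true }];
  }
  const frames = [];
  for (let offset = 0; offset < payload.length; offset += segmentSize) {
    const end = Math.min(offset + segmentSize, payload.length);
    frames.push({ data: payload.subarray(offset, end), endOfMessage: end === payload.length });
  }
  return frames;
}

// Sketch: reassemble a message from frames received into one reused buffer.
// Each chunk is copied (Uint8Array.prototype.slice) before the next receive
// overwrites the buffer.
function reassemble(receiveIntoBuffer, buffer) {
  const chunks = [];
  let total = 0;
  let result;
  do {
    result = receiveIntoBuffer(buffer); // returns { count, endOfMessage }
    chunks.push(buffer.slice(0, result.count));
    total += result.count;
  } while (!result.endOfMessage);

  const message = new Uint8Array(total);
  let offset = 0;
  for (const chunk of chunks) {
    message.set(chunk, offset);
    offset += chunk.length;
  }
  return message;
}
```

Copying each chunk out of the shared buffer is what keeps the reassembled message correct. Holding on to references to the buffer itself would leave only the last frame's bytes.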

With this base the permessage-deflate specifics can be implemented. Let's start with the byte boundary alignment method. In practice it boils down to two operations:

  • When compressing, the compressed data should end with an empty deflate block, and the last four octets of that block (0x00 0x00 0xFF 0xFF) should be removed.
  • When decompressing, those last four octets of an empty deflate block should be appended to the received payload before decompression.
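Both operations can be tried out with Node.js's zlib module, which exposes the sync flush directly. This is a sketch of the technique, not of the .NET DeflateStream code in this post, and `compressMessage` and `decompressMessage` are hypothetical names. A raw deflate stream flushed with Z_SYNC_FLUSH ends with an empty stored block, which is encoded as the 0x00 0x00 0xFF 0xFF sequence mentioned above.

```javascript
// Sketch of the RFC 7692 byte boundary alignment using Node's zlib (raw deflate).
const zlib = require('zlib');

const TAIL = Buffer.from([0x00, 0x00, 0xff, 0xff]);

function compressMessage(text) {
  // Finishing with Z_SYNC_FLUSH (instead of Z_FINISH) ends the output with an
  // empty, non-final stored block: 00 00 FF FF.
  const deflated = zlib.deflateRawSync(Buffer.from(text, 'utf8'), {
    finishFlush: zlib.constants.Z_SYNC_FLUSH,
  });
  if (!deflated.subarray(-4).equals(TAIL)) {
    throw new Error('expected the output to end with an empty deflate block');
  }
  // Remove the last four octets before the payload is put into frames.
  return deflated.subarray(0, deflated.length - 4);
}

function decompressMessage(payload) {
  // Append the four octets back. The stream still has no final block, so
  // Z_SYNC_FLUSH is used to avoid an "unexpected end of file" error.
  return zlib
    .inflateRawSync(Buffer.concat([payload, TAIL]), { finishFlush: zlib.constants.Z_SYNC_FLUSH })
    .toString('utf8');
}

const compressed = compressMessage('Hello, permessage-deflate! '.repeat(20));
console.log(compressed.length, decompressMessage(compressed).length);
```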

It appears that, when compressing with the DeflateStream provided by .NET, the empty deflate block is always present, so the above can be implemented with two helper methods.

public sealed class WebSocketDeflateCompressionProvider : WebSocketCompressionProviderBase
{
    private static readonly byte[] LAST_FOUR_OCTETS = new byte[] { 0x00, 0x00, 0xFF, 0xFF };
    ...

    private byte[] TrimLastFourOctetsOfEmptyNonCompressedDeflateBlock(byte[] compressedMessagePayload)
    {
        // Search backwards for the last 0x00 0x00 0xFF 0xFF sequence; -1 means "not found".
        int lastFourOctetsOfEmptyNonCompressedDeflateBlockPosition = -1;
        for (int position = compressedMessagePayload.Length - 1; position >= 3; position--)
        {
            if ((compressedMessagePayload[position - 3] == LAST_FOUR_OCTETS[0])
                && (compressedMessagePayload[position - 2] == LAST_FOUR_OCTETS[1])
                && (compressedMessagePayload[position - 1] == LAST_FOUR_OCTETS[2])
                && (compressedMessagePayload[position] == LAST_FOUR_OCTETS[3]))
            {
                lastFourOctetsOfEmptyNonCompressedDeflateBlockPosition = position - 3;
                break;
            }
        }

        if (lastFourOctetsOfEmptyNonCompressedDeflateBlockPosition >= 0)
        {
            // Drop the four octets and anything that follows them.
            Array.Resize(ref compressedMessagePayload, lastFourOctetsOfEmptyNonCompressedDeflateBlockPosition);
        }

        return compressedMessagePayload;
    }

    private byte[] AppendLastFourOctetsOfEmptyNonCompressedDeflateBlock(byte[] compressedMessagePayload)
    {
        Array.Resize(ref compressedMessagePayload, compressedMessagePayload.Length + 4);

        compressedMessagePayload[compressedMessagePayload.Length - 4] = LAST_FOUR_OCTETS[0];
        compressedMessagePayload[compressedMessagePayload.Length - 3] = LAST_FOUR_OCTETS[1];
        compressedMessagePayload[compressedMessagePayload.Length - 2] = LAST_FOUR_OCTETS[2];
        compressedMessagePayload[compressedMessagePayload.Length - 1] = LAST_FOUR_OCTETS[3];

        return compressedMessagePayload;
    }
}
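Searching backwards for the marker, rather than simply dropping the last four bytes, also handles output where the compressor wrote something after the sync flush. A JavaScript sketch of that search follows, with hypothetical names. The sample input ends with `03 00`, which is how an empty final deflate block with fixed Huffman codes is encoded.

```javascript
// Sketch: cut the payload at the last 00 00 FF FF sequence.
const MARKER = [0x00, 0x00, 0xff, 0xff];

function trimAtLastEmptyBlock(bytes) {
  for (let start = bytes.length - MARKER.length; start >= 0; start--) {
    if (MARKER.every((value, i) => bytes[start + i] === value)) {
      // Drop the marker and everything after it (e.g. a trailing final block).
      return bytes.slice(0, start);
    }
  }
  return bytes; // marker not found: leave the payload untouched
}

const sample = Uint8Array.from([0xf2, 0x48, 0x00, 0x00, 0xff, 0xff, 0x03, 0x00]);
console.log(trimAtLastEmptyBlock(sample));
```

If the marker is missing, the payload is returned unchanged. That is safer than truncating it.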

The second set of helper methods needed is the actual compression and decompression. For simplicity, only text messages will be considered from this point on.

public sealed class WebSocketDeflateCompressionProvider : WebSocketCompressionProviderBase
{
    ...
    private static readonly Encoding UTF8_WITHOUT_BOM = new UTF8Encoding(false);

    ...

    private async Task<byte[]> CompressTextWithDeflateAsync(string message)
    {
        byte[] compressedMessagePayload = null;

        using (MemoryStream compressedMessagePayloadStream = new MemoryStream())
        {
            using (DeflateStream compressedMessagePayloadCompressStream =
                new DeflateStream(compressedMessagePayloadStream, CompressionMode.Compress))
            {
                using (StreamWriter compressedMessagePayloadCompressWriter =
                    new StreamWriter(compressedMessagePayloadCompressStream, UTF8_WITHOUT_BOM))
                {
                    await compressedMessagePayloadCompressWriter.WriteAsync(message);
                }
            }

            compressedMessagePayload = compressedMessagePayloadStream.ToArray();
        }

        return compressedMessagePayload;
    }

    private async Task<string> DecompressTextWithDeflateAsync(byte[] compressedMessagePayload)
    {
        string message = null;

        using (MemoryStream compressedMessagePayloadStream = new MemoryStream(compressedMessagePayload))
        {
            using (DeflateStream compressedMessagePayloadDecompressStream =
                new DeflateStream(compressedMessagePayloadStream, CompressionMode.Decompress))
            {
                using (StreamReader compressedMessagePayloadDecompressReader =
                    new StreamReader(compressedMessagePayloadDecompressStream, UTF8_WITHOUT_BOM))
                {
                    message = await compressedMessagePayloadDecompressReader.ReadToEndAsync();
                }
            }
        }

        return message;
    }
}

Now the public API can be exposed.

public interface IWebSocketCompressionProvider
{
    Task CompressTextMessageAsync(WebSocket webSocket, string message,
        CancellationToken cancellationToken);

    Task<string> DecompressTextMessageAsync(WebSocket webSocket,
        WebSocketReceiveResult webSocketReceiveResult, byte[] receivePayloadBuffer);
}

public sealed class WebSocketDeflateCompressionProvider :
    WebSocketCompressionProviderBase, IWebSocketCompressionProvider
{
    ...

    public override async Task CompressTextMessageAsync(WebSocket webSocket, string message,
        CancellationToken cancellationToken)
    {
        byte[] compressedMessagePayload = await CompressTextWithDeflateAsync(message);

        compressedMessagePayload =
            TrimLastFourOctetsOfEmptyNonCompressedDeflateBlock(compressedMessagePayload);

        await SendMessageAsync(webSocket, compressedMessagePayload, WebSocketMessageType.Text, true,
            cancellationToken);
    }

    public override async Task<string> DecompressTextMessageAsync(WebSocket webSocket,
        WebSocketReceiveResult webSocketReceiveResult, byte[] receivePayloadBuffer)
    {
        string message = null;

        CompressionWebSocketReceiveResult compressionWebSocketReceiveResult =
            webSocketReceiveResult as CompressionWebSocketReceiveResult;

        if ((compressionWebSocketReceiveResult != null) && compressionWebSocketReceiveResult.Compressed)
        {
            byte[] compressedMessagePayload =
                await ReceiveMessagePayloadAsync(webSocket, webSocketReceiveResult, receivePayloadBuffer);

            compressedMessagePayload =
                AppendLastFourOctetsOfEmptyNonCompressedDeflateBlock(compressedMessagePayload);

            message = await DecompressTextWithDeflateAsync(compressedMessagePayload);
        }
        else
        {
            byte[] messagePayload =
                await ReceiveMessagePayloadAsync(webSocket, webSocketReceiveResult, receivePayloadBuffer);

            message = Encoding.UTF8.GetString(messagePayload);
        }

        return message;
    }

    ...
}

This API makes it easy to plug a compression provider into the typical (SendAsync and ReceiveAsync based) WebSocket flow. Calls to SendAsync are replaced with calls to CompressTextMessageAsync, and DecompressTextMessageAsync is called whenever the WebSocketReceiveResult acquired from ReceiveAsync indicates a text message. But before this can be done, the permessage-deflate extension must be properly negotiated.

Context takeover and LZ77 window size

The parameters are an important part of compression extension negotiation. The permessage-deflate extension defines four of them: server_no_context_takeover, client_no_context_takeover, server_max_window_bits and client_max_window_bits. The first two define whether the server and/or the client can reuse the same context (LZ77 sliding window) for subsequent messages. The remaining two allow limiting the LZ77 sliding window size. The sad truth is that the above implementation cannot handle most of these parameters properly. The negotiation process therefore needs to make sure that only acceptable values are used, or else the negotiation fails. A failed negotiation doesn't fail the connection; the handshake response simply doesn't list the extension as accepted. So what are the acceptable values for this implementation?

The DeflateStream doesn't provide control over the LZ77 sliding window size, so the negotiation must fail if the offer contains the server_max_window_bits parameter, as it can't be honored. The presence of client_max_window_bits, on the other hand, can be ignored, as it is only a hint that the client supports this parameter.

When it comes to context reuse, the above implementation creates a new DeflateStream for every message, which means it always works in "no context takeover" mode. Because of that, the negotiation response must prevent the client from reusing the context: client_no_context_takeover must always be included in the response. This also means that a server_no_context_takeover sent by the client in the offer can always be accepted.
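The acceptance rules above are compact enough to sketch. The JavaScript function below is hypothetical; it is not the WebSocketCompressionService code. It is also deliberately naive: it splits the Sec-WebSocket-Extensions value on commas and semicolons and ignores quoted values and duplicate parameters, both of which a real implementation must validate.

```javascript
// Hypothetical sketch of the permessage-deflate acceptance rules described above.
// Input: the Sec-WebSocket-Extensions request value; output: the response value,
// or null when no offer can be accepted (the extension is then simply not used).
const KNOWN_PARAMS = ['server_no_context_takeover', 'client_no_context_takeover', 'client_max_window_bits'];

function negotiatePermessageDeflate(offersHeader) {
  for (const offer of offersHeader.split(',')) {
    const [name, ...rawParams] = offer.split(';').map(part => part.trim());
    if (name !== 'permessage-deflate') continue;

    const params = rawParams.filter(p => p.length > 0).map(p => p.split('=')[0].trim());
    // The LZ77 window can't be limited, so this offer is declined.
    if (params.includes('server_max_window_bits')) continue;
    // Unknown parameters make the offer invalid.
    if (!params.every(p => KNOWN_PARAMS.includes(p))) continue;

    const response = ['permessage-deflate'];
    // A fresh compressor per message means the server never takes over context anyway.
    if (params.includes('server_no_context_takeover')) response.push('server_no_context_takeover');
    // Each message is also decompressed with a fresh context, so the client must not reuse its own.
    response.push('client_no_context_takeover');
    // client_max_window_bits (a hint) is intentionally not echoed back.
    return response.join('; ');
  }
  return null;
}
```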

I'm skipping the code which handles the negotiation here. Although it is based on the NameValueWithParametersHeaderValue class, which handles all the parsing, it is still quite lengthy (it must also validate the parameters for "technical correctness"). For anyone interested, the implementation is split between the WebSocketCompressionService and WebSocketDeflateCompressionOptions classes, which can be found on GitHub (link below).

Trying this out

There was an issue in Microsoft.AspNetCore.WebSockets repository for implementing per-message compression. It's currently closed, so I've made this implementation available independently on GitHub and NuGet. It is also part of my WebSocket demo project if somebody is looking for ready to use playground.

In my previous post I've written about subprotocols in WebSocket protocol. This time I wanted to focus on Cross-Site WebSocket Hijacking vulnerability.

Cross-Site WebSocket Hijacking

The WebSocket protocol is not subject to the same-origin policy, so the browser will allow any page to open a WebSocket connection. The specification leaves the check to the server: "Servers that are not intended to process input from any web page but only for certain sites SHOULD verify the |Origin| field is an origin they expect."

Let's imagine a scenario in which the application sends sensitive data over a WebSocket, and authentication is based on cookies sent as part of the initial handshake. In such a case, if the user visits a malicious page while logged in to the application, that page can open an authenticated WebSocket connection, because the browser will automatically send all the cookies. This is quite a common and, if unprotected, dangerous scenario. More "interesting" scenarios are also possible, like this case of remote code execution.

Protecting against CSWSH

Protection against CSWSH is easy to implement. Browsers always include the Origin header in the initial handshake, so the application should check its value against a list of acceptable origins and, if the origin is not on the list, respond with a 403 Forbidden status code. The sample from my previous post used WebSocketConnectionsMiddleware to handle the connections, which makes it a perfect place to add this check.

public class WebSocketConnectionsOptions
{
    public HashSet<string> AllowedOrigins { get; set; }
}

public class WebSocketConnectionsMiddleware
{
    private WebSocketConnectionsOptions _options;
    ...

    public WebSocketConnectionsMiddleware(RequestDelegate next, WebSocketConnectionsOptions options, ...)
    {
        _options = options ?? throw new ArgumentNullException(nameof(options));
        ...
    }

    public async Task Invoke(HttpContext context)
    {
        if (context.WebSockets.IsWebSocketRequest)
        {
            if (ValidateOrigin(context))
            {
                ...
            }
            else
            {
                context.Response.StatusCode = StatusCodes.Status403Forbidden;
            }
        }
        else
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
        }
    }

    private bool ValidateOrigin(HttpContext context)
    {
        return (_options.AllowedOrigins == null)
            || (_options.AllowedOrigins.Count == 0)
            || (_options.AllowedOrigins.Contains(context.Request.Headers["Origin"].ToString()));
    }

    ...
}

Now the list of acceptable origins can be passed during the middleware registration.

public class Startup
{
    ...

    public void Configure(IApplicationBuilder app)
    {
        ...

        WebSocketConnectionsOptions webSocketConnectionsOptions = new WebSocketConnectionsOptions
        {
            AllowedOrigins = new HashSet<string> { "http://localhost:63290" }
        };

        ...
        app.UseWebSockets();
        app.Map("/socket", branchedApp =>
        {
            branchedApp.UseMiddleware<WebSocketConnectionsMiddleware>(webSocketConnectionsOptions);
        });
        ...
    }
}
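The decision logic of the middleware fits in two small functions. The JavaScript sketch below uses hypothetical names and mirrors ValidateOrigin. An empty or missing allow-list accepts everything. Otherwise the Origin value must match an entry exactly, because the C# HashSet<string> uses case-sensitive ordinal comparison by default. The value 101 stands for the handshake going ahead via AcceptWebSocketAsync.

```javascript
// Hypothetical sketch mirroring ValidateOrigin and the middleware's responses.
function isOriginAllowed(allowedOrigins, originHeader) {
  // No allow-list configured: every origin is accepted, as in the C# code.
  if (!allowedOrigins || allowedOrigins.size === 0) {
    return true;
  }
  // Exact, case-sensitive match; a missing header is treated as an empty string.
  return allowedOrigins.has(originHeader || '');
}

function handshakeStatus(isWebSocketRequest, allowedOrigins, originHeader) {
  if (!isWebSocketRequest) return 400; // not a WebSocket request
  return isOriginAllowed(allowedOrigins, originHeader) ? 101 : 403;
}

const allowed = new Set(['http://localhost:63290']);
console.log(handshakeStatus(true, allowed, 'http://localhost:63290'));
console.log(handshakeStatus(true, allowed, 'https://attacker.example'));
```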

I've extended the demo available at GitHub with this functionality.

WebSocket is the closest API to a network socket available in the browser. This makes it probably the most flexible transport a web application can use, but that flexibility comes at a price. From the WebSocket perspective, the message content is opaque; the protocol only distinguishes between text and binary data. There is also no ready-to-use mechanism for communicating additional metadata. This means that client and server must agree on an application subprotocol. This isn't a problem as long as the scenario is simple. It becomes one the moment there are clients outside our control and we want to evolve the subprotocol. WebSocket provides a solution for this problem in the form of a simple subprotocol negotiation mechanism, and the Microsoft.AspNetCore.WebSockets package, which provides low-level WebSocket support for ASP.NET Core, fully supports it.

Sample scenario

The sample scenario will be a very simple web application, which regularly receives plain text messages over WebSocket and displays them to users. The relevant part of client side code is the below snippet.

var handleWebSocketPlainTextData = function(data) {
    ...
};

var webSocket = new WebSocket('ws://example.com/socket');

webSocket.onmessage = function(message) {
    handleWebSocketPlainTextData(message.data);
};

On the server side there is a simple middleware which manages WebSocket connections.

public class WebSocketConnectionsMiddleware
{
    private IWebSocketConnectionsService _connectionsService;

    public WebSocketConnectionsMiddleware(RequestDelegate next,
        IWebSocketConnectionsService connectionsService)
    {
        _connectionsService = connectionsService ??
            throw new ArgumentNullException(nameof(connectionsService));
    }

    public async Task Invoke(HttpContext context)
    {
        if (context.WebSockets.IsWebSocketRequest)
        {
            WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync();

            WebSocketConnection webSocketConnection = new WebSocketConnection(webSocket);

            _connectionsService.AddConnection(webSocketConnection);

            byte[] webSocketBuffer = new byte[1024 * 4];
            WebSocketReceiveResult webSocketReceiveResult = await webSocket.ReceiveAsync(
                new ArraySegment<byte>(webSocketBuffer), CancellationToken.None);
            if (webSocketReceiveResult.MessageType != WebSocketMessageType.Close)
            {
                ...
            }
            await webSocket.CloseAsync(webSocketReceiveResult.CloseStatus.Value,
                webSocketReceiveResult.CloseStatusDescription, CancellationToken.None);

            _connectionsService.RemoveConnection(webSocketConnection.Id);
        }
        else
        {
            context.Response.StatusCode = 400;
        }
    }
}

The IWebSocketConnectionsService implementation manages connections with the help of a ConcurrentDictionary, and WebSocketConnection is a wrapper around the WebSocket class which abstracts the low-level aspects of the API.

public class WebSocketConnection
{
    private WebSocket _webSocket;

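    // Generated once per connection; an expression-bodied "=> Guid.NewGuid()" would
    // return a different value on every access and break RemoveConnection(Id).
    public Guid Id { get; } = Guid.NewGuid();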

    public WebSocketConnection(WebSocket webSocket)
    {
        _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket));
    }

    public async Task SendAsync(string message, CancellationToken cancellationToken)
    {
        if (_webSocket.State == WebSocketState.Open)
        {
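            // Text frames carry UTF-8; the segment covers all encoded bytes,
            // which may be more than message.Length for non-ASCII characters.
            ArraySegment<byte> buffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(message));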

            await _webSocket.SendAsync(buffer, WebSocketMessageType.Text, true, cancellationToken);
        }
    }

    ...
}

The goal is to introduce a new (JSON-based) subprotocol which allows sending additional metadata, while preserving backward compatibility.

Abstracting the subprotocol

First, an abstraction of a subprotocol is needed. It has to provide the name of the subprotocol and methods for sending and receiving. In general the application will still be sending text-based messages, so the following interface should be sufficient.

public interface ITextWebSocketSubprotocol
{
    string SubProtocol { get; }

    Task SendAsync(string message, WebSocket webSocket, CancellationToken cancellationToken);

    ...
}

The implementation for the plain text version can be extracted from WebSocketConnection.

public class PlainTextWebSocketSubprotocol : ITextWebSocketSubprotocol
{
    public string SubProtocol => "aspnetcore-ws.plaintext";

    public async Task SendAsync(string message, WebSocket webSocket,
        CancellationToken cancellationToken)
    {
        if (webSocket.State == WebSocketState.Open)
        {
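            // Text frames carry UTF-8; the segment covers all encoded bytes,
            // which may be more than message.Length for non-ASCII characters.
            ArraySegment<byte> buffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(message));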

            await webSocket.SendAsync(buffer, WebSocketMessageType.Text, true, cancellationToken);
        }
    }

    ...
}

This means that WebSocketConnection should now be dependent on the subprotocol abstraction.

public class WebSocketConnection
{
    private WebSocket _webSocket;
    private ITextWebSocketSubprotocol _subProtocol;

    // Generated once per connection so the same Id is used when adding and removing it.
    public Guid Id { get; } = Guid.NewGuid();

    public WebSocketConnection(WebSocket webSocket, ITextWebSocketSubprotocol subProtocol)
    {
        _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket));
        _subProtocol = subProtocol ?? throw new ArgumentNullException(nameof(subProtocol));
    }

    public Task SendAsync(string message, CancellationToken cancellationToken)
    {
        return _subProtocol.SendAsync(message, _webSocket, cancellationToken);
    }

    ...
}

A small adjustment to the middleware is also needed.

public class WebSocketConnectionsMiddleware
{
    private readonly ITextWebSocketSubprotocol _defaultSubProtocol;
    private IWebSocketConnectionsService _connectionsService;

    public WebSocketConnectionsMiddleware(RequestDelegate next,
        IWebSocketConnectionsService connectionsService)
    {
        _defaultSubProtocol = new PlainTextWebSocketSubprotocol();
        _connectionsService = connectionsService ??
            throw new ArgumentNullException(nameof(connectionsService));
    }

    public async Task Invoke(HttpContext context)
    {
        if (context.WebSockets.IsWebSocketRequest)
        {
            WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync();

            WebSocketConnection webSocketConnection = new WebSocketConnection(webSocket,
                _defaultSubProtocol);

            ...
        }
        else
        {
            context.Response.StatusCode = 400;
        }
    }
}

Now the infrastructure needed for introducing a second subprotocol is in place. It will be a JSON based subprotocol which in addition to the message provides a timestamp.

public class JsonWebSocketSubprotocol : ITextWebSocketSubprotocol
{
    public string SubProtocol => "aspnetcore-ws.json";

    public async Task SendAsync(string message, WebSocket webSocket,
        CancellationToken cancellationToken)
    {
        if (webSocket.State == WebSocketState.Open)
        {
            string jsonMessage = JsonConvert.SerializeObject(new {
                message,
                timestamp = DateTime.UtcNow
            });

            // Encode as UTF-8 and send all encoded bytes (not jsonMessage.Length).
            ArraySegment<byte> buffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(jsonMessage));

            await webSocket.SendAsync(buffer, WebSocketMessageType.Text, true, cancellationToken);
        }
    }
}
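For comparison, here is what each subprotocol puts on the wire for the same message. The JavaScript sketch below uses hypothetical names. The field names follow JsonWebSocketSubprotocol, but toISOString() is not guaranteed to produce the same date format as JsonConvert.

```javascript
// Hypothetical sketch: the two subprotocols as name + encoder pairs.
const plainTextSubprotocol = {
  name: 'aspnetcore-ws.plaintext',
  encode: message => message,
};

const jsonSubprotocol = {
  name: 'aspnetcore-ws.json',
  encode: (message, now = new Date()) =>
    JSON.stringify({ message, timestamp: now.toISOString() }),
};

// Text frames carry UTF-8, so the frame length is the encoded byte count,
// which can be larger than the number of characters in the string.
function textFramePayload(subprotocol, message, now) {
  return Buffer.from(subprotocol.encode(message, now), 'utf8');
}

console.log(jsonSubprotocol.encode('hello', new Date(0)));
console.log('zażółć'.length, textFramePayload(plainTextSubprotocol, 'zażółć').length); // 6 10
```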

Subprotocol negotiation

The subprotocol negotiation starts on the client side. An array of supported subprotocols can be passed to the WebSocket object constructor. If the negotiation succeeds, the chosen subprotocol is available through the protocol attribute of the WebSocket instance.

var handleWebSocketPlainTextData = function(data) {
    ...
};

var handleWebSocketJsonData = function(data) {
    ...
};

var webSocket = new WebSocket('ws://example.com/socket',
    ['aspnetcore-ws.plaintext', 'aspnetcore-ws.json']);

webSocket.onmessage = function(message) {
    if (webSocket.protocol == 'aspnetcore-ws.json') {
        handleWebSocketJsonData(message.data);
    } else {
        handleWebSocketPlainTextData(message.data);
    }
};

The advertised subprotocols are transferred to the server as part of the connection handshake in the Sec-WebSocket-Protocol header. On the server side this list is available through the HttpContext.WebSockets.WebSocketRequestedProtocols property. The server completes the handshake by passing the name of the selected subprotocol as a parameter to the HttpContext.WebSockets.AcceptWebSocketAsync method.
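RFC 6455 allows the Sec-WebSocket-Protocol header to appear more than once in a request, with each occurrence holding comma-separated tokens, so the requested list is the union of all of them. ASP.NET Core exposes the parsed result through WebSocketRequestedProtocols. The JavaScript sketch below, with a hypothetical name, only illustrates that parsing.

```javascript
// Sketch: turn all Sec-WebSocket-Protocol request header values into one list
// of requested subprotocols, keeping the client's order.
function parseRequestedProtocols(headerValues) {
  return headerValues
    .flatMap(value => value.split(','))
    .map(token => token.trim())
    .filter(token => token.length > 0);
}

console.log(parseRequestedProtocols(['aspnetcore-ws.plaintext, aspnetcore-ws.json']));
```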

The rules of subprotocol negotiation are simple. If the client has advertised a list of subprotocols the server must choose one of them. If the client hasn't advertised any subprotocols the server can't provide a subprotocol name as part of handshake. The best place to implement those rules seems to be the middleware.

The client's list is ordered by preference, but the server is free to pick any of the advertised subprotocols, so the final choice is up to the server. In the code below, the supported subprotocols are kept in a list whose order represents the server's preference.

public class WebSocketConnectionsMiddleware
{
    private readonly ITextWebSocketSubprotocol _defaultSubProtocol;
    private readonly IList<ITextWebSocketSubprotocol> _supportedSubProtocols;
    private IWebSocketConnectionsService _connectionsService;

    public WebSocketConnectionsMiddleware(RequestDelegate next,
        IWebSocketConnectionsService connectionsService)
    {
        _defaultSubProtocol = new PlainTextWebSocketSubprotocol();
        // Order represents the server's preference: JSON first, plain text as fallback.
        _supportedSubProtocols = new List<ITextWebSocketSubprotocol>
        {
            new JsonWebSocketSubprotocol(),
            _defaultSubProtocol
        };
        _connectionsService = connectionsService ??
            throw new ArgumentNullException(nameof(connectionsService));
    }

    public async Task Invoke(HttpContext context)
    {
        if (context.WebSockets.IsWebSocketRequest)
        {
            ITextWebSocketSubprotocol subProtocol =
                NegotiateSubProtocol(context.WebSockets.WebSocketRequestedProtocols);

            WebSocket webSocket =
                await context.WebSockets.AcceptWebSocketAsync(subProtocol?.SubProtocol);

            WebSocketConnection webSocketConnection =
                new WebSocketConnection(webSocket, subProtocol ?? _defaultSubProtocol);

            ...
        }
        else
        {
            context.Response.StatusCode = 400;
        }
    }

    private ITextWebSocketSubprotocol NegotiateSubProtocol(IList<string> requestedSubProtocols)
    {
        ITextWebSocketSubprotocol subProtocol = null;

        foreach (ITextWebSocketSubprotocol supportedSubProtocol in _supportedSubProtocols)
        {
            if (requestedSubProtocols.Contains(supportedSubProtocol.SubProtocol))
            {
                subProtocol = supportedSubProtocol;
                break;
            }
        }

        return subProtocol;
    }
}
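The negotiation rule itself is language-agnostic. Below is a JavaScript sketch, with hypothetical names, of the same server-preference loop. Here `null` stands for answering the handshake without a subprotocol and falling back to the default plain text handling.

```javascript
// Sketch of the server-preference rule: walk the server's list in order and
// pick the first entry the client also requested.
const SERVER_PREFERENCE = ['aspnetcore-ws.json', 'aspnetcore-ws.plaintext'];

function negotiateSubprotocol(requested, supported = SERVER_PREFERENCE) {
  for (const candidate of supported) {
    if (requested.includes(candidate)) {
      return candidate;
    }
  }
  return null; // no subprotocol name in the handshake response
}

console.log(negotiateSubprotocol(['aspnetcore-ws.plaintext', 'aspnetcore-ws.json']));
console.log(negotiateSubprotocol([]));
```

The client's order does not matter here. A client listing plain text first still gets JSON, because the server prefers it.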

With this implementation, clients that don't support the new subprotocol will continue to work without any changes. They don't advertise any subprotocols, so the server uses the default one without providing its name as part of the handshake. Meanwhile our application, which advertises support for the new subprotocol, will use it because it is the one preferred by the server.

It is also worth mentioning that subprotocols are not meant only for internal purposes. There are a number of well-known subprotocols, whose registry is available here.

The demo project is available on GitHub, it should be a good starting point for playing with WebSocket subprotocol negotiation.
