Recently I needed to add support for SSL Acceleration (Offloading) to one of the projects I'm working on. In ASP.NET MVC this usually meant a custom RequireHttpsAttribute, URL generator and IsHttps method. The whole team needed to be aware that the custom components must be used instead of the ones provided by the framework, otherwise things would break. This is no longer the case for ASP.NET Core; thanks to low-level APIs like request features, there is a more elegant way.

SSL Acceleration (Offloading)

SSL Acceleration is the process of using a hardware accelerator to perform SSL encryption and/or decryption. The process usually takes place on a load balancer or firewall, in which case it's called SSL Offloading. There are two flavors of SSL Offloading: SSL Bridging and SSL Termination. SSL Bridging usually doesn't require anything specific from the application, but SSL Termination does. In the case of SSL Termination the SSL connection doesn't go beyond the SSL Accelerator. There are two main benefits of SSL Termination:

  • Improved performance (the web servers don't have to use resources for SSL processing)
  • Simplified certificate management (the certificates are managed on a single device instead of on every web server in the cluster)

The drawback is that HTTPS traffic never reaches the application, which also puts the performance benefit into question. The application is no longer able to fully utilize some HTTP/2 features (for example Server Push), while the resource gains might not be that significant, as modern CPUs have good support for encryption/decryption.

Despite the fact that SSL is being terminated, the application still must be able to verify whether the original request was made over HTTPS (otherwise the application's security could be lowered). Typically SSL Accelerators provide information about the original protocol through a dedicated HTTP header (a quite popular one is X-Forwarded-Proto), which the application needs to interpret properly.

Making ASP.NET Core understand SSL Acceleration

The "properly interpret" means that application needs to detect the presence of the header and if the value indicates that original request was over HTTPS it should be treated as such. In case of ASP.NET Core the perfect behavior would be for HttpContext.Request.IsHttps to return true. This would automatically make RequireHttpsAttribute and AddRedirectToHttps from URL Rewriting Middleware behave correctly. Also any other code which depends on that property will keep working as expected.

Luckily the value of HttpContext.Request.IsHttps is based on the IHttpRequestFeature.Scheme property, whose value can be changed by the application. Assuming that the header name is X-Forwarded-Proto and its value is the original scheme in lower case, the following snippet is exactly what is needed.

if (!context.Request.IsHttps)
{
    if (context.Request.Headers.ContainsKey("X-Forwarded-Proto")
        && context.Request.Headers["X-Forwarded-Proto"].Equals("https"))
    {
        IHttpRequestFeature httpRequestFeature = context.Features.Get<IHttpRequestFeature>();
        httpRequestFeature.Scheme = "https";
    }
}

This snippet can easily be wrapped inside a reusable and parametrized middleware like the one here.
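A minimal sketch of such a middleware could look like the one below (the class name and the configurable header name parameter are my own illustration, not necessarily matching the linked implementation).

public class SslAccelerationMiddleware
{
    private readonly RequestDelegate _next;
    private readonly string _protoHeaderName;

    public SslAccelerationMiddleware(RequestDelegate next, string protoHeaderName)
    {
        _next = next ?? throw new ArgumentNullException(nameof(next));
        _protoHeaderName = protoHeaderName ?? throw new ArgumentNullException(nameof(protoHeaderName));
    }

    public Task Invoke(HttpContext context)
    {
        if (!context.Request.IsHttps && context.Request.Headers[_protoHeaderName].Equals("https"))
        {
            // Overwriting the scheme makes HttpContext.Request.IsHttps return true
            context.Features.Get<IHttpRequestFeature>().Scheme = "https";
        }

        return _next(context);
    }
}

With that, the registration boils down to app.UseMiddleware<SslAccelerationMiddleware>("X-Forwarded-Proto") placed at the very beginning of the pipeline.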

This scenario is a nice example of how ASP.NET Core is layered and how much power the access to its low-level building blocks gives.

In one of the projects I'm involved in we had to configure a limit for the number of concurrent requests being processed. The project is quite specific: processing a request is heavy and requires considerable resources. Because of that, when there was a high number of concurrent requests to process, the server was becoming unresponsive. I did some research on best practices in this area. "Stress your server and see" seems to be the best advice, but the number of different possible approaches got me interested. In order to satisfy my curiosity I've decided to implement a couple of them myself.

General remark

Limiting concurrent requests is a server responsibility; there shouldn't be a need for the application to handle it. This is one of the reasons for the ASP.NET Team's suggestion to run Kestrel behind a fully featured web server like IIS or Nginx (some limits support has come in 2.0). Despite that, here I'm implementing a middleware. This is not that bad: registering such middleware at the beginning of the pipeline should prevent heavy processing early enough. Still, I'm doing this just for fun; don't do it in production unless you have an unquestionable reason to do so.
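For reference, the connection limits that came with Kestrel 2.0 can be configured as in the sketch below (note that these cap connections, not in-flight requests, so they are not a direct equivalent of the middleware from this post).

public static IWebHost BuildWebHost(string[] args) =>
    WebHost.CreateDefaultBuilder(args)
        .UseKestrel(options =>
        {
            // Connections over these limits are refused by the server
            options.Limits.MaxConcurrentConnections = 100;
            options.Limits.MaxConcurrentUpgradedConnections = 100;
        })
        .UseStartup<Startup>()
        .Build();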

Testing scenario

For testing purposes I've decided to set up a couple of integration tests using xUnit and Microsoft.AspNetCore.TestHost. The general setup is very well described in the documentation. As I intended to spawn multiple requests, I wanted to have some rough timings captured for them, so I've prepared the following extension method.

internal class HttpResponseMessageWithTiming
{
    internal HttpResponseMessage Response { get; set; }

    internal TimeSpan Timing { get; set; }
}

internal static class HttpClientExtensions
{
    internal static async Task<HttpResponseMessageWithTiming> GetWithTimingAsync(this HttpClient client,
        string requestUri)
    {
        Stopwatch stopwatch = Stopwatch.StartNew();

        HttpResponseMessage response = await client.GetAsync(requestUri);
        TimeSpan timing = stopwatch.Elapsed;

        stopwatch.Stop();

        return new HttpResponseMessageWithTiming
        {
            Response = response,
            Timing = timing
        };
    }
}

This is not perfect (the timings risk being inaccurate) but should be good enough. I also wanted to have easy access to the status code and timing during debugging, so I've introduced another intermediate representation.

private struct HttpResponseInformation
{
    public HttpStatusCode StatusCode { get; set; }

    public TimeSpan Timing { get; set; }

    public override string ToString()
    {
        return $"StatusCode: {StatusCode} | Timing {Timing}";
    }
}

I've also created a "prepare SUT" method for instantiating the TestServer.

private TestServer PrepareTestServer(IEnumerable<KeyValuePair<string, string>> configuration = null)
{
    IWebHostBuilder webHostBuilder = new WebHostBuilder()
        .UseStartup<Startup>();

    if (configuration != null)
    {
        ConfigurationBuilder configurationBuilder = new ConfigurationBuilder();
        configurationBuilder.AddInMemoryCollection(configuration);
        IConfiguration builtConfiguration = configurationBuilder.Build();

        webHostBuilder.UseConfiguration(builtConfiguration);
        webHostBuilder.ConfigureServices((services) =>
        {
            services.Configure<MaxConcurrentRequestsOptions>(options =>
                builtConfiguration.GetSection("MaxConcurrentRequestsOptions").Bind(options)
            );
        });
    }

    return new TestServer(webHostBuilder);
}

The MaxConcurrentRequestsOptions class (empty at this point) will be used for controlling the behavior of the middleware. Judging by how it's used later in this post, it should eventually grow into something like the sketch below (the exact shape and defaults in the polished GitHub version may differ).
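public class MaxConcurrentRequestsOptions
{
    public const int MaxTimeInQueueUnlimited = -1;

    // Maximum number of requests processed concurrently
    public int Limit { get; set; }

    // What to do with requests over the limit (Drop, FifoQueueDropTail or FifoQueueDropHead)
    public MaxConcurrentRequestsLimitExceededPolicy LimitExceededPolicy { get; set; }

    // Maximum number of requests waiting in the queue
    public int MaxQueueLength { get; set; }

    // Maximum time (in milliseconds) a request is allowed to spend in the queue
    public int MaxTimeInQueue { get; set; } = MaxTimeInQueueUnlimited;
}

The Startup looks like this (to simulate long request processing):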

public class Startup
{
    public void Configure(IApplicationBuilder app)
    {
        app.Run(async (context) =>
        {
            await Task.Delay(500);

            await context.Response.WriteAsync("-- Demo.AspNetCore.MaxConcurrentConnections --");
        });
    }
}

With all those elements in place I've created a general method to be reused by all tests.

private HttpResponseInformation[] GetResponseInformation(Dictionary<string, string> configuration,
    int concurrentRequestsCount)
{
    HttpResponseInformation[] responseInformation;

    using (TestServer server = PrepareTestServer(configuration))
    {
        List<HttpClient> clients = new List<HttpClient>();
        for (int i = 0; i < concurrentRequestsCount; i++)
        {
            clients.Add(server.CreateClient());
        }

        List<Task<HttpResponseMessageWithTiming>> responsesWithTimingsTasks =
            new List<Task<HttpResponseMessageWithTiming>>();
        foreach (HttpClient client in clients)
        {
            responsesWithTimingsTasks.Add(Task.Run(async () => {
                return await client.GetWithTimingAsync("/");
            }));
        }
        Task.WaitAll(responsesWithTimingsTasks.ToArray());

        clients.ForEach(client => client.Dispose());

        responseInformation = responsesWithTimingsTasks.Select(task => new HttpResponseInformation
        {
            StatusCode = task.Result.Response.StatusCode,
            Timing = task.Result.Timing
        }).ToArray();
    }

    return responseInformation;
}

This (by providing different configuration and concurrentRequestsCount values) will allow me to test the different approaches I'm going to play with.
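For illustration, once the middleware described below is in place, a single test built on top of this helper could look like the one here (the configuration key mirrors the options binding above; the numbers and the assertion are just an example).

[Fact]
public void SingleProcessingSlot_TwoConcurrentRequests_OneRequestSucceeds()
{
    Dictionary<string, string> configuration = new Dictionary<string, string>
    {
        { "MaxConcurrentRequestsOptions:Limit", "1" }
    };

    HttpResponseInformation[] responseInformation = GetResponseInformation(configuration, 2);

    // With a single processing slot exactly one of the two requests should get 200 OK
    Assert.Equal(1, responseInformation.Count(r => r.StatusCode == HttpStatusCode.OK));
}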

Limiting concurrent requests through middleware

The first and most important thing which the middleware needs to support is a hard limit. The functionality is very simple in theory: if the application is already processing the maximum number of concurrent requests, every additional incoming request should immediately result in 503 Service Unavailable.

public class MaxConcurrentRequestsMiddleware
{
    private int _concurrentRequestsCount;

    private readonly RequestDelegate _next;
    private readonly MaxConcurrentRequestsOptions _options;

    public MaxConcurrentRequestsMiddleware(RequestDelegate next,
        IOptions<MaxConcurrentRequestsOptions> options)
    {
        _concurrentRequestsCount = 0;

        _next = next ?? throw new ArgumentNullException(nameof(next));
        _options = options?.Value ?? throw new ArgumentNullException(nameof(options));
    }

    public async Task Invoke(HttpContext context)
    {
        if (CheckLimitExceeded())
        {
            IHttpResponseFeature responseFeature = context.Features.Get<IHttpResponseFeature>();

            responseFeature.StatusCode = StatusCodes.Status503ServiceUnavailable;
            responseFeature.ReasonPhrase = "Concurrent request limit exceeded.";
        }
        else
        {
            await _next(context);

            // TODO: Decrement concurrent requests count
        }
    }

    private bool CheckLimitExceeded()
    {
        bool limitExceeded = false;

        // TODO: Check and increment concurrent requests count

        return limitExceeded;
    }
}

The challenge hides in maintaining the count of concurrent requests. It must be incremented and decremented in a thread-safe way while affecting performance as little as possible. The Interlocked class with its atomic operations on shared variables seems perfect for the job. I've decided to use the Interlocked.CompareExchange "do-while pattern" for the check-and-increment, which should ensure that the limit will not be exceeded.

public class MaxConcurrentRequestsMiddleware
{
    ...

    public async Task Invoke(HttpContext context)
    {
        if (CheckLimitExceeded())
        {
            ...
        }
        else
        {
            await _next(context);

            Interlocked.Decrement(ref _concurrentRequestsCount);
        }
    }

    private bool CheckLimitExceeded()
    {
        bool limitExceeded;

        int initialConcurrentRequestsCount, incrementedConcurrentRequestsCount;
        do
        {
            limitExceeded = true;

            initialConcurrentRequestsCount = _concurrentRequestsCount;
            if (initialConcurrentRequestsCount >= _options.Limit)
            {
                break;
            }

            limitExceeded = false;
            incrementedConcurrentRequestsCount = initialConcurrentRequestsCount + 1;
        }
        while (initialConcurrentRequestsCount != Interlocked.CompareExchange(
            ref _concurrentRequestsCount, incrementedConcurrentRequestsCount, initialConcurrentRequestsCount));

        return limitExceeded;
    }
}
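For completeness, the middleware can be wired into the test pipeline like this (a minimal sketch reusing the Startup shown earlier).

public class Startup
{
    public void Configure(IApplicationBuilder app)
    {
        // Registered first so over-limit requests are rejected before any heavy processing
        app.UseMiddleware<MaxConcurrentRequestsMiddleware>();

        app.Run(async (context) =>
        {
            await Task.Delay(500);

            await context.Response.WriteAsync("-- Demo.AspNetCore.MaxConcurrentConnections --");
        });
    }
}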

After adding the middleware to the pipeline I've set up a test with 30 concurrent requests and MaxConcurrentRequestsOptions.Limit set to 10. The result is best observed through the responseInformation temporary value.

Visual Studio Locals Window - Content of responseInformation array for hard limit scenario

There are 10 requests which resulted in 200 OK and 20 requests which resulted in 503 Service Unavailable (red frames) - the desired effect.

Queueing additional requests

The hard limit approach gets the job done, but there are more flexible approaches. In general they are based on the assumption that the application can cheaply store more requests in memory than it's currently processing, and that the client can wait a little longer for the response. Additional requests can wait in a queue (typically a FIFO one) for resources to become available. The queue should have a size limit, otherwise the application might end up processing only queued requests with constantly growing latency.

Because of the size limit requirement this can't simply be implemented with ConcurrentQueue; a custom (again thread-safe) solution is needed. The middleware should also be able to await a spot in the queue. On a high level, a class looking like the one below should provide the desired interface.

internal class MaxConcurrentRequestsEnqueuer
{
    private static readonly Task<bool> _enqueueFailedTask = Task.FromResult(false);
    private readonly int _maxQueueLength;

    public MaxConcurrentRequestsEnqueuer(int maxQueueLength)
    {
        _maxQueueLength = maxQueueLength;
    }

    public Task<bool> EnqueueAsync()
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        // TODO: Check the size and enqueue

        return enqueueTask;
    }

    public bool Dequeue()
    {
        bool dequeued = false;

        // TODO: Dequeue

        return dequeued;
    }
}

The EnqueueAsync method returns a Task<bool>, so the middleware can await it and use the Result as an indicator of whether the request should be processed or not.

Internally the class will maintain a Queue<TaskCompletionSource<bool>>.

internal class MaxConcurrentRequestsEnqueuer
{
    ...
    private readonly object _lock = new object();
    private readonly Queue<TaskCompletionSource<bool>> _queue = new Queue<TaskCompletionSource<bool>>();

    public Task<bool> EnqueueAsync()
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        if (_maxQueueLength > 0)
        {
            lock (_lock)
            {
                if (_queue.Count < _maxQueueLength)
                {
                    TaskCompletionSource<bool> enqueueTaskCompletionSource =
                        new TaskCompletionSource<bool>();

                    _queue.Enqueue(enqueueTaskCompletionSource);

                    enqueueTask = enqueueTaskCompletionSource.Task;
                }
            }
        }

        return enqueueTask;
    }

    public bool Dequeue()
    {
        bool dequeued = false;

        lock (_lock)
        {
            if (_queue.Count > 0)
            {
                _queue.Dequeue().SetResult(true);

                dequeued = true;
            }
        }

        return dequeued;
    }
}

Yes, this is locking. Not perfect, but the request which ends up here is going to wait anyway, so it can be considered acceptable. Some consolation is the fact that this class can be used quite cleanly from the middleware.

public class MaxConcurrentRequestsMiddleware
{
    ...
    private readonly MaxConcurrentRequestsEnqueuer _enqueuer;

    public MaxConcurrentRequestsMiddleware(RequestDelegate next,
        IOptions<MaxConcurrentRequestsOptions> options)
    {
        ...

        if (_options.LimitExceededPolicy != MaxConcurrentRequestsLimitExceededPolicy.Drop)
        {
            _enqueuer = new MaxConcurrentRequestsEnqueuer(_options.MaxQueueLength);
        }
    }

    public async Task Invoke(HttpContext context)
    {
        if (CheckLimitExceeded() && !(await TryWaitInQueue()))
        {
            ...
        }
        else
        {
            await _next(context);

            if (ShouldDecrementConcurrentRequestsCount())
            {
                Interlocked.Decrement(ref _concurrentRequestsCount);
            }
        }
    }

    ...

    private async Task<bool> TryWaitInQueue()
    {
        return (_enqueuer != null) && (await _enqueuer.EnqueueAsync());
    }

    private bool ShouldDecrementConcurrentRequestsCount()
    {
        return (_enqueuer == null) || !_enqueuer.Dequeue();
    }
}

Below is the result of another test. The number of concurrent requests is again 30 and the maximum number of concurrent requests is 10. The queue size has also been set to 10. The requests which have waited in the queue are the ones in orange.

Visual Studio Locals Window - Content of responseInformation array for hard limit with FIFO queue scenario

This approach is called drop tail and has an alternative in the form of drop head. In the case of drop head, when a new request arrives (and the queue is full) the first request in the queue is dropped. This can often result in better latency for the waiting requests, at the price of dropped requests having had to wait as well.

The MaxConcurrentRequestsEnqueuer can be easily modified to support drop head.

internal class MaxConcurrentRequestsEnqueuer
{
    public enum DropMode
    {
        Tail = MaxConcurrentRequestsLimitExceededPolicy.FifoQueueDropTail,
        Head = MaxConcurrentRequestsLimitExceededPolicy.FifoQueueDropHead
    }

    ...
    private readonly DropMode _dropMode;

    public MaxConcurrentRequestsEnqueuer(int maxQueueLength, DropMode dropMode)
    {
        ...
        _dropMode = dropMode;
    }

    public Task<bool> EnqueueAsync()
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        if (_maxQueueLength > 0)
        {
            lock (_lock)
            {
                if (_queue.Count < _maxQueueLength)
                {
                    enqueueTask = InternalEnqueueAsync();
                }
                else if (_dropMode == DropMode.Head)
                {
                    _queue.Dequeue().SetResult(false);

                    enqueueTask = InternalEnqueueAsync();
                }
            }
        }

        return enqueueTask;
    }

    ...

    private Task<bool> InternalEnqueueAsync()
    {
        TaskCompletionSource<bool> enqueueTaskCompletionSource = new TaskCompletionSource<bool>();

        _queue.Enqueue(enqueueTaskCompletionSource);

        return enqueueTaskCompletionSource.Task;
    }
}

My test scenario is not sophisticated enough (all requests arrive at roughly the same time) to show the difference between drop tail and drop head, but this article has some nice comparisons.

Further improvements

There are two more things which would be "nice to have" for the queue. The first is support for request aborting. Right now, if a request gets cancelled by the client, it will still wait in the queue, taking a spot which one of the incoming requests could take.

In order to implement support for request aborting, the internals of MaxConcurrentRequestsEnqueuer need to be changed in a way which enables removing a specific request. To achieve that, Queue<TaskCompletionSource<bool>> can no longer be used; it will be replaced with LinkedList<TaskCompletionSource<bool>>. This allows enqueue and dequeue to remain O(1) operations and adds the possibility of O(n) removal. The O(n) is a pessimistic value; in reality the aborted request will typically be at the beginning of the list, so it will be found quickly. Only a couple of changes to MaxConcurrentRequestsEnqueuer are needed.

internal class MaxConcurrentRequestsEnqueuer
{
    ...
    private readonly LinkedList<TaskCompletionSource<bool>> _queue =
        new LinkedList<TaskCompletionSource<bool>>();

    ...

    public Task<bool> EnqueueAsync()
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        if (_maxQueueLength > 0)
        {
            lock (_lock)
            {
                if (_queue.Count < _maxQueueLength)
                {
                    enqueueTask = InternalEnqueueAsync();
                }
                else if (_dropMode == DropMode.Head)
                {
                    InternalDequeue(false);

                    enqueueTask = InternalEnqueueAsync();
                }
            }
        }

        return enqueueTask;
    }

    public bool Dequeue()
    {
        bool dequeued = false;

        lock (_lock)
        {
            if (_queue.Count > 0)
            {
                InternalDequeue(true);

                dequeued = true;
            }
        }

        return dequeued;
    }

    private Task<bool> InternalEnqueueAsync()
    {
        TaskCompletionSource<bool> enqueueTaskCompletionSource = new TaskCompletionSource<bool>();

        _queue.AddLast(enqueueTaskCompletionSource);

        return enqueueTaskCompletionSource.Task;
    }

    private void InternalDequeue(bool result)
    {
        TaskCompletionSource<bool> enqueueTaskCompletionSource = _queue.First.Value;

        _queue.RemoveFirst();

        enqueueTaskCompletionSource.SetResult(result);
    }
}

Now the support for CancellationToken (which will carry the information about the request being aborted) can be added.

internal class MaxConcurrentRequestsEnqueuer
{
    ...

    public Task<bool> EnqueueAsync(CancellationToken cancellationToken)
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        if (_maxQueueLength > 0)
        {
            lock (_lock)
            {
                if (_queue.Count < _maxQueueLength)
                {
                    enqueueTask = InternalEnqueueAsync(cancellationToken);
                }
                else if (_dropMode == DropMode.Head)
                {
                    InternalDequeue(false);

                    enqueueTask = InternalEnqueueAsync(cancellationToken);
                }
            }
        }

        return enqueueTask;
    }

    ...

    private Task<bool> InternalEnqueueAsync(CancellationToken cancellationToken)
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        TaskCompletionSource <bool> enqueueTaskCompletionSource = new TaskCompletionSource<bool>();

        cancellationToken.Register(CancelEnqueue, enqueueTaskCompletionSource);

        if (!cancellationToken.IsCancellationRequested)
        {
            _queue.AddLast(enqueueTaskCompletionSource);
            enqueueTask = enqueueTaskCompletionSource.Task;
        }

        return enqueueTask;
    }

    private void CancelEnqueue(object state)
    {
        bool removed = false;

        TaskCompletionSource<bool> enqueueTaskCompletionSource = ((TaskCompletionSource<bool>)state);
        lock (_lock)
        {
            removed = _queue.Remove(enqueueTaskCompletionSource);
        }

        if (removed)
        {
            enqueueTaskCompletionSource.SetResult(false);
        }
    }

    ...
}

The middleware can pass HttpContext.RequestAborted as the CancellationToken (it is also a good idea to wrap the response status setting code in an IsCancellationRequested check), which should ensure that all aborted requests will be removed from the queue quickly.
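A minimal sketch of those middleware changes could look like this (assuming the EnqueueAsync(CancellationToken) overload introduced above).

public class MaxConcurrentRequestsMiddleware
{
    ...

    public async Task Invoke(HttpContext context)
    {
        if (CheckLimitExceeded() && !(await TryWaitInQueue(context)))
        {
            // Don't bother writing the response if the client is already gone
            if (!context.RequestAborted.IsCancellationRequested)
            {
                IHttpResponseFeature responseFeature = context.Features.Get<IHttpResponseFeature>();

                responseFeature.StatusCode = StatusCodes.Status503ServiceUnavailable;
                responseFeature.ReasonPhrase = "Concurrent request limit exceeded.";
            }
        }
        else
        {
            ...
        }
    }

    private async Task<bool> TryWaitInQueue(HttpContext context)
    {
        return (_enqueuer != null) && (await _enqueuer.EnqueueAsync(context.RequestAborted));
    }

    ...
}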

The second useful thing is limiting the time which a request can spend in the queue. This is just additional protection for cases when the queue gets stuck and clients don't implement their own timeouts. As cancellation is already supported, this can be nicely introduced with the help of CancellationTokenSource.CancelAfter and CancellationTokenSource.CreateLinkedTokenSource.

internal class MaxConcurrentRequestsEnqueuer
{
    ...
    private readonly int _maxTimeInQueue;

    public MaxConcurrentRequestsEnqueuer(int maxQueueLength, DropMode dropMode, int maxTimeInQueue)
    {
        ...
        _maxTimeInQueue = maxTimeInQueue;
    }

    ...

    private Task<bool> InternalEnqueueAsync(CancellationToken cancellationToken)
    {
        Task<bool> enqueueTask = _enqueueFailedTask;

        TaskCompletionSource <bool> enqueueTaskCompletionSource = new TaskCompletionSource<bool>();

        CancellationToken enqueueCancellationToken =
            GetEnqueueCancellationToken(enqueueTaskCompletionSource, cancellationToken);

        if (!enqueueCancellationToken.IsCancellationRequested)
        {
            _queue.AddLast(enqueueTaskCompletionSource);
            enqueueTask = enqueueTaskCompletionSource.Task;
        }

        return enqueueTask;
    }

    private CancellationToken GetEnqueueCancellationToken(
        TaskCompletionSource<bool> enqueueTaskCompletionSource, CancellationToken cancellationToken)
    {
        CancellationToken enqueueCancellationToken = CancellationTokenSource.CreateLinkedTokenSource(
            cancellationToken,
            GetTimeoutToken(enqueueTaskCompletionSource)
        ).Token;

        enqueueCancellationToken.Register(CancelEnqueue, enqueueTaskCompletionSource);

        return enqueueCancellationToken;
    }

    private CancellationToken GetTimeoutToken(TaskCompletionSource<bool> enqueueTaskCompletionSource)
    {
        CancellationToken timeoutToken = CancellationToken.None;

        if (_maxTimeInQueue != MaxConcurrentRequestsOptions.MaxTimeInQueueUnlimited)
        {
            CancellationTokenSource timeoutTokenSource = new CancellationTokenSource();

            timeoutToken = timeoutTokenSource.Token;
            timeoutToken.Register(CancelEnqueue, enqueueTaskCompletionSource);

            timeoutTokenSource.CancelAfter(_maxTimeInQueue);
        }

        return timeoutToken;
    }

    ...
}

The timeout can be clearly shown in a test by setting the max time in queue to a value lower than the processing time (for example 300ms).

Visual Studio Locals Window - Content of responseInformation array for hard limit with FIFO queue and max time in queue scenario

As expected, there are two groups of dropped requests: the ones which were dropped immediately and the ones which were dropped after the max time in queue.

There is more

This (quite long) post doesn't fully explore the subject. There are other approaches, for example ones based on a LIFO queue. If this post got you interested in the subject, I suggest further research.

More "polished" version of MaxConcurrentRequestsEnqueuer and MaxConcurrentRequestsMiddleware with demo project and tests can be found on GitHub.

I've received a question from a user of my Server-Sent Events Middleware. In general the user was asking if it can be used together with the Response Compression Middleware, because he had a hard time making it work. As there shouldn't be any technical limitation in such a scenario, I've decided to quickly test it in my simple demo.

public class Startup
{
    public void ConfigureServices(IServiceCollection services)
    {
        services.AddResponseCompression(options =>
        {
            options.MimeTypes = ResponseCompressionDefaults.MimeTypes.Concat(new[]
            {
                "text/event-stream"
            });
        });

        services.AddServerSentEvents();
        ...
    }

    public void Configure(IApplicationBuilder app)
    {
        app.UseResponseCompression()
            .MapServerSentEvents("/see-heartbeat")
            ...;

        ...
    }
}

It worked without any issues.

Chrome Developer Tools Network Tab - Server-Sent Events with gzip

I've quickly written a response including my snippet. Unfortunately the user was doing exactly the same thing, and while for me it was working perfectly, for him it seemed to be doing nothing (there was no error or any other side effect, just events not arriving at the client). We exchanged a couple of emails until we finally discovered a difference in our scenarios: my target framework was netcoreapp1.0 while his was net451. I've switched my project to net451 and could observe the same behavior.

Difference between .NET Framework and .NET Core

I've started looking for the root cause. It was obviously somehow related to the response compression, but I had no idea what it could be, until I found the fragment below inside GzipCompressionProvider.

public bool SupportsFlush
{
    get
    {
#if NET451
        return false;
#elif NETSTANDARD1_3
        return true;
#else
        // Not implemented, compiler break
#endif
    }
}

This clearly shows that GZipStream (which is used internally by GzipCompressionProvider) doesn't support flushing in .NET Framework (and my Server-Sent Events implementation flushes the events when they are completely written). This difference is not documented; in fact the documentation states that GZipStream.Flush has no functionality regardless of implementation. I was able to find this issue, which sheds some light on how GZipStream started actually flushing. The bottom line is that, when used on .NET Framework, the Response Compression Middleware is also buffering the response.

Response Buffering in ASP.NET Core

In ASP.NET Core, a component which provides response buffering capabilities should implement IHttpBufferingFeature. The Response Compression Middleware does it through BodyWrapperStream, in which it wraps the original body stream. This means that it should be possible to disable the buffering with the following code.

private void DisableResponseBuffering(HttpContext context)
{
    IHttpBufferingFeature bufferingFeature = context.Features.Get<IHttpBufferingFeature>();
    if (bufferingFeature != null)
    {
        bufferingFeature.DisableResponseBuffering();
    }
}

I've added this to ServerSentEventsMiddleware.

public class ServerSentEventsMiddleware
{
    ...

    public async Task Invoke(HttpContext context)
    {
        if (context.Request.Headers[Constants.ACCEPT_HTTP_HEADER] == Constants.SSE_CONTENT_TYPE)
        {
            DisableResponseBuffering(context);

            ...
        }
        else
        {
            await _next(context);
        }
    }

    ...
}

After this change, running the demo application with net451 as the target framework resulted in events correctly reaching the client. The difference was that the response wasn't compressed.

Chrome Developer Tools Network Tab - Server-Sent Events no compression for .NET Framework

This is how the Response Compression Middleware handles the DisableResponseBuffering method: if the compression provider which is supposed to be used doesn't support flushing (the SupportsFlush property above), it disables the compression.

This is an interesting difference in behavior between .NET Framework and .NET Core, and one worth remembering.

This is my third post about the WebSocket protocol in ASP.NET Core. Previously I've written about subprotocol negotiation and Cross-Site WebSocket Hijacking. The subject I'm focusing on here is per-message compression, which is supported out of the box by Chrome, Firefox and other browsers.

WebSocket Extensions

The WebSocket protocol has a concept of extensions which can provide new capabilities. An extension can define any additional functionality that is able to work on top of the WebSocket framing layer. The specification reserves three header bits (RSV1, RSV2 and RSV3) and all opcodes from 3 to 7 and 11 to 15 for use by extensions (it also allows using the reserved bits to create additional opcodes, or even using some of the payload data for that purpose). The extensions (similarly to subprotocols) are negotiated through a dedicated header (Sec-WebSocket-Extensions) as part of the handshake. A client can advertise the supported extensions by putting a list into the header, and the server can accept one or more of them in exactly the same way.
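For example, a handshake negotiating the permessage-deflate extension (described later in this post) could contain headers like these (other handshake headers omitted).

GET /socket HTTP/1.1
...
Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits

HTTP/1.1 101 Switching Protocols
...
Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover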

There are two extensions I have heard of: A Multiplexing Extension for WebSockets and Compression Extensions for WebSocket. The first one never went beyond the draft stage, but the second has become a standard and has been adopted by several browsers.

WebSocket Per-Message Compression Extensions

The Compression Extensions for WebSocket standard defines two things. The first is a framework for adding compression functionality to the WebSocket protocol. The framework is really simple; it states only two things:

  • A Per-Message Compression Extension operates only on message data (so compression takes place before splitting into frames and decompression takes place after all frames have been received).
  • A Per-Message Compression Extension allocates the RSV1 bit and calls it the Per-Message Compressed bit. The bit is supposed to be set to 1 on the first frame of a compressed message.

The challenging part is the allocation of the RSV1 bit. It makes it impossible to implement support for per-message compression on top of the WebSocket stack available in ASP.NET Core. Because of that I've decided to roll my own implementation of IHttpWebSocketFeature. It is very similar to the one provided by Microsoft.AspNetCore.WebSockets, and the underlying WebSocket implementation is based on ManagedWebSocket so closely that the needed changes can be described in its context (the key difference is that my implementation is stripped of the client-specific logic, as it is not needed).

From the public API perspective there must be a way to set and get the information that a message is compressed. The first can be achieved with an overload of the SendAsync method (or rather, extending the current SendAsync with one more parameter and providing an overload which doesn't need it).

internal class CompressionWebSocket : WebSocket
{
    public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType,
        bool endOfMessage, CancellationToken cancellationToken)
    {
        return SendAsync(buffer, messageType, false, endOfMessage, cancellationToken);
    }

    public Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool compressed,
        bool endOfMessage, CancellationToken cancellationToken)
    {
        ...;
    }
}

The information about a received message can be exposed through a derived WebSocketReceiveResult.

public class CompressionWebSocketReceiveResult : WebSocketReceiveResult
{
    public bool Compressed { get; }

    public CompressionWebSocketReceiveResult(int count, WebSocketMessageType messageType,
        bool compressed, bool endOfMessage)
        : base(count, messageType, endOfMessage)
    {
        Compressed = compressed;
    }

    public CompressionWebSocketReceiveResult(int count, WebSocketMessageType messageType,
        bool endOfMessage, WebSocketCloseStatus? closeStatus, string closeStatusDescription)
        : base(count, messageType, endOfMessage, closeStatus, closeStatusDescription)
    {
        Compressed = false;
    }
}

The next step is adjusting the internals of the WebSocket implementation to properly write and read the RSV1 bit. The writing part is handled by the WriteHeader method. This method needs to be changed in a way that sets the RSV1 bit when the message is compressed and the current frame is not a continuation.

private static int WriteHeader(WebSocketMessageOpcode opcode, byte[] sendBuffer,
    ArraySegment<byte> payload, bool compressed, bool endOfMessage)
{
    sendBuffer[0] = (byte)opcode;

    if (compressed && (opcode != WebSocketMessageOpcode.Continuation))
    {
        sendBuffer[0] |= 0x40;
    }

    if (endOfMessage)
    {
        sendBuffer[0] |= 0x80;
    }

    ...
}

After this change, all the paths leading to the WriteHeader method must be changed to provide either the (passed down) value of the compressed parameter from SendAsync, or false.
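A hypothetical illustration of one such path is shown below (the SendFrameAsync name and signature are modeled on ManagedWebSocket, but the details are an assumption).

private Task SendFrameAsync(WebSocketMessageOpcode opcode, bool compressed, bool endOfMessage,
    ArraySegment<byte> payloadBuffer, CancellationToken cancellationToken)
{
    ...

    // The compressed flag travels from the public SendAsync all the way down to WriteHeader
    int sendBytes = WriteHeader(opcode, _sendBuffer, payloadBuffer, compressed, endOfMessage);

    ...
}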

The receiving flow has a corresponding method, TryParseMessageHeaderFromReceiveBuffer, which fills out a MessageHeader struct. A different version of that struct is needed.

[StructLayout(LayoutKind.Auto)]
internal struct CompressionWebSocketMessageHeader
{
    internal WebSocketMessageOpcode Opcode { get; set; }

    internal bool Compressed { get; set; }

    internal bool Fin { get; set; }

    internal long PayloadLength { get; set; }

    internal int Mask { get; set; }
}

The TryParseMessageHeaderFromReceiveBuffer method requires two changes. One takes care of reading the RSV1 bit, and the second changes the validation of the RSV bits' values (per the protocol specification, an invalid combination of RSV bits must fail the connection).

private bool TryParseMessageHeaderFromReceiveBuffer(out CompressionWebSocketMessageHeader resultHeader)
{
    var header = new CompressionWebSocketMessageHeader();

    header.Opcode = (WebSocketMessageOpcode)(_receiveBuffer[_receiveBufferOffset] & 0xF);
    header.Compressed = (_receiveBuffer[_receiveBufferOffset] & 0x40) != 0;
    header.Fin = (_receiveBuffer[_receiveBufferOffset] & 0x80) != 0;

    bool reservedSet = (_receiveBuffer[_receiveBufferOffset] & 0x70) != 0;
    bool reservedExceptCompressedSet = (_receiveBuffer[_receiveBufferOffset] & 0x30) != 0;

    ...

    bool shouldFail = (!header.Compressed && reservedSet) || reservedExceptCompressedSet;

    ...
}

The last step is to modify the InternalReceiveAsync method so it skips UTF-8 validation for compressed messages and properly creates a CompressionWebSocketReceiveResult.

private async Task<WebSocketReceiveResult> InternalReceiveAsync(ArraySegment<byte> payloadBuffer,
    CancellationToken cancellationToken)
{
    ...

    try
    {
        while (true)
        {
            ...

            if ((header.Opcode == WebSocketMessageOpcode.Text) && !header.Compressed
                && !TryValidateUtf8(
                    new ArraySegment<byte>(payloadBuffer.Array, payloadBuffer.Offset, bytesToCopy),
                    header.Fin, _utf8TextState))
            {
                await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData,
                    WebSocketError.Faulted, cancellationToken).ConfigureAwait(false);
            }

            _lastReceiveHeader = header;
            return new CompressionWebSocketReceiveResult(
                bytesToCopy,
                header.Opcode == WebSocketMessageOpcode.Text ?
                    WebSocketMessageType.Text : WebSocketMessageType.Binary,
                header.Compressed,
                bytesToCopy == 0 || (header.Fin && header.PayloadLength == 0));
        }
    }
    catch (Exception ex)
    {
        ...
    }
    finally
    {
        ...
    }
}

With those changes in place, the WebSocket implementation has support for the per-message compression framework. Support for a specific compression extension can be implemented on top of that.

Deflate based PMCE

The second thing which the Compression Extensions for WebSocket standard defines is the permessage-deflate compression extension. This extension specifies a way of compressing the message payload using the DEFLATE algorithm with the help of the byte boundary alignment method. But first it is worth implementing the concepts which are shared by all potential compression extensions: receiving and sending the message payload. The methods responsible for handling those operations should be able to properly concatenate or split the message into frames.

public abstract class WebSocketCompressionProviderBase
{
    private readonly int? _sendSegmentSize;

    ...

    protected async Task SendMessageAsync(WebSocket webSocket, byte[] message,
        WebSocketMessageType messageType, bool compressed, CancellationToken cancellationToken)
    {
        if (webSocket.State == WebSocketState.Open)
        {
            if (_sendSegmentSize.HasValue && (_sendSegmentSize.Value < message.Length))
            {
                int messageOffset = 0;
                int messageBytesToSend = message.Length;

                while (messageBytesToSend > 0)
                {
                    int messageSegmentSize = Math.Min(_sendSegmentSize.Value, messageBytesToSend);
                    ArraySegment<byte> messageSegment = new ArraySegment<byte>(message, messageOffset,
                        messageSegmentSize);

                    messageOffset += messageSegmentSize;
                    messageBytesToSend -= messageSegmentSize;

                    await SendAsync(webSocket, messageSegment, messageType, compressed,
                        (messageBytesToSend == 0), cancellationToken);
                }
            }
            else
            {
                ArraySegment<byte> messageSegment = new ArraySegment<byte>(message, 0, message.Length);

                await SendAsync(webSocket, messageSegment, messageType, compressed, true,
                    cancellationToken);
            }
        }
    }

    private Task SendAsync(WebSocket webSocket, ArraySegment<byte> messageSegment,
        WebSocketMessageType messageType, bool compressed, bool endOfMessage,
        CancellationToken cancellationToken)
    {
        if (compressed)
        {
            CompressionWebSocket compressionWebSocket = (webSocket as CompressionWebSocket)
                ?? throw new InvalidOperationException("Used WebSocket must be CompressionWebSocket.");

            return compressionWebSocket.SendAsync(messageSegment, messageType, true, endOfMessage,
                cancellationToken);
        }
        else
        {
            return webSocket.SendAsync(messageSegment, messageType, endOfMessage, cancellationToken);
        }
    }

    protected async Task<byte[]> ReceiveMessagePayloadAsync(WebSocket webSocket,
        WebSocketReceiveResult webSocketReceiveResult, byte[] receivePayloadBuffer)
    {
        byte[] messagePayload = null;

        if (webSocketReceiveResult.EndOfMessage)
        {
            messagePayload = new byte[webSocketReceiveResult.Count];
            Array.Copy(receivePayloadBuffer, messagePayload, webSocketReceiveResult.Count);
        }
        else
        {
            IEnumerable<byte> webSocketReceivedBytesEnumerable = Enumerable.Empty<byte>();
            webSocketReceivedBytesEnumerable = webSocketReceivedBytesEnumerable
                .Concat(receivePayloadBuffer);

            while (!webSocketReceiveResult.EndOfMessage)
            {
                webSocketReceiveResult = await webSocket.ReceiveAsync(
                    new ArraySegment<byte>(receivePayloadBuffer), CancellationToken.None);
                webSocketReceivedBytesEnumerable = webSocketReceivedBytesEnumerable
                    .Concat(receivePayloadBuffer.Take(webSocketReceiveResult.Count));
            }

            messagePayload = webSocketReceivedBytesEnumerable.ToArray();
        }

        return messagePayload;
    }
}

With this base in place, the permessage-deflate specifics can be implemented. Let's start with the byte boundary alignment method. In practice it boils down to two operations:

  • In the case of compression, the compressed data should end with an empty deflate block, and the last four octets of that block should be removed.
  • In the case of decompression, the last four octets of an empty deflate block should be appended to the received payload before decompressing.

It looks like, in the case of compressing with the DeflateStream provided by .NET, the empty deflate block is always there, so the above can be implemented with two helper methods.

public sealed class WebSocketDeflateCompressionProvider : WebSocketCompressionProviderBase
{
    private static readonly byte[] LAST_FOUR_OCTETS = new byte[] { 0x00, 0x00, 0xFF, 0xFF };
    ...

    private byte[] TrimLastFourOctetsOfEmptyNonCompressedDeflateBlock(byte[] compressedMessagePayload)
    {
        int lastFourOctetsOfEmptyNonCompressedDeflateBlockPosition = 0;
        for (int position = compressedMessagePayload.Length - 1; position >= 4; position--)
        {
            if ((compressedMessagePayload[position - 3] == LAST_FOUR_OCTETS[0])
                && (compressedMessagePayload[position - 2] == LAST_FOUR_OCTETS[1])
                && (compressedMessagePayload[position - 1] == LAST_FOUR_OCTETS[2])
                && (compressedMessagePayload[position] == LAST_FOUR_OCTETS[3]))
            {
                lastFourOctetsOfEmptyNonCompressedDeflateBlockPosition = position - 3;
                break;
            }
        }
        Array.Resize(ref compressedMessagePayload, lastFourOctetsOfEmptyNonCompressedDeflateBlockPosition);

        return compressedMessagePayload;
    }

    private byte[] AppendLastFourOctetsOfEmptyNonCompressedDeflateBlock(byte[] compressedMessagePayload)
    {
        Array.Resize(ref compressedMessagePayload, compressedMessagePayload.Length + 4);

        compressedMessagePayload[compressedMessagePayload.Length - 4] = LAST_FOUR_OCTETS[0];
        compressedMessagePayload[compressedMessagePayload.Length - 3] = LAST_FOUR_OCTETS[1];
        compressedMessagePayload[compressedMessagePayload.Length - 2] = LAST_FOUR_OCTETS[2];
        compressedMessagePayload[compressedMessagePayload.Length - 1] = LAST_FOUR_OCTETS[3];

        return compressedMessagePayload;
    }
}

The second set of helper methods needed is the actual compression and decompression. For simplicity, only text messages will be considered from this point on.

public sealed class WebSocketDeflateCompressionProvider : WebSocketCompressionProviderBase
{
    ...
    private static readonly Encoding UTF8_WITHOUT_BOM = new UTF8Encoding(false);

    ...

    private async Task<byte[]> CompressTextWithDeflateAsync(string message)
    {
        byte[] compressedMessagePayload = null;

        using (MemoryStream compressedMessagePayloadStream = new MemoryStream())
        {
            using (DeflateStream compressedMessagePayloadCompressStream =
                new DeflateStream(compressedMessagePayloadStream, CompressionMode.Compress))
            {
                using (StreamWriter compressedMessagePayloadCompressWriter =
                    new StreamWriter(compressedMessagePayloadCompressStream, UTF8_WITHOUT_BOM))
                {
                    await compressedMessagePayloadCompressWriter.WriteAsync(message);
                }
            }

            compressedMessagePayload = compressedMessagePayloadStream.ToArray();
        }

        return compressedMessagePayload;
    }

    private async Task<string> DecompressTextWithDeflateAsync(byte[] compressedMessagePayload)
    {
        string message = null;

        using (MemoryStream compressedMessagePayloadStream = new MemoryStream(compressedMessagePayload))
        {
            using (DeflateStream compressedMessagePayloadDecompressStream =
                new DeflateStream(compressedMessagePayloadStream, CompressionMode.Decompress))
            {
                using (StreamReader compressedMessagePayloadDecompressReader =
                    new StreamReader(compressedMessagePayloadDecompressStream, UTF8_WITHOUT_BOM))
                {
                    message = await compressedMessagePayloadDecompressReader.ReadToEndAsync();
                }
            }
        }

        return message;
    }
}

Now the public API can be exposed.

public interface IWebSocketCompressionProvider
{
    Task CompressTextMessageAsync(WebSocket webSocket, string message,
        CancellationToken cancellationToken);

    Task<string> DecompressTextMessageAsync(WebSocket webSocket,
        WebSocketReceiveResult webSocketReceiveResult, byte[] receivePayloadBuffer);
}

public sealed class WebSocketDeflateCompressionProvider :
    WebSocketCompressionProviderBase, IWebSocketCompressionProvider
{
    ...

    public override async Task CompressTextMessageAsync(WebSocket webSocket, string message,
        CancellationToken cancellationToken)
    {
        byte[] compressedMessagePayload = await CompressTextWithDeflateAsync(message);

        compressedMessagePayload =
            TrimLastFourOctetsOfEmptyNonCompressedDeflateBlock(compressedMessagePayload);

        await SendMessageAsync(webSocket, compressedMessagePayload, WebSocketMessageType.Text, true,
            cancellationToken);
    }

    public override async Task<string> DecompressTextMessageAsync(WebSocket webSocket,
        WebSocketReceiveResult webSocketReceiveResult, byte[] receivePayloadBuffer)
    {
        string message = null;

        CompressionWebSocketReceiveResult compressionWebSocketReceiveResult =
            webSocketReceiveResult as CompressionWebSocketReceiveResult;

        if ((compressionWebSocketReceiveResult != null) && compressionWebSocketReceiveResult.Compressed)
        {
            byte[] compressedMessagePayload =
                await ReceiveMessagePayloadAsync(webSocket, webSocketReceiveResult, receivePayloadBuffer);

            compressedMessagePayload =
                AppendLastFourOctetsOfEmptyNonCompressedDeflateBlock(compressedMessagePayload);

            message = await DecompressTextWithDeflateAsync(compressedMessagePayload);
        }
        else
        {
            byte[] messagePayload =
                await ReceiveMessagePayloadAsync(webSocket, webSocketReceiveResult, receivePayloadBuffer);

            message = Encoding.UTF8.GetString(messagePayload);
        }

        return message;
    }

    ...
}

This API makes it easy to plug a compression provider into a typical (SendAsync and ReceiveAsync based) WebSocket flow: calls to SendAsync get replaced with calls to CompressTextMessageAsync, and DecompressTextMessageAsync gets called whenever the WebSocketReceiveResult acquired from ReceiveAsync indicates a text message. A rough sketch of such a flow is shown below (the echo handler structure and the buffer size are illustrative only, not part of the library).
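// A rough sketch of an echo-style handler built around the compression provider
WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync();
IWebSocketCompressionProvider compressionProvider = new WebSocketDeflateCompressionProvider(...);

byte[] receivePayloadBuffer = new byte[4096];
WebSocketReceiveResult receiveResult = await webSocket.ReceiveAsync(
    new ArraySegment<byte>(receivePayloadBuffer), CancellationToken.None);

while (receiveResult.MessageType != WebSocketMessageType.Close)
{
    if (receiveResult.MessageType == WebSocketMessageType.Text)
    {
        string message = await compressionProvider.DecompressTextMessageAsync(webSocket,
            receiveResult, receivePayloadBuffer);

        // Echo the message back (compressed, if the extension has been negotiated)
        await compressionProvider.CompressTextMessageAsync(webSocket, message,
            CancellationToken.None);
    }

    receiveResult = await webSocket.ReceiveAsync(
        new ArraySegment<byte>(receivePayloadBuffer), CancellationToken.None);
}

But before this can be done, the permessage-deflate extension must be properly negotiated.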

Context takeover and LZ77 window size

An important part of compression extension negotiation is its parameters. The permessage-deflate extension defines four of them: server_no_context_takeover, client_no_context_takeover, server_max_window_bits and client_max_window_bits. The first two define whether the server and/or client can reuse the same context (LZ77 sliding window) for subsequent messages. The remaining two allow limiting the LZ77 sliding window size. The sad truth is that the above implementation is not able to handle most of these parameters properly, so the negotiation process needs to make sure that acceptable values are being used, or the negotiation fails (failing the negotiation doesn't fail the connection; the handshake response simply doesn't contain the extension as an accepted one). So what are the acceptable values for this implementation?

The DeflateStream doesn't provide control over the LZ77 sliding window size, which means that the negotiation must be failed if the offer contains the server_max_window_bits parameter, as it can't be handled. At the same time the presence of client_max_window_bits should be ignored, as this is just a hint that the client can support this parameter.

When it comes to the context reuse, the above implementation creates a new DeflateStream for every message, which means it always works in "no context takeover" mode. Because of that the negotiation response must prevent the client from reusing the context: the client_no_context_takeover parameter must always be included in the response. This also means that server_no_context_takeover sent by the client in the offer can always be accepted.
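Condensed into code, the acceptance rules could look like this (an illustrative method, not the actual negotiation implementation; it also skips the parameter validation mentioned below).

private static string NegotiatePermessageDeflate(NameValueWithParametersHeaderValue offer)
{
    // server_max_window_bits can't be honored by DeflateStream, so the negotiation must fail
    if (offer.Parameters.Any(parameter => parameter.Name == "server_max_window_bits"))
    {
        return null;
    }

    // client_max_window_bits is just a hint from the client and can be ignored,
    // server_no_context_takeover can always be accepted (a new DeflateStream is used per message)
    // and client_no_context_takeover must always be sent back for the same reason
    return "permessage-deflate; client_no_context_takeover";
}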

I'm skipping the full code which handles the negotiation here. Despite being based on the NameValueWithParametersHeaderValue class, which handles all the parsing, it is still quite lengthy (it also must validate the parameters for "technical correctness"). For anyone who is interested, the implementation is split between the WebSocketCompressionService and WebSocketDeflateCompressionOptions classes which can be found on GitHub (link below).

Trying this out

There was an issue in the Microsoft.AspNetCore.WebSockets repository for implementing per-message compression. It's currently closed, so I've made this implementation available independently on GitHub and NuGet. It is also part of my WebSocket demo project, if somebody is looking for a ready-to-use playground.

In my previous post I've written about subprotocols in the WebSocket protocol. This time I wanted to focus on the Cross-Site WebSocket Hijacking vulnerability.

Cross-Site WebSocket Hijacking

The WebSocket protocol is not subject to the same-origin policy. The specification states that "Servers that are not intended to process input from any web page but only for certain sites SHOULD verify the |Origin| field is an origin they expect.". This means that the browser will allow any page to open a WebSocket connection.

Let's imagine a scenario in which the application is sending sensitive data over a WebSocket, and authentication is based on cookies which are sent as part of the initial handshake. In such a case, if the user visits a malicious page while being logged in to the application, that page can open an authenticated WebSocket connection, because the browser will automatically send all the cookies. This is a quite common and (if not protected against) dangerous scenario. There are also more "interesting" scenarios possible, like this case of remote code execution.

Protecting against CSWSH

Protection against CSWSH is easy to implement. As the Origin header is a required part of the initial handshake, the application should check its value against a list of acceptable origins and, if it's not there, respond with a 403 Forbidden status code. The sample from my previous post was using WebSocketConnectionsMiddleware for handling the connections, which makes it a perfect place to add this check.

public class WebSocketConnectionsOptions
{
    public HashSet<string> AllowedOrigins { get; set; }
}

public class WebSocketConnectionsMiddleware
{
    private WebSocketConnectionsOptions _options;
    ...

    public WebSocketConnectionsMiddleware(RequestDelegate next, WebSocketConnectionsOptions options, ...)
    {
        _options = options ?? throw new ArgumentNullException(nameof(options));
        ...
    }

    public async Task Invoke(HttpContext context)
    {
        if (context.WebSockets.IsWebSocketRequest)
        {
            if (ValidateOrigin(context))
            {
                ...
            }
            else
            {
                context.Response.StatusCode = StatusCodes.Status403Forbidden;
            }
        }
        else
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
        }
    }

    private bool ValidateOrigin(HttpContext context)
    {
        return (_options.AllowedOrigins == null)
            || (_options.AllowedOrigins.Count == 0)
            || (_options.AllowedOrigins.Contains(context.Request.Headers["Origin"].ToString()));
    }

    ...
}

Now the list of acceptable origins can be passed during the middleware registration.

public class Startup
{
    ...

    public void Configure(IApplicationBuilder app)
    {
        ...

        WebSocketConnectionsOptions webSocketConnectionsOptions = new WebSocketConnectionsOptions
        {
            AllowedOrigins = new HashSet<string> { "http://localhost:63290" }
        };

        ...
        app.UseWebSockets();
        app.Map("/socket", branchedApp =>
        {
            branchedApp.UseMiddleware<WebSocketConnectionsMiddleware>(webSocketConnectionsOptions);
        });
        ...
    }
}

I've extended the demo available at GitHub with this functionality.
