Sometimes a web application needs to acquire the client's originating IP address (for location-dependent content, audit requirements, etc.). ASP.NET Core provides the HttpContext.Connection.RemoteIpAddress property, which exposes the originating IP address of the connection. In today's web the origin of the connection seen by the web server is rarely the client; more likely it is the last proxy on the path. So, in order to attempt acquiring the client's originating IP address, a developer needs to work a little harder than just reading a property.

The most common way of preserving the client IP address is the X-Forwarded-For header (it's even more popular than the standardized Forwarded header). It should contain a comma-separated list of addresses which represent the request path through the network. The first entry on that list should be the client address. The X-Forwarded-For header is supported out of the box in ASP.NET Core by the Forwarded Headers Middleware; it just needs to be added to the pipeline.
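To make the header format concrete, here is a minimal sketch of reading the first entry from such a list. XForwardedForSketch and GetClientAddress are hypothetical names used only for illustration; this is not how the middleware is implemented, and in an application the Forwarded Headers Middleware should do this work.

```csharp
using System;
using System.Net;

// Hypothetical helper, for illustration only.
public static class XForwardedForSketch
{
    // Returns the leftmost entry of a comma-separated X-Forwarded-For value,
    // or null when the value is empty or the entry is not a valid IP address.
    public static IPAddress GetClientAddress(string headerValue)
    {
        if (String.IsNullOrWhiteSpace(headerValue))
        {
            return null;
        }

        // The first entry is expected to be the client, later ones are proxies.
        string firstEntry = headerValue.Split(',')[0].Trim();

        return IPAddress.TryParse(firstEntry, out IPAddress address) ? address : null;
    }
}
```

Because the value is supplied by whoever sent the request, the result is only as trustworthy as the proxies in front of the application.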

Unfortunately, the X-Forwarded-For header is often not implemented correctly by proxies (for example, they might override the value instead of appending to it). In this context (and given that the Forwarded header is not being adopted as quickly as it should be), global reverse proxy providers like Akamai or CloudFlare have done what typically happens in software in such cases: they have introduced their own headers. Akamai is using True-Client-IP (to be precise, it's not unique to Akamai, some others are using it as well, but Akamai is the biggest) and CloudFlare is using CF-Connecting-IP (note the hyphens; a name written with underscores, like CF_CONNECTING_IP, will not match the header CloudFlare sends).

So, how can all of those be handled within an application if needed? Luckily, the Forwarded Headers Middleware can be registered multiple times.

Stacking Forwarded Headers Middleware

The Forwarded Headers Middleware supports multiple headers with different responsibilities. The options allow for providing different names for those headers and choosing which ones should be processed. Assuming that part of the stack will be the Forwarded Headers Middleware configured to handle its default headers, two more registrations need to be added. For each of them an extension method can be created to encapsulate the configuration.

For Akamai the ForwardedForHeaderName property should be set to True-Client-IP and processing limited to XForwardedFor.

public static class ForwardedHeadersExtensions
{
    public static IApplicationBuilder UseAkamaiTrueClientIp(this IApplicationBuilder app)
    {
        if (app == null)
        {
            throw new ArgumentNullException(nameof(app));
        }

        return app.UseForwardedHeaders(new ForwardedHeadersOptions
        {
            ForwardedForHeaderName = "True-Client-IP",
            ForwardedHeaders = ForwardedHeaders.XForwardedFor
        });
    }
}

The only difference in the CloudFlare configuration is the value of the ForwardedForHeaderName property.

public static class ForwardedHeadersExtensions
{
    ...

    public static IApplicationBuilder UseCloudFlareConnectingIp(this IApplicationBuilder app)
    {
        if (app == null)
        {
            throw new ArgumentNullException(nameof(app));
        }

        return app.UseForwardedHeaders(new ForwardedHeadersOptions
        {
            ForwardedForHeaderName = "CF-Connecting-IP",
            ForwardedHeaders = ForwardedHeaders.XForwardedFor
        });
    }
}

This makes it easy to stack the middleware.

public void Configure(IApplicationBuilder app, IHostingEnvironment env)
{
    app.UseForwardedHeaders(new ForwardedHeadersOptions
    {
        ForwardedHeaders = ForwardedHeaders.All
    });
    app.UseAkamaiTrueClientIp();
    app.UseCloudFlareConnectingIp();

    ...
}

This works, but you might want something more. One of the issues with forwarded headers is trust: they can be easily spoofed. In the case of global providers there are options to protect against that. For simplicity, I will focus on CloudFlare from this point on.

Preventing CloudFlare Connecting IP spoofing

CloudFlare makes its IP ranges available here. This list can be used to validate whether an incoming request is allowed to carry the CF-Connecting-IP header, as CloudFlare should be the last intermediary in front of the server. To perform that validation one can start by wrapping the ForwardedHeadersMiddleware.

public class CloudFlareConnectingIpMiddleware
{
    private readonly RequestDelegate _next;
    private readonly ForwardedHeadersMiddleware _forwardedHeadersMiddleware;

    public CloudFlareConnectingIpMiddleware(RequestDelegate next, ILoggerFactory loggerFactory)
    {
        _next = next ?? throw new ArgumentNullException(nameof(next));

        _forwardedHeadersMiddleware = new ForwardedHeadersMiddleware(next, loggerFactory,
            Options.Create(new ForwardedHeadersOptions
        {
            ForwardedForHeaderName = "CF-Connecting-IP",
            ForwardedHeaders = ForwardedHeaders.XForwardedFor
        }));
    }

    public Task Invoke(HttpContext context)
    {
        return _forwardedHeadersMiddleware.Invoke(context);
    }
}

To validate the request's originating IP address, the list provided by CloudFlare must be parsed into some usable form. There are existing parsers available, for example IPAddressRange.

public class CloudFlareConnectingIpMiddleware
{
    private static readonly IPAddressRange[] _cloudFlareIpAddressRanges = new IPAddressRange[]
    {
        IPAddressRange.Parse("103.21.244.0/22"),
        ...
        IPAddressRange.Parse("2a06:98c0::/29")
    };

    ...

    private bool IsCloudFlareIp(IPAddress ipAddress)
    {
        bool isCloudFlareIp = false;

        for (int i = 0; i < _cloudFlareIpAddressRanges.Length; i++)
        {
            isCloudFlareIp = _cloudFlareIpAddressRanges[i].Contains(ipAddress);
            if (isCloudFlareIp)
            {
                break;
            }
        }

        return isCloudFlareIp;
    }
}
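For readers curious what a range check like Contains boils down to, here is a minimal sketch of CIDR matching using only System.Net. CidrSketch is a hypothetical helper for illustration, not the IPAddressRange implementation, and it assumes well-formed input (no validation of the prefix length).

```csharp
using System;
using System.Net;

// Hypothetical helper, for illustration only.
public static class CidrSketch
{
    // Checks whether the address falls into a range written as "network/prefixLength".
    public static bool Contains(string cidr, IPAddress address)
    {
        string[] parts = cidr.Split('/');
        byte[] networkBytes = IPAddress.Parse(parts[0]).GetAddressBytes();
        byte[] addressBytes = address.GetAddressBytes();
        int prefixLength = Int32.Parse(parts[1]);

        // An IPv4 address can't belong to an IPv6 range and vice versa.
        if (networkBytes.Length != addressBytes.Length)
        {
            return false;
        }

        int fullBytes = prefixLength / 8;
        int remainingBits = prefixLength % 8;

        // Bytes fully covered by the prefix must be identical.
        for (int i = 0; i < fullBytes; i++)
        {
            if (networkBytes[i] != addressBytes[i])
            {
                return false;
            }
        }

        if (remainingBits == 0)
        {
            return true;
        }

        // Compare only the leading bits of the partially covered byte.
        int mask = (0xFF << (8 - remainingBits)) & 0xFF;

        return (networkBytes[fullBytes] & mask) == (addressBytes[fullBytes] & mask);
    }
}
```

In the middleware itself, a maintained parser such as IPAddressRange remains the safer choice.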

With the ability to check whether a request is coming from CloudFlare, the ForwardedHeadersMiddleware can be called only when needed.

public class CloudFlareConnectingIpMiddleware
{
    ...

    public Task Invoke(HttpContext context)
    {
        if (context.Request.Headers.ContainsKey("CF-Connecting-IP")
            && IsCloudFlareIp(context.Connection.RemoteIpAddress))
        {
            return _forwardedHeadersMiddleware.Invoke(context);
        }

        return _next(context);
    }

    ...
}

Complete code can be found here.

There is one drawback to this approach: what if the IP ranges change? This can be handled as well. CloudFlare provides two endpoints which return text lists, one for IPv4 and one for IPv6. Those two endpoints can be used by an IHostedService-based background task to update the ranges on startup or periodically. I'll leave this as an additional exercise.
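As a starting point for that exercise, the parsing and swapping part could look roughly like the sketch below. CloudFlareRangesHolder is a hypothetical name, and downloading the two lists (for example with HttpClient from an IHostedService) is assumed to happen elsewhere. The sketch keeps the ranges as CIDR strings, which could then be passed to IPAddressRange.Parse.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;

// Hypothetical holder for the current ranges; readers always see a complete array.
public sealed class CloudFlareRangesHolder
{
    private string[] _ranges = Array.Empty<string>();

    public IReadOnlyList<string> Current => Volatile.Read(ref _ranges);

    // Replaces the ranges with the content of the two downloaded text lists.
    public void Update(string ipv4List, string ipv6List)
    {
        List<string> ranges = new List<string>();
        ranges.AddRange(ParseList(ipv4List));
        ranges.AddRange(ParseList(ipv6List));

        // Keep the previous ranges if the download produced nothing usable.
        if (ranges.Count > 0)
        {
            Volatile.Write(ref _ranges, ranges.ToArray());
        }
    }

    // Each non-empty line of the text list is treated as one CIDR entry.
    public static IEnumerable<string> ParseList(string text)
    {
        if (String.IsNullOrWhiteSpace(text))
        {
            yield break;
        }

        foreach (string line in text.Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries))
        {
            string trimmed = line.Trim();
            if (trimmed.Length > 0)
            {
                yield return trimmed;
            }
        }
    }
}
```

Swapping the whole array in one write means a request never observes a half-updated list.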

I have a demo project on GitHub which accompanies my blog series about Web Push based push notifications in ASP.NET Core. There is one thing in that project which I have wanted to "fix" for some time: requesting delivery of notifications, which is being done inside an action.

public class PushNotificationsApiController : Controller
{
    ...

    [HttpPost("notifications")]
    public async Task<IActionResult> SendNotification([FromBody]PushMessageViewModel message)
    {
        PushMessage pushMessage = new PushMessage(message.Notification)
        {
            Topic = message.Topic,
            Urgency = message.Urgency
        };

        await _subscriptionStore.ForEachSubscriptionAsync((PushSubscription subscription) =>
        {
            _notificationService.SendNotificationAsync(subscription, pushMessage);
        });

        return NoContent();
    }
}

If you have read the post about requesting delivery, you know it's an expensive operation. Taking into consideration a possibly high number of subscriptions, this is something which shouldn't be done in the context of a request. It would be much better to queue it in the background, independent of any request. Back in the ASP.NET days this could be done with the QueueBackgroundWorkItem method, but it's not available in ASP.NET Core (at least not yet). However, there is a prototype implementation based on IHostedService which can be used as it is or adjusted to a specific case. I've decided to go down the second path. The first step on that path is the queue itself.

Creating the queue

The queue interface should be simple. Only two operations are needed: enqueue and dequeue. The dequeue operation should return a Task so the dequeuer can wait for new items. It should also accept a CancellationToken so the dequeuer can be stopped while it's waiting on dequeue.

internal interface IPushNotificationsQueue
{
    void Enqueue(PushMessage message);

    Task<PushMessage> DequeueAsync(CancellationToken cancellationToken);
}

The implementation is based on ConcurrentQueue and SemaphoreSlim. That SemaphoreSlim is where the magic happens. DequeueAsync waits on that semaphore. When a new message is enqueued the semaphore is released, which allows DequeueAsync to continue. If the semaphore is released more than once, the subsequent calls to DequeueAsync will not wait; each one just decrements the internal count of the semaphore until it's back at 0 again.

internal class PushNotificationsQueue : IPushNotificationsQueue
{
    private readonly ConcurrentQueue<PushMessage> _messages = new ConcurrentQueue<PushMessage>();
    private readonly SemaphoreSlim _messageEnqueuedSignal = new SemaphoreSlim(0);

    public void Enqueue(PushMessage message)
    {
        if (message == null)
        {
            throw new ArgumentNullException(nameof(message));
        }

        _messages.Enqueue(message);

        _messageEnqueuedSignal.Release();
    }

    public async Task<PushMessage> DequeueAsync(CancellationToken cancellationToken)
    {
        await _messageEnqueuedSignal.WaitAsync(cancellationToken);

        _messages.TryDequeue(out PushMessage message);

        return message;
    }
}
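The counting behaviour of the semaphore can be checked in isolation. The sketch below uses only SemaphoreSlim, without the queue; SemaphoreSignalSketch is a hypothetical name introduced for this demonstration.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical demonstration of how SemaphoreSlim counts releases and waits.
public static class SemaphoreSignalSketch
{
    public static async Task<(int AfterReleases, int AfterWaits, bool ThirdWaitPending)> RunAsync()
    {
        SemaphoreSlim signal = new SemaphoreSlim(0);

        // Two "enqueues" before anybody is waiting.
        signal.Release();
        signal.Release();
        int afterReleases = signal.CurrentCount;

        // Two "dequeues" complete immediately, each decrementing the count.
        await signal.WaitAsync();
        await signal.WaitAsync();
        int afterWaits = signal.CurrentCount;

        // With the count back at zero, the next wait stays pending until a release.
        Task thirdWait = signal.WaitAsync();
        bool thirdWaitPending = !thirdWait.IsCompleted;

        signal.Release();
        await thirdWait;

        return (afterReleases, afterWaits, thirdWaitPending);
    }
}
```

This is why a burst of Enqueue calls is never lost: every release is remembered until a matching wait consumes it.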

With the queue in place, the next step is implementing the dequeuer.

Implementing the dequeuer

The dequeuer is an implementation of IHostedService. In general it should wait on DequeueAsync and perform the same logic as the action does. But there are two important differences from the code in the action. First, a service scope needs to be created. The reason is IPushSubscriptionStore. By itself it's transient, so it wouldn't cause any issues, but its Sqlite implementation depends on DbContext, which is scoped. Second, the whole processing must support cancellation in order for the host to be able to shut down gracefully.

internal class PushNotificationsDequeuer : IHostedService
{
    private readonly IServiceProvider _serviceProvider;
    private readonly IPushNotificationsQueue _messagesQueue;
    private readonly IPushNotificationService _notificationService;
    private readonly CancellationTokenSource _stopTokenSource = new CancellationTokenSource();

    private Task _dequeueMessagesTask;

    public PushNotificationsDequeuer(IServiceProvider serviceProvider,
        IPushNotificationsQueue messagesQueue, IPushNotificationService notificationService)
    {
        _serviceProvider = serviceProvider;
        _messagesQueue = messagesQueue;
        _notificationService = notificationService;
    }

    public Task StartAsync(CancellationToken cancellationToken)
    {
        _dequeueMessagesTask = Task.Run(DequeueMessagesAsync);

        return Task.CompletedTask;
    }

    public Task StopAsync(CancellationToken cancellationToken)
    {
        _stopTokenSource.Cancel();

        return Task.WhenAny(_dequeueMessagesTask, Task.Delay(Timeout.Infinite, cancellationToken));
    }

    private async Task DequeueMessagesAsync()
    {
        while (!_stopTokenSource.IsCancellationRequested)
        {
            PushMessage message = await _messagesQueue.DequeueAsync(_stopTokenSource.Token);

            if (!_stopTokenSource.IsCancellationRequested)
            {
                using (IServiceScope serviceScope = _serviceProvider.CreateScope())
                {
                    IPushSubscriptionStore subscriptionStore =
                        serviceScope.ServiceProvider.GetRequiredService<IPushSubscriptionStore>();

                    await subscriptionStore.ForEachSubscriptionAsync(
                        (PushSubscription subscription) =>
                        {
                            _notificationService.SendNotificationAsync(subscription, message,
                                _stopTokenSource.Token);
                        },
                        _stopTokenSource.Token
                    );
                }
            }
        }
    }
}

Now the queue and dequeuer just need to be registered (both as singletons).

public static class ServiceCollectionExtensions
{
    ...

    public static IServiceCollection AddPushNotificationsQueue(this IServiceCollection services)
    {
        services.AddSingleton<IPushNotificationsQueue, PushNotificationsQueue>();
        services.AddSingleton<IHostedService, PushNotificationsDequeuer>();

        return services;
    }
}

Queueing requesting delivery

With the queue and dequeuer available, the action can be changed to pass the message to the background.

public class PushNotificationsApiController : Controller
{
    ...

    [HttpPost("notifications")]
    public IActionResult SendNotification([FromBody]PushMessageViewModel message)
    {
        _pushNotificationsQueue.Enqueue(new PushMessage(message.Notification)
        {
            Topic = message.Topic,
            Urgency = message.Urgency
        });

        return NoContent();
    }
}

It is important to note that the dequeuer is sequential. If one wanted to parallelize, there are two ways. One way is to use the dequeuer implementation as a base class and register multiple derived dequeuers. The other way is to introduce parallelization inside the dequeuer. In this approach a single instance would manage multiple reading threads. It's also easy to achieve; only proper synchronization inside the StopAsync method is needed. I prefer the second approach, as the first is rather ugly.
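The second approach could be sketched as follows. To stay self-contained, this sketch doesn't implement IHostedService or use the queue types from the post; ParallelDequeuerSketch, processNextAsync and readersCount are hypothetical names, where processNextAsync stands in for "dequeue one message and deliver it". The synchronization in StopAsync comes down to waiting for all reader tasks instead of one.

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical sketch: one instance managing several reading loops.
public sealed class ParallelDequeuerSketch
{
    private readonly Func<CancellationToken, Task> _processNextAsync;
    private readonly int _readersCount;
    private readonly CancellationTokenSource _stopTokenSource = new CancellationTokenSource();

    private Task[] _readerTasks = Array.Empty<Task>();

    public ParallelDequeuerSketch(Func<CancellationToken, Task> processNextAsync, int readersCount)
    {
        _processNextAsync = processNextAsync ?? throw new ArgumentNullException(nameof(processNextAsync));
        _readersCount = (readersCount > 0) ? readersCount
            : throw new ArgumentOutOfRangeException(nameof(readersCount));
    }

    public Task StartAsync(CancellationToken cancellationToken)
    {
        // Start the requested number of independent reading loops.
        _readerTasks = Enumerable.Range(0, _readersCount)
            .Select(_ => Task.Run(() => ReadLoopAsync()))
            .ToArray();

        return Task.CompletedTask;
    }

    public Task StopAsync(CancellationToken cancellationToken)
    {
        _stopTokenSource.Cancel();

        // Wait for every reader, unless the host gives up waiting first.
        return Task.WhenAny(Task.WhenAll(_readerTasks),
            Task.Delay(Timeout.Infinite, cancellationToken));
    }

    private async Task ReadLoopAsync()
    {
        try
        {
            while (!_stopTokenSource.IsCancellationRequested)
            {
                await _processNextAsync(_stopTokenSource.Token);
            }
        }
        catch (OperationCanceledException)
        {
            // Expected when stopping while a reader is waiting or processing.
        }
    }
}
```

With a shared queue behind processNextAsync, each message is still handled by exactly one reader, because every dequeue consumes one semaphore release.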

ASP.NET Core comes with out-of-the-box support for server-side response caching. It's easy to use and (when configured properly) can give you a nice performance boost. But it also has some shortcomings.

Under the hood it utilizes in-memory caching, which means that the cache has low latency at the price of increased memory usage. In high-load scenarios this can lead to memory pressure, and memory pressure can lead to entries being evicted prior to their expiration. This also means that the cache is not durable: if the process goes down for any reason, the cache needs to be repopulated. Last but not least, it provides no support for load balancing scenarios, as every node has to keep its own full cache.

[Figure: Load balancing with in-memory response cache]

None of those limitations may be a problem for you, but if they are, you might want to trade some cache latency to solve them. One approach is to use a distributed cache like Redis. This way the nodes are no longer responsible for holding the cache, the memory usage is lower, and when an instance recycles it doesn't have to warm up again.

[Figure: Load balancing with Redis backed response cache]

Implementing Redis backed IResponseCache

The heart of server-side response caching in ASP.NET Core is ResponseCachingMiddleware. It orchestrates the entire process and makes other components like IResponseCachingPolicyProvider, IResponseCachingKeyProvider and IResponseCache talk to each other. The component which needs to be implemented in order to switch caching from in-memory to Redis is IResponseCache, as it represents the storage for entries. It needs to be able to set or get an IResponseCacheEntry by a string key. The IResponseCacheEntry doesn't make any assumptions about the shape of an entry (it's an empty interface), so the only thing that can be done with an instance of it is to blindly attempt binary serialization. That might not be a good idea, so it is better to focus on its implementations: CachedResponse and CachedVaryByRules. They can be stored in Redis by using Hashes. I'm going to focus only on CachedResponse, as CachedVaryByRules is simpler and can be handled by replicating the same approach.

An instance of CachedResponse can't be represented by a single hash because it contains a headers collection. What can be done is to represent it by two separate hashes which share a key pattern. First, some helper methods which will take care of the conversion (I will be using StackExchange.Redis).

internal class RedisResponseCache : IResponseCache
{
    ...

    private HashEntry[] CachedResponseToHashEntryArray(CachedResponse cachedResponse)
    {
        MemoryStream bodyStream = new MemoryStream();
        cachedResponse.Body.CopyTo(bodyStream);

        return new HashEntry[]
        {
            new HashEntry("Type", nameof(CachedResponse)),
            new HashEntry(nameof(cachedResponse.Created), cachedResponse.Created.ToUnixTimeMilliseconds()),
            new HashEntry(nameof(cachedResponse.StatusCode), cachedResponse.StatusCode),
            new HashEntry(nameof(cachedResponse.Body), bodyStream.ToArray())
        };
    }

    private HashEntry[] HeaderDictionaryToHashEntryArray(IHeaderDictionary headerDictionary)
    {
        HashEntry[] headersHashEntries = new HashEntry[headerDictionary.Count];

        int headersHashEntriesIndex = 0;
        foreach (KeyValuePair<string, StringValues> header in headerDictionary)
        {
            headersHashEntries[headersHashEntriesIndex++] = new HashEntry(header.Key, (string)header.Value);
        }

        return headersHashEntries;
    }
}

With the conversion in place, the entry can be set in the cache (I will show only the async version). It is important to set expiration for both hashes.

internal class RedisResponseCache : IResponseCache
{
    private ConnectionMultiplexer _redis;

    public RedisResponseCache(string redisConnectionMultiplexerConfiguration)
    {
        if (String.IsNullOrWhiteSpace(redisConnectionMultiplexerConfiguration))
        {
            throw new ArgumentNullException(nameof(redisConnectionMultiplexerConfiguration));
        }

        _redis = ConnectionMultiplexer.Connect(redisConnectionMultiplexerConfiguration);
    }

    ...

    public async Task SetAsync(string key, IResponseCacheEntry entry, TimeSpan validFor)
    {
        if (entry is CachedResponse cachedResponse)
        {
            string headersKey = key + "_Headers";

            IDatabase redisDatabase = _redis.GetDatabase();

            await redisDatabase.HashSetAsync(key, CachedResponseToHashEntryArray(cachedResponse));
            await redisDatabase.HashSetAsync(headersKey, HeaderDictionaryToHashEntryArray(cachedResponse.Headers));

            await redisDatabase.KeyExpireAsync(headersKey, validFor);
            await redisDatabase.KeyExpireAsync(key, validFor);
        }
        else if (entry is CachedVaryByRules cachedVaryByRules)
        {
            ...
        }
    }

    ...
}

Getting an entry from the cache is similar. Conversion methods in the opposite direction are needed.

internal class RedisResponseCache : IResponseCache
{
    ...

    private CachedResponse CachedResponseFromHashEntryArray(HashEntry[] hashEntries)
    {
        CachedResponse cachedResponse = new CachedResponse();

        foreach (HashEntry hashEntry in hashEntries)
        {
            switch (hashEntry.Name)
            {
                case nameof(cachedResponse.Created):
                    cachedResponse.Created = DateTimeOffset.FromUnixTimeMilliseconds((long)hashEntry.Value);
                    break;
                case nameof(cachedResponse.StatusCode):
                    cachedResponse.StatusCode = (int)hashEntry.Value;
                    break;
                case nameof(cachedResponse.Body):
                    cachedResponse.Body = new MemoryStream(hashEntry.Value);
                    break;
            }
        }

        return cachedResponse;
    }

    private IHeaderDictionary HeaderDictionaryFromHashEntryArray(HashEntry[] headersHashEntries)
    {
        IHeaderDictionary headerDictionary = new HeaderDictionary();

        foreach (HashEntry headersHashEntry in headersHashEntries)
        {
            headerDictionary.Add(headersHashEntry.Name, (string)headersHashEntry.Value);
        }

        return headerDictionary;
    }
}

So the hashes can be retrieved and the entry recreated (only the async version again).

internal class RedisResponseCache : IResponseCache
{
    ...

    public async Task<IResponseCacheEntry> GetAsync(string key)
    {
        IResponseCacheEntry responseCacheEntry = null;

        IDatabase redisDatabase = _redis.GetDatabase();

        HashEntry[] hashEntries = await redisDatabase.HashGetAllAsync(key);

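        // A cache miss yields an empty array, so the "Type" field may be absent (type is then null).
        string type = hashEntries.FirstOrDefault(e => e.Name == "Type").Value;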
        if (type == nameof(CachedResponse))
        {
            HashEntry[] headersHashEntries = await redisDatabase.HashGetAllAsync(key + "_Headers");

            if ((headersHashEntries != null) && (headersHashEntries.Length > 0)
                && (hashEntries != null) && (hashEntries.Length > 0))
            {
                CachedResponse cachedResponse = CachedResponseFromHashEntryArray(hashEntries);
                cachedResponse.Headers = HeaderDictionaryFromHashEntryArray(headersHashEntries);

                responseCacheEntry = cachedResponse;
            }
        }
        else if (type == nameof(CachedVaryByRules))
        {
            ...
        }

        return responseCacheEntry;
    }

    ...
}

At this point adding sync versions and code for CachedVaryByRules shouldn't be hard.

Having the implementation is step one, step two is using it.

Using custom IResponseCache

Back in ASP.NET Core 1.1, ResponseCachingMiddleware had a constructor which allowed for providing your own implementation of IResponseCache. This constructor is gone in 2.0 in order to guarantee a limit on memory usage by making ResponseCachingMiddleware use its own private instance of the cache. The IResponseCache implementation can still be replaced by (please don't throw rocks at me) using reflection. Yes, reflection is not a perfect solution. It results in less readable and harder to maintain code. But here very little of it is needed, just enough to gain access to a single field, _cache. A custom middleware can be derived from ResponseCachingMiddleware and expose this field through a property.

internal class RedisResponseCachingMiddleware : ResponseCachingMiddleware
{
    private RedisResponseCache Cache
    {
        set
        {
            FieldInfo cacheFieldInfo = typeof(ResponseCachingMiddleware)
                .GetField("_cache", BindingFlags.NonPublic | BindingFlags.Instance);

            cacheFieldInfo.SetValue(this, value);
        }
    }

    public RedisResponseCachingMiddleware(RequestDelegate next, IOptions<RedisResponseCachingOptions> options,
        ILoggerFactory loggerFactory, IResponseCachingPolicyProvider policyProvider,
        IResponseCachingKeyProvider keyProvider)
        : base(next, options, loggerFactory, policyProvider, keyProvider)
    {
        Cache = new RedisResponseCache(options.Value.RedisConnectionMultiplexerConfiguration);
    }
}
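The constructor above takes IOptions<RedisResponseCachingOptions>. A minimal sketch of what that options class might look like, assuming it adds nothing beyond the connection configuration property read in the constructor:

```csharp
using Microsoft.AspNetCore.ResponseCaching;

// Sketch only: inherits the standard response caching options
// and adds the single Redis-related setting used by the middleware.
public class RedisResponseCachingOptions : ResponseCachingOptions
{
    // Configuration string handed to ConnectionMultiplexer.Connect, e.g. "localhost".
    public string RedisConnectionMultiplexerConfiguration { get; set; }
}
```

Deriving from ResponseCachingOptions lets the same instance be passed to the base middleware constructor.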

The RedisResponseCachingOptions class extends ResponseCachingOptions by adding the option needed to establish a connection to Redis. Now this can all be put together by providing the options and registering the custom middleware instead of calling the UseResponseCaching method.

public class Startup
{
    public void ConfigureServices(IServiceCollection services)
    {
        services.AddResponseCaching();
        services.Configure<RedisResponseCachingOptions>((options) =>
        {
            ...
            options.RedisConnectionMultiplexerConfiguration = "localhost";
        });
    }

    public void Configure(IApplicationBuilder app)
    {
        ...

        app.UseMiddleware<RedisResponseCachingMiddleware>();

        ...
    }
}

This is a fully working solution. Redis is used just as an example; you can use a distributed cache of your choosing or even a database if you wish (although I wouldn't recommend that). It's all about implementing IResponseCache and following the pattern for replacing it.

Some time ago I learned about a new security header named Clear-Site-Data, but only recently have I had a chance to try it in action. The goal of the header is to provide a mechanism which allows developers to instruct a browser to clear a site's data. This can be useful, for example, upon sign out, to ensure that locally stored data is removed. I wanted to explore that scenario in the context of ASP.NET Core.

Before going further, it's worth mentioning that Clear-Site-Data is not yet supported by all browsers.

The Clear-Site-Data header

The header has a simple structure. Its value should be a comma separated list containing types of data to clear. The specification defines four of them:

  • "cache" which indicates that the server wishes to remove locally cached data
  • "cookies" which indicates that the server wishes to remove cookies
  • "storage" which indicates that the server wishes to remove data from local storages (localStorage, sessionStorage, IndexedDB, etc)
  • "executionContexts" which indicates that the server wishes to neuter and reload execution contexts

Additionally there is a wildcard pseudotype which indicates that the server wishes to remove all data types.

Based on the specification, a simple class for generating the header value can be created. There is nothing particularly interesting about this class, so I will not show its implementation (for those interested, it can be found here). What I will show are helpers which can be used to set the value on the response.

public static class HttpResponseHeadersExtensions
{
    public static void SetClearSiteData(this HttpResponse response, ClearSiteDataHeaderValue clearSiteData)
    {
        response.SetResponseHeader("Clear-Site-Data", clearSiteData?.ToString());
    }

    public static void SetWildcardClearSiteData(this HttpResponse response)
    {
        response.SetResponseHeader("Clear-Site-Data", "\"*\"");
    }

    internal static void SetResponseHeader(this HttpResponse response, string headerName, string headerValue)
    {
        if (!String.IsNullOrWhiteSpace(headerValue))
        {
            if (response.Headers.ContainsKey(headerName))
            {
                response.Headers[headerName] = headerValue;
            }
            else
            {
                response.Headers.Append(headerName, headerValue);
            }
        }
    }
}
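For readers who want to experiment without the linked class, a minimal stand-in for generating the value could look like the sketch below. ClearSiteDataTypes and ClearSiteDataValueSketch are hypothetical names, not the ClearSiteDataHeaderValue class from the post; the sketch only formats the four types listed earlier as quoted, comma-separated strings.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical flags for the four data types defined by the specification.
[Flags]
public enum ClearSiteDataTypes
{
    None = 0,
    Cache = 1,
    Cookies = 2,
    Storage = 4,
    ExecutionContexts = 8
}

// Hypothetical formatter, for illustration only.
public static class ClearSiteDataValueSketch
{
    // Builds a value such as: "cache", "cookies"
    public static string Format(ClearSiteDataTypes types)
    {
        List<string> parts = new List<string>();

        if (types.HasFlag(ClearSiteDataTypes.Cache)) parts.Add("\"cache\"");
        if (types.HasFlag(ClearSiteDataTypes.Cookies)) parts.Add("\"cookies\"");
        if (types.HasFlag(ClearSiteDataTypes.Storage)) parts.Add("\"storage\"");
        if (types.HasFlag(ClearSiteDataTypes.ExecutionContexts)) parts.Add("\"executionContexts\"");

        return String.Join(", ", parts);
    }
}
```

An empty selection produces an empty string, which the SetResponseHeader helper above would skip rather than emit.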

Attaching to sign out in ASP.NET Core

Generating the header value and setting it on the response was the easy part. The hard part is plugging this nicely into the sign out functionality provided by ASP.NET Core. It is good to keep such things as transparent as possible, so they don't have to become required knowledge for all developers on the team. Plugging in at a low level also helps when the operation is triggered by the framework without a direct call from the application code.

After reviewing the ASP.NET Core code related to authentication, the common call point seems to be the SignOutAsync extension method available on HttpContext. The method is very simple.

public static Task SignOutAsync(this HttpContext context, string scheme, AuthenticationProperties properties) =>
    context.RequestServices.GetRequiredService<IAuthenticationService>().SignOutAsync(context, scheme, properties);

This would suggest that the heart of everything is the IAuthenticationService implementation. Extending that implementation with the additional behavior is an appealing idea. One way to do it is direct inheritance from AuthenticationService, available in Microsoft.AspNetCore.Authentication.Core (this is the default implementation), but it would require making the new class aware of other authentication core concepts (like IAuthenticationSchemeProvider, IAuthenticationHandlerProvider and IClaimsTransformation) for the sake of passing them to the base implementation. An alternative approach is to write a wrapper which takes an instance of a class implementing IAuthenticationService as a parameter, passes all calls to it and adds the new behavior. The added value of this approach is less coupling with implementation details (it only relies on the interface) and, as a result, the ability to work with custom implementations as well.

internal class ClearSiteDataAuthenticationService : IAuthenticationService
{
    private readonly IAuthenticationService _authenticationService;
    private readonly ClearSiteDataHeaderValue _clearSiteDataHeaderValue;

    public ClearSiteDataAuthenticationService(IAuthenticationService authenticationService,
        ClearSiteDataHeaderValue clearSiteDataHeaderValue)
    {
        _authenticationService = authenticationService
                                 ?? throw new ArgumentNullException(nameof(authenticationService));
        _clearSiteDataHeaderValue = clearSiteDataHeaderValue
                                 ?? throw new ArgumentNullException(nameof(clearSiteDataHeaderValue));
    }

    public Task<AuthenticateResult> AuthenticateAsync(HttpContext context, string scheme)
    {
        return _authenticationService.AuthenticateAsync(context, scheme);
    }

    ...

    public async Task SignOutAsync(HttpContext context, string scheme, AuthenticationProperties properties)
    {
        await _authenticationService.SignOutAsync(context, scheme, properties);

        if (context.User?.Identity?.IsAuthenticated ?? false)
        {
            context.Response.SetClearSiteData(_clearSiteDataHeaderValue);
        }
    }
}

The last challenge is replacing the IAuthenticationService registered in the services collection with the wrapper. Thankfully, IServiceCollection exposes the underlying collection of ServiceDescriptors. This allows for locating the currently registered descriptor, re-registering its implementation type under its own type, and setting an implementation factory which will take care of creating the wrapper. The following code does exactly that.

public static class ClearSiteDataAuthenticationServiceCollectionExtensions
{
    public static IServiceCollection AddClearSiteDataAuthentication(this IServiceCollection services,
        ClearSiteDataHeaderValue clearSiteDataHeaderValue)
    {
        if (services == null)
        {
            throw new ArgumentNullException(nameof(services));
        }

        ServiceDescriptor authenticationServiceDescriptor = services
            .FirstOrDefault(d => d.ServiceType == typeof(IAuthenticationService));

        if (authenticationServiceDescriptor != null)
        {
            Type authenticationServiceImplementationType = authenticationServiceDescriptor.ImplementationType;
            ServiceLifetime authenticationServiceLifetime = authenticationServiceDescriptor.Lifetime;

            if (authenticationServiceImplementationType != null)
            {
                services.Remove(authenticationServiceDescriptor);

                services.Add(new ServiceDescriptor(
                    authenticationServiceImplementationType,
                    authenticationServiceImplementationType,
                    authenticationServiceLifetime)
                );

                services.Add(new ServiceDescriptor(
                    typeof(IAuthenticationService),
                    (IServiceProvider serviceProvider) =>
                    {
                        IAuthenticationService authenticationService = (IAuthenticationService)serviceProvider
                            .GetRequiredService(authenticationServiceImplementationType);

                        return new ClearSiteDataAuthenticationService(authenticationService, clearSiteDataHeaderValue);
                    },
                    authenticationServiceLifetime
                ));
            }
        }

        return services;
    }
}

Now all that needs to be done is calling AddClearSiteDataAuthentication after calling AddAuthentication.
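
The ordering matters because the extension looks up an already registered descriptor: when it runs before the registration, there is nothing to wrap and the call does nothing. The following self-contained sketch replays the same descriptor-swapping steps on a made-up service to show the difference. IGreeter, Greeter, LoudGreeter and DecorateGreeter are hypothetical names used only for illustration; they are not part of the code above.

```csharp
using System;
using System.Linq;
using Microsoft.Extensions.DependencyInjection;

public interface IGreeter { string Greet(); }

public class Greeter : IGreeter
{
    public string Greet() => "hello";
}

// Hypothetical wrapper, standing in for ClearSiteDataAuthenticationService.
public class LoudGreeter : IGreeter
{
    private readonly IGreeter _inner;

    public LoudGreeter(IGreeter inner) => _inner = inner;

    public string Greet() => _inner.Greet().ToUpperInvariant();
}

public static class GreeterDecoration
{
    // Same steps as AddClearSiteDataAuthentication: locate, remove, re-register, wrap.
    public static IServiceCollection DecorateGreeter(this IServiceCollection services)
    {
        ServiceDescriptor descriptor = services
            .FirstOrDefault(d => d.ServiceType == typeof(IGreeter));
        Type implementationType = descriptor?.ImplementationType;

        if (implementationType == null)
        {
            // Nothing registered yet (or registered via factory/instance): nothing is wrapped.
            return services;
        }

        services.Remove(descriptor);
        services.Add(new ServiceDescriptor(implementationType, implementationType, descriptor.Lifetime));
        services.Add(new ServiceDescriptor(
            typeof(IGreeter),
            serviceProvider => new LoudGreeter(
                (IGreeter)serviceProvider.GetRequiredService(implementationType)),
            descriptor.Lifetime));

        return services;
    }

    // Builds a container with the decoration applied either after or before the registration.
    public static string Resolve(bool decorateAfterRegistration)
    {
        IServiceCollection services = new ServiceCollection();

        if (!decorateAfterRegistration) { services.DecorateGreeter(); }
        services.AddTransient<IGreeter, Greeter>();
        if (decorateAfterRegistration) { services.DecorateGreeter(); }

        using (ServiceProvider provider = services.BuildServiceProvider())
        {
            return provider.GetRequiredService<IGreeter>().Greet();
        }
    }
}
```

Calling GreeterDecoration.Resolve(true) goes through the wrapper, while GreeterDecoration.Resolve(false) resolves the plain Greeter, which is the same thing that would happen to IAuthenticationService if the two calls were swapped.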

All the code above can be found on GitHub.

I've written about WebSockets in ASP.NET Core several times before (touching subjects like subprotocol negotiation, Cross-Site WebSocket Hijacking and per-message compression), but I've never written a post about the fundamentals of managing WebSocket lifetime. In this post I intend to fix that.

The documentation does a pretty good job of explaining the key aspects and showing the most basic scenario. Unfortunately, real-life scenarios are usually not that simple. In most cases we want to have access to the WebSocket outside of the middleware responsible for accepting the requests. This leads to handing the WebSocket off to a dedicated service/manager. While doing this, a couple of rules need to be followed to ensure correct lifetime handling.

WebSocket "handing off" middleware flow

The most important thing to remember when writing a middleware for accepting WebSocket requests is that the middleware should not terminate before the WebSocket is closed. Allowing the middleware to terminate results in the request being completed and the underlying connection being closed.

It's often also worth designing the middleware to be the last in the pipeline. This comes from the fact that the middleware is bound to a specific ws:// or wss:// URI, so unless the goal is to handle different protocols under the same URI, anything other than a WebSocket request is an error.
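
One way to bind such a middleware to a single URI is to put it at the end of a dedicated pipeline branch. The snippet below is only a sketch of that idea: the MapWebSocketMiddleware extension name is my own invention, while Map, UseWebSockets and UseMiddleware are the standard ASP.NET Core extensions.

```csharp
using Microsoft.AspNetCore.Builder;

public static class WebSocketPipelineSketch
{
    // Hypothetical helper: creates a branch for the given path in which
    // the WebSocket middleware is the last (terminal) component.
    public static IApplicationBuilder MapWebSocketMiddleware<TMiddleware>(
        this IApplicationBuilder app, string path)
    {
        return app.Map(path, branch =>
        {
            // Makes context.WebSockets usable inside the branch.
            branch.UseWebSockets();

            // Nothing is registered after this, so requests matching the path end here.
            branch.UseMiddleware<TMiddleware>();
        });
    }
}
```

Assuming a middleware class like the one shown in the next section, the registration in Startup.Configure could then look like app.MapWebSocketMiddleware<WebSocketConnectionsMiddleware>("/socket"), where "/socket" is just an example path.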

The general flow of a middleware which accepts a WebSocket request and then hands it off is shown in the diagram below.

[Diagram: WebSocket 'handing off' middleware flow]

There are a number of ways to implement this flow, so the next section should be considered my opinionated approach (a different one, which leverages inheritance, is described elsewhere).

Opinionated implementation

I will first show the code of the middleware and then discuss its details.

using System;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;

public class WebSocketConnectionsMiddleware
{
    private readonly IWebSocketConnectionsService _connectionsService;

    // The next delegate is required by the middleware convention, but it is never invoked
    // because this middleware is meant to be the last one in the pipeline.
    public WebSocketConnectionsMiddleware(RequestDelegate next,
        IWebSocketConnectionsService connectionsService)
    {
        _connectionsService = connectionsService
                              ?? throw new ArgumentNullException(nameof(connectionsService));
    }

    public async Task Invoke(HttpContext context)
    {
        if (context.WebSockets.IsWebSocketRequest)
        {
            WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync();

            // Hand the WebSocket off to the connections service.
            WebSocketConnection webSocketConnection = new WebSocketConnection(webSocket);

            _connectionsService.AddConnection(webSocketConnection);

            // Keep the middleware (and the request) alive until the Close message arrives.
            await webSocketConnection.ReceiveMessagesUntilCloseAsync();

            // Complete the close handshake.
            await webSocket.CloseAsync(webSocketConnection.CloseStatus.Value,
                webSocketConnection.CloseStatusDescription, CancellationToken.None);

            _connectionsService.RemoveConnection(webSocketConnection.Id);
        }
        else
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
        }
    }
}

The code can be easily mapped to the diagram. The important parts are IWebSocketConnectionsService and WebSocketConnection.

There is nothing very specific about the IWebSocketConnectionsService implementation. It's supposed to store the connections (in my demo project it uses a ConcurrentDictionary for that purpose) and expose the high-level operations which are needed by other parts of the application.
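
To give a rough idea, such a service could look more or less like the sketch below. This is not the code from the demo project. IConnection is a hypothetical stand-in for WebSocketConnection (so the sample compiles on its own), and InMemoryConnectionsService, SendToAllAsync and RecordingConnection are names I made up for illustration.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical stand-in for the parts of WebSocketConnection the service needs.
public interface IConnection
{
    Guid Id { get; }

    Task SendAsync(string message, CancellationToken cancellationToken);
}

public class InMemoryConnectionsService
{
    private readonly ConcurrentDictionary<Guid, IConnection> _connections =
        new ConcurrentDictionary<Guid, IConnection>();

    public int Count => _connections.Count;

    public void AddConnection(IConnection connection)
    {
        _connections.TryAdd(connection.Id, connection);
    }

    public void RemoveConnection(Guid connectionId)
    {
        _connections.TryRemove(connectionId, out _);
    }

    // Example of a high-level operation: send the same text to every stored connection.
    public Task SendToAllAsync(string message, CancellationToken cancellationToken)
    {
        List<Task> sendTasks = new List<Task>();

        foreach (IConnection connection in _connections.Values)
        {
            sendTasks.Add(connection.SendAsync(message, cancellationToken));
        }

        return Task.WhenAll(sendTasks);
    }
}

// Test double which only records what was sent to it.
public class RecordingConnection : IConnection
{
    public Guid Id { get; } = Guid.NewGuid();

    public List<string> Sent { get; } = new List<string>();

    public Task SendAsync(string message, CancellationToken cancellationToken)
    {
        Sent.Add(message);
        return Task.CompletedTask;
    }
}
```

Because the dictionary is keyed by the connection identifier, removing a connection only works if Id returns the same value every time it is read.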

The WebSocketConnection, which serves as a WebSocket abstraction, is worth taking a closer look at. Its public API needs to bridge two worlds. It must allow for sending and receiving messages in order to be usable by other parts of the application. It also needs to provide the receiving loop for the middleware to wait on, together with the information required to complete the close handshake.

public class WebSocketConnection
{
    private readonly WebSocket _webSocket;

    // Generated once per connection, so adding and removing use the same identifier.
    public Guid Id { get; } = Guid.NewGuid();

    public WebSocketCloseStatus? CloseStatus { get; private set; } = null;

    public string CloseStatusDescription { get; private set; } = null;

    public event EventHandler<string> ReceiveText;

    public event EventHandler<byte[]> ReceiveBinary;

    public WebSocketConnection(WebSocket webSocket)
    {
        _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket));
    }

    public Task SendAsync(string message, CancellationToken cancellationToken)
    {
        ...
    }

    public Task SendAsync(byte[] message, CancellationToken cancellationToken)
    {
        ...
    }

    public async Task ReceiveMessagesUntilCloseAsync()
    {
        ...
    }
}

I will skip the details of the SendAsync methods (typically they check the WebSocket status and call SendAsync on it). What I will focus on is the receiving loop hiding under ReceiveMessagesUntilCloseAsync. It has two tasks. The first is to wait for the Close message and, when it arrives, return control to the waiting middleware (the middleware will need CloseStatus and CloseStatusDescription from that message to complete the handshake). The second is triggering the ReceiveText or ReceiveBinary event when either a Text or a Binary message arrives.
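
For the curious, a text-sending method along those lines could be sketched as follows. It's written as a static helper (WebSocketSendSketch.SendTextAsync is a made-up name) so it can be tried in isolation, and silently skipping a socket which is not open is my assumption; throwing might be preferred in some applications.

```csharp
using System;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

public static class WebSocketSendSketch
{
    // Hypothetical helper mirroring what SendAsync(string, ...) could do internally.
    public static Task SendTextAsync(WebSocket webSocket, string message,
        CancellationToken cancellationToken)
    {
        if (webSocket == null)
        {
            throw new ArgumentNullException(nameof(webSocket));
        }

        // Assumption: sending on a socket which is not open is silently skipped.
        if (webSocket.State != WebSocketState.Open)
        {
            return Task.CompletedTask;
        }

        byte[] payload = Encoding.UTF8.GetBytes(message);

        // Passing true as endOfMessage sends the text as a single, complete message.
        return webSocket.SendAsync(new ArraySegment<byte>(payload),
            WebSocketMessageType.Text, true, cancellationToken);
    }
}
```

With sending out of the way, the receiving loop looks like this.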

public class WebSocketConnection
{
    ...
    private int _receivePayloadBufferSize = 4 * 1024;

    ...

    public async Task ReceiveMessagesUntilCloseAsync()
    {
        byte[] receivePayloadBuffer = new byte[_receivePayloadBufferSize];
        WebSocketReceiveResult webSocketReceiveResult =
            await _webSocket.ReceiveAsync(new ArraySegment<byte>(receivePayloadBuffer),
                CancellationToken.None);

        while (webSocketReceiveResult.MessageType != WebSocketMessageType.Close)
        {
            if (webSocketReceiveResult.MessageType == WebSocketMessageType.Binary)
            {
                byte[] webSocketMessage = await ReceiveMessagePayloadAsync(webSocketReceiveResult,
                    receivePayloadBuffer);
                ReceiveBinary?.Invoke(this, webSocketMessage);
            }
            else
            {
                byte[] webSocketMessage = await ReceiveMessagePayloadAsync(webSocketReceiveResult,
                    receivePayloadBuffer);
                ReceiveText?.Invoke(this, Encoding.UTF8.GetString(webSocketMessage));
            }

            webSocketReceiveResult =
                await _webSocket.ReceiveAsync(new ArraySegment<byte>(receivePayloadBuffer),
                    CancellationToken.None);
        }

        CloseStatus = webSocketReceiveResult.CloseStatus.Value;
        CloseStatusDescription = webSocketReceiveResult.CloseStatusDescription;
    }

    private async Task<byte[]> ReceiveMessagePayloadAsync(
        WebSocketReceiveResult webSocketReceiveResult, byte[] receivePayloadBuffer)
    {
        byte[] messagePayload = null;

        if (webSocketReceiveResult.EndOfMessage)
        {
            messagePayload = new byte[webSocketReceiveResult.Count];
            Array.Copy(receivePayloadBuffer, messagePayload, webSocketReceiveResult.Count);
        }
        else
        {
            using (MemoryStream messagePayloadStream = new MemoryStream())
            {
                messagePayloadStream.Write(receivePayloadBuffer, 0, webSocketReceiveResult.Count);
                while (!webSocketReceiveResult.EndOfMessage)
                {
                    webSocketReceiveResult =
                        await _webSocket.ReceiveAsync(new ArraySegment<byte>(receivePayloadBuffer),
                            CancellationToken.None);
                    messagePayloadStream.Write(receivePayloadBuffer, 0, webSocketReceiveResult.Count);
                }

                messagePayload = messagePayloadStream.ToArray();
            }
        }

        return messagePayload;
    }
}

This will work nicely, until something unexpected happens...

Prematurely closed connection

There is one unexpected situation which should always be expected: the client closing the connection prematurely. Most of the time this means that the client has crashed (the easiest way to simulate this is killing the browser process while it's connected to the application). A prematurely closed connection manifests itself through a WebSocketException (with WebSocketError.ConnectionClosedPrematurely as the WebSocketErrorCode value) thrown from the receiving loop. The WebSocketConnection should handle this exception gracefully so the middleware can remove the connection from the manager and terminate.

public class WebSocketConnection
{
    ...

    public async Task ReceiveMessagesUntilCloseAsync()
    {
        try
        {
            ...
        }
        catch (WebSocketException wsex)
            when (wsex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
        {
            // Perform some logging
            ...
        }
    }

    ...
}
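
The exception filter is what keeps this handling narrow: only the premature close is swallowed, any other WebSocketException still propagates. The behaviour can be checked in isolation with a small sketch like the one below, where PrematureCloseSketch and RunReceiveLoopAsync are hypothetical names and the receiving loop is replaced by a delegate.

```csharp
using System;
using System.Net.WebSockets;
using System.Threading.Tasks;

public static class PrematureCloseSketch
{
    // Returns true when the loop finished normally and false when the
    // connection was closed prematurely. Other errors are not caught.
    public static async Task<bool> RunReceiveLoopAsync(Func<Task> receiveLoop)
    {
        try
        {
            await receiveLoop();
            return true;
        }
        catch (WebSocketException wsex)
            when (wsex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
        {
            // This is the place for logging; the caller only learns that no handshake happened.
            return false;
        }
    }
}
```

Passing a delegate which fails with new WebSocketException(WebSocketError.ConnectionClosedPrematurely) yields false, while one failing with, for example, WebSocketError.InvalidState lets the exception bubble up to the caller.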

A small change to the middleware is needed, as in this case the close handshake shouldn't be completed (it was never started).

public class WebSocketConnectionsMiddleware
{
    ...

    public async Task Invoke(HttpContext context)
    {
        if (context.WebSockets.IsWebSocketRequest)
        {
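            ...

            // CloseStatus has no value when the connection was closed prematurely;
            // in that case there is no close handshake to complete.
            if (webSocketConnection.CloseStatus.HasValue)
            {
                await webSocket.CloseAsync(webSocketConnection.CloseStatus.Value,
                    webSocketConnection.CloseStatusDescription, CancellationToken.None);
            }

            _connectionsService.RemoveConnection(webSocketConnection.Id);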
        }
        else
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
        }
    }
}

This way the prematurely closed connections will be handled correctly.

The demo project available on GitHub contains a modified version of this code as it also shows all the features described in my previous posts.
