-
-
Notifications
You must be signed in to change notification settings - Fork 37
Expand file tree
/
Copy pathRedisLockExtension.cs
More file actions
117 lines (99 loc) · 3.54 KB
/
RedisLockExtension.cs
File metadata and controls
117 lines (99 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using StackExchange.Redis;
namespace CacheTower.Extensions.Redis
{
	/// <summary>
	/// Uses a Redis key as a lock so that only the caller that acquires it runs the value provider
	/// for a cache key. Callers that do not acquire the lock wait for a notification on a Redis
	/// channel and then read the entry from the registered cache stack.
	/// Based on Loris Cro's "RedisMemoLock"
	/// https://github.com/kristoff-it/redis-memolock/blob/77da8f82711309b9dd81eafd02cb7ccfb22344c7/csharp/redis-memolock/RedisMemoLock.cs
	/// </summary>
	public class RedisLockExtension : IRefreshWrapperExtension
	{
		/// <summary>
		/// Upper bound used when waiting for an unlock notification, kept within the range
		/// that <see cref="Task.Delay(TimeSpan)"/> accepts on every runtime.
		/// </summary>
		private static readonly TimeSpan MaxWaitDelay = TimeSpan.FromMilliseconds(int.MaxValue);

		private ISubscriber Subscriber { get; }
		private IDatabaseAsync Database { get; }
		private string RedisChannel { get; }
		private TimeSpan LockTimeout { get; } = TimeSpan.FromMinutes(1);
		private ICacheStack RegisteredStack { get; set; }
		private ConcurrentDictionary<string, IEnumerable<TaskCompletionSource<bool>>> LockedOnKeyRefresh { get; }

		/// <summary>
		/// Creates the extension and subscribes to the unlock notification channel.
		/// </summary>
		/// <param name="connection">The Redis connection used for both the lock keys and the pub/sub channel.</param>
		/// <param name="databaseIndex">The Redis database index passed to <c>GetDatabase</c>.</param>
		/// <param name="channelPrefix">Prefix of the notification channel; the channel name is <c>{channelPrefix}.CacheLock</c>.</param>
		/// <param name="lockTimeout">Expiry applied to the lock key; one minute when not supplied.</param>
		/// <exception cref="ArgumentNullException"><paramref name="connection"/> or <paramref name="channelPrefix"/> is null.</exception>
		public RedisLockExtension(ConnectionMultiplexer connection, int databaseIndex = -1, string channelPrefix = "CacheTower", TimeSpan? lockTimeout = default)
		{
			if (connection == null)
			{
				throw new ArgumentNullException(nameof(connection));
			}

			if (channelPrefix == null)
			{
				throw new ArgumentNullException(nameof(channelPrefix));
			}

			Database = connection.GetDatabase(databaseIndex);
			Subscriber = connection.GetSubscriber();
			RedisChannel = $"{channelPrefix}.CacheLock";

			if (lockTimeout.HasValue)
			{
				LockTimeout = lockTimeout.Value;
			}

			LockedOnKeyRefresh = new ConcurrentDictionary<string, IEnumerable<TaskCompletionSource<bool>>>(StringComparer.Ordinal);

			//The published message is the cache key whose refresh has finished
			Subscriber.Subscribe(RedisChannel, (channel, value) => UnlockWaitingTasks(value));
		}

		/// <summary>
		/// Associates this extension with the cache stack that waiting callers read refreshed entries from.
		/// </summary>
		/// <param name="cacheStack">The cache stack to register.</param>
		/// <exception cref="ArgumentNullException"><paramref name="cacheStack"/> is null.</exception>
		/// <exception cref="InvalidOperationException">A cache stack has already been registered.</exception>
		public void Register(ICacheStack cacheStack)
		{
			//Fail fast: a null stack would otherwise only surface later as a NullReferenceException while waiting
			if (cacheStack == null)
			{
				throw new ArgumentNullException(nameof(cacheStack));
			}

			if (RegisteredStack != null)
			{
				throw new InvalidOperationException($"{nameof(RedisLockExtension)} can only be registered to one {nameof(ICacheStack)}");
			}

			RegisteredStack = cacheStack;
		}

		/// <summary>
		/// Runs <paramref name="valueProvider"/> if the Redis lock for <paramref name="cacheKey"/> can be acquired,
		/// otherwise waits for the lock holder to finish and returns the entry from the registered cache stack.
		/// </summary>
		/// <typeparam name="T">The type of the cached value.</typeparam>
		/// <param name="cacheKey">The cache key being refreshed; also used as the Redis lock key and the notification payload.</param>
		/// <param name="valueProvider">Produces the refreshed cache entry.</param>
		/// <param name="settings">Lifetime settings used to decide whether an existing entry is still fresh.</param>
		/// <returns>The entry from <paramref name="valueProvider"/> when the lock was acquired, otherwise the entry read from the cache stack.</returns>
		public async ValueTask<CacheEntry<T>> RefreshValueAsync<T>(string cacheKey, Func<ValueTask<CacheEntry<T>>> valueProvider, CacheEntryLifetime settings)
		{
			//NOTE(review): the lock is stored under the cache key itself - if another component stores
			//values under the same key in the same Redis database the two would collide; confirm against callers
			var hasLock = await Database.StringSetAsync(cacheKey, RedisValue.EmptyString, expiry: LockTimeout, when: When.NotExists);

			if (!hasLock)
			{
				return await WaitForResult<T>(cacheKey, settings);
			}

			try
			{
				return await valueProvider();
			}
			finally
			{
				//Notify even when the value provider throws, otherwise waiting callers are never released
				try
				{
					await Subscriber.PublishAsync(RedisChannel, cacheKey, CommandFlags.FireAndForget);
				}
				finally
				{
					//Always release the lock, even if the notification could not be sent
					await Database.KeyDeleteAsync(cacheKey, CommandFlags.FireAndForget);
				}
			}
		}

		/// <summary>
		/// Waits until the refresh of <paramref name="cacheKey"/> is signalled (or the lock timeout elapses)
		/// and then returns the entry from the registered cache stack.
		/// </summary>
		/// <typeparam name="T">The type of the cached value.</typeparam>
		/// <param name="cacheKey">The cache key being waited on.</param>
		/// <param name="settings">Lifetime settings used to decide whether an existing entry is still fresh.</param>
		/// <returns>The entry read from the registered cache stack.</returns>
		private async Task<CacheEntry<T>> WaitForResult<T>(string cacheKey, CacheEntryLifetime settings)
		{
			//Continuations are queued rather than run inline so that the thread calling
			//UnlockWaitingTasks (the Redis subscription callback) is not used to resume every waiter
			var delayedResultSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
			var waitList = new[] { delayedResultSource };
			LockedOnKeyRefresh.AddOrUpdate(cacheKey, waitList, (key, oldList) => oldList.Concat(waitList));

			//Last minute check to confirm whether waiting is required (in case the notification is missed)
			var currentEntry = await RegisteredStack.GetAsync<T>(cacheKey);
			if (currentEntry != null && currentEntry.GetStaleDate(settings) > DateTime.UtcNow)
			{
				UnlockWaitingTasks(cacheKey);
				return currentEntry;
			}

			//Never wait longer than the lock itself can live: once the lock key has expired
			//no notification is guaranteed to arrive (eg. the lock holder crashed or the message was lost)
			var waitLimit = LockTimeout;
			if (waitLimit < TimeSpan.Zero)
			{
				waitLimit = TimeSpan.Zero;
			}
			else if (waitLimit > MaxWaitDelay)
			{
				waitLimit = MaxWaitDelay;
			}

			using (var timeoutCancellation = new System.Threading.CancellationTokenSource())
			{
				var timeoutTask = Task.Delay(waitLimit, timeoutCancellation.Token);
				var finishedTask = await Task.WhenAny(delayedResultSource.Task, timeoutTask);
				if (finishedTask == delayedResultSource.Task)
				{
					//Notified in time - stop the pending timer rather than leaving it to run out
					timeoutCancellation.Cancel();
				}
				else
				{
					//Timed out - release every waiter registered for this key so none are left
					//behind in the dictionary; each of them re-reads the cache stack below
					UnlockWaitingTasks(cacheKey);
				}
			}

			//Get the updated value from the cache stack
			return await RegisteredStack.GetAsync<T>(cacheKey);
		}

		/// <summary>
		/// Removes the wait list for <paramref name="cacheKey"/> and completes every waiter in it.
		/// </summary>
		/// <param name="cacheKey">The cache key whose waiters are released.</param>
		private void UnlockWaitingTasks(string cacheKey)
		{
			if (LockedOnKeyRefresh.TryRemove(cacheKey, out var waitingTasks))
			{
				foreach (var task in waitingTasks)
				{
					task.TrySetResult(true);
				}
			}
		}
	}
}