Skip to content

Commit 3f01ece

Browse files
committed
Fix Orleans identity authorization review findings
1 parent 2d8e093 commit 3f01ece

8 files changed

Lines changed: 289 additions & 99 deletions

File tree

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
using System.Threading.Tasks;
22
using Microsoft.AspNetCore.Mvc.Filters;
3-
using Orleans.Runtime;
4-
using ManagedCode.Orleans.Identity.Core.Constants;
53

64
namespace ManagedCode.Orleans.Identity.Client.Filters;
75

86
public sealed class OrleansAuthorizationActionFilter : IAsyncActionFilter
97
{
10-
public Task OnActionExecutionAsync(ActionExecutingContext context, ActionExecutionDelegate next)
8+
public async Task OnActionExecutionAsync(ActionExecutingContext context, ActionExecutionDelegate next)
119
{
12-
RequestContext.Set(OrleansIdentityConstants.USER_CLAIMS, context.HttpContext.User);
13-
return next();
10+
using var requestContextScope = new OrleansRequestContextScope(context.HttpContext.User);
11+
await next();
1412
}
15-
}
13+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using System;
2+
using System.Security.Claims;
3+
using ManagedCode.Orleans.Identity.Core.Constants;
4+
using Orleans.Runtime;
5+
6+
namespace ManagedCode.Orleans.Identity.Client.Filters;
7+
8+
internal readonly struct OrleansRequestContextScope : IDisposable
9+
{
10+
private readonly object? previousUser;
11+
private readonly bool hasPreviousUser;
12+
13+
public OrleansRequestContextScope(ClaimsPrincipal? user)
14+
{
15+
previousUser = RequestContext.Get(OrleansIdentityConstants.USER_CLAIMS);
16+
hasPreviousUser = previousUser is not null;
17+
18+
SetUser(user);
19+
}
20+
21+
public void Dispose()
22+
{
23+
if (hasPreviousUser)
24+
{
25+
RequestContext.Set(OrleansIdentityConstants.USER_CLAIMS, previousUser!);
26+
return;
27+
}
28+
29+
RequestContext.Remove(OrleansIdentityConstants.USER_CLAIMS);
30+
}
31+
32+
private static void SetUser(ClaimsPrincipal? user)
33+
{
34+
if (user is null)
35+
{
36+
RequestContext.Remove(OrleansIdentityConstants.USER_CLAIMS);
37+
return;
38+
}
39+
40+
RequestContext.Set(OrleansIdentityConstants.USER_CLAIMS, user);
41+
}
42+
}

ManagedCode.Orleans.Identity.Client/Filters/SignalRAuthorizationFilter.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
using System;
2-
using System.Security.Claims;
32
using System.Threading.Tasks;
4-
using ManagedCode.Orleans.Identity.Core.Constants;
53
using Microsoft.AspNetCore.SignalR;
6-
using Orleans.Runtime;
74

85
namespace ManagedCode.Orleans.Identity.Client.Filters;
96

@@ -13,8 +10,15 @@ public sealed class SignalRAuthorizationFilter : IHubFilter
1310
HubInvocationContext invocationContext,
1411
Func<HubInvocationContext, ValueTask<object?>> next)
1512
{
16-
RequestContext.Set(OrleansIdentityConstants.USER_CLAIMS, invocationContext.Context.User!); // can be null
17-
return next(invocationContext);
13+
return InvokeMethodWithRequestContextAsync(invocationContext, next);
14+
}
15+
16+
private static async ValueTask<object?> InvokeMethodWithRequestContextAsync(
17+
HubInvocationContext invocationContext,
18+
Func<HubInvocationContext, ValueTask<object?>> next)
19+
{
20+
using var requestContextScope = new OrleansRequestContextScope(invocationContext.Context.User);
21+
return await next(invocationContext);
1822
}
1923

2024
public Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
@@ -27,4 +31,4 @@ public Task OnDisconnectedAsync(HubLifetimeContext context, Exception? exception
2731
{
2832
return next(context, exception);
2933
}
30-
}
34+
}

ManagedCode.Orleans.Identity.Server/GrainCallFilter/GrainAuthorizationIncomingFilter.cs

Lines changed: 64 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,54 +4,36 @@
44
using System.Reflection;
55
using System.Security.Claims;
66
using System.Threading.Tasks;
7+
using ManagedCode.Orleans.Identity.Core.Constants;
78
using Microsoft.AspNetCore.Authorization;
8-
using Microsoft.Extensions.Logging;
9-
using Orleans;
109
using Orleans.Runtime;
11-
using ManagedCode.Orleans.Identity.Core.Constants;
1210

1311
namespace ManagedCode.Orleans.Identity.Server.GrainCallFilter;
1412

1513
public class GrainAuthorizationIncomingFilter : IIncomingGrainCallFilter
1614
{
15+
private const int EmptyAttributeCount = 0;
16+
private const char RoleSeparator = ',';
17+
private const string AccessDeniedNotAuthenticated = "Access denied. User is not authenticated.";
18+
private const string AccessDeniedMissingRoles = "Access denied. User does not have required roles.";
19+
1720
public async Task Invoke(IIncomingGrainCallContext context)
1821
{
19-
// Check both interface method and implementation method
20-
if (IsGrainAuthorized(context.ImplementationMethod, out var attributes))
22+
if (IsGrainAuthorized(context, out var attributes))
2123
{
2224
var user = GetUserFromRequestContext();
23-
25+
2426
if (user == null || user.Identity?.IsAuthenticated != true)
2527
{
26-
throw new UnauthorizedAccessException("Access denied. User is not authenticated.");
28+
throw new UnauthorizedAccessException(AccessDeniedNotAuthenticated);
2729
}
2830

29-
// Check if any attribute requires specific roles
30-
var rolesRequired = attributes.Any(attr => !string.IsNullOrWhiteSpace(attr.Roles));
31-
32-
if (rolesRequired)
31+
if (!HasRequiredRoles(attributes, user))
3332
{
34-
var userRoles = user.FindAll(ClaimTypes.Role).Select(c => c.Value).ToHashSet();
35-
36-
// Check if user has any of the required roles from any attribute
37-
var hasRequiredRole = attributes.Any(attribute =>
38-
{
39-
if (string.IsNullOrWhiteSpace(attribute.Roles))
40-
return true; // No specific role required by this attribute
41-
42-
var requiredRoles = attribute.Roles.Split(',', StringSplitOptions.RemoveEmptyEntries)
43-
.Select(r => r.Trim());
44-
45-
return requiredRoles.Any(role => userRoles.Contains(role));
46-
});
47-
48-
if (!hasRequiredRole)
49-
{
50-
throw new UnauthorizedAccessException("Access denied. User does not have required roles.");
51-
}
33+
throw new UnauthorizedAccessException(AccessDeniedMissingRoles);
5234
}
5335
}
54-
36+
5537
await context.Invoke();
5638
}
5739

@@ -61,28 +43,71 @@ public async Task Invoke(IIncomingGrainCallContext context)
6143
return requestContext as ClaimsPrincipal;
6244
}
6345

64-
private static bool IsGrainAuthorized(MemberInfo methodInfo, out List<AuthorizeAttribute> attributes)
46+
private static bool IsGrainAuthorized(IIncomingGrainCallContext context, out List<AuthorizeAttribute> attributes)
6547
{
6648
attributes = [];
49+
var members = GetAuthorizationMembers(context);
6750

68-
if (Attribute.IsDefined(methodInfo, typeof(AllowAnonymousAttribute)))
51+
if (members.Any(HasAllowAnonymousAttribute))
6952
{
7053
return false;
7154
}
7255

73-
if (methodInfo.DeclaringType != null && Attribute.IsDefined(methodInfo.DeclaringType, typeof(AuthorizeAttribute)))
56+
attributes.AddRange(members.SelectMany(GetAuthorizeAttributes));
57+
return attributes.Count != EmptyAttributeCount;
58+
}
59+
60+
private static IReadOnlyList<MemberInfo> GetAuthorizationMembers(IIncomingGrainCallContext context)
61+
{
62+
var members = new List<MemberInfo>();
63+
64+
if (context.InterfaceMethod.DeclaringType is { } interfaceType)
65+
{
66+
members.Add(interfaceType);
67+
}
68+
69+
if (context.ImplementationMethod.DeclaringType is { } implementationType)
7470
{
75-
attributes.AddRange(Attribute.GetCustomAttributes(methodInfo.DeclaringType, typeof(AuthorizeAttribute))
76-
.Cast<AuthorizeAttribute>());
71+
members.Add(implementationType);
7772
}
7873

79-
if (Attribute.IsDefined(methodInfo, typeof(AuthorizeAttribute)))
74+
members.Add(context.InterfaceMethod);
75+
members.Add(context.ImplementationMethod);
76+
77+
return members;
78+
}
79+
80+
private static bool HasAllowAnonymousAttribute(MemberInfo memberInfo)
81+
{
82+
return Attribute.IsDefined(memberInfo, typeof(AllowAnonymousAttribute), inherit: true);
83+
}
84+
85+
private static IEnumerable<AuthorizeAttribute> GetAuthorizeAttributes(MemberInfo memberInfo)
86+
{
87+
return Attribute
88+
.GetCustomAttributes(memberInfo, typeof(AuthorizeAttribute), inherit: true)
89+
.Cast<AuthorizeAttribute>();
90+
}
91+
92+
private static bool HasRequiredRoles(IEnumerable<AuthorizeAttribute> attributes, ClaimsPrincipal user)
93+
{
94+
return attributes
95+
.Where(attribute => !string.IsNullOrWhiteSpace(attribute.Roles))
96+
.All(attribute => HasAnyRequiredRole(attribute, user));
97+
}
98+
99+
private static bool HasAnyRequiredRole(AuthorizeAttribute attribute, ClaimsPrincipal user)
100+
{
101+
if (string.IsNullOrWhiteSpace(attribute.Roles))
80102
{
81-
attributes.AddRange(Attribute.GetCustomAttributes(methodInfo, typeof(AuthorizeAttribute))
82-
.Cast<AuthorizeAttribute>());
83103
return true;
84104
}
85105

86-
return attributes.Count != 0;
106+
var roles = attribute.Roles.Split(
107+
RoleSeparator,
108+
StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries
109+
);
110+
111+
return roles.Any(user.IsInRole);
87112
}
88113
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
using System.Security.Claims;
2+
using ManagedCode.Orleans.Identity.Client.Filters;
3+
using ManagedCode.Orleans.Identity.Core.Constants;
4+
using Microsoft.AspNetCore.Http;
5+
using Microsoft.AspNetCore.Mvc;
6+
using Microsoft.AspNetCore.Mvc.Abstractions;
7+
using Microsoft.AspNetCore.Mvc.Filters;
8+
using Microsoft.AspNetCore.Routing;
9+
using Orleans.Runtime;
10+
using Shouldly;
11+
using Xunit;
12+
13+
namespace ManagedCode.Orleans.Identity.Tests;
14+
15+
public class ClientFilterTests
16+
{
17+
private const string AuthenticationType = "Test";
18+
private const string CurrentUserName = "current-user";
19+
private const string PreviousUserName = "previous-user";
20+
21+
[Fact]
22+
public async Task OrleansAuthorizationActionFilter_RestoresPreviousRequestContext()
23+
{
24+
try
25+
{
26+
var previousUser = CreatePrincipal(PreviousUserName);
27+
var currentUser = CreatePrincipal(CurrentUserName);
28+
var context = CreateActionExecutingContext(currentUser);
29+
var filter = new OrleansAuthorizationActionFilter();
30+
31+
RequestContext.Set(OrleansIdentityConstants.USER_CLAIMS, previousUser);
32+
33+
await filter.OnActionExecutionAsync(
34+
context,
35+
() =>
36+
{
37+
RequestContext.Get(OrleansIdentityConstants.USER_CLAIMS).ShouldBeSameAs(currentUser);
38+
return Task.FromResult(CreateActionExecutedContext(context));
39+
}
40+
);
41+
42+
RequestContext.Get(OrleansIdentityConstants.USER_CLAIMS).ShouldBeSameAs(previousUser);
43+
}
44+
finally
45+
{
46+
RequestContext.Clear();
47+
}
48+
}
49+
50+
[Fact]
51+
public async Task OrleansAuthorizationActionFilter_ClearsRequestContextWhenNoPreviousValue()
52+
{
53+
try
54+
{
55+
var currentUser = CreatePrincipal(CurrentUserName);
56+
var context = CreateActionExecutingContext(currentUser);
57+
var filter = new OrleansAuthorizationActionFilter();
58+
59+
RequestContext.Remove(OrleansIdentityConstants.USER_CLAIMS);
60+
61+
await filter.OnActionExecutionAsync(
62+
context,
63+
() =>
64+
{
65+
RequestContext.Get(OrleansIdentityConstants.USER_CLAIMS).ShouldBeSameAs(currentUser);
66+
return Task.FromResult(CreateActionExecutedContext(context));
67+
}
68+
);
69+
70+
RequestContext.Get(OrleansIdentityConstants.USER_CLAIMS).ShouldBeNull();
71+
}
72+
finally
73+
{
74+
RequestContext.Clear();
75+
}
76+
}
77+
78+
private static ClaimsPrincipal CreatePrincipal(string userName)
79+
{
80+
return new ClaimsPrincipal(new ClaimsIdentity([new Claim(ClaimTypes.Name, userName)], AuthenticationType));
81+
}
82+
83+
private static ActionExecutingContext CreateActionExecutingContext(ClaimsPrincipal user)
84+
{
85+
var httpContext = new DefaultHttpContext { User = user };
86+
var actionContext = new ActionContext(httpContext, new RouteData(), new ActionDescriptor());
87+
88+
return new ActionExecutingContext(actionContext, [], new Dictionary<string, object?>(), new object());
89+
}
90+
91+
private static ActionExecutedContext CreateActionExecutedContext(ActionExecutingContext context)
92+
{
93+
return new ActionExecutedContext(context, [], context.Controller);
94+
}
95+
}

ManagedCode.Orleans.Identity.Tests/Cluster/Grains/IUserGrain.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using ManagedCode.Orleans.Identity.Tests.Constants;
2+
using Microsoft.AspNetCore.Authorization;
3+
14
namespace ManagedCode.Orleans.Identity.Tests.Cluster.Grains;
25

36
public interface IUserGrain : IGrainWithStringKey
@@ -8,4 +11,7 @@ public interface IUserGrain : IGrainWithStringKey
811
Task<string> GetPublicInfo();
912
Task<string> ModifyUser();
1013
Task<string> AddToList();
11-
}
14+
15+
[Authorize(Roles = TestRoles.ADMIN)]
16+
Task<string> GetInterfaceAdminInfo();
17+
}

0 commit comments

Comments
 (0)