-
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathP2PKernelSerializer.cs
More file actions
151 lines (135 loc) · 5.83 KB
/
Copy pathP2PKernelSerializer.cs
File metadata and controls
151 lines (135 loc) · 5.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
using System.Reflection;
using System.Text.Json;
namespace SpawnDev.ILGPU.P2P;
/// <summary>
/// Serializes kernel references for P2P dispatch with security restrictions.
///
/// SECURITY: Only methods on explicitly registered types can be resolved.
/// This prevents arbitrary code execution via malicious dispatch requests.
/// Both coordinator and worker must register the same kernel types.
/// Type names received from a peer are matched against the allowlist only; they are never
/// handed to the runtime type loader, so a request cannot cause assembly loading or type
/// initialization of unregistered types.
/// </summary>
public static class P2PKernelSerializer
{
    // Allowlist of types whose static methods may be dispatched. Used as a concurrent set;
    // the bool value carries no meaning.
    private static readonly System.Collections.Concurrent.ConcurrentDictionary<Type, bool> _allowedTypes = new();

    // Successful resolutions only, keyed by the (KernelType, KernelMethod) strings exactly as
    // received. A tuple key avoids collisions that a concatenated "type::method" key allows.
    private static readonly System.Collections.Concurrent.ConcurrentDictionary<(string TypeName, string MethodName), MethodInfo> _resolveCache = new();

    /// <summary>
    /// Register a type whose static methods can be dispatched via P2P.
    /// Both coordinator and worker must register the same types.
    /// </summary>
    /// <param name="type">The type declaring the kernel methods to allow.</param>
    /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
    public static void RegisterKernelType(Type type)
    {
        ArgumentNullException.ThrowIfNull(type);
        _allowedTypes[type] = true;
        // Only successful resolutions are cached, so adding a type cannot make an entry stale;
        // clearing anyway keeps the cache trivially coherent with every allowlist mutation.
        _resolveCache.Clear();
    }

    /// <summary>
    /// Register multiple kernel types at once.
    /// </summary>
    /// <param name="types">Types declaring the kernel methods to allow.</param>
    /// <exception cref="ArgumentNullException"><paramref name="types"/> or one of its elements is null.</exception>
    public static void RegisterKernelTypes(params Type[] types)
    {
        ArgumentNullException.ThrowIfNull(types);
        foreach (var type in types)
        {
            _allowedTypes[type] = true;
        }
        // Same cache policy as RegisterKernelType, applied once for the whole batch.
        _resolveCache.Clear();
    }

    /// <summary>
    /// Check if a type is registered for P2P dispatch.
    /// </summary>
    public static bool IsTypeAllowed(Type type) => _allowedTypes.ContainsKey(type);

    /// <summary>
    /// Clear the allowlist and the resolve cache. Resolution is fail-closed, so afterwards
    /// every dispatch is rejected until types are registered again. For testing only.
    /// </summary>
    public static void ClearAllowlist()
    {
        _allowedTypes.Clear();
        _resolveCache.Clear();
    }

    /// <summary>
    /// Number of registered kernel types.
    /// </summary>
    public static int AllowedTypeCount => _allowedTypes.Count;

    /// <summary>
    /// Serialize a kernel method reference to a dispatch request.
    /// The method's declaring type must be registered via RegisterKernelType.
    /// Only the method name is transmitted, so overloaded kernel names cannot be resolved
    /// by the worker.
    /// </summary>
    /// <param name="kernelMethod">
    /// The kernel method. If it has no declaring type, <c>KernelType</c> is set to an empty
    /// string and the request will never resolve on a worker.
    /// </param>
    /// <param name="gridDimX">Grid dimension X.</param>
    /// <param name="gridDimY">Grid dimension Y.</param>
    /// <param name="gridDimZ">Grid dimension Z.</param>
    /// <param name="groupDimX">Group dimension X.</param>
    /// <param name="groupDimY">Group dimension Y.</param>
    /// <param name="groupDimZ">Group dimension Z.</param>
    /// <param name="scalarValues">
    /// Optional scalar kernel parameter values keyed by parameter index.
    /// For a kernel signature <c>(Index1D, ArrayView&lt;float&gt;, ArrayView&lt;float&gt;, float scalar)</c>
    /// pass <c>new Dictionary&lt;int, object&gt; { [3] = 7.5f }</c>. Buffer parameters are NOT included here
    /// (they use the <see cref="KernelDispatchRequest.Buffers"/> binding mechanism).
    /// Supports all ILGPU primitive types: float, double, int, uint, long, ulong, short,
    /// ushort, byte, sbyte, bool, System.Half, plus any JSON-serializable struct.
    /// </param>
    /// <param name="f64Mode">Copied verbatim into <c>F64Mode</c> on the request.</param>
    /// <returns>The populated dispatch request.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="kernelMethod"/> is null.</exception>
    public static KernelDispatchRequest CreateDispatch(
        MethodInfo kernelMethod,
        long gridDimX, long gridDimY = 1, long gridDimZ = 1,
        int groupDimX = 256, int groupDimY = 1, int groupDimZ = 1,
        IReadOnlyDictionary<int, object>? scalarValues = null,
        string? f64Mode = null)
    {
        ArgumentNullException.ThrowIfNull(kernelMethod);

        var request = new KernelDispatchRequest
        {
            KernelType = kernelMethod.DeclaringType?.AssemblyQualifiedName ?? "",
            KernelMethod = kernelMethod.Name,
            GridDimX = gridDimX,
            GridDimY = gridDimY,
            GridDimZ = gridDimZ,
            GroupDimX = groupDimX,
            GroupDimY = groupDimY,
            GroupDimZ = groupDimZ,
            F64Mode = f64Mode,
        };

        if (scalarValues is { Count: > 0 })
        {
            // Encode scalar parameters as a JSON object keyed by parameter index (stringified with
            // the invariant culture, since this is a wire format).
            // System.Half is not natively supported by System.Text.Json; promote to float for transmission
            // and the worker will narrow it back to Half per the kernel signature.
            var encoded = new Dictionary<string, object?>(scalarValues.Count);
            foreach (var (index, value) in scalarValues)
            {
                encoded[index.ToString(System.Globalization.CultureInfo.InvariantCulture)] =
                    value is System.Half h ? (float)h : value;
            }
            request.ScalarParams = JsonSerializer.SerializeToUtf8Bytes(encoded);
        }

        return request;
    }

    /// <summary>
    /// Resolve a kernel method on the worker side from the dispatch request.
    /// SECURITY: Only resolves methods on registered types; the peer-supplied type name is
    /// matched against the allowlist and never passed to the runtime type loader.
    /// Returns null if the type or method name is empty, the type is not in the allowlist,
    /// no static method with that name exists, or the name is overloaded.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
    public static MethodInfo? ResolveKernel(KernelDispatchRequest request)
    {
        ArgumentNullException.ThrowIfNull(request);

        var typeName = request.KernelType;
        var methodName = request.KernelMethod;
        if (string.IsNullOrEmpty(typeName) || string.IsNullOrEmpty(methodName))
            return null;

        var cacheKey = (typeName, methodName);

        // Re-check the allowlist on a hit: a resolution racing ClearAllowlist could have
        // written to the cache after it was cleared.
        if (_resolveCache.TryGetValue(cacheKey, out var cached)
            && cached.DeclaringType is { } cachedType
            && _allowedTypes.ContainsKey(cachedType))
        {
            return cached;
        }

        // SECURITY: Fail-closed - if no types are registered, this rejects ALL dispatches.
        var type = FindAllowedType(typeName);
        if (type is null)
            return null;

        MethodInfo? method;
        try
        {
            // Include NonPublic so private/internal static kernel methods (the C# default when
            // no access modifier is specified) resolve correctly. Type allowlist is the
            // security boundary, not method visibility.
            method = type.GetMethod(methodName,
                BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static);
        }
        catch (AmbiguousMatchException)
        {
            // The request carries only a name, so an overloaded kernel cannot be disambiguated.
            return null;
        }

        if (method is not null)
            _resolveCache[cacheKey] = method;
        return method;
    }

    /// <summary>
    /// Check if this worker can execute the requested kernel.
    /// </summary>
    public static bool CanExecute(KernelDispatchRequest request)
    {
        return ResolveKernel(request) != null;
    }

    /// <summary>
    /// Finds a registered type matching a peer-supplied name, without consulting the type loader.
    /// Accepts, in priority order: the exact assembly-qualified name or the full type name;
    /// otherwise "FullName, AssemblySimpleName[, ...]" whose version/culture/public-key parts differ.
    /// </summary>
    private static Type? FindAllowedType(string typeName)
    {
        Type? looseMatch = null;
        foreach (var (candidate, _) in _allowedTypes)
        {
            var fullName = candidate.FullName;
            if (fullName is null)
                continue;

            if (string.Equals(candidate.AssemblyQualifiedName, typeName, StringComparison.Ordinal)
                || string.Equals(fullName, typeName, StringComparison.Ordinal))
            {
                return candidate;
            }

            // Tolerate assembly version/identity differences between coordinator and worker,
            // but require the same type full name and assembly simple name.
            var prefix = fullName + ", " + candidate.Assembly.GetName().Name;
            if (looseMatch is null
                && typeName.StartsWith(prefix, StringComparison.Ordinal)
                && (typeName.Length == prefix.Length || typeName[prefix.Length] == ','))
            {
                looseMatch = candidate;
            }
        }
        return looseMatch;
    }
}