Skip to content

Commit 9b0965c

Browse files
authored
[C#] Remove WaitableSet and its global. Tidy <T> in codegen (#1552)
* Remove WaitableSet and its global. Tidy <T> in codegen * address feedback, comment use of unsafe, refactor repeated code.
1 parent 2c1b579 commit 9b0965c

File tree

3 files changed

+73
-74
lines changed

3 files changed

+73
-74
lines changed

crates/csharp/src/AsyncSupport.cs

Lines changed: 70 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,13 @@ public enum CallbackCode : uint
2121
//#define TEST_CALLBACK_CODE_WAIT(set) (2 | (set << 4))
2222
}
2323

24-
public class WaitableSet(int handle) : IDisposable
24+
// The context that we will create in unmanaged memory and pass to context_set.
25+
// TODO: C has world specific types for these pointers, perhaps C# would benefit from those also.
26+
[StructLayout(LayoutKind.Sequential)]
27+
public struct ContextTask
2528
{
26-
public int Handle { get; } = handle;
27-
28-
void Dispose(bool _disposing)
29-
{
30-
AsyncSupport.WaitableSetDrop(handle);
31-
}
32-
33-
public void Dispose()
34-
{
35-
Dispose(true);
36-
GC.SuppressFinalize(this);
37-
}
38-
39-
~WaitableSet()
40-
{
41-
Dispose(false);
42-
}
29+
public int WaitableSetHandle;
30+
public int FutureHandle;
4331
}
4432

4533
public static class AsyncSupport
@@ -51,9 +39,6 @@ internal static class PollWasmInterop
5139
internal static extern void wasmImportPoll(nint p0, int p1, nint p2);
5240
}
5341

54-
// TODO: How do we allow multiple waitable sets?
55-
internal static WaitableSet WaitableSet;
56-
5742
private static class Interop
5843
{
5944
[global::System.Runtime.InteropServices.DllImport("$root", EntryPoint = "[waitable-set-new]"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute]
@@ -69,7 +54,7 @@ private static class Interop
6954
internal static unsafe extern uint WaitableSetPoll(int waitable, uint* waitableHandlePtr);
7055

7156
[global::System.Runtime.InteropServices.DllImport("$root", EntryPoint = "[waitable-set-drop]"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute]
72-
internal static unsafe extern void WaitableSetDrop(int waitable);
57+
internal static extern void WaitableSetDrop(int waitable);
7358

7459
[global::System.Runtime.InteropServices.DllImport("$root", EntryPoint = "[context-set-0]"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute]
7560
internal static unsafe extern void ContextSet(ContextTask* waitable);
@@ -78,13 +63,14 @@ private static class Interop
7863
internal static unsafe extern ContextTask* ContextGet();
7964
}
8065

81-
public static WaitableSet WaitableSetNew()
66+
public static int WaitableSetNew()
8267
{
8368
var waitableSet = Interop.WaitableSetNew();
8469
Console.WriteLine($"WaitableSet created with number {waitableSet}");
85-
return new WaitableSet(waitableSet);
70+
return waitableSet;
8671
}
8772

73+
// unsafe because we are using pointers.
8874
public static unsafe void WaitableSetPoll(int waitableHandle)
8975
{
9076
var error = Interop.WaitableSetPoll(waitableHandle, null);
@@ -94,16 +80,16 @@ public static unsafe void WaitableSetPoll(int waitableHandle)
9480
}
9581
}
9682

97-
internal static void Join(SubtaskStatus subtask, WaitableSet set, WaitableInfoState waitableInfoState)
83+
internal static void Join(SubtaskStatus subtask, int waitableSetHandle, WaitableInfoState waitableInfoState)
9884
{
99-
AddTaskToWaitables(set.Handle, subtask.Handle, waitableInfoState);
100-
Interop.WaitableJoin(subtask.Handle, set.Handle);
85+
AddTaskToWaitables(waitableSetHandle, subtask.Handle, waitableInfoState);
86+
Interop.WaitableJoin(subtask.Handle, waitableSetHandle);
10187
}
10288

103-
internal static void Join(int readerWriterHandle, WaitableSet set, WaitableInfoState waitableInfoState)
89+
internal static void Join(int readerWriterHandle, int waitableHandle, WaitableInfoState waitableInfoState)
10490
{
105-
AddTaskToWaitables(set.Handle, readerWriterHandle, waitableInfoState);
106-
Interop.WaitableJoin(readerWriterHandle, set.Handle);
91+
AddTaskToWaitables(waitableHandle, readerWriterHandle, waitableInfoState);
92+
Interop.WaitableJoin(readerWriterHandle, waitableHandle);
10793
}
10894

10995
// TODO: Revisit this to see if we can remove it.
@@ -120,10 +106,11 @@ private static void AddTaskToWaitables(int waitableSetHandle, int waitableHandle
120106
waitableSetOfTasks[waitableHandle] = waitableInfoState;
121107
}
122108

123-
public unsafe static EventWaitable WaitableSetWait(WaitableSet set)
109+
// unsafe because we use a fixed size buffer.
110+
public static unsafe EventWaitable WaitableSetWait(int waitableSetHandle)
124111
{
125112
uint* buffer = stackalloc uint[2];
126-
var eventCode = (EventCode)Interop.WaitableSetWait(set.Handle, buffer);
113+
var eventCode = (EventCode)Interop.WaitableSetWait(waitableSetHandle, buffer);
127114
return new EventWaitable(eventCode, buffer[0], buffer[1]);
128115
}
129116

@@ -132,34 +119,25 @@ public static void WaitableSetDrop(int handle)
132119
Interop.WaitableSetDrop(handle);
133120
}
134121

135-
// The context that we will create in unmanaged memory and pass to context_set.
136-
// TODO: C has world specific types for these pointers, perhaps C# would benefit from those also.
137-
[StructLayout(LayoutKind.Sequential)]
138-
public struct ContextTask
139-
{
140-
public int Set;
141-
public int FutureHandle;
142-
}
143-
122+
// unsafe because we are using pointers.
144123
public static unsafe void ContextSet(ContextTask* contextTask)
145124
{
146125
Interop.ContextSet(contextTask);
147126
}
148127

128+
// unsafe because we are using pointers.
149129
public static unsafe ContextTask* ContextGet()
150130
{
151-
ContextTask* contextTaskPtr = Interop.ContextGet();
152-
if(contextTaskPtr == null)
153-
{
154-
throw new Exception("null context returned.");
155-
}
156-
return contextTaskPtr;
131+
return Interop.ContextGet();
157132
}
158133

134+
// unsafe because we are using pointers.
159135
public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Action taskReturn)
160136
{
161137
Console.WriteLine($"Callback Event code {e.EventCode} Code {e.Code} Waitable {e.Waitable} Waitable Status {e.WaitableStatus.State}, Count {e.WaitableCount}");
162-
var waitables = pendingTasks[WaitableSet.Handle];
138+
ContextTask* contextTaskPtr = ContextGet();
139+
140+
var waitables = pendingTasks[contextTaskPtr->WaitableSetHandle];
163141
var waitableInfoState = waitables[e.Waitable];
164142

165143
if (e.IsDropped)
@@ -195,32 +173,36 @@ public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Act
195173

196174
if (waitables.Count == 0)
197175
{
198-
Console.WriteLine($"No more waitables for waitable {e.Waitable} in set {WaitableSet.Handle}");
176+
Console.WriteLine($"No more waitables for waitable {e.Waitable} in set {contextTaskPtr->WaitableSetHandle}");
199177
taskReturn();
178+
ContextSet(null);
179+
Marshal.FreeHGlobal((IntPtr)contextTaskPtr);
200180
return (uint)CallbackCode.Exit;
201181
}
202182

203183
Console.WriteLine("More waitables in the set.");
204-
return (uint)CallbackCode.Wait | (uint)(WaitableSet.Handle << 4);
184+
return (uint)CallbackCode.Wait | (uint)(contextTaskPtr->WaitableSetHandle << 4);
205185
}
206186

207-
throw new NotImplementedException($"WaitableStatus not implemented {e.WaitableStatus.State} in set {WaitableSet.Handle}");
187+
throw new NotImplementedException($"WaitableStatus not implemented {e.WaitableStatus.State} in set {contextTaskPtr->WaitableSetHandle}");
208188
}
209189

210-
public static Task TaskFromStatus(uint status)
190+
// This method is unsafe because we are using unmanaged memory to store the context.
191+
internal static unsafe Task TaskFromStatus(uint status)
211192
{
212193
var subtaskStatus = new SubtaskStatus(status);
213194
status = status & 0xF;
214195

215196
if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted)
216197
{
217-
if (WaitableSet == null) {
218-
WaitableSet = WaitableSetNew();
219-
Console.WriteLine($"TaskFromStatus creating WaitableSet {WaitableSet.Handle}");
198+
ContextTask* contextTaskPtr = ContextGet();
199+
if (contextTaskPtr == null) {
200+
contextTaskPtr = AllocateAndSetNewContext();
201+
Console.WriteLine($"TaskFromStatus creating WaitableSet {contextTaskPtr->WaitableSetHandle}");
220202
}
221203

222204
TaskCompletionSource tcs = new TaskCompletionSource();
223-
AsyncSupport.Join(subtaskStatus, WaitableSet, new WaitableInfoState(tcs));
205+
Join(subtaskStatus, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs));
224206
return tcs.Task;
225207
}
226208
else if (subtaskStatus.IsSubtaskReturned)
@@ -233,7 +215,8 @@ public static Task TaskFromStatus(uint status)
233215
}
234216
}
235217

236-
public static Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
218+
// unsafe because we are using pointers.
219+
public static unsafe Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
237220
{
238221
var subtaskStatus = new SubtaskStatus(status);
239222
status = status & 0xF;
@@ -242,9 +225,12 @@ public static Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
242225
var tcs = new TaskCompletionSource<T>();
243226
if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted)
244227
{
245-
if (WaitableSet == null) {
228+
ContextTask* contextTaskPtr = ContextGet();
229+
if (contextTaskPtr == null) {
230+
contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf<ContextTask>());
246231
Console.WriteLine("TaskFromStatus<T> creating WaitableSet");
247-
WaitableSet = AsyncSupport.WaitableSetNew();
232+
contextTaskPtr->WaitableSetHandle = WaitableSetNew();
233+
ContextSet(contextTaskPtr);
248234
}
249235

250236
return tcs.Task;
@@ -259,6 +245,15 @@ public static Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
259245
throw new Exception($"unexpected subtask status: {status}");
260246
}
261247
}
248+
249+
// unsafe because we are working with native memory.
250+
internal static unsafe ContextTask* AllocateAndSetNewContext()
251+
{
252+
var contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf<ContextTask>());
253+
contextTaskPtr->WaitableSetHandle = AsyncSupport.WaitableSetNew();
254+
AsyncSupport.ContextSet(contextTaskPtr);
255+
return contextTaskPtr;
256+
}
262257
}
263258

264259

@@ -371,6 +366,7 @@ internal int TakeHandle()
371366

372367
internal abstract uint VTableRead(IntPtr bufferPtr, int length);
373368

369+
// unsafe as we are working with pointers.
374370
internal unsafe Task<int> ReadInternal(Func<GCHandle?> liftBuffer, int length)
375371
{
376372
if (Handle == 0)
@@ -389,14 +385,15 @@ internal unsafe Task<int> ReadInternal(Func<GCHandle?> liftBuffer, int length)
389385
{
390386
Console.WriteLine("Read Blocked");
391387
var tcs = new TaskCompletionSource<int>();
392-
if(AsyncSupport.WaitableSet == null)
388+
ContextTask* contextTaskPtr = AsyncSupport.ContextGet();
389+
if(contextTaskPtr == null)
393390
{
394391
Console.WriteLine("FutureReader Read Blocked creating WaitableSet");
395-
AsyncSupport.WaitableSet = AsyncSupport.WaitableSetNew();
392+
contextTaskPtr = AsyncSupport.AllocateAndSetNewContext();
396393
}
397394
Console.WriteLine("blocked read before join");
398395

399-
AsyncSupport.Join(Handle, AsyncSupport.WaitableSet, new WaitableInfoState(tcs, this));
396+
AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs, this));
400397
Console.WriteLine("blocked read after join");
401398
return tcs.Task;
402399
}
@@ -470,7 +467,7 @@ public class FutureReader<T>(int handle, FutureVTable vTable) : ReaderBase(handl
470467
{
471468
public FutureVTable VTable { get; private set; } = vTable;
472469

473-
private GCHandle LiftBuffer<T>(T buffer)
470+
private GCHandle LiftBuffer(T buffer)
474471
{
475472
if(typeof(T) == typeof(byte))
476473
{
@@ -483,7 +480,7 @@ private GCHandle LiftBuffer<T>(T buffer)
483480
}
484481
}
485482

486-
public unsafe Task Read<T>(T buffer)
483+
public Task Read(T buffer)
487484
{
488485
return ReadInternal(() => LiftBuffer(buffer), 1);
489486
}
@@ -508,7 +505,7 @@ public StreamReader(int handle, StreamVTable vTable) : base(handle)
508505

509506
public StreamVTable VTable { get; private set; }
510507

511-
public unsafe Task Read(int length)
508+
public Task Read(int length)
512509
{
513510
return ReadInternal(() => null, length);
514511
}
@@ -528,7 +525,7 @@ public class StreamReader<T>(int handle, StreamVTable vTable) : ReaderBase(hand
528525
{
529526
public StreamVTable VTable { get; private set; } = vTable;
530527

531-
private GCHandle LiftBuffer<T>(T[] buffer)
528+
private GCHandle LiftBuffer(T[] buffer)
532529
{
533530
if(typeof(T) == typeof(byte))
534531
{
@@ -541,7 +538,7 @@ private GCHandle LiftBuffer<T>(T[] buffer)
541538
}
542539
}
543540

544-
public unsafe Task<int> Read<T>(T[] buffer)
541+
public Task<int> Read(T[] buffer)
545542
{
546543
return ReadInternal(() => LiftBuffer(buffer), buffer.Length);
547544
}
@@ -582,6 +579,7 @@ internal int TakeHandle()
582579

583580
internal abstract uint VTableWrite(IntPtr bufferPtr, int length);
584581

582+
// unsafe as we are working with pointers.
585583
internal unsafe Task<int> WriteInternal(Func<GCHandle?> lowerPayload, int length)
586584
{
587585
if (Handle == 0)
@@ -600,12 +598,13 @@ internal unsafe Task<int> WriteInternal(Func<GCHandle?> lowerPayload, int length
600598
{
601599
Console.WriteLine("blocked write");
602600
var tcs = new TaskCompletionSource<int>();
603-
if(AsyncSupport.WaitableSet == null)
601+
ContextTask* contextTaskPtr = AsyncSupport.ContextGet();
602+
if(contextTaskPtr == null)
604603
{
605-
AsyncSupport.WaitableSet = AsyncSupport.WaitableSetNew();
604+
contextTaskPtr = AsyncSupport.AllocateAndSetNewContext();
606605
}
607606
Console.WriteLine("blocked write before join");
608-
AsyncSupport.Join(Handle, AsyncSupport.WaitableSet, new WaitableInfoState(tcs, this));
607+
AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs, this));
609608
Console.WriteLine("blocked write after join");
610609
return tcs.Task;
611610
}
@@ -679,7 +678,6 @@ public class FutureWriter<T>(int handle, FutureVTable vTable) : WriterBase(handl
679678
// TODO: Generate per type for this instrinsic.
680679
public Task Write()
681680
{
682-
// TODO: Lower T
683681
return WriteInternal(() => null, 1);
684682
}
685683

@@ -719,7 +717,7 @@ public class StreamWriter<T>(int handle, StreamVTable vTable) : WriterBase(handl
719717
private GCHandle bufferHandle;
720718
public StreamVTable VTable { get; private set; } = vTable;
721719

722-
private GCHandle LowerPayload<T>(T[] payload)
720+
private GCHandle LowerPayload(T[] payload)
723721
{
724722
if (VTable.Lower == null)
725723
{

crates/csharp/src/function.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,8 @@ impl Bindgen for FunctionBindgen<'_, '_> {
11631163
}});
11641164
11651165
// TODO: Defer dropping borrowed resources until a result is returned.
1166-
return (uint)CallbackCode.Wait | (uint)(AsyncSupport.WaitableSet.Handle << 4);
1166+
ContextTask* contextTaskPtr = AsyncSupport.ContextGet();
1167+
return (uint)CallbackCode.Wait | (uint)(contextTaskPtr->WaitableSetHandle << 4);
11671168
"#);
11681169
}
11691170

crates/csharp/src/interface.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ var {async_status_var} = {raw_name}({wasm_params});
880880
uwriteln!(
881881
self.csharp_interop_src,
882882
r#"
883-
return (uint)AsyncSupport.Callback(e, (AsyncSupport.ContextTask *)IntPtr.Zero, () => {camel_name}TaskReturn());
883+
return (uint)AsyncSupport.Callback(e, (ContextTask *)IntPtr.Zero, () => {camel_name}TaskReturn());
884884
}}
885885
"#
886886
);

0 commit comments

Comments
 (0)