Skip to content

Commit 21b2aea

Browse files
committed
Fix LuaWeakReference safety
1 parent 139d571 commit 21b2aea

2 files changed

Lines changed: 46 additions & 15 deletions

File tree

src/Laylua/Library/Entities/Reference/LuaWeakReference.cs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ internal LuaThread Thread
2424
}
2525

2626
private readonly LuaThread _thread;
27-
private readonly void* _pointer;
27+
private readonly LuaTable? _key;
2828

29-
internal LuaWeakReference(LuaThread thread, void* pointer)
29+
internal LuaWeakReference(LuaThread thread, LuaTable key)
3030
{
3131
_thread = thread;
32-
_pointer = pointer;
32+
_key = key;
3333
}
3434

3535
/// <summary>
@@ -40,7 +40,7 @@ internal LuaWeakReference(LuaThread thread, void* pointer)
4040
/// </returns>
4141
public TReference? GetValue()
4242
{
43-
if (_thread == null || !_thread.Stack.TryEnsureFreeCapacity(2))
43+
if (_thread == null || _key == null || !_thread.Stack.TryEnsureFreeCapacity(2))
4444
{
4545
return default;
4646
}
@@ -54,7 +54,8 @@ internal LuaWeakReference(LuaThread thread, void* pointer)
5454
return default;
5555
}
5656

57-
var type = lua_rawgetp(L, -1, _pointer);
57+
_thread.Stack.Push(_key);
58+
var type = lua_rawget(L, -2);
5859
if (type.IsNoneOrNil())
5960
{
6061
return default;
@@ -72,10 +73,10 @@ internal LuaWeakReference(LuaThread thread, void* pointer)
7273
}
7374
}
7475

75-
[MemberNotNull(nameof(_thread))]
76+
[MemberNotNull(nameof(_thread), nameof(_key))]
7677
private void ThrowIfInvalid()
7778
{
78-
if (_thread == null)
79+
if (_thread == null || _key == null)
7980
{
8081
throw new InvalidOperationException($"This '{GetType().ToTypeString()}' has not been initialized.");
8182
}
@@ -89,22 +90,22 @@ internal static class LuaWeakReference
8990
public static unsafe bool TryCreate<TReference>(LuaThread thread, int stackIndex, out LuaWeakReference<TReference> weakReference)
9091
where TReference : LuaReference
9192
{
92-
if (!TryCreate(thread, stackIndex, out var targetPointer))
93+
if (!TryCreate(thread, stackIndex, out var key))
9394
{
9495
weakReference = default;
9596
return false;
9697
}
9798

98-
weakReference = new LuaWeakReference<TReference>(thread, targetPointer);
99+
weakReference = new LuaWeakReference<TReference>(thread, key);
99100
return true;
100101
}
101102

102-
private static unsafe bool TryCreate(LuaThread thread, int stackIndex, out void* targetPointer)
103+
private static unsafe bool TryCreate(LuaThread thread, int stackIndex, [MaybeNullWhen(false)] out LuaTable key)
103104
{
104105
var L = thread.State.L;
105-
if (!lua_checkstack(L, 4))
106+
if (!lua_checkstack(L, 5))
106107
{
107-
targetPointer = null;
108+
key = null;
108109
return false;
109110
}
110111

@@ -117,15 +118,24 @@ private static unsafe bool TryCreate(LuaThread thread, int stackIndex, out void*
117118
lua_createtable(L, 0, 1);
118119

119120
lua_pushstring(L, LuaMetatableKeysUtf8.__mode);
120-
lua_pushstring(L, "v"u8);
121+
lua_pushstring(L, "kv"u8);
121122
lua_rawset(L, -3);
122123

123124
lua_setmetatable(L, -2);
124125
}
125126

127+
lua_createtable(L, 0, 0);
128+
var keyIndex = lua_gettop(L);
129+
if (!thread.Stack[keyIndex].TryGetValue(out LuaTable? createdKey) || createdKey == null)
130+
{
131+
key = null;
132+
return false;
133+
}
134+
135+
key = createdKey;
136+
lua_pushvalue(L, keyIndex);
126137
lua_pushvalue(L, stackIndex);
127-
targetPointer = lua_topointer(L, -1);
128-
lua_rawsetp(L, -2, targetPointer);
138+
lua_rawset(L, -4);
129139
return true;
130140
}
131141
finally

tests/Laylua.Tests/Tests/Library/Entities/LuaWeakReferenceTests.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,25 @@ public void NoLuaReferencesToTarget_ObjectIsGarbageCollected()
4242
// Assert
4343
Assert.That(function, Is.Null);
4444
}
45+
46+
[Test]
47+
public void WeakLuaReference_UsesIndependentKeyAndDoesNotResolveCollectedTargetToAnotherObject()
48+
{
49+
// Arrange
50+
var staleWeakReference = Lua.Evaluate<LuaWeakReference<LuaTable>>("return {}");
51+
Lua.State.GC.Collect();
52+
53+
using var liveReference = Lua.Evaluate<LuaTable>("return {}")!;
54+
var liveWeakReference = liveReference.CreateWeakReference();
55+
56+
// Act
57+
Lua.State.GC.Collect();
58+
59+
using var staleValue = staleWeakReference.GetValue();
60+
using var liveValue = liveWeakReference.GetValue();
61+
62+
// Assert
63+
Assert.That(staleValue, Is.Null);
64+
Assert.That(liveValue, Is.Not.Null);
65+
}
4566
}

0 commit comments

Comments
 (0)