Skip to content

Commit 2eed441

Browse files
committed
Safe implementation of GetData
1 parent 4386a66 commit 2eed441

File tree

3 files changed

+111
-16
lines changed

3 files changed

+111
-16
lines changed

LLama.Unittest/MtmdWeightsTests.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ private SafeMtmdInputChunks TokenizeWithEmbed(Func<SafeMtmdEmbed> loadEmbed)
5353
Assert.True(embed.Nx > 0);
5454
Assert.True(embed.Ny > 0);
5555
Assert.False(embed.IsAudio);
56-
Assert.True(embed.GetDataSpan().Length > 0);
56+
57+
Assert.True(embed.ByteCount > 0);
58+
using var mem = embed.GetData();
59+
Assert.True(mem.Data.Length > 0);
5760

5861
var status = _mtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks);
5962
Assert.Equal(0, status);

LLama/Native/NativeApi.Mtmd.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ internal struct mtmd_context_params
7575
internal static extern uint mtmd_bitmap_get_ny(SafeMtmdEmbed bitmap);
7676

7777
[DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_data", CallingConvention = CallingConvention.Cdecl)]
78-
internal static extern IntPtr mtmd_bitmap_get_data(SafeMtmdEmbed bitmap);
78+
internal static extern unsafe byte* mtmd_bitmap_get_data(SafeMtmdEmbed bitmap);
7979

8080
[DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_n_bytes", CallingConvention = CallingConvention.Cdecl)]
8181
internal static extern UIntPtr mtmd_bitmap_get_n_bytes(SafeMtmdEmbed bitmap);

LLama/Native/SafeMtmdEmbed.cs

Lines changed: 106 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.IO;
3+
using System.Threading;
34

45
namespace LLama.Native
56
{
@@ -28,6 +29,7 @@ private SafeMtmdEmbed(IntPtr ptr)
2829
throw new InvalidOperationException("Failed to create MTMD bitmap.");
2930
}
3031

32+
#region Create Embed
3133
/// <summary>
3234
/// Create an embedding from raw RGB bytes.
3335
/// </summary>
@@ -101,17 +103,21 @@ private SafeMtmdEmbed(IntPtr ptr)
101103
/// <returns>Managed wrapper when decoding succeeds; otherwise <c>null</c>.</returns>
102104
/// <exception cref="ArgumentNullException">The context is null.</exception>
103105
/// <exception cref="ArgumentException">The buffer is empty.</exception>
104-
public static unsafe SafeMtmdEmbed? FromMediaBuffer(SafeMtmdModelHandle mtmdContext, ReadOnlySpan<byte> data)
106+
public static SafeMtmdEmbed? FromMediaBuffer(SafeMtmdModelHandle mtmdContext, ReadOnlySpan<byte> data)
105107
{
106108
if (data.IsEmpty)
107109
throw new ArgumentException("Media buffer must not be empty.", nameof(data));
108110

109-
fixed (byte* bufferPtr = data)
111+
unsafe
110112
{
111-
var native = NativeApi.mtmd_helper_bitmap_init_from_buf(mtmdContext, bufferPtr, (nuint)data.Length);
112-
return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native);
113+
fixed (byte* bufferPtr = data)
114+
{
115+
var native = NativeApi.mtmd_helper_bitmap_init_from_buf(mtmdContext, bufferPtr, (nuint)data.Length);
116+
return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native);
117+
}
113118
}
114119
}
120+
#endregion
115121

116122
/// <summary>
117123
/// Width of the bitmap in pixels (or number of samples for audio embeddings).
@@ -128,6 +134,11 @@ private SafeMtmdEmbed(IntPtr ptr)
128134
/// </summary>
129135
public bool IsAudio => NativeApi.mtmd_bitmap_is_audio(this);
130136

137+
/// <summary>
138+
/// Get the size, in bytes, of the raw bitmap/audio data in this embed.
139+
/// </summary>
140+
public ulong ByteCount => NativeApi.mtmd_bitmap_get_n_bytes(this).ToUInt64();
141+
131142
/// <summary>
132143
/// Optional identifier assigned to this embedding.
133144
/// </summary>
@@ -137,21 +148,102 @@ public string? Id
137148
set => NativeApi.mtmd_bitmap_set_id(this, value);
138149
}
139150

151+
#region GetData
152+
/// <summary>
153+
/// Provides safe zero-copy access to the underlying bitmap bytes.
154+
/// </summary>
155+
/// <returns>An accessor for the raw data. The data is guaranteed to remain valid until the returned accessor is disposed.</returns>
156+
public IEmbedData GetData()
157+
{
158+
// Increment the reference count on this embed so it is not released while the returned accessor exists. Disposing the accessor (the "lifetime") decrements the reference count again.
159+
var success = false;
160+
DangerousAddRef(ref success);
161+
162+
try
163+
{
164+
unsafe
165+
{
166+
return new EmbedDataLifetime(this, NativeApi.mtmd_bitmap_get_data(this), checked((int)ByteCount));
167+
}
168+
}
169+
catch
170+
{
171+
DangerousRelease();
172+
throw;
173+
}
174+
}
175+
140176
/// <summary>
141-
/// Zero-copy access to the underlying bitmap bytes. The span remains valid while this wrapper is alive.
177+
/// Accessor for the raw data of a <see cref="SafeMtmdEmbed"/>
142178
/// </summary>
143-
/// <returns>Read-only span exposing the native data buffer.</returns>
144-
/// <exception cref="ObjectDisposedException">The embedding has been disposed.</exception>
145-
public unsafe ReadOnlySpan<byte> GetDataSpan()
179+
public interface IEmbedData
180+
: IDisposable
146181
{
147-
EnsureNotDisposed();
182+
/// <summary>
183+
/// Get the raw data. Access to this span is only guaranteed to be valid until this accessor is disposed.
184+
/// </summary>
185+
/// <exception cref="ObjectDisposedException">Thrown if this accessor has been disposed</exception>
186+
ReadOnlySpan<byte> Data { get; }
148187

149-
var dataPtr = (byte*)NativeApi.mtmd_bitmap_get_data(this);
150-
var length = checked((int)NativeApi.mtmd_bitmap_get_n_bytes(this).ToUInt64());
151-
return dataPtr == null || length == 0
152-
? ReadOnlySpan<byte>.Empty
153-
: new ReadOnlySpan<byte>(dataPtr, length);
188+
/// <summary>
189+
/// Indicates whether this accessor is still valid (i.e. it has not been disposed).
190+
/// </summary>
191+
bool IsValid { get; }
192+
}
193+
194+
private sealed class EmbedDataLifetime
195+
: IEmbedData
196+
{
197+
private int _valid;
198+
private readonly SafeMtmdEmbed _embed;
199+
private readonly unsafe byte* _dataPtr;
200+
private readonly int _dataLength;
201+
202+
public ReadOnlySpan<byte> Data
203+
{
204+
get
205+
{
206+
unsafe
207+
{
208+
if (!IsValid)
209+
throw new ObjectDisposedException("Cannot access Embed data, accessor has been disposed");
210+
return new ReadOnlySpan<byte>(_dataPtr, _dataLength);
211+
}
212+
}
213+
}
214+
215+
public bool IsValid => _valid != 0;
216+
217+
public unsafe EmbedDataLifetime(SafeMtmdEmbed embed, byte* dataPtr, int dataLength)
218+
{
219+
_embed = embed;
220+
221+
_dataPtr = dataPtr;
222+
_dataLength = dataLength;
223+
224+
_valid = 1;
225+
}
226+
227+
~EmbedDataLifetime()
228+
{
229+
Dispose(false);
230+
}
231+
232+
public void Dispose()
233+
{
234+
Dispose(true);
235+
}
236+
237+
private void Dispose(bool disposing)
238+
{
239+
if (Interlocked.Exchange(ref _valid, 0) == 1)
240+
_embed.DangerousRelease();
241+
242+
if (disposing)
243+
GC.SuppressFinalize(this);
244+
}
154245
}
246+
#endregion
155247

156248
/// <summary>
157249
/// Release the underlying native bitmap.

0 commit comments

Comments
 (0)