Skip to content

Multimodal embedding #1193

Description

@LoicDagnas

Description

Taking inspiration from the LlamaEmbedder and the multimodal support that has been added to LlamaInteractExecutor, I have been trying to implement a multimodal embedder. The main idea is to support Qwen2-VL-based models specialized in screenshot embedding, such as:

IMO, it should work the way I did it, i.e.:

  • manually building the prompt
  • separately tokenizing the parts of the prompt before and after the image marker
  • feeding the model the tokens before the image, then the image using LlavaWeights.EvalImageEmbed, and finally the tokens after the image
  • taking the embedding of the last token (<|endoftext|>) and normalizing it

But:

  • I don't get the same vectors as the ones I obtain with the Python code
  • even worse, embedding the same image twice doesn't give me the same vectors
  • and even with two different context instances, I don't get the same vector

Does this ring a bell for anyone?

Here is my class, along with a dummy unit test that compares two runs of the image embedding computation:

using System.Numerics.Tensors;
using System.Text;
using LLama.Common;
using LLama.Extensions;
using LLama.Native;

namespace LLama.Unittest;

/// <summary>
/// Computes embeddings of text and/or images with a multimodal (LLaVA-style) model.
/// </summary>
/// <remarks>
/// <para>
/// The prompt is built by hand in the Qwen2-VL chat format. Its text tokens are decoded with the image
/// embeddings spliced in at the <c>&lt;|image_pad|&gt;</c> marker. The embedding of the last prompt token
/// (<c>&lt;|endoftext|&gt;</c>) is returned after Euclidean normalization.
/// </para>
/// <para>
/// The embedder takes ownership of the <see cref="LLamaContext"/> and <see cref="LLavaWeights"/> passed to
/// its constructor, and <see cref="Dispose"/> disposes both. Calls must not overlap, because each call clears
/// and then refills the KV cache of the single shared context.
/// </para>
/// <para>
/// NOTE(review): image embeddings are evaluated with <see cref="LLavaWeights.EvalImageEmbed"/>, which advances
/// <c>nPast</c> sequentially. Qwen2-VL is believed to use multi-dimensional (M-RoPE) positions for image tokens,
/// which may explain differences from the reference Python implementation. Confirm this against llama.cpp's
/// Qwen2-VL example.
/// </para>
/// </remarks>
public sealed class LlamaMultimodalEmbedder : IDisposable
{
    // Placeholder in the prompt marking where the image embeddings are injected.
    private const string ImageMarker = "<|image_pad|>";

    private readonly LLavaWeights _llavaWeights;
    private readonly LLamaContext _context;
    private bool _disposed;

    /// <summary>
    /// Creates an embedder over an existing context and multimodal projector.
    /// </summary>
    /// <param name="context">Context used to evaluate prompts. Embedding output is enabled on it. Owned by this instance.</param>
    /// <param name="llavaWeights">Projector weights used to embed images. Owned by this instance.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="llavaWeights"/> is null.</exception>
    /// <exception cref="ArgumentException">The context's ubatch size differs from its batch size.</exception>
    public LlamaMultimodalEmbedder(LLamaContext context, LLavaWeights llavaWeights)
    {
        ArgumentNullException.ThrowIfNull(context);
        ArgumentNullException.ThrowIfNull(llavaWeights);

        if (context.Params.UBatchSize != context.Params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size");

        _llavaWeights = llavaWeights;
        _context = context;

        // Have decoding produce embeddings, which are read back with GetEmbeddingsIth.
        NativeApi.llama_set_embeddings(_context.NativeHandle, true);
    }

    /// <summary>
    /// Disposes the owned context and llava weights. Calling this more than once is a no-op.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
            return;

        _context.Dispose();
        _llavaWeights.Dispose();
        _disposed = true;
    }

    /// <summary>Computes the normalized embedding of <paramref name="text"/>.</summary>
    public async Task<float[]> GetTextEmbedding(string text, CancellationToken cancellationToken = default) =>
        await GetEmbedding(text, null, cancellationToken).ConfigureAwait(false);

    /// <summary>Computes the normalized embedding of the image encoded in <paramref name="imageBytes"/>.</summary>
    public async Task<float[]> GetImageEmbedding(byte[] imageBytes, CancellationToken cancellationToken = default) =>
        await GetEmbedding(null, imageBytes, cancellationToken).ConfigureAwait(false);

    /// <summary>
    /// Evaluates the prompt for the given text and/or image and returns the normalized embedding of its last token.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The embedder has been disposed.</exception>
    /// <exception cref="ArgumentException">Neither text nor image was given, or the prompt exceeds the context.</exception>
    /// <exception cref="NotSupportedException">The context uses a pooling type other than <c>None</c>.</exception>
    /// <exception cref="InvalidOperationException">Image embedding, image evaluation or token decoding failed.</exception>
    private async Task<float[]> GetEmbedding(
        string? text,
        byte[]? image,
        CancellationToken cancellationToken = default)
    {
        if (_disposed)
            throw new ObjectDisposedException(nameof(LlamaMultimodalEmbedder));

        var hasText = !string.IsNullOrEmpty(text);
        var hasImage = image is not null;

        if (!hasText && !hasImage)
            throw new ArgumentException("At least one of text or image must be provided");

        // Only per-token embeddings are handled: the last token's embedding is read below. The check is done
        // before any work because pooling is a property of the context, not of the input.
        if (_context.NativeHandle.PoolingType != LLamaPoolingType.None)
            throw new NotSupportedException("Unsupported pooling type");

        var prompt = BuildPrompt(text, hasImage);
        var (tokensBeforeImage, tokensAfterImage) = TokenizePrompt(prompt, hasImage);

        // NOTE(review): this counts text tokens only. The image embedding positions added by EvalImageEmbed
        // are not included, so a large image could still overflow the context.
        var tokensCount = tokensBeforeImage.Count + tokensAfterImage.Count;
        if (tokensCount > _context.ContextSize)
            throw new ArgumentException(
                $"Embedding prompt is longer than the context window ({tokensCount} > {_context.ContextSize})");

        // Check if we should cancel the work, just before doing anything expensive (image embedding/decode).
        cancellationToken.ThrowIfCancellationRequested();

        using var imageEmbeddingHandle = image is not null ? GetImageEmbeddingHandle(image) : null;

        // Start from an empty KV cache so that the result does not depend on previous calls.
        _context.NativeHandle.KvCacheClear();

        var batch = new LLamaBatch();
        var nPast = await DecodeTokensAsync(tokensBeforeImage, batch, 0).ConfigureAwait(false);

        if (imageEmbeddingHandle is not null)
        {
            // EvalImageEmbed reports failure through its return value. Continuing after a failure would
            // read embeddings from an incompletely evaluated sequence.
            if (!_llavaWeights.EvalImageEmbed(_context, imageEmbeddingHandle, ref nPast))
                throw new InvalidOperationException("Failed to evaluate the image embeddings.");

            await DecodeTokensAsync(tokensAfterImage, batch, nPast).ConfigureAwait(false);
        }

        // `batch` holds the last decoded chunk. Its last logit position is the final prompt token.
        var positions = batch.GetLogitPositions();
        if (positions is null || positions.Count == 0)
            throw new InvalidOperationException("GetLogitPositions returned no positions");

        var embedding = _context.NativeHandle.GetEmbeddingsIth(positions[^1].Item2).ToArray();

        embedding.EuclideanNormalization();

        return embedding;
    }

    // Builds the Qwen2-VL style chat prompt. The image, if any, is represented by ImageMarker.
    // The prompt is built manually because the history type does not handle image content.
    private static string BuildPrompt(string? text, bool includeImage)
    {
        var builder = new StringBuilder()
            .Append("<|im_start|>system\n")
            .Append("You are a helpful assistant.<|im_end|>\n")
            .Append("<|im_start|>user\n");

        if (includeImage)
            builder.Append("<|vision_start|>").Append(ImageMarker).Append("<|vision_end|>");

        if (!string.IsNullOrEmpty(text))
            builder.Append(text);

        return builder
            .Append("<|im_end|>\n")
            .Append("<|im_start|>assistant\n")
            .Append("<|endoftext|>")
            .ToString();
    }

    // Tokenizes the prompt. With an image, the text is split around ImageMarker (which is dropped), so the image
    // embeddings can be decoded between the two halves. Without one, all tokens go into the first list.
    private (List<LLamaToken> BeforeImage, List<LLamaToken> AfterImage) TokenizePrompt(string prompt, bool hasImage)
    {
        if (!hasImage)
        {
            var allTokens = _context.Tokenize(prompt, addBos: true, special: true);
            return (new List<LLamaToken>(allTokens), new List<LLamaToken>());
        }

        var markerIndex = prompt.IndexOf(ImageMarker, StringComparison.Ordinal);
        var before = _context.Tokenize(prompt[..markerIndex], addBos: true, special: true);
        var after = _context.Tokenize(prompt[(markerIndex + ImageMarker.Length)..], addBos: false, special: true);

        return (new List<LLamaToken>(before), new List<LLamaToken>(after));
    }

    // Decodes `tokens` into sequence 0 starting at position `nPast` and returns the next free position.
    // It throws instead of silently continuing when decoding reports a failure.
    private async Task<int> DecodeTokensAsync(List<LLamaToken> tokens, LLamaBatch batch, int nPast)
    {
        var (result, _, newPast) = await _context
            .DecodeAsync(tokens, LLamaSeqId.Zero, batch, nPast)
            .ConfigureAwait(false);

        if (result != DecodeResult.Ok)
            throw new InvalidOperationException($"Failed to decode the prompt ({result})");

        return newPast;
    }

    // Runs the projector on the raw image bytes. An invalid handle is disposed before throwing so it does not leak.
    private SafeLlavaImageEmbedHandle GetImageEmbeddingHandle(byte[] imageBytes)
    {
        var embeddingsHandle = _llavaWeights.CreateImageEmbeddings(imageBytes);

        if (embeddingsHandle.IsInvalid)
        {
            embeddingsHandle.Dispose();
            throw new InvalidOperationException(
                "Failed to create embedding handle, make sure that the image bytes contain a valid, supported image.");
        }

        return embeddingsHandle;
    }
}

public sealed class LLamaMultimodalEmbedderTests
{
    // Placeholder paths. Verbatim strings keep the backslashes literal: in a regular string "\t" would be a tab
    // and "\m" is not a valid escape sequence, so the original constants did not compile.
    private const string ModelPath = @"path\to\model.gguf";
    private const string MmprojPath = @"path\to\mmproj.gguf";
    private const string ImagePath = @"path\to\image.png";

    // Maximum Euclidean distance tolerated between two embeddings of the same image. Both vectors are
    // L2-normalized, so distances lie in [0, 2]. The original `10e-1` evaluated to 1.0, which was too loose
    // to detect the non-determinism under investigation.
    private const float Tolerance = 1e-1f;

    [Fact]
    public async Task TestBasic()
    {
        var parameters = new ModelParams(ModelPath)
        {
            GpuLayerCount = 5,
            // The embedder only supports per-token embeddings. Make that explicit instead of relying on the
            // model's default pooling type.
            PoolingType = LLamaPoolingType.None,
        };

        using var model = await LLamaWeights.LoadFromFileAsync(parameters);
        var llavaWeights = await LLavaWeights.LoadFromFileAsync(MmprojPath);
        var context = model.CreateContext(parameters);

        // The embedder owns both the context and the llava weights and disposes them. Because it is declared
        // after `model`, it is disposed before the model.
        using var multimodalEmbedder = new LlamaMultimodalEmbedder(context, llavaWeights);

        var imageBytes = await File.ReadAllBytesAsync(ImagePath);

        var embedding1 = await multimodalEmbedder.GetImageEmbedding(imageBytes, CancellationToken.None);
        var embedding2 = await multimodalEmbedder.GetImageEmbedding(imageBytes, CancellationToken.None);

        Assert.Equal(embedding1.Length, embedding2.Length);

        var diff = TensorPrimitives.Distance(embedding1, embedding2);

        Assert.True(diff < Tolerance, $"Embeddings of the same image differ (distance = {diff})");
    }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    do not close (Protect this issue from auto closing), stale (Stale issue will be autoclosed soon)

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions