Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions mcp/bm25.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright (c) 2023-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

package mcp

import (
"math"
"sort"
"strings"
"unicode"
)

const (
bm25K1 = 1.2
bm25B = 0.75
)

type BM25Document struct {
ID string
Text string
}

type BM25Result struct {
ID string
Score float64
}

type BM25Index struct {
documents []bm25IndexedDocument
documentCount int
documentFreqs map[string]int
averageDocLength float64
}

type bm25IndexedDocument struct {
id string
termFreqs map[string]int
tokenCount float64
}

func NewBM25Index(docs []BM25Document) *BM25Index {
idx := &BM25Index{
documents: make([]bm25IndexedDocument, 0, len(docs)),
documentCount: len(docs),
documentFreqs: make(map[string]int),
}

var totalDocLength float64
for _, doc := range docs {
tokens := tokenizeBM25Text(doc.Text)
termFreqs := make(map[string]int)
for _, token := range tokens {
termFreqs[token]++
}

for token := range termFreqs {
idx.documentFreqs[token]++
}

tokenCount := float64(len(tokens))
totalDocLength += tokenCount
idx.documents = append(idx.documents, bm25IndexedDocument{
id: doc.ID,
termFreqs: termFreqs,
tokenCount: tokenCount,
})
}

if idx.documentCount > 0 {
idx.averageDocLength = totalDocLength / float64(idx.documentCount)
}

return idx
}

func (idx *BM25Index) Search(query string, limit int) []BM25Result {
if idx == nil || idx.documentCount == 0 || idx.averageDocLength == 0 {
return nil
}

queryTokens := uniqueBM25Tokens(tokenizeBM25Text(query))
if len(queryTokens) == 0 {
return nil
}

results := make([]BM25Result, 0, len(idx.documents))
for _, doc := range idx.documents {
var score float64
for _, token := range queryTokens {
tf := float64(doc.termFreqs[token])
if tf == 0 {
continue
}

df := idx.documentFreqs[token]
idf := math.Log(1 + (float64(idx.documentCount-df)+0.5)/(float64(df)+0.5))
score += idf * (tf * (bm25K1 + 1)) / (tf + bm25K1*(1-bm25B+bm25B*doc.tokenCount/idx.averageDocLength))
}

if score > 0 {
results = append(results, BM25Result{
ID: doc.id,
Score: score,
})
}
}

sort.Slice(results, func(i, j int) bool {
if results[i].Score == results[j].Score {
return results[i].ID < results[j].ID
}
return results[i].Score > results[j].Score
})

if len(results) == 0 {
return nil
}

if limit > 0 && len(results) > limit {
results = results[:limit]
}

return results
}

func tokenizeBM25Text(text string) []string {
var tokens []string
var current strings.Builder

for _, r := range text {
if unicode.IsLetter(r) || unicode.IsNumber(r) {
current.WriteRune(unicode.ToLower(r))
continue
}

if current.Len() > 0 {
tokens = append(tokens, current.String())
current.Reset()
}
}

if current.Len() > 0 {
tokens = append(tokens, current.String())
}

return tokens
}

func uniqueBM25Tokens(tokens []string) []string {
seen := make(map[string]bool, len(tokens))
unique := make([]string, 0, len(tokens))
for _, token := range tokens {
if seen[token] {
continue
}
seen[token] = true
unique = append(unique, token)
}
return unique
}
98 changes: 98 additions & 0 deletions mcp/bm25_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) 2023-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

package mcp

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestBM25SearchRanksRelevantDocuments(t *testing.T) {
idx := NewBM25Index([]BM25Document{
{ID: "jira__get_issue", Text: "jira__get_issue get issue jira ticket"},
{ID: "github__create_pull_request", Text: "github__create_pull_request create pull request"},
{ID: "mattermost__search_users", Text: "mattermost__search_users search users"},
})

results := idx.Search("jira issue", 10)

require.NotEmpty(t, results)
require.Equal(t, "jira__get_issue", results[0].ID)
}

func TestBM25SearchUsesNameAndDescription(t *testing.T) {
idx := NewBM25Index([]BM25Document{
{ID: "jira__get_issue", Text: "jira__get_issue get_issue"},
{ID: "github__create_pull_request", Text: "opens a collaboration review"},
})

nameResults := idx.Search("get issue", 10)
require.NotEmpty(t, nameResults)
require.Equal(t, "jira__get_issue", nameResults[0].ID)

descriptionResults := idx.Search("collaboration review", 10)
require.NotEmpty(t, descriptionResults)
require.Equal(t, "github__create_pull_request", descriptionResults[0].ID)
}

func TestBM25SearchLimitAndTieBreak(t *testing.T) {
idx := NewBM25Index([]BM25Document{
{ID: "charlie", Text: "shared"},
{ID: "bravo", Text: "shared"},
{ID: "alpha", Text: "shared"},
})

results := idx.Search("shared", 2)

require.Len(t, results, 2)
require.Equal(t, "alpha", results[0].ID)
require.Equal(t, "bravo", results[1].ID)
}

func TestBM25EmptyQueryReturnsNil(t *testing.T) {
idx := NewBM25Index([]BM25Document{
{ID: "jira__get_issue", Text: "jira issue"},
})

require.Nil(t, idx.Search("", 10))
require.Nil(t, idx.Search(" ", 10))
}

func TestBM25NoMatchingTokensReturnsNil(t *testing.T) {
idx := NewBM25Index([]BM25Document{
{ID: "jira__get_issue", Text: "jira issue"},
})

require.Nil(t, idx.Search("github", 10))
}

func TestBM25TokenizeNonLatin(t *testing.T) {
idx := NewBM25Index([]BM25Document{
{ID: "japanese", Text: "検索 ユーザー"},
{ID: "chinese_contiguous", Text: "用户搜索"},
})

japaneseResults := idx.Search("検索", 10)
require.NotEmpty(t, japaneseResults)
require.Equal(t, "japanese", japaneseResults[0].ID)

contiguousResults := idx.Search("用户搜索", 10)
require.NotEmpty(t, contiguousResults)
require.Equal(t, "chinese_contiguous", contiguousResults[0].ID)

// There is no CJK segmentation: a substring query does not match a contiguous token.
require.Nil(t, idx.Search("用户", 10))
}

func TestBM25TokenizeNamespacedSnakeNames(t *testing.T) {
idx := NewBM25Index([]BM25Document{
{ID: "jira__get_issue", Text: "jira__get_issue"},
})

results := idx.Search("get issue", 10)

require.NotEmpty(t, results)
require.Equal(t, "jira__get_issue", results[0].ID)
}
Loading
Loading