Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"os/signal"
"strings"
"sync"
"time"

"github.com/getsentry/sentry-go"
Expand Down Expand Up @@ -140,6 +141,18 @@ var (
} else {
ctx = telemetry.WithService(ctx, service)
}
if service != nil {
var stitchOnce sync.Once
Comment thread
seanoliver marked this conversation as resolved.
utils.OnGotrueID = func(gotrueID string) {
if service.NeedsIdentityStitch() {
stitchOnce.Do(func() {
if err := service.StitchLogin(gotrueID); err != nil {
fmt.Fprintln(utils.GetDebugLogger(), err)
}
})
}
}
}
ctx = telemetry.WithCommandContext(ctx, commandAnalyticsContext(cmd))
cmd.SetContext(ctx)
// Setup sentry last to ignore errors from parsing cli flags
Expand Down
4 changes: 4 additions & 0 deletions internal/telemetry/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ func (s *Service) ClearDistinctID() error {
return SaveState(s.state, s.fsys)
}

func (s *Service) NeedsIdentityStitch() bool {
return s != nil && s.state.DistinctID == "" && s.canSend()
}

func (s *Service) GroupIdentify(groupType string, groupKey string, properties map[string]any) error {
if !s.canSend() {
return nil
Expand Down
22 changes: 22 additions & 0 deletions internal/telemetry/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,28 @@ func TestServiceCaptureIncludesLinkedProjectGroups(t *testing.T) {
}, analytics.captures[0].groups)
}

func TestServiceNeedsIdentityStitch(t *testing.T) {
now := time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
t.Setenv("SUPABASE_HOME", "/tmp/supabase-home")
fsys := afero.NewMemMapFs()
analytics := &fakeAnalytics{enabled: true}

service, err := NewService(fsys, Options{
Analytics: analytics,
Now: func() time.Time { return now },
})
require.NoError(t, err)

t.Run("true when DistinctID is empty", func(t *testing.T) {
assert.True(t, service.NeedsIdentityStitch())
})

t.Run("false after StitchLogin", func(t *testing.T) {
require.NoError(t, service.StitchLogin("user-123"))
assert.False(t, service.NeedsIdentityStitch())
})
}

func TestServiceCaptureHonorsConsentAndEnvOptOut(t *testing.T) {
now := time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)

Expand Down
7 changes: 7 additions & 0 deletions internal/utils/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ const (
DNS_OVER_HTTPS = "https"
)

var OnGotrueID func(string)

Comment thread
seanoliver marked this conversation as resolved.
var (
clientOnce sync.Once
apiClient *supabase.ClientWithResponses
Expand Down Expand Up @@ -123,8 +125,13 @@ func GetSupabase() *supabase.ClientWithResponses {
if t, ok := http.DefaultTransport.(*http.Transport); ok {
t.DialContext = withFallbackDNS(t.DialContext)
}
transport := &identityTransport{
RoundTripper: http.DefaultTransport,
onGotrueID: &OnGotrueID,
}
apiClient, err = supabase.NewClientWithResponses(
GetSupabaseAPIHost(),
supabase.WithHTTPClient(&http.Client{Transport: transport}),
supabase.WithRequestEditorFn(func(ctx context.Context, req *http.Request) error {
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "SupabaseCLI/"+Version)
Expand Down
21 changes: 21 additions & 0 deletions internal/utils/identity_transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package utils

import "net/http"

const HeaderGotrueID = "X-Gotrue-Id"

type identityTransport struct {
http.RoundTripper
onGotrueID *func(string)
}

func (t *identityTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := t.RoundTripper.RoundTrip(req)
if err != nil {
return resp, err
}
if id := resp.Header.Get(HeaderGotrueID); id != "" && t.onGotrueID != nil && *t.onGotrueID != nil {
(*t.onGotrueID)(id)
}
return resp, err
}
101 changes: 101 additions & 0 deletions internal/utils/identity_transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package utils

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestIdentityTransport_CapturesGotrueIdHeader(t *testing.T) {
var captured string
cb := func(id string) { captured = id }
transport := &identityTransport{
RoundTripper: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Header: http.Header{"X-Gotrue-Id": []string{"user-abc-123"}},
}, nil
}),
onGotrueID: &cb,
}
req, _ := http.NewRequest("GET", "https://api.supabase.io/v1/projects", nil)
resp, err := transport.RoundTrip(req)
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
assert.Equal(t, "user-abc-123", captured)
}

func TestIdentityTransport_IgnoresWhenHeaderMissing(t *testing.T) {
var captured string
cb := func(id string) { captured = id }
transport := &identityTransport{
RoundTripper: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Header: http.Header{},
}, nil
}),
onGotrueID: &cb,
}
req, _ := http.NewRequest("GET", "https://api.supabase.io/v1/projects", nil)
_, err := transport.RoundTrip(req)
assert.NoError(t, err)
assert.Empty(t, captured)
}

func TestIdentityTransport_NilCallbackDoesNotPanic(t *testing.T) {
transport := &identityTransport{
RoundTripper: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Header: http.Header{"X-Gotrue-Id": []string{"user-abc-123"}},
}, nil
}),
onGotrueID: nil,
}
req, _ := http.NewRequest("GET", "https://api.supabase.io/v1/projects", nil)
resp, err := transport.RoundTrip(req)
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
}

func TestIdentityTransport_NilFuncBehindPointerDoesNotPanic(t *testing.T) {
var cb func(string) // nil func
transport := &identityTransport{
RoundTripper: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Header: http.Header{"X-Gotrue-Id": []string{"user-abc-123"}},
}, nil
}),
onGotrueID: &cb,
}
req, _ := http.NewRequest("GET", "https://api.supabase.io/v1/projects", nil)
resp, err := transport.RoundTrip(req)
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
}

func TestIdentityTransport_InnerTransportError(t *testing.T) {
var captured string
cb := func(id string) { captured = id }
transport := &identityTransport{
RoundTripper: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return nil, assert.AnError
}),
onGotrueID: &cb,
}
req, _ := http.NewRequest("GET", "https://api.supabase.io/v1/projects", nil)
resp, err := transport.RoundTrip(req)
assert.Error(t, err)
assert.Nil(t, resp)
assert.Empty(t, captured)
}

// roundTripFunc is a test helper to create inline RoundTrippers.
type roundTripFunc func(*http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
Loading