diff --git a/cmd/root.go b/cmd/root.go index c8c29174eb..ae2966ab97 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -9,6 +9,7 @@ import ( "os" "os/signal" "strings" + "sync" "time" "github.com/getsentry/sentry-go" @@ -140,6 +141,18 @@ var ( } else { ctx = telemetry.WithService(ctx, service) } + if service != nil { + var stitchOnce sync.Once + utils.OnGotrueID = func(gotrueID string) { + if service.NeedsIdentityStitch() { + stitchOnce.Do(func() { + if err := service.StitchLogin(gotrueID); err != nil { + fmt.Fprintln(utils.GetDebugLogger(), err) + } + }) + } + } + } ctx = telemetry.WithCommandContext(ctx, commandAnalyticsContext(cmd)) cmd.SetContext(ctx) // Setup sentry last to ignore errors from parsing cli flags diff --git a/internal/telemetry/service.go b/internal/telemetry/service.go index 8bf920b8d0..39e0b6c073 100644 --- a/internal/telemetry/service.go +++ b/internal/telemetry/service.go @@ -153,6 +153,10 @@ func (s *Service) ClearDistinctID() error { return SaveState(s.state, s.fsys) } +func (s *Service) NeedsIdentityStitch() bool { + return s != nil && s.state.DistinctID == "" && s.canSend() +} + func (s *Service) GroupIdentify(groupType string, groupKey string, properties map[string]any) error { if !s.canSend() { return nil diff --git a/internal/telemetry/service_test.go b/internal/telemetry/service_test.go index d8620c85ff..c46a793bed 100644 --- a/internal/telemetry/service_test.go +++ b/internal/telemetry/service_test.go @@ -213,6 +213,28 @@ func TestServiceCaptureIncludesLinkedProjectGroups(t *testing.T) { }, analytics.captures[0].groups) } +func TestServiceNeedsIdentityStitch(t *testing.T) { + now := time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC) + t.Setenv("SUPABASE_HOME", "/tmp/supabase-home") + fsys := afero.NewMemMapFs() + analytics := &fakeAnalytics{enabled: true} + + service, err := NewService(fsys, Options{ + Analytics: analytics, + Now: func() time.Time { return now }, + }) + require.NoError(t, err) + + t.Run("true when DistinctID is empty", func(t *testing.T) { + assert.True(t, service.NeedsIdentityStitch()) + }) + + t.Run("false after StitchLogin", func(t *testing.T) { + require.NoError(t, service.StitchLogin("user-123")) + assert.False(t, service.NeedsIdentityStitch()) + }) +} + func TestServiceCaptureHonorsConsentAndEnvOptOut(t *testing.T) { now := time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC) diff --git a/internal/utils/api.go b/internal/utils/api.go index 5b9a96fdee..06f37a8a50 100644 --- a/internal/utils/api.go +++ b/internal/utils/api.go @@ -21,6 +21,8 @@ const ( DNS_OVER_HTTPS = "https" ) +var OnGotrueID func(string) + var ( clientOnce sync.Once apiClient *supabase.ClientWithResponses @@ -123,8 +125,13 @@ func GetSupabase() *supabase.ClientWithResponses { if t, ok := http.DefaultTransport.(*http.Transport); ok { t.DialContext = withFallbackDNS(t.DialContext) } + transport := &identityTransport{ + RoundTripper: http.DefaultTransport, + onGotrueID: &OnGotrueID, + } apiClient, err = supabase.NewClientWithResponses( GetSupabaseAPIHost(), + supabase.WithHTTPClient(&http.Client{Transport: transport}), supabase.WithRequestEditorFn(func(ctx context.Context, req *http.Request) error { req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("User-Agent", "SupabaseCLI/"+Version) diff --git a/internal/utils/identity_transport.go b/internal/utils/identity_transport.go new file mode 100644 index 0000000000..bcf01d97e7 --- /dev/null +++ b/internal/utils/identity_transport.go @@ -0,0 +1,21 @@ +package utils + +import "net/http" + +const HeaderGotrueID = "X-Gotrue-Id" + +type identityTransport struct { + http.RoundTripper + onGotrueID *func(string) +} + +func (t *identityTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := t.RoundTripper.RoundTrip(req) + if err != nil { + return resp, err + } + if id := resp.Header.Get(HeaderGotrueID); id != "" && t.onGotrueID != nil && *t.onGotrueID != nil { + (*t.onGotrueID)(id) + } + return resp, err +} diff --git a/internal/utils/identity_transport_test.go b/internal/utils/identity_transport_test.go new file mode 100644 index 0000000000..314039d0a9 --- /dev/null +++ b/internal/utils/identity_transport_test.go @@ -0,0 +1,101 @@ +package utils + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIdentityTransport_CapturesGotrueIdHeader(t *testing.T) { + var captured string + cb := func(id string) { captured = id } + transport := &identityTransport{ + RoundTripper: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Header: http.Header{"X-Gotrue-Id": []string{"user-abc-123"}}, + }, nil + }), + onGotrueID: &cb, + } + req, _ := http.NewRequest("GET", "https://api.supabase.io/v1/projects", nil) + resp, err := transport.RoundTrip(req) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, "user-abc-123", captured) +} + +func TestIdentityTransport_IgnoresWhenHeaderMissing(t *testing.T) { + var captured string + cb := func(id string) { captured = id } + transport := &identityTransport{ + RoundTripper: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Header: http.Header{}, + }, nil + }), + onGotrueID: &cb, + } + req, _ := http.NewRequest("GET", "https://api.supabase.io/v1/projects", nil) + _, err := transport.RoundTrip(req) + assert.NoError(t, err) + assert.Empty(t, captured) +} + +func TestIdentityTransport_NilCallbackDoesNotPanic(t *testing.T) { + transport := &identityTransport{ + RoundTripper: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Header: http.Header{"X-Gotrue-Id": []string{"user-abc-123"}}, + }, nil + }), + onGotrueID: nil, + } + req, _ := http.NewRequest("GET", "https://api.supabase.io/v1/projects", nil) + resp, err := transport.RoundTrip(req) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) +} + +func TestIdentityTransport_NilFuncBehindPointerDoesNotPanic(t *testing.T) { + var cb func(string) // nil func + transport := &identityTransport{ + RoundTripper: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Header: http.Header{"X-Gotrue-Id": []string{"user-abc-123"}}, + }, nil + }), + onGotrueID: &cb, + } + req, _ := http.NewRequest("GET", "https://api.supabase.io/v1/projects", nil) + resp, err := transport.RoundTrip(req) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) +} + +func TestIdentityTransport_InnerTransportError(t *testing.T) { + var captured string + cb := func(id string) { captured = id } + transport := &identityTransport{ + RoundTripper: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, assert.AnError + }), + onGotrueID: &cb, + } + req, _ := http.NewRequest("GET", "https://api.supabase.io/v1/projects", nil) + resp, err := transport.RoundTrip(req) + assert.Error(t, err) + assert.Nil(t, resp) + assert.Empty(t, captured) +} + +// roundTripFunc is a test helper to create inline RoundTrippers. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +}