Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions server/cmd/api/api/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package api

import (
"context"
"encoding/json"
"net/http"
"sync/atomic"
"time"

chiMiddleware "github.com/go-chi/chi/v5/middleware"

"github.com/kernel/kernel-images/server/lib/events"
oapi "github.com/kernel/kernel-images/server/lib/oapi"
)

// Per-request scratch shared between the chi-level HTTP middleware and the
// strict-server middleware so the latter can stamp the matched operationId.
type telemetryCtxKey struct{}

type telemetryRequestCtx struct {
operationID string
}

// Process-wide toggle for the api_call middleware. Flipped by
// Enable/DisableTelemetryMiddleware; both middleware layers short-circuit
// to passthroughs when false.
var telemetryMiddlewareEnabled atomic.Bool

// EnableTelemetryMiddleware turns on api_call event emission.
func EnableTelemetryMiddleware() { telemetryMiddlewareEnabled.Store(true) }

// DisableTelemetryMiddleware turns api_call event emission off.
func DisableTelemetryMiddleware() { telemetryMiddlewareEnabled.Store(false) }

// TelemetryMiddlewareEnabled reports the current state.
func TelemetryMiddlewareEnabled() bool { return telemetryMiddlewareEnabled.Load() }

// TelemetryHTTPMiddleware emits a BrowserApiCallEvent per documented operation,
// capturing the final status and wall-clock duration.
func TelemetryHTTPMiddleware(publish func(events.Event)) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !telemetryMiddlewareEnabled.Load() {
next.ServeHTTP(w, r)
return
}
tc := &telemetryRequestCtx{}
ctx := context.WithValue(r.Context(), telemetryCtxKey{}, tc)
ww := chiMiddleware.NewWrapResponseWriter(w, r.ProtoMajor)
start := time.Now()

next.ServeHTTP(ww, r.WithContext(ctx))

if tc.operationID == "" {
return
}
data, _ := json.Marshal(oapi.BrowserApiCallEventData{
RequestId: chiMiddleware.GetReqID(ctx),
OperationId: tc.operationID,
Status: ww.Status(),
DurationMs: float32(time.Since(start).Microseconds()) / 1000.0,
})
Comment thread
cursor[bot] marked this conversation as resolved.
publish(events.Event{
Ts: time.Now().UnixMicro(),
Type: "api_call",
Category: events.Api,
Source: oapi.BrowserEventSource{Kind: oapi.KernelApi},
Data: data,
})
})
}
}

// TelemetryStrictMiddleware records the matched OpenAPI operationId onto the
// per-request scratch so TelemetryHTTPMiddleware can include it in the event.
func TelemetryStrictMiddleware() oapi.StrictMiddlewareFunc {
return func(next oapi.StrictHandlerFunc, operationID string) oapi.StrictHandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (any, error) {
if !telemetryMiddlewareEnabled.Load() {
return next(ctx, w, r, request)
}
if tc, ok := ctx.Value(telemetryCtxKey{}).(*telemetryRequestCtx); ok {
tc.operationID = operationID
}
return next(ctx, w, r, request)
}
}
}
149 changes: 149 additions & 0 deletions server/cmd/api/api/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package api

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"

chiMiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/kernel/kernel-images/server/lib/events"
oapi "github.com/kernel/kernel-images/server/lib/oapi"
)

// recordingPublisher captures published events for assertion.
type recordingPublisher struct {
mu sync.Mutex
events []events.Event
}

func (rp *recordingPublisher) publish(ev events.Event) {
rp.mu.Lock()
defer rp.mu.Unlock()
rp.events = append(rp.events, ev)
}

func (rp *recordingPublisher) snapshot() []events.Event {
rp.mu.Lock()
defer rp.mu.Unlock()
out := make([]events.Event, len(rp.events))
copy(out, rp.events)
return out
}

// Mirrors the oapi-codegen strict dispatcher: middleware chain -> inner
// handler -> response write.
func fakeStrictHandler(operationID string, status int, mws []oapi.StrictMiddlewareFunc) http.Handler {
inner := oapi.StrictHandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (any, error) {
return nil, nil
})
for _, mw := range mws {
inner = mw(inner, operationID)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = inner(r.Context(), w, r, nil)
w.WriteHeader(status)
})
}

// Flips the package-level toggle on for the test, restoring prior state
// via t.Cleanup.
func withTelemetryMiddlewareEnabled(t *testing.T) {
t.Helper()
prev := TelemetryMiddlewareEnabled()
EnableTelemetryMiddleware()
t.Cleanup(func() {
if prev {
EnableTelemetryMiddleware()
} else {
DisableTelemetryMiddleware()
}
})
}

func TestTelemetryMiddleware_EmitsApiCallEventOnDocumentedRoute(t *testing.T) {
withTelemetryMiddlewareEnabled(t)
rp := &recordingPublisher{}
chain := chiHandler(t, rp.publish, "ProcessExec", http.StatusOK)

req := httptest.NewRequest(http.MethodPost, "/process/exec", nil)
rec := httptest.NewRecorder()
chain.ServeHTTP(rec, req)

captured := rp.snapshot()
require.Len(t, captured, 1)
ev := captured[0]
assert.Equal(t, "api_call", ev.Type)
assert.Equal(t, events.Api, ev.Category)
assert.Equal(t, oapi.KernelApi, ev.Source.Kind)

var data struct {
RequestID string `json:"request_id"`
OperationID string `json:"operation_id"`
Status int `json:"status"`
DurationMs float64 `json:"duration_ms"`
}
require.NoError(t, json.Unmarshal(ev.Data, &data))
assert.NotEmpty(t, data.RequestID, "request_id should be set by chi RequestID middleware")
assert.Equal(t, "ProcessExec", data.OperationID)
assert.Equal(t, http.StatusOK, data.Status)
assert.GreaterOrEqual(t, data.DurationMs, 0.0)
}

func TestTelemetryMiddleware_CapturesNonOKStatus(t *testing.T) {
withTelemetryMiddlewareEnabled(t)
rp := &recordingPublisher{}
chain := chiHandler(t, rp.publish, "ProcessExec", http.StatusInternalServerError)

req := httptest.NewRequest(http.MethodPost, "/process/exec", nil)
rec := httptest.NewRecorder()
chain.ServeHTTP(rec, req)

captured := rp.snapshot()
require.Len(t, captured, 1)
var data struct {
Status int `json:"status"`
}
require.NoError(t, json.Unmarshal(captured[0].Data, &data))
assert.Equal(t, http.StatusInternalServerError, data.Status)
}

func TestTelemetryMiddleware_SkipsUndocumentedRoutes(t *testing.T) {
withTelemetryMiddlewareEnabled(t)
rp := &recordingPublisher{}
mw := TelemetryHTTPMiddleware(rp.publish)
plain := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest(http.MethodGet, "/health", nil)
chiMiddleware.RequestID(plain).ServeHTTP(httptest.NewRecorder(), req)

assert.Empty(t, rp.snapshot(), "no event should be emitted when operationId is unset")
}

func TestTelemetryMiddleware_ShortCircuitsWhenDisabled(t *testing.T) {
DisableTelemetryMiddleware()
rp := &recordingPublisher{}
chain := chiHandler(t, rp.publish, "ProcessExec", http.StatusOK)

req := httptest.NewRequest(http.MethodPost, "/process/exec", nil)
rec := httptest.NewRecorder()
chain.ServeHTTP(rec, req)

assert.Empty(t, rp.snapshot(), "disabled middleware must not emit")
}

// Builds the same middleware stack as main.go: RequestID -> HTTP middleware ->
// strict dispatch -> inner handler.
func chiHandler(t *testing.T, publish func(events.Event), operationID string, status int) http.Handler {
t.Helper()
inner := fakeStrictHandler(operationID, status, []oapi.StrictMiddlewareFunc{TelemetryStrictMiddleware()})
telemetry := TelemetryHTTPMiddleware(publish)(inner)
return chiMiddleware.RequestID(telemetry)
}
44 changes: 37 additions & 7 deletions server/cmd/api/api/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package api
import (
"context"

"github.com/nrednav/cuid2"
oapi "github.com/kernel/kernel-images/server/lib/oapi"
"github.com/nrednav/cuid2"

"github.com/kernel/kernel-images/server/lib/events"
"github.com/kernel/kernel-images/server/lib/logger"
Expand All @@ -25,7 +25,7 @@ func (s *ApiService) GetTelemetry(_ context.Context, _ oapi.GetTelemetryRequestO

// PutTelemetry handles PUT /telemetry.
// Sets the telemetry configuration. Returns 201 if not previously configured, 200 if it was.
// Setting all four categories to enabled:false clears the configuration (200).
// Setting all five categories to enabled:false clears the configuration (200).
func (s *ApiService) PutTelemetry(ctx context.Context, req oapi.PutTelemetryRequestObject) (oapi.PutTelemetryResponseObject, error) {
s.monitorMu.Lock()
defer s.monitorMu.Unlock()
Expand All @@ -45,12 +45,14 @@ func (s *ApiService) PutTelemetry(ctx context.Context, req oapi.PutTelemetryRequ
// All categories disabled: clear the configuration.
s.cdpMonitor.Stop()
s.telemetrySession.Stop()
s.applyTelemetryMiddlewareState()
return oapi.PutTelemetry200JSONResponse(oapi.TelemetryState{Config: disabledConfig(), Seq: int64(s.telemetrySession.Seq())}), nil
}

if wasActive {
// Replace config on the running session.
s.telemetrySession.UpdateConfig(cfg)
s.applyTelemetryMiddlewareState()
return oapi.PutTelemetry200JSONResponse(s.buildTelemetryResponse()), nil
}

Expand All @@ -61,16 +63,18 @@ func (s *ApiService) PutTelemetry(ctx context.Context, req oapi.PutTelemetryRequ
if err := s.cdpMonitor.Start(s.lifecycleCtx); err != nil {
// Roll back: clear the session so a retry can succeed.
s.telemetrySession.Stop()
s.applyTelemetryMiddlewareState()
logger.FromContext(ctx).Error("failed to start telemetry monitor", "err", err)
return oapi.PutTelemetry500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "failed to start telemetry"}}, nil
}

s.applyTelemetryMiddlewareState()
return oapi.PutTelemetry201JSONResponse(s.buildTelemetryResponse()), nil
}

// PatchTelemetry handles PATCH /telemetry.
// Partially updates the telemetry configuration. Returns 404 if not configured.
// Setting all four categories to enabled:false clears the configuration (200).
// Setting all five categories to enabled:false clears the configuration (200).
func (s *ApiService) PatchTelemetry(_ context.Context, req oapi.PatchTelemetryRequestObject) (oapi.PatchTelemetryResponseObject, error) {
s.monitorMu.Lock()
defer s.monitorMu.Unlock()
Expand All @@ -88,14 +92,33 @@ func (s *ApiService) PatchTelemetry(_ context.Context, req oapi.PatchTelemetryRe
// All categories disabled: clear the configuration.
s.cdpMonitor.Stop()
s.telemetrySession.Stop()
s.applyTelemetryMiddlewareState()
return oapi.PatchTelemetry200JSONResponse(oapi.TelemetryState{Config: disabledConfig(), Seq: int64(s.telemetrySession.Seq())}), nil
}
s.telemetrySession.UpdateConfig(cfg)
s.applyTelemetryMiddlewareState()
}

return oapi.PatchTelemetry200JSONResponse(s.buildTelemetryResponse()), nil
}

// applyTelemetryMiddlewareState turns the api_call middleware on iff the
// session is active and the api category is enabled. Call after any config
// change.
func (s *ApiService) applyTelemetryMiddlewareState() {
if !s.telemetrySession.Active() {
DisableTelemetryMiddleware()
return
}
for _, c := range s.telemetrySession.Config().Categories {
if c == events.Api {
EnableTelemetryMiddleware()
return
}
}
DisableTelemetryMiddleware()
}

// buildTelemetryResponse constructs a TelemetryState response from the current configuration.
func (s *ApiService) buildTelemetryResponse() oapi.TelemetryState {
resp := oapi.TelemetryState{
Expand Down Expand Up @@ -127,13 +150,14 @@ func telemetryConfigFromOAPI(cfg *oapi.BrowserTelemetryConfig) (telemetry.Teleme
networkOn := isEnabled(b.Network)
pageOn := isEnabled(b.Page)
interactionOn := isEnabled(b.Interaction)
apiOn := isEnabled(b.Api)

allDisabled := !consoleOn && !networkOn && !pageOn && !interactionOn
allDisabled := !consoleOn && !networkOn && !pageOn && !interactionOn && !apiOn
if allDisabled {
return telemetry.TelemetryConfig{}, true, nil
}

cats := make([]oapi.TelemetryEventCategory, 0, 5)
cats := make([]oapi.TelemetryEventCategory, 0, 6)
if consoleOn {
cats = append(cats, events.Console)
}
Expand All @@ -146,6 +170,9 @@ func telemetryConfigFromOAPI(cfg *oapi.BrowserTelemetryConfig) (telemetry.Teleme
if interactionOn {
cats = append(cats, events.Interaction)
}
if apiOn {
cats = append(cats, events.Api)
}
// CategorySystem is always appended by TelemetrySession.Start/UpdateConfig;
// no need to include it here.
return telemetry.TelemetryConfig{Categories: cats}, false, nil
Expand Down Expand Up @@ -177,6 +204,7 @@ func mergeTelemetryConfig(current telemetry.TelemetryConfig, patch *oapi.Browser
override(events.Network, patch.Network)
override(events.Page, patch.Page)
override(events.Interaction, patch.Interaction)
override(events.Api, patch.Api)

// CategorySystem is managed internally by TelemetrySession; exclude from the
// user-facing allDisabled check.
Expand All @@ -185,6 +213,7 @@ func mergeTelemetryConfig(current telemetry.TelemetryConfig, patch *oapi.Browser
events.Network,
events.Page,
events.Interaction,
events.Api,
}
allDisabled := true
for _, c := range userCats {
Expand All @@ -204,7 +233,7 @@ func mergeTelemetryConfig(current telemetry.TelemetryConfig, patch *oapi.Browser
return telemetry.TelemetryConfig{Categories: cats}, false
}

// disabledConfig returns a BrowserTelemetryConfig with all four user-facing categories explicitly disabled.
// disabledConfig returns a BrowserTelemetryConfig with all five user-facing categories explicitly disabled.
func disabledConfig() oapi.BrowserTelemetryConfig {
f := false
cat := &oapi.BrowserTelemetryCategoryConfig{Enabled: &f}
Expand All @@ -214,6 +243,7 @@ func disabledConfig() oapi.BrowserTelemetryConfig {
Network: cat,
Page: cat,
Interaction: cat,
Api: cat,
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
},
}
}
Expand All @@ -238,7 +268,7 @@ func telemetryConfigToOAPI(cfg telemetry.TelemetryConfig) oapi.BrowserTelemetryC
Network: enabled(events.Network),
Page: enabled(events.Page),
Interaction: enabled(events.Interaction),
Api: enabled(events.Api),
},
}
}

Loading