diff --git a/.gitignore b/.gitignore index de588db3..1bd57665 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ /.idea /build /dist -/internal/util/version.txt +/internal/version/version.txt /python/cog/cog-* /python/coglet/_version.py /uv.lock diff --git a/cmd/cog/main.go b/cmd/cog/main.go index 26643796..c659947b 100644 --- a/cmd/cog/main.go +++ b/cmd/cog/main.go @@ -16,7 +16,7 @@ import ( "github.com/replicate/cog-runtime/internal/config" "github.com/replicate/cog-runtime/internal/runner" "github.com/replicate/cog-runtime/internal/service" - "github.com/replicate/cog-runtime/internal/util" + "github.com/replicate/cog-runtime/internal/version" ) type ServerCmd struct { @@ -125,7 +125,7 @@ func (s *ServerCmd) Run() error { } addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) - log.Infow("starting Cog HTTP server", "addr", addr, "version", util.Version(), "pid", os.Getpid()) + log.Infow("starting Cog HTTP server", "addr", addr, "version", version.Version(), "pid", os.Getpid()) // Create service with base logger svc := service.New(cfg, baseLogger) diff --git a/internal/runner/path.go b/internal/runner/path.go index b4edaead..96862065 100644 --- a/internal/runner/path.go +++ b/internal/runner/path.go @@ -15,8 +15,7 @@ import ( "github.com/gabriel-vasile/mimetype" "github.com/getkin/kin-openapi/openapi3" - - "github.com/replicate/cog-runtime/internal/util" + "github.com/replicate/go/httpclient" ) var Base64Regex = regexp.MustCompile(`^data:.*;base64,(?P.*)$`) @@ -210,7 +209,7 @@ type uploader struct { // newUploader creates a new uploader instance func newUploader(uploadURL string) *uploader { return &uploader{ - client: util.HTTPClientWithRetry(), + client: httpclient.ApplyRetryPolicy(http.DefaultClient), uploadURL: uploadURL, } } diff --git a/internal/runner/runner.go b/internal/runner/runner.go index f1dd75d4..511e6a0d 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -2,10 +2,12 @@ package runner import ( "bufio" + "bytes" "context" "encoding/json" "errors" "fmt" + "net/http" "os" "os/exec" "path" @@ -20,8 +22,10 @@ import ( "github.com/getkin/kin-openapi/openapi3" "go.uber.org/zap" + "github.com/replicate/go/httpclient" + "github.com/replicate/cog-runtime/internal/config" - "github.com/replicate/cog-runtime/internal/util" + "github.com/replicate/cog-runtime/internal/version" "github.com/replicate/cog-runtime/internal/webhook" ) @@ -293,7 +297,7 @@ type Runner struct { schema string doc *openapi3.T setupResult SetupResult - logs []string + logs LogsSlice asyncPredict bool maxConcurrency int pending map[string]*PendingPrediction @@ -464,14 +468,27 @@ func (r *Runner) setupLogCapture() error { func (r *Runner) logStdout(line string) { r.captureLogLine(line) - _, _ = fmt.Fprintln(os.Stdout, line) //nolint:forbidigo // mirror log to stdout + // Strip [pid=xxxxx] prefix before mirroring to stdout + mirrorLine := stripPIDPrefix(line) + _, _ = fmt.Fprintln(os.Stdout, mirrorLine) //nolint:forbidigo // mirror log to stdout } // logStderr captures a line from stderr and mirrors to stderr func (r *Runner) logStderr(line string) { r.captureLogLine(line) - _, _ = fmt.Fprintln(os.Stderr, line) //nolint:forbidigo // mirror log to stderr + // Strip [pid=xxxxx] prefix before mirroring to stderr + mirrorLine := stripPIDPrefix(line) + _, _ = fmt.Fprintln(os.Stderr, mirrorLine) //nolint:forbidigo // mirror log to stderr +} + +func stripPIDPrefix(line string) string { + if LogRegex.MatchString(line) { + if m := LogRegex.FindStringSubmatch(line); m != nil { + return m[2] // Extract message without pid prefix + } + } + return line } // captureLogLine handles routing log lines like the old implementation @@ -521,7 +538,7 @@ func (r *Runner) captureLogLine(line string) { } else { // Add to runner logs for crash reporting r.logs = append(r.logs, line) - r.setupResult.Logs = util.JoinLogs(r.logs) + r.setupResult.Logs = r.logs.String() } r.mu.Unlock() default: @@ -566,6 +583,9 @@ func (r *Runner) Config(ctx context.Context) error { // Default to 1 if not set in cog.yaml, regardless whether async predict or not maxConcurrency := max(1, cogYaml.Concurrency.Max) + // Send metrics + go r.sendRunnerMetric(*cogYaml) + // Create config.json for the coglet process configJSON := map[string]any{ "module_name": moduleName, @@ -593,6 +613,36 @@ func (r *Runner) Config(ctx context.Context) error { return nil } +func (r *Runner) sendRunnerMetric(cogYaml CogYaml) { + log := r.logger.Sugar() + // FIXME: wire this up through more than os.getenv + endpoint := os.Getenv("COG_METRICS_ENDPOINT") + if endpoint == "" { + return + } + data := map[string]any{ + "gpu": cogYaml.Build.GPU, + "fast": cogYaml.Build.Fast, + "cog_runtime": cogYaml.Build.CogRuntime, + "version": version.Version(), + } + payload := MetricsPayload{ + Source: "cog-runtime", + Type: "runner", + Data: data, + } + body, err := json.Marshal(payload) + if err != nil { + log.Errorw("failed to marshal payload", "error", err) + return + } + resp, err := httpclient.ApplyRetryPolicy(http.DefaultClient).Post(endpoint, "application/json", bytes.NewBuffer(body)) + if err != nil || resp.StatusCode != http.StatusOK { + log.Errorw("failed to send runner metrics", "error", err) + } + defer resp.Body.Close() +} + func (r *Runner) Stop() error { log := r.logger.Sugar() r.mu.Lock() @@ -913,7 +963,10 @@ func (r *Runner) updateSetupResult() { } // Set logs first (original pattern) - r.setupResult.Logs = util.JoinLogs(logLines) + r.setupResult.Logs = strings.Join(logLines, "\n") + if r.setupResult.Logs != "" { + r.setupResult.Logs += "\n" + } setupResultPath := filepath.Join(r.runnerCtx.workingdir, "setup_result.json") log.Debug("reading setup_result.json", "path", setupResultPath) @@ -954,7 +1007,7 @@ func (r *Runner) rotateLogs() string { r.mu.Lock() defer r.mu.Unlock() - allLogs := util.JoinLogs(r.logs) + allLogs := r.logs.String() r.logs = r.logs[:0] return allLogs } diff --git a/internal/runner/runner_test.go b/internal/runner/runner_test.go index 911b93e4..d4586fcc 100644 --- a/internal/runner/runner_test.go +++ b/internal/runner/runner_test.go @@ -1208,7 +1208,7 @@ func TestPerPredictionWatcher(t *testing.T) { assert.Equal(t, "partial output", pending.response.Output) assert.Equal(t, predictionID, pending.response.ID) assert.Equal(t, map[string]any{"test": "input"}, pending.response.Input) - assert.Equal(t, []string{"existing log"}, pending.response.Logs) // Logs preserved + assert.Equal(t, LogsSlice{"existing log"}, pending.response.Logs) // Logs preserved pending.mu.Unlock() }) diff --git a/internal/runner/types.go b/internal/runner/types.go index d91d5935..4cfa6b02 100644 --- a/internal/runner/types.go +++ b/internal/runner/types.go @@ -16,10 +16,50 @@ import ( "syscall" "time" - "github.com/replicate/cog-runtime/internal/util" "github.com/replicate/cog-runtime/internal/webhook" ) +// LogsSlice is a []string that marshals to/from a newline-joined string in JSON +type LogsSlice []string + +func (l LogsSlice) String() string { + r := strings.Join(l, "\n") + if r != "" { + r += "\n" + } + return r +} + +// MarshalJSON implements custom JSON marshaling to convert logs from []string to string +func (l LogsSlice) MarshalJSON() ([]byte, error) { + result := strings.Join(l, "\n") + if result != "" { + result += "\n" + } + return json.Marshal(result) +} + +// UnmarshalJSON implements custom JSON unmarshaling to convert logs from string to []string +func (l *LogsSlice) UnmarshalJSON(data []byte) error { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + + if str == "" { + *l = nil + return nil + } + + // Split on newline and remove the trailing empty element if it exists + parts := strings.Split(str, "\n") + if len(parts) > 0 && parts[len(parts)-1] == "" { + parts = parts[:len(parts)-1] + } + *l = LogsSlice(parts) + return nil +} + type Status int const ( @@ -116,72 +156,11 @@ type PredictionResponse struct { Input any `json:"input,omitempty"` Output any `json:"output,omitempty"` Error string `json:"error,omitempty"` - Logs []string `json:"logs,omitempty"` + Logs LogsSlice `json:"logs,omitempty"` Metrics any `json:"metrics,omitempty"` WebhookURL string `json:"webhook,omitempty"` } -// MarshalJSON implements custom JSON marshaling to convert logs from []string to string -func (pr PredictionResponse) MarshalJSON() ([]byte, error) { - return json.Marshal(&struct { - ID string `json:"id"` - Status PredictionStatus `json:"status"` - Input any `json:"input,omitempty"` - Output any `json:"output,omitempty"` - Error string `json:"error,omitempty"` - Logs string `json:"logs,omitempty"` - Metrics any `json:"metrics,omitempty"` - WebhookURL string `json:"webhook,omitempty"` - }{ - ID: pr.ID, - Status: pr.Status, - Input: pr.Input, - Output: pr.Output, - Error: pr.Error, - Logs: util.JoinLogs(pr.Logs), - Metrics: pr.Metrics, - WebhookURL: pr.WebhookURL, - }) -} - -// UnmarshalJSON implements custom JSON unmarshalling to convert logs from string to []string -func (pr *PredictionResponse) UnmarshalJSON(data []byte) error { - aux := &struct { - ID string `json:"id"` - Status PredictionStatus `json:"status"` - Input any `json:"input,omitempty"` - Output any `json:"output,omitempty"` - Error string `json:"error,omitempty"` - Logs string `json:"logs,omitempty"` - Metrics any `json:"metrics,omitempty"` - WebhookURL string `json:"webhook,omitempty"` - }{} - if err := json.Unmarshal(data, aux); err != nil { - return err - } - - pr.ID = aux.ID - pr.Status = aux.Status - pr.Input = aux.Input - pr.Output = aux.Output - pr.Error = aux.Error - pr.Metrics = aux.Metrics - pr.WebhookURL = aux.WebhookURL - - // Convert string logs back to []string - if aux.Logs != "" { - // Split on newline and remove the trailing empty element if it exists - parts := strings.Split(aux.Logs, "\n") - if len(parts) > 0 && parts[len(parts)-1] == "" { - parts = parts[:len(parts)-1] - } - pr.Logs = parts - } else { - pr.Logs = nil - } - return nil -} - // RunnerID is a unique identifier for a runner instance. // Format: 8-character base32 string (no leading zeros) // Example: "k7m3n8p2", "b9q4x2w1" @@ -391,3 +370,9 @@ func (p *PendingPrediction) sendWebhookSync(event webhook.Event) error { _ = p.webhookSender.SendConditional(p.request.Webhook, bytes.NewReader(body), event, p.request.WebhookEventsFilter, &p.lastUpdated) return nil } + +type MetricsPayload struct { + Source string `json:"source,omitempty"` + Type string `json:"type,omitempty"` + Data map[string]any `json:"data,omitempty"` +} diff --git a/internal/runner/types_test.go b/internal/runner/types_test.go index d4b35c8b..cc565bd2 100644 --- a/internal/runner/types_test.go +++ b/internal/runner/types_test.go @@ -139,7 +139,7 @@ func TestPredictionResponse(t *testing.T) { assert.Equal(t, PredictionSucceeded, resp.Status) assert.Equal(t, map[string]any{"result": "success"}, resp.Output) assert.Empty(t, resp.Error) - assert.Equal(t, []string{"log1", "log2"}, resp.Logs) + assert.Equal(t, LogsSlice{"log1", "log2"}, resp.Logs) assert.Equal(t, map[string]any{"duration": 1.5}, resp.Metrics) assert.Equal(t, "http://example.com/webhook", resp.WebhookURL) }) @@ -330,7 +330,7 @@ func TestPredictionResponseUnmarshalFromExternalJSON(t *testing.T) { err := json.Unmarshal([]byte(jsonStr), &response) require.NoError(t, err) - expected := []string{ + expected := LogsSlice{ "starting prediction", "prediction in progress 1/2", "prediction in progress 2/2", diff --git a/internal/server/server.go b/internal/server/server.go index 41c19670..a7fe6f96 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,6 +3,7 @@ package server import ( "bytes" "context" + "encoding/base32" "encoding/json" "errors" "fmt" @@ -15,11 +16,15 @@ import ( "go.uber.org/zap" + "github.com/replicate/go/httpclient" + "github.com/replicate/go/uuid" + "github.com/replicate/cog-runtime/internal/config" "github.com/replicate/cog-runtime/internal/runner" - "github.com/replicate/cog-runtime/internal/util" ) +const TimeLayout = "2006-01-02T15:04:05.999999-07:00" + // errAsyncPrediction is a sentinel error used to indicate that a prediction is being served asynchronously, it is not surfaced outside of server var errAsyncPrediction = errors.New("async prediction") @@ -110,8 +115,8 @@ func (h *Handler) healthCheck() (*HealthCheck, error) { hc := HealthCheck{ Status: runnerStatus, Setup: &SetupResult{ - StartedAt: util.FormatTime(h.startedAt), - CompletedAt: util.FormatTime(h.startedAt), + StartedAt: formatTime(h.startedAt), + CompletedAt: formatTime(h.startedAt), Status: runnerSetupResult.Status, Logs: logsStr, }, @@ -231,7 +236,7 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) { req.ID = id } if req.ID == "" { - req.ID, err = util.PredictionID() + req.ID, err = PredictionID() if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -304,7 +309,7 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) { log.Debugw("runner result received", "id", runnerResult.ID, "logs_count", len(runnerResult.Logs)) if len(runnerResult.Logs) > 0 { log.Debugw("joining logs", "logs", runnerResult.Logs) - logsStr = util.JoinLogs(runnerResult.Logs) + logsStr = runnerResult.Logs.String() log.Debugw("joined logs result", "logs_str", logsStr) } var metrics map[string]any @@ -390,7 +395,7 @@ func SendWebhook(webhook string, pr *PredictionResponse) error { // Only retry on completed webhooks client := http.DefaultClient if pr.Status.IsCompleted() { - client = util.HTTPClientWithRetry() + client = httpclient.ApplyRetryPolicy(http.DefaultClient) } resp, err := client.Do(req) if err != nil { @@ -438,3 +443,20 @@ func writeReadyFile() error { return nil } + +func PredictionID() (string, error) { + u, err := uuid.NewV7() + if err != nil { + return "", err + } + shuffle := make([]byte, uuid.Size) + for i := 0; i < 4; i++ { + shuffle[i], shuffle[i+4], shuffle[i+8], shuffle[i+12] = u[i+12], u[i+4], u[i], u[i+8] + } + encoding := base32.NewEncoding("0123456789abcdefghjkmnpqrstvwxyz").WithPadding(base32.NoPadding) + return encoding.EncodeToString(shuffle), nil +} + +func formatTime(t time.Time) string { + return t.UTC().Format(TimeLayout) +} diff --git a/internal/tests/async_prediction_test.go b/internal/tests/async_prediction_test.go index fa190f7a..d701976a 100644 --- a/internal/tests/async_prediction_test.go +++ b/internal/tests/async_prediction_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/replicate/cog-runtime/internal/runner" - "github.com/replicate/cog-runtime/internal/util" + "github.com/replicate/cog-runtime/internal/server" "github.com/replicate/cog-runtime/internal/webhook" ) @@ -62,7 +62,7 @@ func TestAsyncPrediction(t *testing.T) { }) waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded) - predictionID, err := util.PredictionID() + predictionID, err := server.PredictionID() require.NoError(t, err) prediction := runner.PredictionRequest{ Input: map[string]any{"i": 1, "s": "bar"}, @@ -118,7 +118,7 @@ func TestAsyncPredictionCanceled(t *testing.T) { }) waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded) - predictionID, err := util.PredictionID() + predictionID, err := server.PredictionID() require.NoError(t, err) prediction := runner.PredictionRequest{ Input: map[string]any{"i": 60, "s": "bar"}, @@ -195,7 +195,7 @@ func TestAsyncPredictionConcurrency(t *testing.T) { assert.Equal(t, 1, hc.Concurrency.Max) assert.Equal(t, 0, hc.Concurrency.Current) - predictionID, err := util.PredictionID() + predictionID, err := server.PredictionID() require.NoError(t, err) prediction := runner.PredictionRequest{ Input: map[string]any{"i": 1, "s": "bar"}, diff --git a/internal/tests/async_predictor_test.go b/internal/tests/async_predictor_test.go index 7c1a0793..bcdf2cbb 100644 --- a/internal/tests/async_predictor_test.go +++ b/internal/tests/async_predictor_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/replicate/cog-runtime/internal/runner" - "github.com/replicate/cog-runtime/internal/util" + "github.com/replicate/cog-runtime/internal/server" "github.com/replicate/cog-runtime/internal/webhook" ) @@ -28,9 +28,9 @@ func TestAsyncPredictorConcurrency(t *testing.T) { receiverServer := testHarnessReceiverServer(t) waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded) - barID, err := util.PredictionID() + barID, err := server.PredictionID() require.NoError(t, err) - bazID, err := util.PredictionID() + bazID, err := server.PredictionID() require.NoError(t, err) barReq := httpPredictionRequestWithID(t, runtimeServer, runner.PredictionRequest{ Input: map[string]any{"i": 1, "s": "bar"}, @@ -94,7 +94,7 @@ func TestAsyncPredictorCanceled(t *testing.T) { receiverServer := testHarnessReceiverServer(t) waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded) - barID, err := util.PredictionID() + barID, err := server.PredictionID() require.NoError(t, err) barReq := httpPredictionRequestWithID(t, runtimeServer, runner.PredictionRequest{ Input: map[string]any{"i": 60, "s": "bar"}, diff --git a/internal/tests/filter_test.go b/internal/tests/filter_test.go index bb0b3000..a87bef5b 100644 --- a/internal/tests/filter_test.go +++ b/internal/tests/filter_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/replicate/cog-runtime/internal/runner" - "github.com/replicate/cog-runtime/internal/util" + "github.com/replicate/cog-runtime/internal/server" "github.com/replicate/cog-runtime/internal/webhook" ) @@ -111,7 +111,7 @@ func TestPredictionWebhookFilter(t *testing.T) { }) waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded) - predictionID, err := util.PredictionID() + predictionID, err := server.PredictionID() require.NoError(t, err) prediction := runner.PredictionRequest{ Input: map[string]any{"i": 2, "s": "bar"}, diff --git a/internal/tests/iterator_test.go b/internal/tests/iterator_test.go index 53c3365d..6de3a877 100644 --- a/internal/tests/iterator_test.go +++ b/internal/tests/iterator_test.go @@ -10,7 +10,6 @@ import ( "github.com/replicate/cog-runtime/internal/runner" "github.com/replicate/cog-runtime/internal/server" - "github.com/replicate/cog-runtime/internal/util" "github.com/replicate/cog-runtime/internal/webhook" ) @@ -92,9 +91,9 @@ func TestPredictionAsyncIteratorConcurrency(t *testing.T) { waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded) - barID, err := util.PredictionID() + barID, err := server.PredictionID() require.NoError(t, err) - bazID, err := util.PredictionID() + bazID, err := server.PredictionID() require.NoError(t, err) barPrediction := runner.PredictionRequest{ Input: map[string]any{"i": 1, "s": "bar"}, diff --git a/internal/tests/path_test.go b/internal/tests/path_test.go index 37436c1c..7f26144c 100644 --- a/internal/tests/path_test.go +++ b/internal/tests/path_test.go @@ -18,7 +18,6 @@ import ( "github.com/replicate/cog-runtime/internal/runner" "github.com/replicate/cog-runtime/internal/server" - "github.com/replicate/cog-runtime/internal/util" "github.com/replicate/cog-runtime/internal/webhook" ) @@ -337,13 +336,13 @@ func TestPredictionPathMimeTypes(t *testing.T) { testDataPrefix := contentServer.URL + "/mimetype/" - gifPredictionID, err := util.PredictionID() + gifPredictionID, err := server.PredictionID() require.NoError(t, err) - jarPredictionID, err := util.PredictionID() + jarPredictionID, err := server.PredictionID() require.NoError(t, err) - tarPredictionID, err := util.PredictionID() + tarPredictionID, err := server.PredictionID() require.NoError(t, err) - webpPredictionID, err := util.PredictionID() + webpPredictionID, err := server.PredictionID() require.NoError(t, err) predictions := []struct { diff --git a/internal/tests/prediction_test.go b/internal/tests/prediction_test.go index f9ecdf1b..fa5313fc 100644 --- a/internal/tests/prediction_test.go +++ b/internal/tests/prediction_test.go @@ -13,7 +13,6 @@ import ( "github.com/replicate/cog-runtime/internal/runner" "github.com/replicate/cog-runtime/internal/server" - "github.com/replicate/cog-runtime/internal/util" "github.com/replicate/cog-runtime/internal/webhook" ) @@ -62,7 +61,7 @@ func TestPredictionWithIdSucceeded(t *testing.T) { waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded) input := map[string]any{"i": 1, "s": "bar"} - predictionID, err := util.PredictionID() + predictionID, err := server.PredictionID() require.NoError(t, err) predictionReq := runner.PredictionRequest{ ID: predictionID, diff --git a/internal/tests/shutdown_test.go b/internal/tests/shutdown_test.go index c1c76ad8..8caf50ab 100644 --- a/internal/tests/shutdown_test.go +++ b/internal/tests/shutdown_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/replicate/cog-runtime/internal/runner" - "github.com/replicate/cog-runtime/internal/util" + "github.com/replicate/cog-runtime/internal/server" "github.com/replicate/cog-runtime/internal/webhook" ) @@ -83,7 +83,7 @@ func TestShutdownEndpointWaitsForInflightPredictions(t *testing.T) { baseURL := httpTestServer.URL // Start an async prediction - predictionID, err := util.PredictionID() + predictionID, err := server.PredictionID() require.NoError(t, err) prediction := runner.PredictionRequest{ @@ -118,7 +118,7 @@ func TestShutdownEndpointWaitsForInflightPredictions(t *testing.T) { assert.Equal(t, http.StatusOK, shutdownResp.StatusCode) // Verify new predictions are rejected during shutdown with 503 - newPredictionID, err := util.PredictionID() + newPredictionID, err := server.PredictionID() require.NoError(t, err) newPrediction := runner.PredictionRequest{ Input: map[string]any{"i": 1, "s": "should_be_rejected"}, diff --git a/internal/util/metrics.go b/internal/util/metrics.go deleted file mode 100644 index bf06958a..00000000 --- a/internal/util/metrics.go +++ /dev/null @@ -1,45 +0,0 @@ -package util //nolint:revive // FIXME: break up util package and move functions to where they're used - -import ( - "bytes" - "encoding/json" - "net/http" - "os" -) - -type MetricsPayload struct { - Source string `json:"source,omitempty"` - Type string `json:"type,omitempty"` - Data map[string]any `json:"data,omitempty"` -} - -const MetricsEndpointEnv = "COG_METRICS_ENDPOINT" - -func SendRunnerMetric(yaml CogYaml) { - log := logger.Sugar() - endpoint := os.Getenv(MetricsEndpointEnv) - if endpoint == "" { - return - } - data := map[string]any{ - "gpu": yaml.Build.GPU, - "fast": yaml.Build.Fast, - "cog_runtime": yaml.Build.CogRuntime, - "version": Version(), - } - payload := MetricsPayload{ - Source: "cog-runtime", - Type: "runner", - Data: data, - } - body, err := json.Marshal(payload) - if err != nil { - log.Errorw("failed to marshal payload", "error", err) - return - } - resp, err := HTTPClientWithRetry().Post(endpoint, "application/json", bytes.NewBuffer(body)) - if err != nil || resp.StatusCode != http.StatusOK { - log.Errorw("failed to send runner metrics", "error", err) - } - defer resp.Body.Close() -} diff --git a/internal/util/util.go b/internal/util/util.go deleted file mode 100644 index 27c9b5e9..00000000 --- a/internal/util/util.go +++ /dev/null @@ -1,130 +0,0 @@ -package util //nolint:revive // FIXME: break up util package and move functions to where they're used - -import ( - "embed" - "encoding/base32" - "fmt" - "net/http" - "os" - "path/filepath" - "strings" - "time" - - "github.com/replicate/go/httpclient" - "github.com/replicate/go/logging" - "github.com/replicate/go/uuid" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "gopkg.in/yaml.v3" -) - -var logger = CreateLogger("cog-util") - -type Build struct { - GPU bool `yaml:"gpu"` - Fast bool `yaml:"fast"` - CogRuntime bool `yaml:"cog_runtime"` -} - -type Concurrency struct { - Max int `yaml:"max"` -} - -type CogYaml struct { - Build Build `yaml:"build"` - Concurrency Concurrency `yaml:"concurrency"` - Predict string `yaml:"predict"` -} - -func ReadCogYaml(dir string) (*CogYaml, error) { - var cogYaml CogYaml - bs, err := os.ReadFile(filepath.Join(dir, "cog.yaml")) //nolint:gosec // expected dynamic path - if err != nil { - return nil, err - } - if err := yaml.Unmarshal(bs, &cogYaml); err != nil { - return nil, err - } - return &cogYaml, nil -} - -func (y *CogYaml) PredictModuleAndPredictor() (string, string, error) { - parts := strings.Split(y.Predict, ":") - if len(parts) != 2 { - return "", "", fmt.Errorf("invalid predict: %s", y.Predict) - } - moduleName := strings.TrimSuffix(parts[0], ".py") - predictorName := parts[1] - return moduleName, predictorName, nil -} - -// api.git: internal/logic/id.go -func PredictionID() (string, error) { - u, err := uuid.NewV7() - if err != nil { - return "", err - } - shuffle := make([]byte, uuid.Size) - for i := 0; i < 4; i++ { - shuffle[i], shuffle[i+4], shuffle[i+8], shuffle[i+12] = u[i+12], u[i+4], u[i], u[i+8] - } - encoding := base32.NewEncoding("0123456789abcdefghjkmnpqrstvwxyz").WithPadding(base32.NoPadding) - return encoding.EncodeToString(shuffle), nil -} - -const TimeLayout = "2006-01-02T15:04:05.999999-07:00" - -func NowIso() string { - // Python: datetime.now(tz=timezone.utc).isoformat() - return time.Now().UTC().Format(TimeLayout) -} - -func FormatTime(t time.Time) string { - return t.UTC().Format(TimeLayout) -} - -func ParseTime(t string) (time.Time, error) { - parsedTime, err := time.Parse(TimeLayout, t) - if err != nil { - return time.Time{}, err - } - return parsedTime, nil -} - -func JoinLogs(logs []string) string { - r := strings.Join(logs, "\n") - if r != "" { - r += "\n" - } - return r -} - -// Wildcard match in case version.txt is not generated yet -// -//go:embed * -var embedFS embed.FS - -func Version() string { - bs, err := embedFS.ReadFile("version.txt") - if err != nil { - return "0.0.0+unknown" - } - return strings.TrimSpace(string(bs)) -} - -func HTTPClientWithRetry() *http.Client { - return httpclient.ApplyRetryPolicy(http.DefaultClient) -} - -func CreateLogger(name string) *zap.Logger { - logLevel := os.Getenv("COG_LOG_LEVEL") - if logLevel == "" { - logLevel = "info" - } - lvl, err := zapcore.ParseLevel(logLevel) - if err != nil { - fmt.Printf("Failed to parse log level \"%s\": %s\n", logLevel, err) //nolint:forbidigo // if the logger cannot be initialized, we should still be able to report the error - lvl = zapcore.InfoLevel - } - return logging.New(name).WithOptions(zap.IncreaseLevel(lvl)) -} diff --git a/internal/version/version.go b/internal/version/version.go new file mode 100644 index 00000000..df24b55e --- /dev/null +++ b/internal/version/version.go @@ -0,0 +1,17 @@ +package version + +import ( + "embed" + "strings" +) + +//go:embed * +var embedFS embed.FS + +func Version() string { + bs, err := embedFS.ReadFile("version.txt") + if err != nil { + return "0.0.0+unknown" + } + return strings.TrimSpace(string(bs)) +} diff --git a/internal/webhook/webhook.go b/internal/webhook/webhook.go index 5fb5a2a0..d114853a 100644 --- a/internal/webhook/webhook.go +++ b/internal/webhook/webhook.go @@ -7,9 +7,8 @@ import ( "slices" "time" + "github.com/replicate/go/httpclient" "go.uber.org/zap" - - "github.com/replicate/cog-runtime/internal/util" ) // Event represents a webhook event - using string to be compatible with any type @@ -41,7 +40,7 @@ type DefaultSender struct { func NewSender(logger *zap.Logger) *DefaultSender { return &DefaultSender{ logger: logger.Named("webhook"), - client: util.HTTPClientWithRetry(), + client: httpclient.ApplyRetryPolicy(http.DefaultClient), } } diff --git a/script/build.sh b/script/build.sh index 15735da8..702e88ca 100755 --- a/script/build.sh +++ b/script/build.sh @@ -15,7 +15,7 @@ rm -rf python/cog/cog-* # Skip Go binaries if building "clet", i.e. coglet without go for pyodide if [ -z "${CLET:-}" ]; then # Export Python version to Go - uv run --with setuptools_scm python3 -m setuptools_scm > internal/util/version.txt + uv run --with setuptools_scm python3 -m setuptools_scm > internal/version/version.txt # Binaries are bundled in Python wheel for os in darwin linux; do for arch in amd64 arm64; do diff --git a/script/test-setuid-cleanup.sh b/script/test-setuid-cleanup.sh index 975042e8..3e7d31ba 100755 --- a/script/test-setuid-cleanup.sh +++ b/script/test-setuid-cleanup.sh @@ -38,7 +38,7 @@ python3 -m cog.server.http \ EOF # Start Docker container -docker run -it --detach \ +docker run -it --rm --detach \ --name "$name" \ --entrypoint /bin/bash \ --publish "$port:$port" \