Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,22 @@ threat-detect [flags] <artifacts-dir>
- `--model` — Model override for the engine
- `--prompt-template` — Path to custom prompt template
- `--output` — Path to write JSON result (defaults to stdout)
- `--triage` — Enable fast Phase 1 structured-output triage. Default: `true`
- `--reflect-url` — `api-proxy` `/reflect` base URL for structured-output calls. Default: `http://127.0.0.1:8080/reflect`
- `--triage-model` — Model override for Phase 1 `/reflect` triage
- `--triage-max-bytes` — Maximum bytes per artifact to inline during triage
- `--triage-retries` — Retries for malformed structured-output responses
- `--version` — Print version and exit

`--reflect-url` can also be supplied with `THREAT_DETECTION_REFLECT_URL`,
`API_PROXY_REFLECT_URL`, or `REFLECT_URL`. By default, `threat-detect` first
tries a non-agentic `/reflect` call with a strict JSON schema matching the result
contract. An all-false valid triage result exits successfully without the full
detector. Threats, uncertainty, unsupported models, proxy errors, or malformed
responses fail safe into the full detector. The full detector preserves the
existing CLI engine behavior and prefers `/reflect` structured output when a
schema-capable model is available.

**Exit codes:**
- `0` — Safe (no threats detected)
- `1` — Threat detected
Expand Down Expand Up @@ -131,13 +145,13 @@ The extraction staging model is:
- Stage 3: `github/gh-aw` integration

Stage 1 is functionally represented in this repository.
The standalone Go CLI, artifact reader, prompt builder, result parser, engine abstraction, W3C-style specification, unit tests, CI, Dockerfile, and release workflow are present.
The standalone Go CLI, artifact reader, prompt builder, two-phase `/reflect` triage, result parser, engine abstraction, W3C-style specification, unit tests, CI, Dockerfile, and release workflow are present.
Remaining work involves integration with `github/gh-aw` and production hardening of the container runtime in Stage 2/3, not additional JavaScript porting in this repository.

Decisions for the unresolved extraction questions:

- **JavaScript scripts**: detection setup and result parsing are implemented in Go here; the old GitHub Actions JavaScript scripts should not be needed once `gh-aw` switches to the container contract.
- **Engine CLIs**: do not bundle Copilot, Claude, or Codex CLIs into the detector image. The detector invokes the selected engine CLI from `PATH` and forwards the `--model` value. Production `gh-aw` integration should install or provide the selected engine CLI in the detection job, then run the pinned detector binary extracted from the detector image in that same runner/AWF environment. This keeps the image small, avoids runtime installation inside the image, and reuses the existing engine installation/authentication path.
- **Engine CLIs and `/reflect`**: do not bundle Copilot, Claude, or Codex CLIs into the detector image. The detector invokes the selected engine CLI from `PATH` and forwards the `--model` value when full CLI analysis is needed. When `--reflect-url` is configured, the detector can call `api-proxy` directly for structured-output triage and schema-capable full analysis before falling back to CLI behavior. Production `gh-aw` integration should install or provide the selected engine CLI in the detection job, then run the pinned detector binary extracted from the detector image in that same runner/AWF environment. This keeps the image small, avoids runtime installation inside the image, and reuses the existing engine installation/authentication path.
- **Custom steps**: custom `threat-detection.steps` remain orchestrator-owned. They should run before or after the container in the `gh-aw` job rather than being passed into this container as arbitrary scripts.
- **Backward compatibility**: do not ship a long-lived dual-mode compatibility window. Stage 3 should switch `gh-aw` to the pinned detector image path after Stage 4 validation passes; users that need inline detection can pin an older `gh-aw` release. A temporary internal fallback is acceptable during implementation only, but should not become a documented public feature flag unless Stage 4 exposes a blocking compatibility issue.
- **Ollama/LlamaGuard**: keep this as a custom-step pattern unless a dedicated image variant is explicitly required.
Expand Down
140 changes: 124 additions & 16 deletions cmd/threat-detect/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,20 @@ import (
"fmt"
"os"
"os/signal"
"strconv"

"github.com/github/gh-aw-threat-detection/pkg/artifacts"
"github.com/github/gh-aw-threat-detection/pkg/detector"
"github.com/github/gh-aw-threat-detection/pkg/engine"
)

const (
exitSafe = 0
exitThreat = 1
exitError = 2
exitSafe = 0
exitThreat = 1
exitError = 2

fullDetectionCorrectionSummaryFormat = "Your previous response did not contain a valid %s JSON object"
fullDetectionCorrectionInstructionFormat = "Return exactly one corrected result line using the required %s prefix."
)

func main() {
Expand All @@ -42,18 +46,28 @@ func run() int {
defer stop()

var (
engineID string
model string
promptFile string
outputJSON string
version bool
engineID string
model string
promptFile string
outputJSON string
version bool
triage bool
reflectURL string
triageModel string
triageMaxBytes int
triageRetries int
)

flag.StringVar(&engineID, "engine", "", "AI engine to use (copilot, claude, codex)")
flag.StringVar(&model, "model", "", "Model to use for detection")
flag.StringVar(&promptFile, "prompt-template", "", "Path to custom prompt template (defaults to built-in)")
flag.StringVar(&outputJSON, "output", "", "Path to write JSON result (defaults to stdout)")
flag.BoolVar(&version, "version", false, "Print version and exit")
flag.BoolVar(&triage, "triage", envBool("THREAT_DETECTION_TRIAGE", true), "Run Phase 1 structured-output triage before full detection (env: THREAT_DETECTION_TRIAGE)")
flag.StringVar(&reflectURL, "reflect-url", envFirstOrDefault(engine.DefaultReflectURL, "THREAT_DETECTION_REFLECT_URL", "API_PROXY_REFLECT_URL", "REFLECT_URL"), "api-proxy reflect base URL")
flag.StringVar(&triageModel, "triage-model", os.Getenv("THREAT_DETECTION_TRIAGE_MODEL"), "Model to use for reflect triage")
flag.IntVar(&triageMaxBytes, "triage-max-bytes", envInt("THREAT_DETECTION_TRIAGE_MAX_BYTES", detector.DefaultTriageMaxBytes()), "Maximum bytes per artifact to inline for triage")
flag.IntVar(&triageRetries, "triage-retries", envInt("THREAT_DETECTION_TRIAGE_RETRIES", 1), "Retries for malformed structured outputs")
flag.Parse()

if version {
Expand All @@ -77,6 +91,27 @@ func run() int {
return exitError
}

if triage && reflectURL != "" {
triagePrompt, err := detector.BuildTriagePrompt(arts, triageMaxBytes)
if err == nil {
triageResult, err := (&engine.ReflectClient{
BaseURL: reflectURL,
Model: triageModel,
Retries: triageRetries,
}).AnalyzeStructured(ctx, triagePrompt)
if err == nil && triageResult.IsSafe() {
return writeResult(triageResult, outputJSON)
}
if err != nil {
fmt.Fprintf(os.Stderr, "Triage inconclusive, running full detection: %v\n", err)
} else {
fmt.Fprintln(os.Stderr, "Triage found possible threats, running full detection")
}
} else {
fmt.Fprintf(os.Stderr, "Error building triage prompt, running full detection: %v\n", err)
}
}

// Build the prompt
promptTemplate := ""
if promptFile != "" {
Expand All @@ -94,28 +129,59 @@ func run() int {
return exitError
}

if reflectURL != "" {
reflectResult, err := (&engine.ReflectClient{
BaseURL: reflectURL,
Model: firstNonEmpty(model, triageModel),
Retries: triageRetries,
}).AnalyzeStructured(ctx, prompt)
if err == nil {
return writeResult(reflectResult, outputJSON)
}
fmt.Fprintf(os.Stderr, "Structured reflect detection unavailable, using CLI engine: %v\n", err)
}

// Create engine
eng, err := engine.New(engineID, model)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating engine: %v\n", err)
return exitError
}

// Run detection
rawOutput, err := eng.Analyze(ctx, prompt)
result, err := analyzeWithRetries(ctx, eng, prompt, triageRetries)
if err != nil {
fmt.Fprintf(os.Stderr, "Error running detection: %v\n", err)
return exitError
}

// Parse result
result, err := detector.ParseResult(rawOutput)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing result: %v\n", err)
return exitError
return writeResult(result, outputJSON)
}

func analyzeWithRetries(ctx context.Context, eng engine.Engine, prompt string, retries int) (*detector.Result, error) {
attempts := retries + 1
if attempts < 1 {
attempts = 1
}
currentPrompt := prompt
var lastErr error
for i := 0; i < attempts; i++ {
rawOutput, err := eng.Analyze(ctx, currentPrompt)
if err != nil {
return nil, err
}
result, err := detector.ParseResult(rawOutput)
if err == nil {
return result, nil
}
lastErr = err
summary := fmt.Sprintf(fullDetectionCorrectionSummaryFormat, detector.ResultPrefix)
instruction := fmt.Sprintf(fullDetectionCorrectionInstructionFormat, detector.ResultPrefix)
currentPrompt = detector.BuildCorrectionPrompt(prompt, summary, err.Error(), instruction)
}
return nil, lastErr
}

// Output result
func writeResult(result *detector.Result, outputJSON string) int {
jsonBytes, err := json.MarshalIndent(result, "", " ")
if err != nil {
fmt.Fprintf(os.Stderr, "Error marshaling result: %v\n", err)
Expand All @@ -137,3 +203,45 @@ func run() int {
}
return exitSafe
}

func envFirstOrDefault(fallback string, keys ...string) string {
for _, key := range keys {
if value := os.Getenv(key); value != "" {
return value
}
}
return fallback
}

func envBool(key string, fallback bool) bool {
value := os.Getenv(key)
if value == "" {
return fallback
}
parsed, err := strconv.ParseBool(value)
if err != nil {
return fallback
}
return parsed
}

func envInt(key string, fallback int) int {
value := os.Getenv(key)
if value == "" {
return fallback
}
parsed, err := strconv.Atoi(value)
if err != nil {
return fallback
}
return parsed
}

func firstNonEmpty(values ...string) string {
for _, value := range values {
if value != "" {
return value
}
}
return ""
}
152 changes: 152 additions & 0 deletions cmd/threat-detect/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package main

import (
"encoding/json"
"flag"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
)

func TestRunReflectUnavailableFallsBackToAgenticEngine(t *testing.T) {
var reflectRequests atomic.Int32
reflectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reflectRequests.Add(1)
http.Error(w, "reflect not implemented", http.StatusNotImplemented)
}))
defer reflectServer.Close()

artifactsDir := t.TempDir()
outputPath := filepath.Join(t.TempDir(), "result.json")
copilotMarker := filepath.Join(t.TempDir(), "copilot-called")
fakeBinDir := writeFakeCopilot(t, copilotMarker, `THREAT_DETECTION_RESULT:{"prompt_injection":true,"secret_leak":false,"malicious_patch":false,"reasons":["agentic fallback"]}`)

code := runWithTestArgs(t, []string{
"threat-detect",
"-reflect-url", reflectServer.URL,
"-output", outputPath,
artifactsDir,
}, map[string]string{
"PATH": fakeBinDir + string(os.PathListSeparator) + os.Getenv("PATH"),
})

if code != exitThreat {
t.Fatalf("run() exit code = %d, want %d", code, exitThreat)
}
if reflectRequests.Load() == 0 {
t.Fatal("expected /reflect to be attempted before fallback")
}
if _, err := os.Stat(copilotMarker); err != nil {
t.Fatalf("expected copilot fallback to run: %v", err)
}
result := readResultFile(t, outputPath)
if !result["prompt_injection"].(bool) {
t.Fatalf("fallback result prompt_injection = false, want true: %#v", result)
}
}

func TestRunReflectSuccessDoesNotInvokeAgenticEngine(t *testing.T) {
var postRequests atomic.Int32
reflectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"models":[{"id":"schema","provider":"openai","capabilities":{"json_schema":true}}]}`))
case http.MethodPost:
postRequests.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"prompt_injection":false,"secret_leak":false,"malicious_patch":false,"reasons":[]}`))
default:
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
}
}))
defer reflectServer.Close()

artifactsDir := t.TempDir()
outputPath := filepath.Join(t.TempDir(), "result.json")
copilotMarker := filepath.Join(t.TempDir(), "copilot-called")
fakeBinDir := writeFakeCopilot(t, copilotMarker, "copilot should not run")

code := runWithTestArgs(t, []string{
"threat-detect",
"-triage=false",
"-reflect-url", reflectServer.URL,
"-output", outputPath,
artifactsDir,
}, map[string]string{
"PATH": fakeBinDir + string(os.PathListSeparator) + os.Getenv("PATH"),
})

if code != exitSafe {
t.Fatalf("run() exit code = %d, want %d", code, exitSafe)
}
if postRequests.Load() == 0 {
t.Fatal("expected successful structured /reflect detection")
}
if _, err := os.Stat(copilotMarker); !os.IsNotExist(err) {
t.Fatalf("copilot should not run when /reflect succeeds, stat err = %v", err)
}
result := readResultFile(t, outputPath)
if result["prompt_injection"].(bool) || result["secret_leak"].(bool) || result["malicious_patch"].(bool) {
t.Fatalf("reflect result is not safe: %#v", result)
}
}

func runWithTestArgs(t *testing.T, args []string, env map[string]string) int {
t.Helper()

originalArgs := os.Args
originalFlags := flag.CommandLine
t.Cleanup(func() {
os.Args = originalArgs
flag.CommandLine = originalFlags
})
os.Args = args
flag.CommandLine = flag.NewFlagSet(args[0], flag.ContinueOnError)

for key, value := range env {
t.Setenv(key, value)
}

return run()
}

func writeFakeCopilot(t *testing.T, markerPath, output string) string {
t.Helper()

binDir := t.TempDir()
scriptPath := filepath.Join(binDir, "copilot")
script := strings.Join([]string{
"#!/bin/sh",
"cat >/dev/null",
"printf called > " + shellQuote(markerPath),
"printf '%s\\n' " + shellQuote(output),
"",
}, "\n")
if err := os.WriteFile(scriptPath, []byte(script), 0o700); err != nil {
t.Fatalf("writing fake copilot: %v", err)
}
return binDir
}

func readResultFile(t *testing.T, path string) map[string]any {
t.Helper()

data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("reading result file: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("parsing result JSON: %v", err)
}
return result
}

func shellQuote(value string) string {
return "'" + strings.ReplaceAll(value, "'", "'\\''") + "'"
}
Loading
Loading