Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions client/adaptive_ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,12 @@ func (a *AdaptiveRateLimitProvider) checkAndWait(ctx context.Context) error {
if delay > 0 && delay <= a.config.MaxDelay {
a.throttleCount++
a.mu.Unlock()
timer := time.NewTimer(delay)
select {
case <-ctx.Done():
timer.Stop()
return ctx.Err()
case <-time.After(delay):
case <-timer.C:
}
a.mu.Lock()
now = time.Now()
Expand Down Expand Up @@ -421,10 +423,12 @@ func (a *AdaptiveRateLimitProvider) checkAndWait(ctx context.Context) error {
if delay > 0 && delay <= a.config.MaxDelay {
a.throttleCount++
a.mu.Unlock()
timer := time.NewTimer(delay)
select {
case <-ctx.Done():
timer.Stop()
return ctx.Err()
case <-time.After(delay):
case <-timer.C:
}
a.mu.Lock()
now = time.Now()
Expand Down
61 changes: 54 additions & 7 deletions client/guardrails.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ type GuardrailViolation struct {
Rule GuardrailRule `json:"rule"`
MatchedText string `json:"matched_text"`
RedactedResult string `json:"redacted_result,omitempty"`
matchStart int // byte offset of the match in the original response
matchEnd int // byte offset one past the last matched byte
}

// GuardrailError is returned when a guardrail blocks a response.
Expand Down Expand Up @@ -91,6 +93,10 @@ func NewGuardrails(rules ...GuardrailRule) *Guardrails {
}

// AddRule registers a guardrail rule. It panics if the pattern is invalid.
// This follows the regexp.MustCompile convention for programmatic rules
// where an invalid pattern indicates a programmer error. For rules that
// may originate from untrusted sources (config files, user input), use
// AddRuleSafe instead.
func (g *Guardrails) AddRule(r GuardrailRule) {
if r.Pattern != "" {
compiled, err := regexp.Compile(r.Pattern)
Expand All @@ -104,6 +110,37 @@ func (g *Guardrails) AddRule(r GuardrailRule) {
g.rules = append(g.rules, r)
}

// AddRuleSafe registers a guardrail rule and returns an error if the pattern
// is invalid, instead of panicking. Use this when rules may come from
// untrusted sources (config files, user input).
func (g *Guardrails) AddRuleSafe(r GuardrailRule) error {
if r.Pattern != "" {
compiled, err := regexp.Compile(r.Pattern)
if err != nil {
return fmt.Errorf("eyrie: guardrails: invalid regex %q in rule %q: %w", r.Pattern, r.Name, err)
}
r.compiled = compiled
}
g.mu.Lock()
defer g.mu.Unlock()
g.rules = append(g.rules, r)
return nil
}

// NewGuardrailsSafe creates a Guardrails instance and returns an error if any
// rule has an invalid pattern. Use this when rules may come from untrusted
// sources; use NewGuardrails for programmatic rules where invalid patterns
// indicate a programmer error (matching regexp.MustCompile convention).
func NewGuardrailsSafe(rules ...GuardrailRule) (*Guardrails, error) {
g := &Guardrails{}
for _, r := range rules {
if err := g.AddRuleSafe(r); err != nil {
return nil, err
}
}
return g, nil
}

// Rules returns a snapshot of the currently registered rules.
func (g *Guardrails) Rules() []GuardrailRule {
g.mu.RLock()
Expand Down Expand Up @@ -134,17 +171,19 @@ func (g *Guardrails) Check(ctx context.Context, response string) ([]GuardrailVio
if rule.compiled == nil {
continue
}
matches := rule.compiled.FindAllString(response, -1)
matches := rule.compiled.FindAllStringIndex(response, -1)
if len(matches) == 0 {
continue
}
for _, match := range matches {
v := GuardrailViolation{
Rule: rule,
MatchedText: match,
MatchedText: response[match[0]:match[1]],
matchStart: match[0],
matchEnd: match[1],
}
if rule.Action == GuardrailRedact {
v.RedactedResult = strings.Repeat("*", len(match))
v.RedactedResult = strings.Repeat("*", len(v.MatchedText))
}
violations = append(violations, v)
if rule.Action == GuardrailBlock {
Expand All @@ -165,7 +204,9 @@ func (g *Guardrails) Check(ctx context.Context, response string) ([]GuardrailVio

// ApplyRedactions takes the response text and violations, replacing redacted
// matches with their redaction markers. Non-redact violations are left intact.
// Matches are applied positionally to handle overlapping patterns correctly.
// Match positions are used directly from the violations (captured during Check)
// so the correct instance of each match is redacted even when the matched text
// appears multiple times in the response.
func ApplyRedactions(response string, violations []GuardrailViolation) string {
type replacement struct {
start int
Expand All @@ -177,11 +218,17 @@ func ApplyRedactions(response string, violations []GuardrailViolation) string {
if v.Rule.Action != GuardrailRedact {
continue
}
idx := strings.Index(response, v.MatchedText)
if idx < 0 {
if v.matchEnd == 0 && v.matchStart == 0 && v.MatchedText != "" {
// Fallback: violation came from outside Check (e.g. constructed
// manually). Search for the first occurrence.
idx := strings.Index(response, v.MatchedText)
if idx < 0 {
continue
}
reps = append(reps, replacement{start: idx, end: idx + len(v.MatchedText), text: v.RedactedResult})
continue
}
reps = append(reps, replacement{start: idx, end: idx + len(v.MatchedText), text: v.RedactedResult})
reps = append(reps, replacement{start: v.matchStart, end: v.matchEnd, text: v.RedactedResult})
}
if len(reps) == 0 {
return response
Expand Down
36 changes: 36 additions & 0 deletions client/guardrails_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,42 @@ func TestGuardrails_InvalidPatternPanics(t *testing.T) {
})
}

func TestGuardrails_InvalidPatternSafeReturnsError(t *testing.T) {
_, err := NewGuardrailsSafe(GuardrailRule{
Type: GuardrailCustom,
Name: "bad_regex",
Pattern: `[invalid`,
Action: GuardrailBlock,
})
if err == nil {
t.Fatal("expected error for invalid regex in NewGuardrailsSafe")
}
}

func TestGuardrails_AddRuleSafe(t *testing.T) {
g := NewGuardrails()
if err := g.AddRuleSafe(GuardrailRule{
Type: GuardrailCustom,
Name: "dynamic_rule",
Pattern: `dynamic_pattern`,
Action: GuardrailWarn,
}); err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if len(g.Rules()) != 1 {
t.Fatalf("expected 1 rule after AddRuleSafe, got %d", len(g.Rules()))
}

if err := g.AddRuleSafe(GuardrailRule{
Type: GuardrailCustom,
Name: "bad_regex",
Pattern: `[invalid`,
Action: GuardrailBlock,
}); err == nil {
t.Fatal("expected error for invalid regex in AddRuleSafe")
}
}

func TestGuardrails_AddRule(t *testing.T) {
g := NewGuardrails()
g.AddRule(GuardrailRule{
Expand Down
11 changes: 10 additions & 1 deletion client/provider_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"time"

"github.com/GrayCodeAI/eyrie/config"
"github.com/GrayCodeAI/eyrie/credentials"
Expand Down Expand Up @@ -280,5 +281,13 @@ func ResolveProviderModelEnvOverride(provider string) string {
}

func resolveEnvSecret(envKey string) string {
return credentials.LookupSecret(context.Background(), envKey)
// Bound the lookup to prevent indefinite stalls when the OS keyring
// is unresponsive (e.g., locked keychain on macOS, D-Bus failure on
// Linux). The keyring itself has a 30s timeout, but resolveEnvSecret
// is called multiple times in sequence during provider construction
// (up to 6 calls for AWS Bedrock), so a per-call cap keeps the total
// stall bounded.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return credentials.LookupSecret(ctx, envKey)
}
8 changes: 6 additions & 2 deletions client/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,12 @@ func (b *tokenBucket) wait(ctx context.Context) error {
if since < b.minInterval {
wait := b.minInterval - since
b.mu.Unlock()
timer := time.NewTimer(wait)
select {
case <-ctx.Done():
timer.Stop()
return fmt.Errorf("eyrie: rate limiter: %w", ctx.Err())
case <-time.After(wait):
case <-timer.C:
}
b.mu.Lock()
}
Expand All @@ -100,10 +102,12 @@ func (b *tokenBucket) wait(ctx context.Context) error {
waitDur := time.Duration(needed / b.refillRate)
b.mu.Unlock()

timer := time.NewTimer(waitDur)
select {
case <-ctx.Done():
timer.Stop()
return fmt.Errorf("eyrie: rate limiter: %w", ctx.Err())
case <-time.After(waitDur):
case <-timer.C:
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion client/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,16 @@ func doWithRetry(ctx context.Context, httpClient *http.Client, req *http.Request
"attempt", attempt, "max", rc.MaxRetries,
"delay", delay, "url", req.URL.String(),
)
// Use time.NewTimer + Stop instead of time.After to avoid leaking
// the timer in the runtime when ctx is cancelled before the delay
// elapses. time.After allocates a timer that lives until it fires,
// even if the caller has already moved on.
timer := time.NewTimer(delay)
select {
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
case <-time.After(delay):
case <-timer.C:
}
}

Expand Down
19 changes: 16 additions & 3 deletions client/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"log/slog"
"sort"
"strings"
"time"
)
Expand Down Expand Up @@ -428,9 +429,21 @@ func processOpenAIStreamWithOpts(ctx context.Context, sseEvents <-chan SSEEvent,
return
}
toolsEmitted = true
for _, t := range tools {
var args map[string]interface{}
_ = json.Unmarshal([]byte(t.argsBuf.String()), &args)
// Sort by index so tool calls are emitted in the order the
// model produced them, not in random map-iteration order.
indices := make([]int, 0, len(tools))
for idx := range tools {
indices = append(indices, idx)
}
sort.Ints(indices)
for _, idx := range indices {
t := tools[idx]
args := map[string]interface{}{}
if err := json.Unmarshal([]byte(t.argsBuf.String()), &args); err != nil {
// On parse failure, pass the raw string as _raw so
// the caller sees something rather than a nil map.
args = map[string]interface{}{"_raw": t.argsBuf.String()}
}
emit(ctx, ch, EyrieStreamEvent{
Type: "tool_call",
ToolCall: &ToolCall{ID: t.id, Name: t.name, Arguments: args},
Expand Down
17 changes: 13 additions & 4 deletions config/provider_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package config
import (
"encoding/json"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -286,13 +287,21 @@ func ValidateAPIKey(apiKey, providerName string) string {
return ""
}

// ValidateBaseURL validates a base URL.
// ValidateBaseURL validates a base URL. Returns an error message if the URL
// is syntactically invalid (unparseable or missing a scheme), or empty if valid.
func ValidateBaseURL(baseURL string) string {
if baseURL == "" {
return ""
}
if _, err := os.Stat(baseURL); err == nil {
return "Invalid base URL: " + baseURL
u, err := url.Parse(baseURL)
if err != nil {
return "Invalid base URL: " + baseURL + " (" + err.Error() + ")"
}
if u.Scheme != "http" && u.Scheme != "https" {
return "Invalid base URL: " + baseURL + " (must be http or https)"
}
if u.Host == "" {
return "Invalid base URL: " + baseURL + " (missing host)"
}
return ""
}
Expand Down Expand Up @@ -360,7 +369,7 @@ func SaveProviderConfig(config *ProviderConfig, path string) error {
if path == "" {
path = GetProviderConfigPath()
}
_ = os.MkdirAll(filepath.Dir(path), 0o755)
_ = os.MkdirAll(filepath.Dir(path), 0o700)
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
Expand Down
8 changes: 7 additions & 1 deletion config/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"os"
"strings"
"time"

"github.com/GrayCodeAI/eyrie/credentials"
)
Expand Down Expand Up @@ -43,8 +44,13 @@ func envValue(key string) string {
if key == "" {
return ""
}
// Bound the keyring lookup to prevent indefinite stalls when the OS
// keychain is unresponsive. The keyring itself has a 30s timeout,
// but envValue is called many times during provider construction.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Always check the credential store first for ALL keys.
if v := credentials.LookupSecret(context.Background(), key); v != "" {
if v := credentials.LookupSecret(ctx, key); v != "" {
return v
}
// Fall back to process environment only when the credential store has
Expand Down
15 changes: 15 additions & 0 deletions conversation/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ func (e *Engine) DeleteNode(ctx context.Context, id string) error {
const defaultGroupBudgetMultiplier = 4

func (e *Engine) streamAndSave(ctx context.Context, parentNode *storage.Node, messages []client.EyrieMessage, opts PromptOpts) (<-chan Event, error) {
if e.provider == nil {
return nil, fmt.Errorf("conversation: engine has no provider")
}
_, span := tracer.Start(
ctx, "conversation.streamAndSave",
trace.WithAttributes(
Expand Down Expand Up @@ -248,6 +251,18 @@ func (e *Engine) streamAndSave(ctx context.Context, parentNode *storage.Node, me
currentStream = sr.Events
)

// Ensure the last StreamResult is always closed on exit. The
// closure captures currentSR by reference, so it closes whatever
// stream is active when the goroutine returns — including the
// initial stream on early exits and the final continuation stream
// on normal completion. Previous streams are closed explicitly
// during the continuation loop (currentSR.Close() before reassign).
defer func() {
if currentSR != nil {
currentSR.Close()
}
}()

for {
var fullTextBuilder strings.Builder
var usage *client.EyrieUsage
Expand Down
11 changes: 10 additions & 1 deletion credentials/keyring_platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,16 @@ func (k *keyringStore) Get(ctx context.Context, account string) (string, error)
return "", err
}
if r.err != nil {
return "", ErrNotFound
// Only map "not found" errors to ErrNotFound; pass through real
// backend errors (locked keychain, permission denied, D-Bus
// failure, etc.) so callers can distinguish "no secret stored"
// from "lookup failed". Without this, LookupSecret would log a
// real backend failure at Debug level ("no secret stored")
// instead of Warn level ("secret lookup failed").
if isNotFound(r.err) {
return "", ErrNotFound
}
return "", r.err
}
if strings.TrimSpace(r.val) == "" {
return "", ErrNotFound
Expand Down
Loading
Loading