Skip to content

Commit 82fd12a

Browse files
authored
Merge pull request #5 from nnemirovsky/feat/db-watcher
feat(core): auto-reload policy on database changes via SQLite data_version
2 parents 5a7f6e0 + 613b985 commit 82fd12a

3 files changed

Lines changed: 138 additions & 54 deletions

File tree

cmd/sluice/main.go

Lines changed: 54 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -552,71 +552,71 @@ func main() {
552552
sigCh := make(chan os.Signal, 1)
553553
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
554554

555-
sighupCh := make(chan os.Signal, 1)
556-
signal.Notify(sighupCh, syscall.SIGHUP)
555+
// reloadAll reloads policy, bindings, and OAuth index from the database.
556+
// Called by both SIGHUP handler and the SQLite data_version watcher.
557+
reloadAll := func() {
558+
srv.ReloadMu().Lock()
559+
defer srv.ReloadMu().Unlock()
560+
561+
newEng, loadErr := policy.LoadFromStore(db)
562+
if loadErr != nil {
563+
log.Printf("reload policy failed: %v", loadErr)
564+
return
565+
}
557566

558-
go func() {
559-
for range sighupCh {
560-
srv.ReloadMu().Lock()
561-
562-
newEng, loadErr := policy.LoadFromStore(db)
563-
if loadErr != nil {
564-
log.Printf("reload policy failed: %v", loadErr)
565-
drainSignals(sighupCh)
566-
srv.ReloadMu().Unlock()
567-
continue
568-
}
567+
if valErr := newEng.Validate(); valErr != nil {
568+
log.Printf("reload policy validation failed: %v", valErr)
569+
return
570+
}
569571

570-
// Validate the new engine before swapping to catch
571-
// corrupted or incomplete compilation results.
572-
if valErr := newEng.Validate(); valErr != nil {
573-
log.Printf("reload policy validation failed: %v", valErr)
574-
drainSignals(sighupCh)
575-
srv.ReloadMu().Unlock()
576-
continue
572+
log.Printf("reloaded policy: %d allow, %d deny, %d ask rules (default: %s)",
573+
len(newEng.AllowRules), len(newEng.DenyRules), len(newEng.AskRules), newEng.Default)
574+
srv.StoreEngine(newEng)
575+
srv.UpdateInspectRules(newEng)
576+
577+
newBindings, bindErr := readBindings(db)
578+
if bindErr != nil {
579+
log.Printf("reload bindings failed: %v", bindErr)
580+
} else if len(newBindings) > 0 {
581+
newResolver, resolveErr := vault.NewBindingResolver(newBindings)
582+
if resolveErr != nil {
583+
log.Printf("rebuild binding resolver failed: %v", resolveErr)
584+
} else {
585+
srv.StoreResolver(newResolver)
586+
log.Printf("reloaded bindings: %d", len(newBindings))
577587
}
588+
} else if len(newBindings) == 0 {
589+
srv.StoreResolver(nil)
590+
}
578591

579-
log.Printf("reloaded policy: %d allow, %d deny, %d ask rules (default: %s)",
580-
len(newEng.AllowRules), len(newEng.DenyRules), len(newEng.AskRules), newEng.Default)
581-
srv.StoreEngine(newEng)
582-
srv.UpdateInspectRules(newEng)
583-
584-
// Rebuild binding resolver so credential injection picks up
585-
// bindings added via CLI or Telegram since last reload.
586-
newBindings, bindErr := readBindings(db)
587-
if bindErr != nil {
588-
log.Printf("reload bindings failed: %v", bindErr)
589-
} else if len(newBindings) > 0 {
590-
newResolver, resolveErr := vault.NewBindingResolver(newBindings)
591-
if resolveErr != nil {
592-
log.Printf("rebuild binding resolver failed: %v", resolveErr)
593-
} else {
594-
srv.StoreResolver(newResolver)
595-
log.Printf("reloaded bindings: %d", len(newBindings))
596-
}
597-
} else if len(newBindings) == 0 {
598-
srv.StoreResolver(nil)
599-
}
592+
if metas, metaErr := db.ListCredentialMeta(); metaErr == nil {
593+
srv.UpdateOAuthIndex(metas)
594+
} else {
595+
log.Printf("reload oauth index failed: %v", metaErr)
596+
}
600597

601-
// Rebuild OAuth token URL index so response interception picks
602-
// up credentials added via CLI or Telegram since last reload.
603-
if metas, metaErr := db.ListCredentialMeta(); metaErr == nil {
604-
srv.UpdateOAuthIndex(metas)
605-
} else {
606-
log.Printf("reload oauth index failed: %v", metaErr)
607-
}
598+
if broker == nil && (len(newEng.AskRules) > 0 || newEng.Default == policy.Ask) {
599+
log.Printf("WARNING: policy has ask rules but no approval broker is running; ask verdicts will auto-deny")
600+
}
608601

609-
// Warn if the reloaded policy has ask rules but no approval
610-
// broker is running.
611-
if broker == nil && (len(newEng.AskRules) > 0 || newEng.Default == policy.Ask) {
612-
log.Printf("WARNING: policy has ask rules but no approval broker is running; ask verdicts will auto-deny")
613-
}
602+
}
603+
604+
sighupCh := make(chan os.Signal, 1)
605+
signal.Notify(sighupCh, syscall.SIGHUP)
614606

607+
go func() {
608+
for range sighupCh {
609+
reloadAll()
615610
drainSignals(sighupCh)
616-
srv.ReloadMu().Unlock()
617611
}
618612
}()
619613

614+
// Watch for database changes from external connections (CLI commands).
615+
// Triggers the same reload as SIGHUP without requiring manual signals.
616+
dbWatcher := store.NewWatcher(db.DB(), reloadAll)
617+
dbWatcher.Start()
618+
defer dbWatcher.Stop()
619+
620620
errCh := make(chan error, 1)
621621
go func() {
622622
errCh <- srv.ListenAndServe()

internal/store/store.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ type Store struct {
2020
db *sql.DB
2121
}
2222

23+
// DB returns the underlying *sql.DB for use by the data_version watcher.
24+
func (s *Store) DB() *sql.DB {
25+
return s.db
26+
}
27+
2328
// New opens or creates a SQLite database at the given path and runs
2429
// schema migrations. Use ":memory:" for an in-memory database (tests).
2530
func New(path string) (*Store, error) {

internal/store/watcher.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package store
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"log"
7+
"time"
8+
)
9+
10+
// Watcher polls SQLite's PRAGMA data_version to detect changes from external
11+
// connections (e.g. CLI commands). When a change is detected, the onChange
12+
// callback is invoked. This enables hot-reload without signals or IPC.
13+
type Watcher struct {
14+
db *sql.DB
15+
interval time.Duration
16+
onChange func()
17+
cancel context.CancelFunc
18+
}
19+
20+
const defaultWatchInterval = 2 * time.Second
21+
22+
// NewWatcher creates a watcher that polls the database for changes.
23+
// The onChange callback is called when the data_version changes, indicating
24+
// another connection (CLI, API, Telegram) modified the database.
25+
func NewWatcher(db *sql.DB, onChange func(), interval ...time.Duration) *Watcher {
26+
iv := defaultWatchInterval
27+
if len(interval) > 0 && interval[0] > 0 {
28+
iv = interval[0]
29+
}
30+
return &Watcher{
31+
db: db,
32+
interval: iv,
33+
onChange: onChange,
34+
}
35+
}
36+
37+
// Start begins polling in a goroutine. Call Stop to terminate.
38+
func (w *Watcher) Start() {
39+
ctx, cancel := context.WithCancel(context.Background())
40+
w.cancel = cancel
41+
go w.poll(ctx)
42+
}
43+
44+
// Stop terminates the polling goroutine.
45+
func (w *Watcher) Stop() {
46+
if w.cancel != nil {
47+
w.cancel()
48+
}
49+
}
50+
51+
func (w *Watcher) poll(ctx context.Context) {
52+
var lastVersion int64
53+
// Read initial version.
54+
if err := w.db.QueryRow("PRAGMA data_version").Scan(&lastVersion); err != nil {
55+
log.Printf("db watcher: initial data_version read failed: %v", err)
56+
return
57+
}
58+
59+
ticker := time.NewTicker(w.interval)
60+
defer ticker.Stop()
61+
62+
for {
63+
select {
64+
case <-ctx.Done():
65+
return
66+
case <-ticker.C:
67+
var version int64
68+
if err := w.db.QueryRow("PRAGMA data_version").Scan(&version); err != nil {
69+
log.Printf("db watcher: data_version read failed: %v", err)
70+
continue
71+
}
72+
if version != lastVersion {
73+
lastVersion = version
74+
log.Printf("db watcher: change detected (version %d), triggering reload", version)
75+
w.onChange()
76+
}
77+
}
78+
}
79+
}

0 commit comments

Comments
 (0)