diff --git a/README.md b/README.md index 58433237..f29a1c71 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,28 @@ cd kiji-privacy-proxy-X.Y.Z-linux-amd64 ./run.sh ``` +Unix socket listener (optional): +```bash +PROXY_UNIX_SOCKET_PATH="${XDG_RUNTIME_DIR:-/run/kiji-proxy}/kiji-proxy.sock" kiji-proxy +``` + +**`PROXY_UNIX_SOCKET_PATH` behavior** + +When `PROXY_UNIX_SOCKET_PATH` is set, Kiji listens on the given Unix socket path instead of binding the main HTTP API to `PROXY_PORT`. + +- If `PROXY_UNIX_SOCKET_PATH` is unset, Kiji keeps the default TCP listener behavior and binds to `PROXY_PORT`. +- If the socket file already exists, Kiji removes the stale socket before listening. +- The configured path is treated the same as the `UnixSocketPath` config field. +- Socket permissions are controlled by `PROXY_UNIX_SOCKET_ACCESS_MODE` or the config field `UnixSocketAccessMode`. +- Supported access modes are `USER`, `GROUP`, and `ALL`, which map to `0600`, `0660`, and `0666`. + +Example: +```bash +PROXY_UNIX_SOCKET_PATH="${XDG_RUNTIME_DIR:-/run/kiji-proxy}/kiji-proxy.sock" \ +PROXY_UNIX_SOCKET_ACCESS_MODE="GROUP" \ +kiji-proxy +``` + **Test It:** *macOS (with automatic PAC):* diff --git a/src/backend/config/config.development.json b/src/backend/config/config.development.json index 01643d47..17e76d68 100644 --- a/src/backend/config/config.development.json +++ b/src/backend/config/config.development.json @@ -17,6 +17,7 @@ } }, "ProxyPort": ":8080", + "UnixSocketAccessMode": "USER", "Database": { "Enabled": false }, diff --git a/src/backend/config/config.go b/src/backend/config/config.go index 4c932120..6dbaccf4 100644 --- a/src/backend/config/config.go +++ b/src/backend/config/config.go @@ -22,6 +22,10 @@ const DefaultForwardProxyPort = ":8080" // The leading colon is intentional — this is a net.Listen-style address (e.g. ":8081"). const DefaultTransparentProxyPort = ":8081" +// DefaultUnixSocketAccessMode sets the default chmod permissions on the Unix socket, +// but only if config.UnixSocketPath is non-null. Can be "USER", "GROUP" or "ALL". +const DefaultUnixSocketAccessMode = "USER" + // LoggingConfig holds logging configuration options type LoggingConfig struct { LogRequests bool // Log request content @@ -69,16 +73,18 @@ type ProxyConfig struct { // Config holds all configuration for the PII proxy service type Config struct { - Providers ProvidersConfig `json:"providers"` - ProxyPort string - Database DatabaseConfig - Logging LoggingConfig - ONNXModelPath string - TokenizerPath string - ModelVariant string // "trained" (full precision) or "quantized" (INT8). Used to derive ONNXModelDirectory when it isn't set. - ONNXModelDirectory string // Explicit override; takes precedence over ModelVariant. - UIPath string - Proxy ProxyConfig `json:"Proxy"` + Providers ProvidersConfig `json:"providers"` + ProxyPort string + UnixSocketPath string + UnixSocketAccessMode string + Database DatabaseConfig + Logging LoggingConfig + ONNXModelPath string + TokenizerPath string + ModelVariant string // "trained" (full precision) or "quantized" (INT8). Used to derive ONNXModelDirectory when it isn't set. + ONNXModelDirectory string // Explicit override; takes precedence over ModelVariant. + UIPath string + Proxy ProxyConfig `json:"Proxy"` } // ModelVariantTrained is the full-precision model variant. @@ -104,9 +110,15 @@ func (c *Config) ResolveModelDirectory() string { func (c *Config) ValidateConfig() error { var errs []string - // Validate ProxyPort format (":port") - if err := validatePort(c.ProxyPort, "ProxyPort"); err != nil { - errs = append(errs, err.Error()) + if c.UnixSocketPath == "" { + // Validate ProxyPort format (":port") + if err := validatePort(c.ProxyPort, "ProxyPort"); err != nil { + errs = append(errs, err.Error()) + } + } else { + if err := validateSocketAccessMode(c.UnixSocketAccessMode); err != nil { + errs = append(errs, err.Error()) + } } // Validate ProxyConfig fields @@ -137,6 +149,28 @@ func (c *Config) ValidateConfig() error { return nil } +func validateSocketAccessMode(value string) error { + switch value { + case "USER", "GROUP", "ALL": + return nil + default: + return fmt.Errorf("UnixSocketAccessMode: value must be one of USER, GROUP, or ALL (current value: %s)", value) + } +} + +func SocketAccessModeToChmod(value string) (uint32, error) { + switch value { + case "USER": + return 0o600, nil + case "GROUP": + return 0o660, nil + case "ALL": + return 0o666, nil + default: + return 0, fmt.Errorf("UnixSocketAccessMode: value must be one of USER, GROUP, or ALL (current value: %s)", value) + } +} + func validatePort(port string, fieldName string) error { if port == "" { return fmt.Errorf("%s: port cannot be empty", fieldName) @@ -250,12 +284,13 @@ func DefaultConfig() *Config { MistralProviderConfig: defaultMistralProviderConfig, CustomProviderConfig: defaultCustomProviderConfig, }, - ProxyPort: DefaultForwardProxyPort, - ONNXModelPath: "", - TokenizerPath: "", - ModelVariant: ModelVariantTrained, - ONNXModelDirectory: "", - UIPath: "./src/frontend/dist", + ProxyPort: DefaultForwardProxyPort, + UnixSocketAccessMode: DefaultUnixSocketAccessMode, + ONNXModelPath: "", + TokenizerPath: "", + ModelVariant: ModelVariantTrained, + ONNXModelDirectory: "", + UIPath: "./src/frontend/dist", Database: DatabaseConfig{ Path: dbPath, CleanupHours: 24, diff --git a/src/backend/config/config_test.go b/src/backend/config/config_test.go index 6281622b..a1fd5992 100644 --- a/src/backend/config/config_test.go +++ b/src/backend/config/config_test.go @@ -278,6 +278,38 @@ func TestValidateConfig(t *testing.T) { expectErr: true, errString: "Proxy.ProxyPort: port must be in format ':PORT' where PORT is numeric (current value: invalid)", }, + { + name: "unix socket skips main proxy port validation", + config: func() *Config { + c := newDefaultConfig() + c.ProxyPort = "invalid" + c.UnixSocketPath = "/tmp/kiji-proxy.sock" + return c + }(), + expectErr: false, + }, + { + name: "unix socket still validates transparent proxy port", + config: func() *Config { + c := newDefaultConfig() + c.UnixSocketPath = "/tmp/kiji-proxy.sock" + c.Proxy.ProxyPort = "invalid" + return c + }(), + expectErr: true, + errString: "Proxy.ProxyPort: port must be in format ':PORT' where PORT is numeric (current value: invalid)", + }, + { + name: "unix socket rejects invalid access mode", + config: func() *Config { + c := newDefaultConfig() + c.UnixSocketPath = "/tmp/kiji-proxy.sock" + c.UnixSocketAccessMode = "TEAM" + return c + }(), + expectErr: true, + errString: "UnixSocketAccessMode: value must be one of USER, GROUP, or ALL (current value: TEAM)", + }, { name: "invalid openai provider config", config: func() *Config { @@ -330,3 +362,35 @@ func TestValidateConfig(t *testing.T) { func stringContains(s, substr string) bool { return strings.Contains(s, substr) } + +func TestSocketAccessModeToChmod(t *testing.T) { + testCases := []struct { + name string + value string + want uint32 + expectErr bool + }{ + {name: "user", value: "USER", want: 0o600}, + {name: "group", value: "GROUP", want: 0o660}, + {name: "all", value: "ALL", want: 0o666}, + {name: "invalid", value: "TEAM", expectErr: true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := SocketAccessModeToChmod(tc.value) + if tc.expectErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got != tc.want { + t.Fatalf("SocketPermModeToChmod(%q) = %#o, want %#o", tc.value, got, tc.want) + } + }) + } +} diff --git a/src/backend/main.go b/src/backend/main.go index 2a7ff6c5..c8c4c6c7 100644 --- a/src/backend/main.go +++ b/src/backend/main.go @@ -185,6 +185,7 @@ func expandConfigPaths(cfg *config.Config) { cfg.TokenizerPath = expandPath(cfg.TokenizerPath) cfg.ONNXModelDirectory = expandPath(cfg.ONNXModelDirectory) cfg.UIPath = expandPath(cfg.UIPath) + cfg.UnixSocketPath = expandPath(cfg.UnixSocketPath) cfg.Database.Path = expandPath(cfg.Database.Path) cfg.Proxy.CAPath = expandPath(cfg.Proxy.CAPath) cfg.Proxy.KeyPath = expandPath(cfg.Proxy.KeyPath) @@ -216,6 +217,12 @@ func loadApplicationConfig(cfg *config.Config) { if proxyPort := os.Getenv("PROXY_PORT"); proxyPort != "" { cfg.ProxyPort = proxyPort } + if socketPath := os.Getenv("PROXY_UNIX_SOCKET_PATH"); socketPath != "" { + cfg.UnixSocketPath = socketPath + } + if socketAccessMode := os.Getenv("PROXY_UNIX_SOCKET_ACCESS_MODE"); socketAccessMode != "" { + cfg.UnixSocketAccessMode = socketAccessMode + } // Override OpenAI provider config with environment variables if openAIURL := os.Getenv("OPENAI_BASE_URL"); openAIURL != "" { diff --git a/src/backend/main_test.go b/src/backend/main_test.go index 8ba8bab1..93613c80 100644 --- a/src/backend/main_test.go +++ b/src/backend/main_test.go @@ -46,6 +46,7 @@ func TestExpandConfigPaths(t *testing.T) { TokenizerPath: "~/models/tok.json", ONNXModelDirectory: "~/models", UIPath: "./src/frontend/dist", + UnixSocketPath: "~/run/kiji-proxy.sock", Database: config.DatabaseConfig{ Path: "~/.kiji/db.sqlite", }, @@ -62,6 +63,7 @@ func TestExpandConfigPaths(t *testing.T) { "TokenizerPath": filepath.Join(home, "models/tok.json"), "ONNXModelDirectory": filepath.Join(home, "models"), "UIPath": "./src/frontend/dist", + "UnixSocketPath": filepath.Join(home, "run/kiji-proxy.sock"), "Database.Path": filepath.Join(home, ".kiji/db.sqlite"), "Proxy.CAPath": filepath.Join(home, "Library/Application Support/Kiji Privacy Proxy/certs/ca.crt"), "Proxy.KeyPath": "/absolute/keys/ca.key", @@ -72,6 +74,7 @@ func TestExpandConfigPaths(t *testing.T) { "TokenizerPath": cfg.TokenizerPath, "ONNXModelDirectory": cfg.ONNXModelDirectory, "UIPath": cfg.UIPath, + "UnixSocketPath": cfg.UnixSocketPath, "Database.Path": cfg.Database.Path, "Proxy.CAPath": cfg.Proxy.CAPath, "Proxy.KeyPath": cfg.Proxy.KeyPath, @@ -83,3 +86,18 @@ func TestExpandConfigPaths(t *testing.T) { } } } + +func TestLoadApplicationConfigUnixSocket(t *testing.T) { + t.Setenv("PROXY_UNIX_SOCKET_PATH", "/tmp/kiji-proxy.sock") + t.Setenv("PROXY_UNIX_SOCKET_ACCESS_MODE", "GROUP") + + cfg := config.DefaultConfig() + loadApplicationConfig(cfg) + + if cfg.UnixSocketPath != "/tmp/kiji-proxy.sock" { + t.Errorf("UnixSocketPath = %q, want %q", cfg.UnixSocketPath, "/tmp/kiji-proxy.sock") + } + if cfg.UnixSocketAccessMode != "GROUP" { + t.Errorf("SocketPermMode = %q, want %q", cfg.UnixSocketAccessMode, "GROUP") + } +} diff --git a/src/backend/server/server.go b/src/backend/server/server.go index 9f474aaf..206201d9 100644 --- a/src/backend/server/server.go +++ b/src/backend/server/server.go @@ -5,6 +5,7 @@ import ( "fmt" "io/fs" "log" + "net" "net/http" "os" "path/filepath" @@ -180,7 +181,11 @@ func NewServerWithEmbedded(cfg *config.Config, uiFS, modelFS fs.FS, version stri // Start starts the HTTP server func (s *Server) Start() error { - log.Printf("Starting Kiji Privacy Proxy service on port %s", s.config.ProxyPort) + if s.config.UnixSocketPath != "" { + log.Printf("Starting Kiji Privacy Proxy service on Unix socket %s", s.config.UnixSocketPath) + } else { + log.Printf("Starting Kiji Privacy Proxy service on port %s", s.config.ProxyPort) + } log.Printf("Forward OpenAI requests to: %s", s.config.Providers.OpenAIProviderConfig.APIDomain) log.Printf("Forward Anthropic requests to: %s", s.config.Providers.AnthropicProviderConfig.APIDomain) log.Printf("Forward Gemini requests to: %s", s.config.Providers.GeminiProviderConfig.APIDomain) @@ -300,7 +305,36 @@ func (s *Server) Start() error { IdleTimeout: 60 * time.Second, } - return server.ListenAndServe() + return serveHTTP(server, s.config.UnixSocketPath, s.config.UnixSocketAccessMode) +} + +func serveHTTP(server *http.Server, socketPath, socketAccessMode string) error { + if socketPath == "" { + return server.ListenAndServe() + } + + if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("remove stale Unix socket %s: %w", socketPath, err) + } + + listener, err := net.Listen("unix", socketPath) + if err != nil { + return fmt.Errorf("listen on Unix socket %s: %w", socketPath, err) + } + defer os.Remove(socketPath) + + socketMode, err := config.SocketAccessModeToChmod(socketAccessMode) + if err != nil { + _ = listener.Close() + return fmt.Errorf("validate Unix socket permissions for %s: %w", socketPath, err) + } + + if err := os.Chmod(socketPath, os.FileMode(socketMode)); err != nil { + _ = listener.Close() + return fmt.Errorf("set Unix socket permissions on %s: %w", socketPath, err) + } + + return server.Serve(listener) } // startTransparentProxy starts the transparent proxy server