Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions internal/difc/pipeline_decisions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package difc

// ShouldBypassCoarseDeny returns true when a coarse-grained deny should still
// proceed to backend execution so Phase 5 can enforce per-item policy.
func ShouldBypassCoarseDeny(operation OperationType) bool {
return operation == OperationRead
}

// ShouldCallLabelResponse returns true when guards should label response data
// for possible fine-grained filtering.
func ShouldCallLabelResponse(operation OperationType, enforcementMode EnforcementMode) bool {
isPureWrite := operation == OperationWrite
return !isPureWrite && (operation != OperationReadWrite || enforcementMode != EnforcementStrict)
}

// ShouldBlockFilteredResponse returns true when filtered items should block the
// whole response instead of returning a partially filtered result.
func ShouldBlockFilteredResponse(enforcementMode EnforcementMode, filteredCount int) bool {
return enforcementMode == EnforcementStrict && filteredCount > 0
}

// ShouldAccumulateReadLabels returns true when read labels should be
// accumulated back into the agent label set.
func ShouldAccumulateReadLabels(operation OperationType, enforcementMode EnforcementMode) bool {
return operation != OperationWrite && enforcementMode == EnforcementPropagate
}
36 changes: 36 additions & 0 deletions internal/difc/pipeline_decisions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package difc

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestShouldBypassCoarseDeny(t *testing.T) {
assert.True(t, ShouldBypassCoarseDeny(OperationRead))
assert.False(t, ShouldBypassCoarseDeny(OperationWrite))
assert.False(t, ShouldBypassCoarseDeny(OperationReadWrite))
}

func TestShouldCallLabelResponse(t *testing.T) {
assert.False(t, ShouldCallLabelResponse(OperationWrite, EnforcementStrict))
assert.False(t, ShouldCallLabelResponse(OperationReadWrite, EnforcementStrict))
assert.True(t, ShouldCallLabelResponse(OperationRead, EnforcementStrict))
assert.True(t, ShouldCallLabelResponse(OperationReadWrite, EnforcementFilter))
assert.True(t, ShouldCallLabelResponse(OperationReadWrite, EnforcementPropagate))
}

func TestShouldBlockFilteredResponse(t *testing.T) {
assert.True(t, ShouldBlockFilteredResponse(EnforcementStrict, 1))
assert.False(t, ShouldBlockFilteredResponse(EnforcementStrict, 0))
assert.False(t, ShouldBlockFilteredResponse(EnforcementFilter, 3))
assert.False(t, ShouldBlockFilteredResponse(EnforcementPropagate, 2))
}

func TestShouldAccumulateReadLabels(t *testing.T) {
assert.True(t, ShouldAccumulateReadLabels(OperationRead, EnforcementPropagate))
assert.True(t, ShouldAccumulateReadLabels(OperationReadWrite, EnforcementPropagate))
assert.False(t, ShouldAccumulateReadLabels(OperationWrite, EnforcementPropagate))
assert.False(t, ShouldAccumulateReadLabels(OperationRead, EnforcementStrict))
assert.False(t, ShouldAccumulateReadLabels(OperationRead, EnforcementFilter))
}
44 changes: 35 additions & 9 deletions internal/logger/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,38 +183,64 @@ import (

// Log-Level Quad-Function Pattern
//
// Three sets of four public functions — one set per logger variant — share an identical
// structural pattern where each function is a one-liner that delegates to an internal
// helper with the appropriate LogLevel constant:
// Three sets of four public functions — one set per logger variant — share an
// identical structure where each exported one-liner delegates to an unexported
// per-level closure created by helper constructors in this file:
//
// func Log<Level>(category, format string, args ...interface{}) {
// <internalHelper>(LogLevel<Level>, category, format, args...)
// log<level>(category, format, args...)
// }
//
// The three sets and their internal helpers are:
// The three sets and their internal dispatch helpers are:
//
// file_logger.go LogInfo / LogWarn / LogError / LogDebug → logWithLevel
// markdown_logger.go LogInfoMd / LogWarnMd / LogErrorMd / LogDebugMd → logWithMarkdown
// server_file_logger.go LogInfoWithServer / ... / LogDebugWithServer → logWithLevelAndServer
//
// This pattern is intentionally kept across the three files because:
// This pattern keeps exported APIs immutable (`func` declarations) while still
// eliminating repeated inline level wiring.
//
// The makeLevelLogger and makeServerLevelLogger helpers are for internal
// delegation only and should not replace exported functions with reassignable
// function variables.
//
// This remains intentionally consistent across the three files because:
// - Each set is a distinct public API with a different signature and set of callers.
// - The one-liner wrappers are trivial and unlikely to diverge.
// - Go lacks the metaprogramming to eliminate them without sacrificing readability.
// - The exported wrappers preserve a stable, non-mutable API surface.
// - Internal closure generation removes repetitive level-binding boilerplate.
//
// The shared logFuncs map below centralises the LogLevel → log-function
// mapping so that the internal helpers (logWithMarkdown, logWithLevelAndServer)
// do not need their own switch-on-level blocks.
//
// When adding a new LogLevel constant (e.g., LogLevelTrace):
// 1. Add a new entry to the logFuncs map below.
// 2. Add a new LogTrace wrapper to each of the three files above.
// 2. Add a new internal per-level closure and exported wrapper in each of the
// three files above.
//
// logFuncs maps each LogLevel to its corresponding global log function.
// This eliminates repeated switch-on-level blocks in logWithMarkdown
// (markdown_logger.go) and logWithLevelAndServer (server_file_logger.go).
// When adding a new LogLevel constant, add a corresponding entry here so
// that all dispatch sites automatically support the new level.
func makeLevelLogger(
dispatch func(level LogLevel, category, format string, args ...interface{}),
level LogLevel,
) func(category, format string, args ...interface{}) {
return func(category, format string, args ...interface{}) {
dispatch(level, category, format, args...)
}
}
Comment on lines +226 to +233
Copy link

Copilot AI Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The large block comment immediately above this helper says the per-level one-liner wrappers are “intentionally kept” and instructs adding new wrappers in each file when adding a LogLevel. With this refactor, that guidance is now outdated/misleading. Please update or remove that section to reflect the new pattern (and, if you keep exported wrappers as funcs, document that makeLevelLogger is meant for internal delegation rather than replacing exported functions with mutable vars).

Copilot uses AI. Check for mistakes.

func makeServerLevelLogger(
dispatch func(serverID string, level LogLevel, category, format string, args ...interface{}),
level LogLevel,
) func(serverID, category, format string, args ...interface{}) {
return func(serverID, category, format string, args ...interface{}) {
dispatch(serverID, level, category, format, args...)
}
}

var logFuncs = map[LogLevel]func(string, string, ...interface{}){
LogLevelInfo: LogInfo,
LogLevelWarn: LogWarn,
Expand Down
23 changes: 15 additions & 8 deletions internal/logger/file_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,31 @@ func logWithLevel(level LogLevel, category, format string, args ...interface{})
})
}

// LogInfo logs an informational message
var (
logInfo = makeLevelLogger(logWithLevel, LogLevelInfo)
logWarn = makeLevelLogger(logWithLevel, LogLevelWarn)
logError = makeLevelLogger(logWithLevel, LogLevelError)
logDebug = makeLevelLogger(logWithLevel, LogLevelDebug)
)

// LogInfo logs an informational message.
func LogInfo(category, format string, args ...interface{}) {
logWithLevel(LogLevelInfo, category, format, args...)
logInfo(category, format, args...)
}

// LogWarn logs a warning message
// LogWarn logs a warning message.
func LogWarn(category, format string, args ...interface{}) {
logWithLevel(LogLevelWarn, category, format, args...)
logWarn(category, format, args...)
}

// LogError logs an error message
// LogError logs an error message.
func LogError(category, format string, args ...interface{}) {
logWithLevel(LogLevelError, category, format, args...)
logError(category, format, args...)
}

// LogDebug logs a debug message
// LogDebug logs a debug message.
func LogDebug(category, format string, args ...interface{}) {
logWithLevel(LogLevelDebug, category, format, args...)
logDebug(category, format, args...)
}

// CloseGlobalLogger closes the global file logger
Expand Down
23 changes: 15 additions & 8 deletions internal/logger/markdown_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,24 +180,31 @@ func logWithMarkdown(level LogLevel, category, format string, args ...interface{
})
}

// LogInfoMd logs to both regular and markdown loggers
var (
logInfoMd = makeLevelLogger(logWithMarkdown, LogLevelInfo)
logWarnMd = makeLevelLogger(logWithMarkdown, LogLevelWarn)
logErrorMd = makeLevelLogger(logWithMarkdown, LogLevelError)
logDebugMd = makeLevelLogger(logWithMarkdown, LogLevelDebug)
)

// LogInfoMd logs to both regular and markdown loggers.
func LogInfoMd(category, format string, args ...interface{}) {
logWithMarkdown(LogLevelInfo, category, format, args...)
logInfoMd(category, format, args...)
}

// LogWarnMd logs to both regular and markdown loggers
// LogWarnMd logs to both regular and markdown loggers.
func LogWarnMd(category, format string, args ...interface{}) {
logWithMarkdown(LogLevelWarn, category, format, args...)
logWarnMd(category, format, args...)
}

// LogErrorMd logs to both regular and markdown loggers
// LogErrorMd logs to both regular and markdown loggers.
func LogErrorMd(category, format string, args ...interface{}) {
logWithMarkdown(LogLevelError, category, format, args...)
logErrorMd(category, format, args...)
}

// LogDebugMd logs to both regular and markdown loggers
// LogDebugMd logs to both regular and markdown loggers.
func LogDebugMd(category, format string, args ...interface{}) {
logWithMarkdown(LogLevelDebug, category, format, args...)
logDebugMd(category, format, args...)
}

// CloseMarkdownLogger closes the global markdown logger
Expand Down
23 changes: 15 additions & 8 deletions internal/logger/server_file_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,24 +155,31 @@ func logWithLevelAndServer(serverID string, level LogLevel, category, format str
}
}

// LogInfoWithServer logs an informational message to the server-specific log file
var (
logInfoWithServer = makeServerLevelLogger(logWithLevelAndServer, LogLevelInfo)
logWarnWithServer = makeServerLevelLogger(logWithLevelAndServer, LogLevelWarn)
logErrorWithServer = makeServerLevelLogger(logWithLevelAndServer, LogLevelError)
logDebugWithServer = makeServerLevelLogger(logWithLevelAndServer, LogLevelDebug)
)

// LogInfoWithServer logs an informational message to the server-specific log file.
func LogInfoWithServer(serverID, category, format string, args ...interface{}) {
logWithLevelAndServer(serverID, LogLevelInfo, category, format, args...)
logInfoWithServer(serverID, category, format, args...)
}

// LogWarnWithServer logs a warning message to the server-specific log file
// LogWarnWithServer logs a warning message to the server-specific log file.
func LogWarnWithServer(serverID, category, format string, args ...interface{}) {
logWithLevelAndServer(serverID, LogLevelWarn, category, format, args...)
logWarnWithServer(serverID, category, format, args...)
}

// LogErrorWithServer logs an error message to the server-specific log file
// LogErrorWithServer logs an error message to the server-specific log file.
func LogErrorWithServer(serverID, category, format string, args ...interface{}) {
logWithLevelAndServer(serverID, LogLevelError, category, format, args...)
logErrorWithServer(serverID, category, format, args...)
}

// LogDebugWithServer logs a debug message to the server-specific log file
// LogDebugWithServer logs a debug message to the server-specific log file.
func LogDebugWithServer(serverID, category, format string, args ...interface{}) {
logWithLevelAndServer(serverID, LogLevelDebug, category, format, args...)
logDebugWithServer(serverID, category, format, args...)
}

// CloseServerFileLogger closes the global server file logger
Expand Down
6 changes: 3 additions & 3 deletions internal/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
evalResult := s.evaluator.Evaluate(agentLabels.Secrecy, agentLabels.Integrity, resource, operation)

if !evalResult.IsAllowed() {
if operation == difc.OperationRead {
if difc.ShouldBypassCoarseDeny(operation) {
// Read in filter mode: skip coarse block, proceed to fine-grained filtering
logHandler.Printf("[DIFC] Phase 2: coarse check failed for read, proceeding to Phase 3")
} else {
Expand Down Expand Up @@ -266,7 +266,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
}

// Strict mode: block entire response if any item filtered
if s.enforcementMode == difc.EnforcementStrict && filtered.GetFilteredCount() > 0 {
if difc.ShouldBlockFilteredResponse(s.enforcementMode, filtered.GetFilteredCount()) {
logHandler.Printf("[DIFC] STRICT: blocking response — %d filtered items", filtered.GetFilteredCount())
writeDIFCForbidden(w, fmt.Sprintf("DIFC policy violation: %d of %d items not accessible",
filtered.GetFilteredCount(), filtered.TotalCount))
Expand Down Expand Up @@ -318,7 +318,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
}

// **Phase 6: Label accumulation (propagate mode)**
if s.enforcementMode == difc.EnforcementPropagate && labeledData != nil {
if labeledData != nil && difc.ShouldAccumulateReadLabels(operation, s.enforcementMode) {
overall := labeledData.Overall()
agentLabels.AccumulateFromRead(overall)
logHandler.Printf("[DIFC] Phase 6: accumulated labels")
Expand Down
11 changes: 5 additions & 6 deletions internal/server/unified.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
// For read operations in any mode, we skip the coarse-grained block
// and let the request proceed. Fine-grained filtering at Phase 5 will filter
// individual items from the response based on their actual labels from LabelResponse().
isReadOperation := (operation == difc.OperationRead)
isReadOperation := difc.ShouldBypassCoarseDeny(operation)
result := requestEvaluator.Evaluate(agentLabels.Secrecy, agentLabels.Integrity, resource, operation)

if !result.IsAllowed() {
Expand Down Expand Up @@ -603,8 +603,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
// Per spec: LabelResponse() is only called for read operations in all modes,
// and for read-write operations in filter/propagate modes.
// For write operations and read-write in strict mode, skip LabelResponse().
isPureWrite := (operation == difc.OperationWrite)
shouldCallLabelResponse := !isPureWrite && (operation != difc.OperationReadWrite || enforcementMode != difc.EnforcementStrict)
shouldCallLabelResponse := difc.ShouldCallLabelResponse(operation, enforcementMode)

var labeledData difc.LabeledData
if shouldCallLabelResponse {
Expand All @@ -631,7 +630,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
filtered.GetAccessibleCount(), filtered.TotalCount)

// **Strict mode: block entire response if ANY item is filtered**
if enforcementMode == difc.EnforcementStrict && filtered.GetFilteredCount() > 0 {
if difc.ShouldBlockFilteredResponse(enforcementMode, filtered.GetFilteredCount()) {
logger.LogWarn("difc", "STRICT MODE: Blocking entire response - %d/%d items violate DIFC policy",
filtered.GetFilteredCount(), filtered.TotalCount)
blockErr := fmt.Errorf("DIFC policy violation: %d of %d items in response are not accessible to agent %s",
Expand Down Expand Up @@ -664,7 +663,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
// **Phase 6: Accumulate labels from this operation (for reads in PROPAGATE mode only)**
// Label accumulation should only happen when mode is EnforcementPropagate
// Filter mode does NOT accumulate - it just filters what the agent can see
if !isPureWrite && enforcementMode == difc.EnforcementPropagate {
if difc.ShouldAccumulateReadLabels(operation, enforcementMode) {
overall := labeledData.Overall()
agentLabels.AccumulateFromRead(overall)
logUnified.Printf("[DIFC] Agent %s accumulated labels (propagate mode) | Secrecy: %v | Integrity: %v",
Expand All @@ -675,7 +674,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
finalResult = backendResult

// **Phase 6: Accumulate labels from resource (for reads in PROPAGATE mode only)**
if !isPureWrite && enforcementMode == difc.EnforcementPropagate {
if difc.ShouldAccumulateReadLabels(operation, enforcementMode) {
agentLabels.AccumulateFromRead(resource)
logUnified.Printf("[DIFC] Agent %s accumulated labels (propagate mode) | Secrecy: %v | Integrity: %v",
agentID, agentLabels.GetSecrecyTags(), agentLabels.GetIntegrityTags())
Expand Down
Loading